SauravMaheshkar committed on
Commit
5287bf1
1 Parent(s): 358b6e7

feat: output multiple images

Browse files
Files changed (2) hide show
  1. app.py +24 -16
  2. model.py +3 -2
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import numpy as np
3
  import gradio as gr
@@ -5,46 +7,52 @@ from faiss import read_index
5
  from PIL import Image, ImageOps
6
  from datasets import load_dataset
7
  import torchvision.transforms as T
8
- from torchvision.models import resnet50
9
 
10
  from model import DINO
11
 
12
- transforms = T.Compose(
13
- [T.ToTensor(), T.Resize(244), T.CenterCrop(224), T.Normalize([0.5], [0.5])]
14
- )
15
-
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
 
18
  dataset = load_dataset("ethz/food101")
19
-
20
  model = DINO(batch_size_per_device=32, num_classes=1000).to(device)
21
  model.load_state_dict(torch.load("./bin/model.ckpt", map_location=device)["state_dict"])
22
 
23
 
24
- def augment(img, transforms=transforms) -> torch.Tensor:
25
  img = Image.fromarray(img)
26
  if img.mode == "L":
27
  # Convert grayscale image to RGB by duplicating the single channel three times
28
  img = ImageOps.colorize(img, black="black", white="white")
 
 
 
 
 
29
  return transforms(img).unsqueeze(0)
30
 
31
 
32
- def search_index(input_image, k = 1):
33
  with torch.no_grad():
34
- embedding = model(augment(input_image))
35
  index = read_index("./bin/dino.index")
36
- _, I = index.search(np.array(embedding[0].reshape(1, -1)), k)
37
- indices = I[0]
38
- answer = ""
39
- for i, index in enumerate(indices[:1]):
40
  retrieved_img = dataset["train"][int(index)]["image"]
41
- return retrieved_img
 
42
 
43
 
44
  app = gr.Interface(
45
  search_index,
46
- inputs=gr.Image(),
47
- outputs="image",
 
 
 
 
 
48
  )
49
 
50
  if __name__ == "__main__":
 
1
+ #!/usr/bin/env python
2
+
3
  import torch
4
  import numpy as np
5
  import gradio as gr
 
7
  from PIL import Image, ImageOps
8
  from datasets import load_dataset
9
  import torchvision.transforms as T
 
10
 
11
  from model import DINO
12
 
 
 
 
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
## Define Model and Dataset
# Dataset whose "train" split is looked up by the retrieval code.
dataset = load_dataset("ethz/food101")

# Build the network on the selected device and restore the trained weights.
# The checkpoint is indexed with "state_dict", i.e. the weights live under
# that key rather than at the top level of the saved object.
model = DINO(batch_size_per_device=32, num_classes=1000).to(device)
model.load_state_dict(torch.load("./bin/model.ckpt", map_location=device)["state_dict"])

# This script only runs inference. Modules start in training mode, so switch
# to evaluation mode to make mode-dependent layers (batch norm, dropout — if
# the model contains any) behave deterministically and stop them from
# updating running statistics on user queries.
# NOTE(review): assumes the search index was built with the model in eval
# mode as well — confirm against the indexing script.
model.eval()
19
 
20
 
21
# Preprocessing pipeline, built once at import time instead of on every call.
# NOTE(review): the resize edge is 244, not the more common 256 — presumably it
# has to match the preprocessing used when the search index was built; confirm
# before changing it.
_TRANSFORMS = T.Compose(
    [T.ToTensor(), T.Resize(244), T.CenterCrop(224), T.Normalize([0.5], [0.5])]
)


def augment(img: np.ndarray) -> torch.Tensor:
    """Turn a raw image array into a normalized, batched model input.

    Args:
        img: Image data accepted by ``PIL.Image.fromarray``.

    Returns:
        The preprocessed image as a tensor with a leading batch dimension of
        size 1 (resized, center-cropped to 224x224 and normalized).
    """
    pil_img = Image.fromarray(img)
    if pil_img.mode != "RGB":
        # Grayscale, RGBA, palette, ... inputs are all brought to three
        # channels. For mode "L" this is the same result as colorizing from
        # black to white: the single channel is duplicated three times.
        pil_img = pil_img.convert("RGB")

    return _TRANSFORMS(pil_img).unsqueeze(0)
32
 
33
 
34
# Lazily loaded FAISS index, cached so the file is read from disk only once
# instead of on every query.
_INDEX = None


def _get_index():
    """Return the FAISS index stored at ``./bin/dino.index``, loading it on first use."""
    global _INDEX
    if _INDEX is None:
        _INDEX = read_index("./bin/dino.index")
    return _INDEX


def search_index(input_image, k: int = 1):
    """Retrieve the training images whose embeddings are nearest to the query.

    Args:
        input_image: Query image as an array accepted by ``augment``.
        k: Number of neighbours to request from the index. UI sliders may
            deliver this as a float, so it is coerced to ``int``.

    Returns:
        A list with up to ``k`` images taken from ``dataset["train"]``, in the
        order returned by the index. The list is empty when ``k`` is below 1.
    """
    k = int(k)
    if k < 1:
        return []

    with torch.no_grad():
        embedding = model(augment(input_image).to(device))

    # FAISS works on host-side float32 arrays; a tensor that lives on the GPU
    # cannot be converted by ``np.array`` directly, so move it to the CPU first.
    query = embedding[0].reshape(1, -1).detach().cpu().numpy().astype(np.float32)
    _, results = _get_index().search(query, k)

    train_split = dataset["train"]
    # FAISS pads the result row with -1 when fewer than k neighbours exist;
    # skip those so they are not mistaken for "last row" indices.
    return [train_split[int(idx)]["image"] for idx in results[0] if idx >= 0]
45
 
46
 
47
  app = gr.Interface(
48
  search_index,
49
+ inputs=[
50
+ gr.Image(),
51
+ gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top K"),
52
+ ],
53
+ outputs=[
54
+ gr.Gallery(label="Retrieved Images"),
55
+ ],
56
  )
57
 
58
  if __name__ == "__main__":
model.py CHANGED
@@ -1,7 +1,9 @@
1
  import copy
2
 
 
3
  from pytorch_lightning import LightningModule
4
  from torch import Tensor
 
5
  from torch.nn import Identity
6
  from torchvision.models import resnet50
7
 
@@ -13,11 +15,10 @@ from lightly.models.utils import (
13
  get_weight_decay_parameters,
14
  update_momentum,
15
  )
16
- from lightly.transforms import DINOTransform
17
  from lightly.utils.benchmarking import OnlineLinearClassifier
18
  from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule
19
 
20
- from typing import Union, Tuple, List
21
 
22
 
23
  class DINO(LightningModule):
 
1
  import copy
2
 
3
+ import torch
4
  from pytorch_lightning import LightningModule
5
  from torch import Tensor
6
+ from torch.optim import SGD
7
  from torch.nn import Identity
8
  from torchvision.models import resnet50
9
 
 
15
  get_weight_decay_parameters,
16
  update_momentum,
17
  )
 
18
  from lightly.utils.benchmarking import OnlineLinearClassifier
19
  from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule
20
 
21
+ from typing import Union, Tuple, List
22
 
23
 
24
  class DINO(LightningModule):