SauravMaheshkar commited on
Commit
6d7226f
1 Parent(s): 5287bf1

feat: add label to input component

Browse files
Files changed (1) hide show
  1. app.py +25 -4
app.py CHANGED
@@ -19,6 +19,15 @@ model.load_state_dict(torch.load("./bin/model.ckpt", map_location=device)["state
19
 
20
 
21
  def augment(img: np.ndarray) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
22
  img = Image.fromarray(img)
23
  if img.mode == "L":
24
  # Convert grayscale image to RGB by duplicating the single channel three times
@@ -31,14 +40,26 @@ def augment(img: np.ndarray) -> torch.Tensor:
31
  return transforms(img).unsqueeze(0)
32
 
33
 
34
- def search_index(input_image, k: int = 1):
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  with torch.no_grad():
36
  embedding = model(augment(input_image).to(device))
37
  index = read_index("./bin/dino.index")
38
  _, results = index.search(np.array(embedding[0].reshape(1, -1)), k)
39
  indices = results[0]
40
- images = []
41
- for i, index in enumerate(indices[:k]):
42
  retrieved_img = dataset["train"][int(index)]["image"]
43
  images.append(retrieved_img)
44
  return images
@@ -47,7 +68,7 @@ def search_index(input_image, k: int = 1):
47
  app = gr.Interface(
48
  search_index,
49
  inputs=[
50
- gr.Image(),
51
  gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top K"),
52
  ],
53
  outputs=[
 
19
 
20
 
21
  def augment(img: np.ndarray) -> torch.Tensor:
22
+ """
23
+ Helper Function to augment the image before we generate embeddings
24
+
25
+ Args:
26
+ img (np.ndarray): Input Image
27
+
28
+ Returns:
29
+ torch.Tensor
30
+ """
31
  img = Image.fromarray(img)
32
  if img.mode == "L":
33
  # Convert grayscale image to RGB by duplicating the single channel three times
 
40
  return transforms(img).unsqueeze(0)
41
 
42
 
43
+ def search_index(input_image: np.ndarray, k: int = 1) -> list:
44
+ """
45
+ Retrieve the Top k images from the given input image
46
+
47
+ Args:
48
+ input_image (np.ndarray): Input Image
49
+ k (int): number of images to fetch
50
+
51
+ Returns:
52
+ list: List of top k images retrieved using the embeddings
53
+ generated from the input image
54
+ """
55
+ images = []
56
+
57
  with torch.no_grad():
58
  embedding = model(augment(input_image).to(device))
59
  index = read_index("./bin/dino.index")
60
  _, results = index.search(np.array(embedding[0].reshape(1, -1)), k)
61
  indices = results[0]
62
+ for _, index in enumerate(indices[:k]):
 
63
  retrieved_img = dataset["train"][int(index)]["image"]
64
  images.append(retrieved_img)
65
  return images
 
68
  app = gr.Interface(
69
  search_index,
70
  inputs=[
71
+ gr.Image(label="Input Image"),
72
  gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top K"),
73
  ],
74
  outputs=[