omerXfaruq commited on
Commit
cf029f9
·
1 Parent(s): a94c49d
Files changed (2) hide show
  1. app.py +21 -27
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,27 +2,22 @@ import gradio as gr
2
  import os
3
  from torchvision.transforms import Resize
4
  from upstash_vector import Index
5
-
 
6
 
7
  index = Index.from_env()
8
- print(os.environ("UPSTASH_VECTOR_REST_URL"))
9
- print(os.environ("UPSTASH_VECTOR_REST_TOKEN"))
10
 
11
- resize_transform = Resize((250,250))
12
-
13
-
14
- from transformers import AutoFeatureExtractor, AutoModel
15
 
16
  model_ckpt = "google/vit-base-patch16-224-in21k"
17
  extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
18
  model = AutoModel.from_pretrained(model_ckpt)
19
  hidden_dim = model.config.hidden_size
20
 
21
- from datasets import load_dataset
22
-
23
  dataset = load_dataset("HengJi/human_faces")
24
 
25
-
26
  with gr.Blocks() as demo:
27
  gr.Markdown(
28
  """
@@ -39,16 +34,16 @@ with gr.Blocks() as demo:
39
  with gr.Column(scale=3):
40
  output_image = gr.Gallery()
41
 
42
-
43
  @input_image.upload(inputs=input_image, outputs=output_image)
44
- def find_similar_faces(image):
45
- resized_image = resize_transform(image)
46
- inputs = extractor(images=image, return_tensors="pt")
47
- outputs = model(**inputs)
48
- embed = outputs.last_hidden_state[0][0]
49
- result = index.query(vector=embed.tolist(), top_k=3)
50
- return[dataset["train"][int(vector.id[3:])]["image"] for vector in result]
51
-
52
  with gr.Tab("Advanced"):
53
  with gr.Row():
54
  with gr.Column(scale=1):
@@ -61,13 +56,12 @@ with gr.Blocks() as demo:
61
 
62
  @adv_input_image.upload(inputs=[adv_input_image, adv_image_count], outputs=[adv_output_image])
63
  def find_similar_faces(image, count):
64
- resized_image = resize_transform(image)
65
- inputs = extractor(images=image, return_tensors="pt")
66
- outputs = model(**inputs)
67
- embed = outputs.last_hidden_state[0][0]
68
- result = index.query(vector=embed.tolist(), top_k=min(count, 9))
69
- return[dataset["train"][int(vector.id[3:])]["image"] for vector in result]
70
-
71
 
72
  if __name__ == "__main__":
73
- demo.launch(debug=True)
 
2
  import os
3
  from torchvision.transforms import Resize
4
  from upstash_vector import Index
5
+ from datasets import load_dataset
6
+ from transformers import AutoFeatureExtractor, AutoModel
7
 
8
  index = Index.from_env()
9
+ print(os.environ["UPSTASH_VECTOR_REST_URL"])
10
+ print(os.environ["UPSTASH_VECTOR_REST_TOKEN"])
11
 
12
+ resize_transform = Resize((250, 250))
 
 
 
13
 
14
  model_ckpt = "google/vit-base-patch16-224-in21k"
15
  extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
16
  model = AutoModel.from_pretrained(model_ckpt)
17
  hidden_dim = model.config.hidden_size
18
 
 
 
19
  dataset = load_dataset("HengJi/human_faces")
20
 
 
21
  with gr.Blocks() as demo:
22
  gr.Markdown(
23
  """
 
34
  with gr.Column(scale=3):
35
  output_image = gr.Gallery()
36
 
37
+
38
  @input_image.upload(inputs=input_image, outputs=output_image)
39
+ def find_similar_faces(image):
40
+ resized_image = resize_transform(image)
41
+ inputs = extractor(images=resized_image, return_tensors="pt")
42
+ outputs = model(**inputs)
43
+ embed = outputs.last_hidden_state[0][0]
44
+ result = index.query(vector=embed.tolist(), top_k=3)
45
+ return [dataset["train"][int(vector.id[3:])]["image"] for vector in result]
46
+
47
  with gr.Tab("Advanced"):
48
  with gr.Row():
49
  with gr.Column(scale=1):
 
56
 
57
  @adv_input_image.upload(inputs=[adv_input_image, adv_image_count], outputs=[adv_output_image])
58
  def find_similar_faces(image, count):
59
+ resized_image = resize_transform(image)
60
+ inputs = extractor(images=resized_image, return_tensors="pt")
61
+ outputs = model(**inputs)
62
+ embed = outputs.last_hidden_state[0][0]
63
+ result = index.query(vector=embed.tolist(), top_k=min(count, 9))
64
+ return [dataset["train"][int(vector.id[3:])]["image"] for vector in result]
 
65
 
66
  if __name__ == "__main__":
67
+ demo.launch(debug=True)
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  torchvision
2
  transformers
3
  datasets
4
- upstash-vector
 
 
1
  torchvision
2
  transformers
3
  datasets
4
+ upstash-vector
5
+ gradio