omerXfaruq commited on
Commit
28de0db
1 Parent(s): a1a1e17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -16
app.py CHANGED
@@ -1,22 +1,27 @@
1
  import gradio as gr
2
  import os
3
- from torchvision.transforms import Resize
4
  from upstash_vector import Index
5
  from datasets import load_dataset
6
  from transformers import AutoFeatureExtractor, AutoModel
7
 
8
  index = Index.from_env()
9
- print(os.environ["UPSTASH_VECTOR_REST_URL"])
10
- print(os.environ["UPSTASH_VECTOR_REST_TOKEN"])
11
 
12
- resize_transform = Resize((250, 250))
 
 
 
 
 
 
 
 
13
 
14
  model_ckpt = "google/vit-base-patch16-224-in21k"
15
  extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
16
  model = AutoModel.from_pretrained(model_ckpt)
17
  hidden_dim = model.config.hidden_size
18
-
19
- dataset = load_dataset("HengJi/human_faces")
20
 
21
  with gr.Blocks() as demo:
22
  gr.Markdown(
@@ -31,18 +36,18 @@ with gr.Blocks() as demo:
31
  with gr.Row():
32
  with gr.Column(scale=1):
33
  input_image = gr.Image(type="pil")
34
- with gr.Column(scale=2):
35
- output_image = gr.Gallery(height=600)
36
 
37
 
38
  @input_image.upload(inputs=input_image, outputs=output_image)
39
  def find_similar_faces(image):
40
- resized_image = resize_transform(image)
41
- inputs = extractor(images=resized_image, return_tensors="pt")
42
  outputs = model(**inputs)
43
  embed = outputs.last_hidden_state[0][0]
44
- result = index.query(vector=embed.tolist(), top_k=4)
45
- return [dataset["train"][int(vector.id[3:])]["image"] for vector in result]
46
 
47
  with gr.Tab("Advanced"):
48
  with gr.Row():
@@ -56,12 +61,12 @@ with gr.Blocks() as demo:
56
 
57
  @adv_input_image.upload(inputs=[adv_input_image, adv_image_count], outputs=[adv_output_image])
58
  def find_similar_faces(image, count):
59
- resized_image = resize_transform(image)
60
- inputs = extractor(images=resized_image, return_tensors="pt")
61
  outputs = model(**inputs)
62
  embed = outputs.last_hidden_state[0][0]
63
- result = index.query(vector=embed.tolist(), top_k=min(count, 9))
64
- return [dataset["train"][int(vector.id[3:])]["image"] for vector in result]
65
 
66
  if __name__ == "__main__":
67
  demo.launch(debug=True)
 
1
  import gradio as gr
2
  import os
3
+ import torchvision.transforms as T
4
  from upstash_vector import Index
5
  from datasets import load_dataset
6
  from transformers import AutoFeatureExtractor, AutoModel
7
 
8
  index = Index.from_env()
 
 
9
 
10
+ # Data transformation chain.
11
+ transformation_chain = T.Compose(
12
+ [
13
+ T.Resize(extractor.size["height"]),
14
+ T.CenterCrop(extractor.size["height"]),
15
+ T.ToTensor(),
16
+ T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
17
+ ]
18
+ )
19
 
20
  model_ckpt = "google/vit-base-patch16-224-in21k"
21
  extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
22
  model = AutoModel.from_pretrained(model_ckpt)
23
  hidden_dim = model.config.hidden_size
24
+ dataset = load_dataset("BounharAbdelaziz/Face-Aging-Dataset")
 
25
 
26
  with gr.Blocks() as demo:
27
  gr.Markdown(
 
36
  with gr.Row():
37
  with gr.Column(scale=1):
38
  input_image = gr.Image(type="pil")
39
+ with gr.Column(scale=3):
40
+ output_image = gr.Gallery(height=800)
41
 
42
 
43
  @input_image.upload(inputs=input_image, outputs=output_image)
44
  def find_similar_faces(image):
45
+ t_image = transformation_chain(image)
46
+ inputs = extractor(images=t_image, return_tensors="pt")
47
  outputs = model(**inputs)
48
  embed = outputs.last_hidden_state[0][0]
49
+ result = index.query(vector=embed, top_k=4)
50
+ return [dataset["train"][int(vector.id)]["image"] for vector in result]
51
 
52
  with gr.Tab("Advanced"):
53
  with gr.Row():
 
61
 
62
  @adv_input_image.upload(inputs=[adv_input_image, adv_image_count], outputs=[adv_output_image])
63
  def find_similar_faces(image, count):
64
+ t_image = transformation_chain(image)
65
+ inputs = extractor(images=t_image, return_tensors="pt")
66
  outputs = model(**inputs)
67
  embed = outputs.last_hidden_state[0][0]
68
+ result = index.query(vector=embed, top_k=max(1, min(19, count)))
69
+ return [dataset["train"][int(vector.id)]["image"] for vector in result]
70
 
71
  if __name__ == "__main__":
72
  demo.launch(debug=True)