Menghuan1918 commited on
Commit
03aafe3
1 Parent(s): 1268f7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -11
app.py CHANGED
@@ -1,13 +1,16 @@
1
  import torch
2
  from PIL import Image
3
  import open_clip
4
- from gradio import Interface, inputs, outputs
5
 
6
  def start():
7
- model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
8
- tokenizer = open_clip.get_tokenizer('ViT-B-32')
 
 
9
  return model, preprocess, tokenizer
10
 
 
11
  def process(model, preprocess, tokenizer, image_path, text):
12
  if isinstance(image_path, str):
13
  image = Image.open(image_path)
@@ -24,17 +27,24 @@ def process(model, preprocess, tokenizer, image_path, text):
24
  similarity = (image_features @ text_features.T) * 100
25
  return similarity
26
 
 
27
  def predict(image, text):
28
- model, preprocess, tokenizer = start()
29
  similarity = process(model, preprocess, tokenizer, image, text)
30
  return similarity.item()
31
 
32
- inputs = [
33
- inputs.Image(type="pil", label="Image"),
34
- inputs.Textbox(label="Text")
35
- ]
36
-
37
- outputs = outputs.Textbox(label="Similarity")
 
 
 
38
 
39
  if __name__ == "__main__":
40
- Interface(fn=predict, inputs=inputs, outputs=outputs).launch()
 
 
 
 
 
1
  import torch
2
  from PIL import Image
3
  import open_clip
4
+ import gradio as gr
5
 
6
  def start():
7
+ model, _, preprocess = open_clip.create_model_and_transforms(
8
+ "ViT-B-32", pretrained="laion2b_s34b_b79k"
9
+ )
10
+ tokenizer = open_clip.get_tokenizer("ViT-B-32")
11
  return model, preprocess, tokenizer
12
 
13
+
14
  def process(model, preprocess, tokenizer, image_path, text):
15
  if isinstance(image_path, str):
16
  image = Image.open(image_path)
 
27
  similarity = (image_features @ text_features.T) * 100
28
  return similarity
29
 
30
+
31
  def predict(image, text):
 
32
  similarity = process(model, preprocess, tokenizer, image, text)
33
  return similarity.item()
34
 
35
+ gradio_app = gr.Interface(
36
+ predict,
37
+ inputs=[
38
+ gr.Image(label="Select the picture", type="pil"),
39
+ gr.Textbox(label="Enter the text"),
40
+ ],
41
+ outputs=gr.Textbox(label="Similarity"),
42
+ title="You draw&AI guess",
43
+ )
44
 
45
  if __name__ == "__main__":
46
+ model, preprocess, tokenizer = start()
47
+ # If you want to run it locally, you can use the following code :(
48
+ gradio_app.launch()
49
+ # If you want to share it online, you can use the following code :)
50
+ # gradio_app.launch(share=True)