dentadelta123 committed
Commit
550b53d
1 Parent(s): 65cf06b

Add application file

Files changed (1)
app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
+ import clip
+ import torch
+ import gradio as gr
+ import torchvision.transforms as T
+ from PIL import Image
+ try:
+     from torchvision.transforms import InterpolationMode
+     BICUBIC = InterpolationMode.BICUBIC
+ except ImportError:
+     BICUBIC = Image.BICUBIC
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ # Keep the model, image tensor, and tokenized captions on the same device (GPU when available).
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model, preprocess = clip.load('ViT-L/14@336px', device=device)
+
+
+ def zeroshot_detection(Press_Clear_Dont_Stack_Image):
+     inp = Press_Clear_Dont_Stack_Image
+
+     captions = "photo of a guardrail, no guardrail in the photo"  # CHANGE THIS IF YOU WANT TO CHANGE THE PREDICTION: separate candidate labels by commas
+
+     captions = [c.strip() for c in captions.split(',')]
+     caption = clip.tokenize(captions).to(device)
+     image = preprocess(inp).unsqueeze(0).to(device)
+     with torch.no_grad():
+         image_features = model.encode_image(image)
+         text_features = model.encode_text(caption)
+     image_features /= image_features.norm(dim=-1, keepdim=True)
+     text_features /= text_features.norm(dim=-1, keepdim=True)
+     similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+     values, indices = similarity[0].topk(len(captions))
+     return {captions[indices[i].item()]: float(values[i].item()) for i in range(len(values))}
+
+
+ gr.Interface(fn=zeroshot_detection,
+              inputs=[gr.Image(type="pil")],
+              outputs=gr.Label(num_top_classes=1)).launch()
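For a quick sanity check outside the Gradio UI, here is a minimal sketch of calling the classifier directly. It assumes the trailing gr.Interface(...).launch() call is skipped (or the function is already defined in the current session), and "guardrail.jpg" is a placeholder path, not a file from this repo:

    from PIL import Image

    img = Image.open("guardrail.jpg")      # placeholder path to any RGB photo
    scores = zeroshot_detection(img)       # dict: caption -> softmax probability
    print(max(scores, key=scores.get))     # the caption CLIP rates closest to the image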