aldrinjenson committed
Commit 9e73ccf · 1 Parent(s): 19c8849

Add application file

Files changed (1)
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+ import open_clip
+
+ # Load the OpenCLIP ViT-B/32 model, its image preprocessing transform, and the matching tokenizer
+ model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
+ tokenizer = open_clip.get_tokenizer('ViT-B-32')
+
+ # Candidate labels for zero-shot classification; tokenize them once up front
+ labels = ["vehicle accident", "fire", "a cat"]
+ text = tokenizer(labels)
+
+
+ def image_classifier(inp):
+     # inp is the uploaded image, passed in by Gradio as a numpy array
+     # image = preprocess(Image.open("accident.jpg")).unsqueeze(0)  # local test image
+     image = preprocess(Image.fromarray(inp)).unsqueeze(0)
+     with torch.no_grad(), torch.cuda.amp.autocast():
+         image_features = model.encode_image(image)
+         text_features = model.encode_text(text)
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+         text_features /= text_features.norm(dim=-1, keepdim=True)
+
+         # Scaled cosine similarities, softmaxed into a probability distribution over the labels
+         text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+     print("Label probs:", text_probs)
+     text_probs = text_probs[0]
+
+     # Index of the most probable label
+     ansIndex = int(text_probs.argmax().item())
+
+     # Per-label probabilities as plain floats, for logging
+     obj = {labels[i]: float(text_probs[i]) for i in range(len(labels))}
+     print(obj)
+
+     # Return only the top label; the Label output shows it with full confidence
+     return {labels[ansIndex]: 1}
+
+
+ # Legacy Gradio component API (gr.inputs / gr.outputs)
+ image_input = gr.inputs.Image(shape=(224, 224))
+ output = gr.outputs.Label()
+ demo = gr.Interface(fn=image_classifier, inputs=image_input, outputs=output)
+ demo.launch(share=False)
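For a quick sanity check of image_classifier outside the Gradio UI, a minimal sketch is shown below. It assumes the definitions above are already in scope (for example, pasted into a REPL before the demo.launch() line is reached) and uses a hypothetical local image file test.jpg.

import numpy as np
from PIL import Image

# "test.jpg" is a hypothetical local image; Gradio hands image_classifier a numpy array like this one
arr = np.array(Image.open("test.jpg").convert("RGB"))
print(image_classifier(arr))  # e.g. {'fire': 1}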