eddydpan commited on
Commit
235b2f5
1 Parent(s): 2e7a500

Create run.py

Browse files
Files changed (1) hide show
  1. run.py +49 -0
run.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ import open_clip
5
+
6
+ model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
7
+ tokenizer = open_clip.get_tokenizer('ViT-B-32')
8
+ material_list = []
9
+
10
+ with open("results.txt", "r") as f:
11
+ results = f.readlines()
12
+ for line in results:
13
+ material = line.split(" [")[0]
14
+ material_list.append(material.strip()) # Trim any leading/trailing whitespace
15
+ f.close()
16
+
17
+ text = tokenizer(material_list)
18
+
19
+ def process_image(image_input):
20
+
21
+ results = {}
22
+ float_values = []
23
+ image = preprocess(image_input).unsqueeze(0)
24
+
25
+ with torch.no_grad(), torch.cuda.amp.autocast():
26
+ image_features = model.encode_image(image)
27
+ text_features = model.encode_text(text)
28
+ image_features /= image_features.norm(dim=-1, keepdim=True)
29
+ text_features /= text_features.norm(dim=-1, keepdim=True)
30
+ text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
31
+
32
+ print("Label probs:", text_probs) # prints: [[1., 0., 0.]]
33
+
34
+ counter = 0
35
+ for row in text_probs:
36
+ for column in row:
37
+ float_values.append(float(column))
38
+ results[float(column)] = material_list[counter]
39
+ counter += 1
40
+
41
+ sorted_float_values = sorted(float_values, reverse=True)
42
+ print(sorted_float_values)
43
+ return [["Material : " + str(results[sorted_float_values[0]]), "Confidence : " + str(sorted_float_values[0])], ["Material : " + str(results[sorted_float_values[1]]), "Confidence : " + str(sorted_float_values[1])]]
44
+
45
+ inputs = gr.inputs.Image(type="pil")
46
+ outputs = [gr.outputs.Textbox(label="Top Result"), gr.outputs.Textbox(label="Second Result")]
47
+
48
+ interface = gr.Interface(fn=process_image, inputs=inputs, outputs=outputs)
49
+ interface.launch(share=True)