baixintech_zhangyiming_prod commited on
Commit
2422ed3
1 Parent(s): ef8cb22

complete app

Browse files
Files changed (5) hide show
  1. app.py +46 -3
  2. images/00000048.jpg +0 -0
  3. images/00004403.jpg +0 -0
  4. images/00004405.jpg +0 -0
  5. word2idx.json +105 -0
app.py CHANGED
@@ -1,7 +1,50 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
  import gradio as gr
2
+ from transformers import ViTForImageClassification
3
+ from torchvision import transforms
4
+ import os
5
+ import numpy as np
6
+ import json
7
 
 
 
8
 
9
+ class WordVocabulary:
10
+ def __init__(self, records=None, word2idx_path=None):
11
+ if word2idx_path is not None:
12
+ self.records = []
13
+ self.word2idx = json.load(open(word2idx_path, "r"))
14
+ self.words = list(self.word2idx.keys())
15
+ return
16
+
17
+
18
+ def build_vocabulary(self):
19
+ words = set()
20
+ for r in self.records:
21
+ words.update([w.strip() for w in r['text'].split(",")])
22
+ self.words = sorted(list(words))
23
+ self.word2idx = {w: idx for (idx, w) in enumerate(self.words)}
24
+
25
+
26
+ vocabulary = WordVocabulary(word2idx_path="word2idx.json")
27
+ model = ViTForImageClassification.from_pretrained("Inf009/food1024_vit_focal_mixup", problem_type="multi_label_classification", num_labels=len(vocabulary))
28
+ test_transforms = transforms.Compose(
29
+ [
30
+ transforms.Resize((256, 256)),
31
+ transforms.CenterCrop(224),
32
+ transforms.ToTensor(),
33
+ ]
34
+ )
35
+
36
+ def multi_label_predict(img, threshold=0.5):
37
+ img_transformed = test_transforms(img)
38
+ outputs = model(img_transformed.unsqueeze(0)).logits.squeeze(0).sigmoid().detach().numpy()
39
+ indices = np.where(outputs > threshold)[0]
40
+ indices = sorted(indices, key=lambda x: outputs[x], reverse=True)
41
+ predict_tags = [vocabulary[idx] for idx in indices]
42
+ return predict_tags
43
+
44
+ demo_image_path = "images"
45
+ images = [f for f in os.listdir(demo_image_path) if f.endswith(".jpg")][:10]
46
+ images = [os.path.join(demo_image_path, file) for file in images]
47
+ examples = [[image, 0.5] for image in images]
48
+ iface = gr.Interface(fn=multi_label_predict, inputs=[gr.inputs.Image(type="pil"), gr.inputs.Number(default=0.5)],
49
+ examples=examples, outputs="text")
50
  iface.launch()
images/00000048.jpg ADDED
images/00004403.jpg ADDED
images/00004405.jpg ADDED
word2idx.json ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "French beans": 0,
3
+ "almond": 1,
4
+ "apple": 2,
5
+ "apricot": 3,
6
+ "asparagus": 4,
7
+ "avocado": 5,
8
+ "bamboo shoots": 6,
9
+ "banana": 7,
10
+ "bean sprouts": 8,
11
+ "biscuit": 9,
12
+ "blueberry": 10,
13
+ "bread": 11,
14
+ "broccoli": 12,
15
+ "cabbage": 13,
16
+ "cake": 14,
17
+ "candy": 15,
18
+ "carrot": 16,
19
+ "cashew": 17,
20
+ "cauliflower": 18,
21
+ "celery stick": 19,
22
+ "cheese butter": 20,
23
+ "cherry": 21,
24
+ "chicken duck": 22,
25
+ "chocolate": 23,
26
+ "cilantro mint": 24,
27
+ "coffee": 25,
28
+ "corn": 26,
29
+ "crab": 27,
30
+ "cucumber": 28,
31
+ "date": 29,
32
+ "dried cranberries": 30,
33
+ "egg": 31,
34
+ "egg tart": 32,
35
+ "eggplant": 33,
36
+ "enoki mushroom": 34,
37
+ "fig": 35,
38
+ "fish": 36,
39
+ "french fries": 37,
40
+ "fried meat": 38,
41
+ "garlic": 39,
42
+ "ginger": 40,
43
+ "grape": 41,
44
+ "green beans": 42,
45
+ "hamburg": 43,
46
+ "hanamaki baozi": 44,
47
+ "ice cream": 45,
48
+ "juice": 46,
49
+ "kelp": 47,
50
+ "king oyster mushroom": 48,
51
+ "kiwi": 49,
52
+ "lamb": 50,
53
+ "lemon": 51,
54
+ "lettuce": 52,
55
+ "mango": 53,
56
+ "melon": 54,
57
+ "milk": 55,
58
+ "milkshake": 56,
59
+ "noodles": 57,
60
+ "okra": 58,
61
+ "olives": 59,
62
+ "onion": 60,
63
+ "orange": 61,
64
+ "other ingredients": 62,
65
+ "oyster mushroom": 63,
66
+ "pasta": 64,
67
+ "peach": 65,
68
+ "peanut": 66,
69
+ "pear": 67,
70
+ "pepper": 68,
71
+ "pie": 69,
72
+ "pineapple": 70,
73
+ "pizza": 71,
74
+ "popcorn": 72,
75
+ "pork": 73,
76
+ "potato": 74,
77
+ "pudding": 75,
78
+ "pumpkin": 76,
79
+ "rape": 77,
80
+ "raspberry": 78,
81
+ "red beans": 79,
82
+ "rice": 80,
83
+ "salad": 81,
84
+ "sauce": 82,
85
+ "sausage": 83,
86
+ "seaweed": 84,
87
+ "shellfish": 85,
88
+ "shiitake": 86,
89
+ "shrimp": 87,
90
+ "snow peas": 88,
91
+ "soup": 89,
92
+ "soy": 90,
93
+ "spring onion": 91,
94
+ "steak": 92,
95
+ "strawberry": 93,
96
+ "tea": 94,
97
+ "tofu": 95,
98
+ "tomato": 96,
99
+ "walnut": 97,
100
+ "watermelon": 98,
101
+ "white button mushroom": 99,
102
+ "white radish": 100,
103
+ "wine": 101,
104
+ "wonton dumplings": 102
105
+ }