egmaminta commited on
Commit
5d5ba51
·
1 Parent(s): c743d51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -1
app.py CHANGED
@@ -1,4 +1,93 @@
1
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
2
  import gradio
 
 
 
3
 
4
- gradio.Interface.load(name='huggingface/vincentclaes/mit-indoor-scenes').launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
2
  import gradio
3
+ import torch
4
+ from einops import rearrange
5
+ import numpy
6
 
7
+ extractor = AutoFeatureExtractor.from_pretrained("vincentclaes/mit-indoor-scenes")
8
+ model = AutoModelForImageClassification.from_pretrained("vincentclaes/mit-indoor-scenes")
9
+
10
+ labels = {
11
+ "0": "airport_inside",
12
+ "1": "artstudio",
13
+ "2": "auditorium",
14
+ "3": "bakery",
15
+ "4": "bar",
16
+ "5": "bathroom",
17
+ "6": "bedroom",
18
+ "7": "bookstore",
19
+ "8": "bowling",
20
+ "9": "buffet",
21
+ "10": "casino",
22
+ "11": "children_room",
23
+ "12": "church_inside",
24
+ "13": "classroom",
25
+ "14": "cloister",
26
+ "15": "closet",
27
+ "16": "clothingstore",
28
+ "17": "computerroom",
29
+ "18": "concert_hall",
30
+ "19": "corridor",
31
+ "20": "deli",
32
+ "21": "dentaloffice",
33
+ "22": "dining_room",
34
+ "23": "elevator",
35
+ "24": "fastfood_restaurant",
36
+ "25": "florist",
37
+ "26": "gameroom",
38
+ "27": "garage",
39
+ "28": "greenhouse",
40
+ "29": "grocerystore",
41
+ "30": "gym",
42
+ "31": "hairsalon",
43
+ "32": "hospitalroom",
44
+ "33": "inside_bus",
45
+ "34": "inside_subway",
46
+ "35": "jewelleryshop",
47
+ "36": "kindergarden",
48
+ "37": "kitchen",
49
+ "38": "laboratorywet",
50
+ "39": "laundromat",
51
+ "40": "library",
52
+ "41": "livingroom",
53
+ "42": "lobby",
54
+ "43": "locker_room",
55
+ "44": "mall",
56
+ "45": "meeting_room",
57
+ "46": "movietheater",
58
+ "47": "museum",
59
+ "48": "nursery",
60
+ "49": "office",
61
+ "50": "operating_room",
62
+ "51": "pantry",
63
+ "52": "poolinside",
64
+ "53": "prisoncell",
65
+ "54": "restaurant",
66
+ "55": "restaurant_kitchen",
67
+ "56": "shoeshop",
68
+ "57": "stairscase",
69
+ "58": "studiomusic",
70
+ "59": "subway",
71
+ "60": "toystore",
72
+ "61": "trainstation",
73
+ "62": "tv_studio",
74
+ "63": "videostore",
75
+ "64": "waitingroom",
76
+ "65": "warehouse",
77
+ "66": "winecellar"
78
+ }
79
+
80
+ def classify(image):
81
+ model.eval()
82
+ with torch.no_grad():
83
+ inputs = extractor(images=image, return_tensors='pt')
84
+ outputs = model(**inputs).logits
85
+ outputs = rearrange(outputs, '1 j->j')
86
+ outputs = outputs.cpu().numpy()
87
+ outputs = (numpy.exp(outputs)) / (numpy.sum(numpy.exp(outputs)))
88
+ return {labels[str(i)]: float(outputs[i]) for i in range(len(labels))}
89
+
90
+ gradio.Interface(fn=classify,
91
+ inputs=gradio.inputs.Image(shape=(224,224), image_mode='RGB', source='upload', tool='editor', type='pil', label=None, optional=False),
92
+ outputs=gradio.outputs.Label(num_top_classes=5, type='auto'),
93
+ allow_flagging='never').launch(inbrowser=True, debug=True)