egmaminta's picture
Update app.py
56b7b2c
raw
history blame
2.59 kB
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import gradio
import torch
from einops import rearrange
import numpy
# Hugging Face Hub checkpoint used for both the image preprocessor and the
# classification model, so the two are always loaded from the same source.
_CHECKPOINT = "vincentclaes/mit-indoor-scenes"

# Preprocessor that turns an input image into model-ready tensors.
extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
# Image-classification model whose logits are mapped to names via `labels`.
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)
# Scene category names, listed in the order of the model's output positions
# (position i of the logits corresponds to entry i here).
# Spellings such as "stairscase", "kindergarden" and "jewelleryshop" are kept
# verbatim. NOTE(review): they presumably mirror the dataset's own category
# names; confirm before "correcting" them, since they are shown to users.
_SCENE_NAMES = (
    "airport_inside", "artstudio", "auditorium", "bakery", "bar",  # 0-4
    "bathroom", "bedroom", "bookstore", "bowling", "buffet",  # 5-9
    "casino", "children_room", "church_inside", "classroom", "cloister",  # 10-14
    "closet", "clothingstore", "computerroom", "concert_hall", "corridor",  # 15-19
    "deli", "dentaloffice", "dining_room", "elevator", "fastfood_restaurant",  # 20-24
    "florist", "gameroom", "garage", "greenhouse", "grocerystore",  # 25-29
    "gym", "hairsalon", "hospitalroom", "inside_bus", "inside_subway",  # 30-34
    "jewelleryshop", "kindergarden", "kitchen", "laboratorywet", "laundromat",  # 35-39
    "library", "livingroom", "lobby", "locker_room", "mall",  # 40-44
    "meeting_room", "movietheater", "museum", "nursery", "office",  # 45-49
    "operating_room", "pantry", "poolinside", "prisoncell", "restaurant",  # 50-54
    "restaurant_kitchen", "shoeshop", "stairscase", "studiomusic", "subway",  # 55-59
    "toystore", "trainstation", "tv_studio", "videostore", "waitingroom",  # 60-64
    "warehouse", "winecellar",  # 65-66
)

# Lookup from class index (as a decimal string: "0", "1", ..., "66") to the
# human-readable scene name.
labels = {str(position): scene for position, scene in enumerate(_SCENE_NAMES)}
def classify(image):
    """Classify an image and return a probability for every scene label.

    Args:
        image: The picture to classify; it is passed unchanged to
            ``extractor(images=...)``. The Gradio input in this app is
            configured with ``type='pil'``, so in practice it is a PIL image.

    Returns:
        dict mapping each scene name in ``labels`` to its softmax probability
        as a Python ``float``. Assumes the model has exactly ``len(labels)``
        output classes, in the same order as ``labels``.
    """
    # Put the model in evaluation mode; repeating this per call is harmless.
    model.eval()
    with torch.no_grad():
        inputs = extractor(images=image, return_tensors='pt')
        logits = model(**inputs).logits
        # One image gives a batch of size 1; this pattern drops the batch axis
        # and raises if the batch size is anything other than 1.
        logits = rearrange(logits, '1 j->j')
        # torch.softmax is max-shifted internally, so very large logits cannot
        # overflow exp() to inf and turn the probabilities into NaN.
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
    # `labels` keys are the class indices written as strings ("0", "1", ...).
    return {name: float(probs[int(index)]) for index, name in labels.items()}
# Web UI: one uploaded RGB image in (handed to `classify` as a PIL image),
# the five highest-scoring labels out; flagging is disabled.
# NOTE(review): gradio.inputs / gradio.outputs are the legacy component
# namespaces (deprecated in Gradio 3.x, removed in 4.0). Keep gradio pinned
# below 4 or migrate to gradio.Image / gradio.Label.
demo = gradio.Interface(
    fn=classify,
    inputs=gradio.inputs.Image(
        shape=(224, 224),
        image_mode='RGB',
        source='upload',
        tool='editor',
        type='pil',
        label=None,
        optional=False,
    ),
    outputs=gradio.outputs.Label(num_top_classes=5, type='auto'),
    allow_flagging='never',
)
demo.launch()