from transformers import AutoFeatureExtractor, ResNetForImageClassification | |
import torch | |
# Optional sample input for manual testing (requires the `datasets` package):
# from datasets import load_dataset
# dataset = load_dataset("huggingface/cats-image")
# image = dataset["test"]["image"][0]
# Load the image preprocessor and the classifier from the same Hugging Face
# Hub checkpoint, named once here instead of repeating the string.
_CHECKPOINT = "microsoft/resnet-50"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = ResNetForImageClassification.from_pretrained(_CHECKPOINT)
import gradio as gr | |
def segment(image):
    """Classify an image with the module-level ResNet-50 model.

    Despite its name, this performs whole-image classification, not
    segmentation. The name is kept because the Gradio interface refers to it.

    Args:
        image: The value supplied by the Gradio "image" input. This is
            presumably a NumPy array (Gradio's default image type; TODO
            confirm for the installed version), or ``None`` when the user
            submits without an image.

    Returns:
        A dict mapping every class label in ``model.config.id2label`` to its
        softmax probability, which is the format Gradio's "label" output
        renders. An empty dict is returned when ``image`` is ``None``.
    """
    if image is None:
        # Nothing was uploaded; don't hand None to the feature extractor.
        return {}
    inputs = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single image is passed in, so row 0 holds its class scores. Convert
    # the whole row to Python floats in one call rather than one tensor->float
    # conversion per class.
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return {model.config.id2label[idx]: prob for idx, prob in enumerate(probs)}
# Web UI: an uploaded image is passed to `segment`, and the returned
# {label: probability} dict is shown by Gradio's "label" output component.
demo = gr.Interface(fn=segment, inputs="image", outputs="label")
demo.launch()
#gr.Interface(fn=segment, inputs="image", outputs="text").launch() | |
# with torch.no_grad(): | |
# prediction = torch.nn.functional.softmax(model(**inputs)[0], dim=0) | |
# return {model.config.id2label[i]: float(prediction[i]) for i in range(3)} | |
#gr.Interface(fn=segment, inputs="image", outputs="label").launch() | |