File size: 1,368 Bytes
d4bddcf
 
2411be5
a7418a6
2411be5
 
a7418a6
d4bddcf
 
8eed7c4
3d2c6fa
 
2411be5
 
51012d6
 
b7802a1
3517393
dbab638
26e56f5
2411be5
 
dbab638
3517393
b7802a1
dbab638
 
2fd03ec
51012d6
 
2411be5
51012d6
 
3d2c6fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from transformers import AutoFeatureExtractor, ResNetForImageClassification
import torch
# from datasets import load_dataset

# dataset = load_dataset("huggingface/cats-image")
# image = dataset["test"]["image"][0]

# Both the preprocessing pipeline and the classifier weights are loaded from
# the same pretrained checkpoint, so the identifier is spelled out only once.
_CHECKPOINT = "microsoft/resnet-50"

# Turns a raw image into the tensor inputs the model is called with.
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
# ResNet classifier; its config carries the id -> label mapping.
model = ResNetForImageClassification.from_pretrained(_CHECKPOINT)

import gradio as gr
def segment(image):
    """Classify an image and return a probability for every known label.

    Despite its name, this function performs image classification, not
    segmentation; the name is kept so existing references keep working.

    Args:
        image: The picture to classify, in whatever form the module-level
            ``feature_extractor`` accepts, or ``None`` when no picture was
            supplied.

    Returns:
        A dict mapping each label name from ``model.config.id2label`` to its
        softmax probability as a Python float, or ``None`` when ``image`` is
        ``None``.
    """
    if image is None:
        # Nothing to classify (e.g. the form was submitted without an image);
        # bail out instead of letting the feature extractor raise.
        return None

    inputs = feature_extractor(image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class axis; row 0 holds the scores of the single image
    # that was passed in. tolist() converts the whole row to Python floats in
    # one call instead of one float() per element.
    scores = torch.softmax(logits, dim=1)[0].tolist()
    id2label = model.config.id2label
    return {id2label[idx]: score for idx, score in enumerate(scores)}
    
# Build the web UI: a single image input fed to `segment`, with the returned
# label/probability mapping rendered by the "label" output, then start it.
demo = gr.Interface(
    fn=segment,
    inputs="image",
    outputs="label",
)
demo.launch()
#gr.Interface(fn=segment, inputs="image", outputs="text").launch()
    
#    with torch.no_grad():
#        prediction = torch.nn.functional.softmax(model(**inputs)[0], dim=0)
    
#    return {model.config.id2label[i]: float(prediction[i]) for i in range(3)}  
#gr.Interface(fn=segment, inputs="image", outputs="label").launch()