File size: 813 Bytes
e660ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import functools

import gradio as gr
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

# Hugging Face Hub namespace (user) that hosts the model repository.
huggingface_username = "i-am-holmes"
# Repository name under that namespace; joined as "<user>/<name>" to form the
# repo id handed to from_pretrained().
model_name = "vit-base-patch16-224-finetuned-flower"

@functools.lru_cache(maxsize=1)
def _load_classifier():
    """Load the classifier and its feature extractor once, then reuse them.

    Loading from the Hub is expensive, so the result is cached for the
    lifetime of the process instead of being redone on every request.

    Returns:
        A ``(model, feature_extractor, device)`` tuple. The model has been
        moved to ``device``, which is CUDA when available, otherwise CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    repo_id = f'{huggingface_username}/{model_name}'
    model = AutoModelForImageClassification.from_pretrained(repo_id).to(device)
    feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    return model, feature_extractor, device


def classify_image(image):
    """Classify a single image and return its predicted label.

    Args:
        image: The image supplied by the Gradio ``gr.Image`` input. It is
            passed straight to the feature extractor (presumably a numpy
            array, Gradio's default ``type`` — TODO confirm).

    Returns:
        The label string from ``model.config.id2label`` for the
        highest-scoring class.

    Raises:
        ValueError: If no image was provided (``image`` is ``None``).
    """
    if image is None:
        raise ValueError('No image provided.')

    model, feature_extractor, device = _load_classifier()
    inputs = feature_extractor(image, return_tensors='pt').to(device)
    # Pure inference: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        outputs = model(**inputs)
    # .item() assumes a batch of exactly one image, which is what we feed in.
    pred = torch.argmax(outputs.logits, dim=1).item()
    return model.config.id2label[pred]

# Keep the Interface object itself bound to `interface` (rather than whatever
# launch() returns) so it can be inspected or relaunched.
# NOTE(review): `shape=` on gr.Image is a Gradio 3.x argument; confirm the
# pinned gradio version before upgrading.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(shape=(224, 224)),
    outputs="text",
)
# debug=True blocks the main thread while the server runs; share=True also
# requests a public share link.
interface.launch(debug=True, share=True)