gbryan commited on
Commit
5815b4c
1 Parent(s): a123851
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -1,6 +1,14 @@
1
- import gradio as gr
2
- from transformer import pipeline
 
3
 
4
- pipe = pipeline(task="image-classification", model="imjeffhi/pokemon_classifier")
 
 
 
5
 
6
- gr.Interface.from_pipeline(pipe, title="Pokemon Classifier", description="A fine-tuned version of ViT-base on a collected set of Pokémon images", allow_flagging="never").launch(inbrowser=True)
 
 
 
 
 
1
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
2
+ from PIL import Image
3
+ import torch
4
 
5
+ # Loading in Model
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+ model = ViTForImageClassification.from_pretrained( "imjeffhi/pokemon_classifier").to(device)
8
+ feature_extractor = ViTFeatureExtractor.from_pretrained('imjeffhi/pokemon_classifier')
9
 
10
+ # Caling the model on a test image
11
+ img = Image.open('test.jpg')
12
+ extracted = feature_extractor(images=img, return_tensors='pt').to(device)
13
+ predicted_id = model(**extracted).logits.argmax(-1).item()
14
+ predicted_pokemon = model.config.id2label[predicted_id]