smaciu commited on
Commit
96c1982
1 Parent(s): bda97f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -25
app.py CHANGED
@@ -1,32 +1,28 @@
1
- import requests
2
-
3
  import gradio as gr
4
- import torch
5
- from timm import create_model
6
- from timm.data import resolve_data_config
7
- from timm.data.transforms_factory import create_transform
8
-
9
- IMAGENET_1k_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
10
- LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
11
-
12
- model = create_model('resnet50', pretrained=True)
13
 
14
- transform = create_transform(
15
- **resolve_data_config({}, model=model)
16
- )
17
- model.eval()
18
 
19
- def predict_fn(img):
20
- img = img.convert('RGB')
21
- img = transform(img).unsqueeze(0)
22
 
23
- with torch.no_grad():
24
- out = model(img)
25
-
26
- probabilites = torch.nn.functional.softmax(out[0], dim=0)
27
 
28
- values, indices = torch.topk(probabilites, k=5)
 
 
 
 
29
 
30
- return {LABELS[i]: v.item() for i, v in zip(indices, values)}
 
 
 
31
 
32
- gr.Interface(predict_fn, gr.inputs.Image(type='pil'), outputs='label').launch()
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from fastai.vision.all import *
3
+ import skimage
4
+ from huggingface_hub import from_pretrained_fastai
 
 
 
 
 
 
5
 
 
 
 
 
6
 
7
+ def label_func(f): return f.name[:2]
 
 
8
 
9
+ learn = from_pretrained_fastai("smaciu/bee-wings-classifier")
 
 
 
10
 
11
+ labels = learn.dls.vocab
12
+ def predict(img):
13
+ #img = PILImage.create(img)
14
+ pred,pred_idx,probs = learn.predict(img)
15
+ return {labels[i]: float(probs[i]) for i in range(len(labels))}
16
 
17
+ title="Honey Bee Wing Classifier"
18
+ description = "A bee wings classifier trained with fastai"
19
+ examples = ['ES-0003-2092-0.dw.png','HU-0001-2019-000007-R.dw.png','PL-0002-000187-R.dw.png']
20
+ enable_queue=True
21
 
22
+ gr.Interface(fn=predict,
23
+ inputs=gr.inputs.Image(type='pil'),
24
+ outputs=gr.outputs.Label(num_top_classes=3),
25
+ title=title,
26
+ description=description,
27
+ examples=examples,
28
+ enable_queue=enable_queue).launch()