Caniya commited on
Commit
d603ca0
1 Parent(s): 0a351b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -13
app.py CHANGED
@@ -1,21 +1,38 @@
 
 
 
 
1
  import torch
2
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
  import gradio as gr
4
 
5
- model_id = f'Caniya/vit-base-patch16-224-finetuned-flower'
 
6
  labels = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
7
 
 
8
  def classify_image(image):
9
- model = AutoModelForImageClassification.from_pretrained(model_id)
10
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
11
- inp = feature_extractor(image, return_tensors='pt')
12
- outp = model(**inp)
13
- pred = torch.nn.functional.softmax(outp.logits, dim=-1)
14
- preds = pred[0].cpu().detach().numpy()
15
- confidence = {label: float(preds[i]) for i, label in enumerate(labels)}
16
- return confidence
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- interface = gr.Interface(fn=classify_image,
19
- inputs='image',
20
- examples=['Dandelion_2.jpg', 'dandelion_1.jpg'],
21
- outputs='label').launch()
 
1
+ # Install necessary packages
2
+ !pip install torch transformers gradio
3
+
4
+ # Import libraries
5
  import torch
6
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
7
  import gradio as gr
8
 
9
+ # Define model ID and labels
10
+ model_id = 'Caniya/vit-base-patch16-224-finetuned-flower'
11
  labels = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
12
 
13
+ # Define classify_image function
14
  def classify_image(image):
15
+ model = AutoModelForImageClassification.from_pretrained(model_id)
16
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
17
+ inp = feature_extractor(image, return_tensors='pt')
18
+ outp = model(**inp)
19
+ pred = torch.nn.functional.softmax(outp.logits, dim=-1)
20
+ preds = pred[0].cpu().detach().numpy()
21
+ confidence = {label: float(preds[i]) for i, label in enumerate(labels)}
22
+ return confidence
23
+
24
+ # Create interface
25
+ interface = gr.Interface(
26
+ fn=classify_image,
27
+ inputs='image',
28
+ outputs='label',
29
+ title='Flower Image Classifier',
30
+ description='Classify images of flowers into different categories.',
31
+ examples=[
32
+ ['flower-1.jpeg'],
33
+ ['flower-2.jpeg']
34
+ ]
35
+ )
36
 
37
+ # Launch the interface
38
+ interface.launch()