JeemsTerri commited on
Commit
4f171d0
1 Parent(s): 9659e27

update gradio interface and inference logic

Browse files
Files changed (1) hide show
  1. app.py +36 -9
app.py CHANGED
@@ -1,19 +1,46 @@
1
  import gradio as gr
2
  from PIL import Image
3
  import numpy as np
4
- from transformers import AutoImageProcessor, AutoModelForImageClassification
5
 
6
- processor = AutoImageProcessor.from_pretrained("jtas/fish_classification")
7
- classifier = AutoModelForImageClassification.from_pretrained("jtas/fish_classification")
 
 
8
 
9
  def fish_classification(image):
10
- inputs = processor(images=image, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- outputs = classifier(**inputs)
 
 
13
 
14
- return outputs
15
 
16
- label_fish = gr.components.Json()
 
 
 
 
 
 
 
 
17
 
18
- iface = gr.Interface(fn=fish_classification, inputs=gr.Image(), outputs=label_fish, title="Fish Classification")
19
- iface.launch()
 
1
  import gradio as gr
2
  from PIL import Image
3
  import numpy as np
4
+ from transformers import AutoImageProcessor, AutoModelForImageClassification, pipeline
5
 
6
+ classifier = pipeline(model="jtas/fish_classification")
7
+ model = AutoModelForImageClassification.from_pretrained("jtas/fish_classification")
8
+
9
+ fish_classes = model.config.id2label
10
 
11
  def fish_classification(image):
12
+ pil_img = Image.fromarray(np.uint8(image))
13
+
14
+ fish_prediction = classifier(pil_img)
15
+ class_probs = {str(pred["label"]): pred["score"] for pred in fish_prediction}
16
+
17
+ return class_probs
18
+
19
+ sample_images = [
20
+ ["img/Rastrelliger Faughni.jpg", "Rastrelliger Faughni"],
21
+ ["img/Chanos Chanos.jpg", "Chanos Chanos"],
22
+ ["img/Eleutheronema Tetradactylum.jpeg", "Eleutheronema Tetradactylum"],
23
+ ["img/Johnius Trachycephalus.jpg", "Johnius Trachycephalus"],
24
+ ["img/Nibea Albiflora.jpeg", "Nibea Albiflora"],
25
+ ["img/Oreochromis Mossambicus.jpg", "Oreochromis Mossambicus"],
26
+ ["img/Oreochromis Niloticus.png", "Oreochromis Niloticus"],
27
+ ["img/Upeneus Moluccensis.jpg", "Upeneus Moluccensis"],
28
+ ]
29
 
30
+ supported_classes_html = ""
31
+ for class_id, class_name in fish_classes.items():
32
+ supported_classes_html += f"<p><b>{class_id}</b>: {class_name}</p>"
33
 
34
+ label = gr.components.Label()
35
 
36
+ iface = gr.Interface(
37
+ fn=fish_classification,
38
+ inputs=gr.Image(label="Upload an image"),
39
+ examples=sample_images,
40
+ outputs=label,
41
+ title="Fish Classification",
42
+ description="This web app classifies fish in an image. Upload an image of a fish to see the predicted class probabilities.",
43
+ article="<div style='margin-top: 20px;'><h2>Supported Fish Classes</h2>" + supported_classes_html + "</div>",
44
+ )
45
 
46
+ iface.launch()