JeemsTerri commited on
Commit
60869f5
1 Parent(s): f0299d1

update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -1
app.py CHANGED
@@ -1,3 +1,26 @@
1
  import gradio as gr
 
 
 
2
 
3
- gr.load("models/jtas/fish_classification").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
5
 
6
+ processor_pipe = AutoImageProcessor.from_pretrained("jtas/fish_classification")
7
+ model_pipe = AutoModelForImageClassification.from_pretrained("jtas/fish_classification")
8
+
9
+ def classify_pipe(image):
10
+ inputs = processor_pipe(images=image, return_tensors="pt")
11
+
12
+ outputs = model_pipe(**inputs)
13
+ logits = outputs.logits
14
+
15
+ # Get the predicted label
16
+ predicted_class_idx = logits.argmax(-1).item()
17
+ labels = model_pipe.config.id2label
18
+ pipe_score = np.argmax(logits, dim=1).max().item()
19
+ predicted_label = labels[predicted_class_idx]
20
+
21
+ # Return the predicted label and score
22
+ return {"label": predicted_label, "score": pipe_score}
23
+
24
+ # Create a Gradio interface
25
+ iface = gr.Interface(fn=classify_pipe, inputs=gr.Image(), outputs="json", title="Fish Classification")
26
+ iface.launch()