JeemsTerri commited on
Commit
1cf3740
1 Parent(s): 0014013

fix app.py for inference logic

Browse files
Files changed (1) hide show
  1. app.py +8 -15
app.py CHANGED
@@ -3,24 +3,17 @@ from PIL import Image
3
  import numpy as np
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
 
6
- processor_pipe = AutoImageProcessor.from_pretrained("jtas/fish_classification")
7
- model_pipe = AutoModelForImageClassification.from_pretrained("jtas/fish_classification")
8
 
9
- def classify_pipe(image):
10
- inputs = processor_pipe(images=image, return_tensors="pt")
11
 
12
- outputs = model_pipe(**inputs)
13
- logits = outputs.logits
14
 
15
- # Get the predicted label
16
- predicted_class_idx = logits.argmax(-1).item()
17
- labels = model_pipe.config.id2label
18
- pipe_score = np.argmax(logits, dim=1).max().item()
19
- predicted_label = labels[predicted_class_idx]
20
 
21
- # Return the predicted label and score
22
- return {"label": predicted_label, "score": pipe_score}
23
 
24
- # Create a Gradio interface
25
- iface = gr.Interface(fn=classify_pipe, inputs=gr.Image(), outputs="json", title="Fish Classification")
26
  iface.launch()
 
3
  import numpy as np
4
  from transformers import AutoImageProcessor, AutoModelForImageClassification
5
 
6
+ processor = AutoImageProcessor.from_pretrained("jtas/fish_classification")
7
+ classifier = AutoModelForImageClassification.from_pretrained("jtas/fish_classification")
8
 
9
+ def fish_classification(image):
10
+ inputs = processor(images=image, return_tensors="pt")
11
 
12
+ outputs = classifier(**inputs)
 
13
 
14
+ return outputs
 
 
 
 
15
 
16
+ label_fish = gr.components.Json()
 
17
 
18
+ iface = gr.Interface(fn=fish_classification, inputs=gr.Image(), outputs=label_fish, title="Fish Classification")
 
19
  iface.launch()