danielhshi8224 commited on
Commit
f572b0e
·
1 Parent(s): ee88f70

Multi image classification

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -126,12 +126,10 @@ model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
126
  model.eval()
127
 
128
  # (Optional) use model's own labels if present
129
- ID2LABEL = (
130
- [model.config.id2label[str(i)] for i in range(model.config.num_labels)]
131
- if getattr(model.config, "id2label", None)
132
- else ['Eel','Scallop','Crab','Flatfish','Roundfish','Skate','Whelk']
133
- )
134
-
135
  def classify_image(image):
136
  if not isinstance(image, Image.Image):
137
  image = Image.fromarray(image).convert("RGB")
@@ -217,7 +215,7 @@ batch = gr.Interface(
217
  fn=classify_images_batch,
218
  inputs=gr.Files(label="Upload up to 10 images"),
219
  outputs=[
220
- gr.Gallery(label="Results (Top-1 in caption)").style(grid=3, height=500),
221
  gr.Dataframe(
222
  headers=["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"],
223
  label="Predictions Table",
 
126
  model.eval()
127
 
128
  # (Optional) use model's own labels if present
129
+ ID2LABEL = [
130
+ model.config.id2label.get(str(i), model.config.id2label.get(i, f"Label_{i}"))
131
+ for i in range(model.config.num_labels)
132
+ ]
 
 
133
  def classify_image(image):
134
  if not isinstance(image, Image.Image):
135
  image = Image.fromarray(image).convert("RGB")
 
215
  fn=classify_images_batch,
216
  inputs=gr.Files(label="Upload up to 10 images"),
217
  outputs=[
218
+ gr.Gallery(label="Results (Top-1 in caption)", height=500, rows=3),
219
  gr.Dataframe(
220
  headers=["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"],
221
  label="Predictions Table",