danielhshi8224 committed
Commit ee88f70 · 1 Parent(s): 9d06c04

update for multi image

Files changed (1): app.py (+210 -87)
app.py CHANGED
@@ -1,112 +1,235 @@
  import gradio as gr
  import torch
  from transformers import AutoImageProcessor, AutoModelForImageClassification
  from PIL import Image
  import os

- # Get model path (Windows compatible)
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  MODEL_ID = "dshi01/convnext-tiny-224-7clss"

- # Try different possible filenames
- # possible_names = ['ConvNextmodel.pth', 'convnextmodel.pth', 'ConvNext_model.pth']
- # model_path = None
-
- # for name in possible_names:
- #     test_path = os.path.join(BASE_DIR, name)
- #     if os.path.exists(test_path):
- #         model_path = test_path
- #         print(f"✓ Found model: {name}")
- #         break
-
- # if model_path is None:
- #     raise FileNotFoundError(f"Could not find model file. Tried: {possible_names}")
-
- # Species categories (7 classes)
- SPECIES_CATEGORIES = [
-     'Eel',
-     'Scallop',
-     'Crab',
-     'Flatfish',
-     'Roundfish',
-     'Skate',
-     'Whelk'
- ]
-
- # Load model
  print(f"Loading model from: {MODEL_ID}")
- # model = AutoModelForImageClassification.from_pretrained(
- #     'facebook/convnext-tiny-224',
- #     num_labels=7,
- #     ignore_mismatched_sizes=True
- # )
- processor=AutoImageProcessor.from_pretrained('facebook/convnext-tiny-224')
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)

- # Load weights
- # checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
- # if isinstance(checkpoint, dict):
- #     if 'model' in checkpoint:
- #         checkpoint = checkpoint['model']
- #     elif 'state_dict' in checkpoint:
- #         checkpoint = checkpoint['state_dict']

- # model.load_state_dict(checkpoint, strict=False)
- # model.eval()

- # Load processor
- # processor = AutoImageProcessor.from_pretrained('facebook/convnext-tiny-224')
- # print("✓ Model loaded successfully!")

- def classify_image(image):
      """
-     Classify a benthic species image.
-
-     Args:
-         image: PIL Image or numpy array
-
      Returns:
-         dict: Predictions with species names and confidence scores
      """
-     # Convert to PIL if needed
-     if not isinstance(image, Image.Image):
-         image = Image.fromarray(image).convert('RGB')
-
-     # Preprocess
-     inputs = processor(images=image, return_tensors="pt")
-
-     # Predict
      with torch.no_grad():
-         outputs = model(**inputs)
-         logits = outputs.logits
-         probabilities = torch.nn.functional.softmax(logits, dim=1)
-
-     # Create results dictionary for Gradio
-     results = {}
-     for idx, prob in enumerate(probabilities[0]):
-         results[SPECIES_CATEGORIES[idx]] = float(prob)
-
-     return results

- # Create Gradio interface
- demo = gr.Interface(
      fn=classify_image,
      inputs=gr.Image(type="pil", label="Upload Underwater Image"),
-     outputs=gr.Label(num_top_classes=7, label="Species Classification"),
-     title="🌊 BenthicAI - Benthic Species Classifier",
-     description="Upload an image of a benthic organism to classify it into one of 7 species categories. Built with ConvNeXT transformer model.",
-     examples=[
-         [os.path.join("examples", "eel.jpg")],
-         [os.path.join("examples", "scallop.jpg")],
-         [os.path.join("examples", "crab.jpg")],
-     ] if os.path.exists("examples") else None,
-     theme=gr.themes.Soft(),
-     allow_flagging="never"
  )

  if __name__ == "__main__":
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=True  # Set to True to get a public URL
-     )
 
+ # import gradio as gr
+ # import torch
+ # from transformers import AutoImageProcessor, AutoModelForImageClassification
+ # from PIL import Image
+ # import os
+
+ # # Get model path (Windows compatible)
+ # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ # MODEL_ID = "dshi01/convnext-tiny-224-7clss"
+
+ # # Try different possible filenames
+ # # possible_names = ['ConvNextmodel.pth', 'convnextmodel.pth', 'ConvNext_model.pth']
+ # # model_path = None
+
+ # # for name in possible_names:
+ # #     test_path = os.path.join(BASE_DIR, name)
+ # #     if os.path.exists(test_path):
+ # #         model_path = test_path
+ # #         print(f"✓ Found model: {name}")
+ # #         break
+
+ # # if model_path is None:
+ # #     raise FileNotFoundError(f"Could not find model file. Tried: {possible_names}")
+
+ # # Species categories (7 classes)
+ # SPECIES_CATEGORIES = [
+ #     'Eel',
+ #     'Scallop',
+ #     'Crab',
+ #     'Flatfish',
+ #     'Roundfish',
+ #     'Skate',
+ #     'Whelk'
+ # ]
+
+ # # Load model
+ # print(f"Loading model from: {MODEL_ID}")
+ # # model = AutoModelForImageClassification.from_pretrained(
+ # #     'facebook/convnext-tiny-224',
+ # #     num_labels=7,
+ # #     ignore_mismatched_sizes=True
+ # # )
+ # processor=AutoImageProcessor.from_pretrained('facebook/convnext-tiny-224')
+ # model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
+
+ # # Load weights
+ # # checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
+ # # if isinstance(checkpoint, dict):
+ # #     if 'model' in checkpoint:
+ # #         checkpoint = checkpoint['model']
+ # #     elif 'state_dict' in checkpoint:
+ # #         checkpoint = checkpoint['state_dict']
+
+ # # model.load_state_dict(checkpoint, strict=False)
+ # # model.eval()
+
+ # # Load processor
+ # # processor = AutoImageProcessor.from_pretrained('facebook/convnext-tiny-224')
+ # # print("✓ Model loaded successfully!")
+
+ # def classify_image(image):
+ #     """
+ #     Classify a benthic species image.
+
+ #     Args:
+ #         image: PIL Image or numpy array
+
+ #     Returns:
+ #         dict: Predictions with species names and confidence scores
+ #     """
+ #     # Convert to PIL if needed
+ #     if not isinstance(image, Image.Image):
+ #         image = Image.fromarray(image).convert('RGB')
+
+ #     # Preprocess
+ #     inputs = processor(images=image, return_tensors="pt")
+
+ #     # Predict
+ #     with torch.no_grad():
+ #         outputs = model(**inputs)
+ #         logits = outputs.logits
+ #         probabilities = torch.nn.functional.softmax(logits, dim=1)
+
+ #     # Create results dictionary for Gradio
+ #     results = {}
+ #     for idx, prob in enumerate(probabilities[0]):
+ #         results[SPECIES_CATEGORIES[idx]] = float(prob)
+
+ #     return results
+
+ # # Create Gradio interface
+ # demo = gr.Interface(
+ #     fn=classify_image,
+ #     inputs=gr.Image(type="pil", label="Upload Underwater Image"),
+ #     outputs=gr.Label(num_top_classes=7, label="Species Classification"),
+ #     title="🌊 BenthicAI - Benthic Species Classifier",
+ #     description="Upload an image of a benthic organism to classify it into one of 7 species categories. Built with ConvNeXT transformer model.",
+ #     examples=[
+ #         [os.path.join("examples", "eel.jpg")],
+ #         [os.path.join("examples", "scallop.jpg")],
+ #         [os.path.join("examples", "crab.jpg")],
+ #     ] if os.path.exists("examples") else None,
+ #     theme=gr.themes.Soft(),
+ #     allow_flagging="never"
+ # )
+
+ # if __name__ == "__main__":
+ #     demo.launch(
+ #         server_name="0.0.0.0",
+ #         server_port=7860,
+ #         share=True  # Set to True to get a public URL
+ #     )
  import gradio as gr
  import torch
+ import torch.nn.functional as F
  from transformers import AutoImageProcessor, AutoModelForImageClassification
  from PIL import Image
  import os

  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  MODEL_ID = "dshi01/convnext-tiny-224-7clss"

  print(f"Loading model from: {MODEL_ID}")
+ processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
+ model.eval()

+ # (Optional) use the model's own labels if present (id2label keys are ints after from_pretrained)
+ ID2LABEL = (
+     [model.config.id2label[i] for i in range(model.config.num_labels)]
+     if getattr(model.config, "id2label", None)
+     else ['Eel', 'Scallop', 'Crab', 'Flatfish', 'Roundfish', 'Skate', 'Whelk']
+ )

+ def classify_image(image):
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(image).convert("RGB")

+     inputs = processor(images=image, return_tensors="pt")
+     with torch.no_grad():
+         logits = model(**inputs).logits
+     probs = F.softmax(logits, dim=1)[0].tolist()

+     return {ID2LABEL[i]: float(p) for i, p in enumerate(probs)}
+
+ # ---------- NEW: batch classify up to 10 images ----------
+ MAX_BATCH = 10
+
+ def classify_images_batch(files):
      """
+     files: list of Gradio uploads (tempfile objects or path strings), or None
      Returns:
+         - gallery: list of (image, caption) pairs
+         - table: list of rows for the Dataframe
      """
+     if not files:
+         return [], []
+
+     # Keep at most MAX_BATCH images
+     files = files[:MAX_BATCH]
+
+     # Load as PIL
+     pil_images, names = [], []
+     for f in files:
+         path = getattr(f, "name", None) or getattr(f, "path", None) or f
+         try:
+             img = Image.open(path).convert("RGB")
+             pil_images.append(img)
+             names.append(os.path.basename(path))
+         except Exception:
+             # Skip unreadable files
+             continue
+
+     if not pil_images:
+         return [], []
+
+     # Batch preprocess + forward pass
+     inputs = processor(images=pil_images, return_tensors="pt")
      with torch.no_grad():
+         logits = model(**inputs).logits
+     probs = F.softmax(logits, dim=1)
+
+     # Build outputs
+     gallery = []
+     table_rows = []  # [filename, top1_label, top1_conf, top3_labels, top3_confs]
+
+     for idx, (img, fname) in enumerate(zip(pil_images, names)):
+         p = probs[idx].tolist()
+         top_idxs = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:3]
+         top1 = top_idxs[0]
+         caption = f"{ID2LABEL[top1]} ({p[top1]:.2%})"
+
+         gallery.append((img, f"{fname}\n{caption}"))

+         top3_labels = [ID2LABEL[i] for i in top_idxs]
+         top3_scores = [round(p[i], 4) for i in top_idxs]
+         table_rows.append([
+             fname,
+             ID2LABEL[top1],
+             round(p[top1], 4),
+             ", ".join(top3_labels),
+             ", ".join(map(str, top3_scores)),
+         ])
+
+     return gallery, table_rows
+
+ # ---------- UI ----------
+ single = gr.Interface(
      fn=classify_image,
      inputs=gr.Image(type="pil", label="Upload Underwater Image"),
+     outputs=gr.Label(num_top_classes=len(ID2LABEL), label="Species Classification"),
+     title="🌊 BenthicAI - Single Image",
+     description="Classify one image into one of 7 benthic species."
+ )
+
+ batch = gr.Interface(
+     fn=classify_images_batch,
+     inputs=gr.Files(label="Upload up to 10 images"),
+     outputs=[
+         gr.Gallery(label="Results (Top-1 in caption)", columns=3, height=500),  # .style() was removed in Gradio 4
+         gr.Dataframe(
+             headers=["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"],
+             label="Predictions Table",
+             wrap=True
+         )
+     ],
+     title="🌊 BenthicAI - Batch (up to 10)",
+     description="Upload multiple images (max 10). Outputs a gallery with captions and a table of top predictions.",
  )

+ demo = gr.TabbedInterface([single, batch], ["Single", "Batch"])
+
  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
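
A minimal smoke test for the new batch path, run locally without the Gradio UI. It assumes the file above is saved as app.py, and the two image paths below are hypothetical placeholders for any readable local images (they are not files shipped with the Space). Because classify_images_batch accepts plain path strings as well as Gradio upload objects, it can be called directly:

# Hedged sketch: importing app loads the processor and model once;
# the example paths are placeholders, substitute any local images.
from app import classify_images_batch

gallery, rows = classify_images_batch(["examples/eel.jpg", "examples/crab.jpg"])
for filename, top1_label, top1_conf, top3_labels, top3_confs in rows:
    print(f"{filename}: {top1_label} ({top1_conf:.2%}) | top-3: {top3_labels} [{top3_confs}]")

Each gallery entry is a (PIL.Image, caption) pair, so the same call also makes it easy to eyeball the 10-image cap and the top-3 table formatting in a notebook.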