laurru01 commited on
Commit
c216acb
·
verified ·
1 Parent(s): 083554c

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +383 -0
app.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard Python imports
2
+ import os
3
+ import re
4
+ import json
5
+ from typing import List, Dict, Any
6
+
7
+ # Data processing and visualization
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ from tqdm.notebook import tqdm
11
+
12
+ # Deep Learning & ML
13
+ import torch
14
+ from transformers import (
15
+ AutoProcessor,
16
+ AutoModelForVision2Seq,
17
+ AutoTokenizer,
18
+ AutoModelForCausalLM,
19
+ TextStreamer,
20
+ Idefics3ForConditionalGeneration,
21
+ BitsAndBytesConfig
22
+
23
+ )
24
+
25
+ from unsloth import FastVisionModel
26
+
27
+ # Dataset handling
28
+ from datasets import load_from_disk
29
+
30
+ # API & Authentication
31
+ from huggingface_hub import login
32
+
33
+ # UI & Environment
34
+ import gradio as gr
35
+ from dotenv import load_dotenv
36
+
37
+ # Available models
38
+ MODELS = {
39
+ "Blood Cell Classifier with Llama-3.2": "laurru01/Llama-3.2-11B-Vision-Instruct-ft-PeripherallBloodCells",
40
+ "Blood Cell Classifier with Qwen2-VL": "laurru01/Qwen2-VL-2B-Instruct-ft-bloodcells-big",
41
+ "Blood Cell Classifier with SmolVLM": "laurru01/SmolVLM-Instruct-ft-PeripherallBloodCells",
42
+ }
43
+
44
+ # Global dictionary to store loaded models
45
+ loaded_models = {}
46
+
47
+ def initialize_models():
48
+ """Preload all models during startup"""
49
+ print("Initializing models...")
50
+ for model_name, model_path in MODELS.items():
51
+ print(f"Loading {model_name}...")
52
+ try:
53
+ if "SmolVLM" in model_name:
54
+ # Carga específica para SmolVLM
55
+ base_model = Idefics3ForConditionalGeneration.from_pretrained(
56
+ "HuggingFaceTB/SmolVLM-Instruct",
57
+ device_map="auto",
58
+ torch_dtype=torch.bfloat16,
59
+ load_in_4bit=True,
60
+ max_memory={0: "12GB"}
61
+ )
62
+ base_model.load_adapter(model_path)
63
+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
64
+
65
+ loaded_models[model_name] = {
66
+ "model": base_model,
67
+ "processor": processor,
68
+ "type": "smolvlm"
69
+ }
70
+ else:
71
+ # Carga original para Llama y Qwen (sin cambios)
72
+ model, tokenizer = FastVisionModel.from_pretrained(
73
+ model_name=model_path,
74
+ load_in_4bit=True,
75
+ use_gradient_checkpointing="unsloth"
76
+ )
77
+ FastVisionModel.for_inference(model)
78
+ processor = AutoProcessor.from_pretrained(model_path)
79
+
80
+ loaded_models[model_name] = {
81
+ "model": model,
82
+ "tokenizer": tokenizer,
83
+ "processor": processor,
84
+ "type": "standard"
85
+ }
86
+ print(f"Successfully loaded {model_name}")
87
+
88
+ except Exception as e:
89
+ print(f"Error loading {model_name}: {str(e)}")
90
+
91
+ print("Model initialization complete")
92
+
93
+ def extract_cell_type(text):
94
+ """Extract cell type from generated description"""
95
+ cell_types = ['neutrophil', 'lymphocyte', 'monocyte', 'eosinophil', 'basophil']
96
+ text_lower = text.lower()
97
+ for cell_type in cell_types:
98
+ if cell_type in text_lower:
99
+ return cell_type.capitalize()
100
+ return "Unidentified Cell Type"
101
+
102
+ @torch.no_grad()
103
+ def generate_description_standard(model, tokenizer, image):
104
+ """Generate description using standard models (Llama and Qwen)"""
105
+ messages = [{
106
+ "role": "user",
107
+ "content": [
108
+ {"type": "image"},
109
+ {"type": "text", "text": "As a hematologist, carefully identify the type of blood cell in this image and describe its key characteristics."}
110
+ ]}]
111
+
112
+ input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
113
+ inputs = tokenizer(image, input_text, add_special_tokens=False, return_tensors="pt").to("cuda")
114
+
115
+ text_streamer = TextStreamer(tokenizer, skip_prompt=True)
116
+ output = model.generate(
117
+ **inputs,
118
+ streamer=text_streamer,
119
+ max_new_tokens=1024,
120
+ use_cache=True,
121
+ temperature=1.5,
122
+ min_p=0.1
123
+ )
124
+
125
+ raw_output = tokenizer.decode(output[0], skip_special_tokens=True)
126
+ if "The provided image" in raw_output:
127
+ start_idx = raw_output.find("assistant")
128
+ cleaned_output = raw_output[start_idx:]
129
+ else:
130
+ cleaned_output = raw_output
131
+
132
+ return cleaned_output.strip()
133
+
134
+ @torch.no_grad()
135
+ def generate_description_smolvlm(model, processor, image):
136
+ """Generate description using SmolVLM model with memory-efficient settings"""
137
+ if image.mode != "RGB":
138
+ image = image.convert("RGB")
139
+
140
+ # Redimensionar a un tamaño más pequeño para reducir memoria
141
+ max_size = 192
142
+ image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
143
+
144
+ sample = [{
145
+ "role": "user",
146
+ "content": [
147
+ {"type": "image", "image": image},
148
+ {"type": "text", "text": "As a hematologist, carefully identify the type of blood cell in this image and describe its key characteristics."}
149
+ ]
150
+ }]
151
+
152
+ text_input = processor.apply_chat_template(
153
+ sample,
154
+ add_generation_prompt=True
155
+ )
156
+
157
+ try:
158
+ torch.cuda.empty_cache()
159
+
160
+ with torch.cuda.amp.autocast():
161
+ model_inputs = processor(
162
+ text=text_input,
163
+ images=[[image]],
164
+ return_tensors="pt",
165
+ ).to("cuda")
166
+
167
+ generated_ids = model.generate(
168
+ **model_inputs,
169
+ max_new_tokens=256,
170
+ do_sample=True,
171
+ temperature=0.7,
172
+ top_p=0.9,
173
+ repetition_penalty=1.5,
174
+ no_repeat_ngram_size=3,
175
+ num_beams=2,
176
+ length_penalty=1.0,
177
+ early_stopping=True,
178
+ use_cache=True,
179
+ pad_token_id=processor.tokenizer.pad_token_id,
180
+ )
181
+
182
+ response_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
183
+ output_text = processor.decode(
184
+ response_ids,
185
+ skip_special_tokens=True,
186
+ clean_up_tokenization_spaces=True
187
+ ).strip()
188
+
189
+ if len(set(output_text.split())) < 5:
190
+ output_text = "Error: Generated response was too repetitive. Please try again."
191
+
192
+ del model_inputs, generated_ids, response_ids
193
+ torch.cuda.empty_cache()
194
+
195
+ return output_text
196
+
197
+ except Exception as e:
198
+ torch.cuda.empty_cache()
199
+ raise e
200
+
201
+ def analyze_cell(image, model_name):
202
+ """Main function to analyze cell images"""
203
+ if not isinstance(image, Image.Image):
204
+ return "Invalid image format. Please upload a valid image.", "", None
205
+
206
+ try:
207
+ if model_name not in loaded_models:
208
+ return f"Model {model_name} not loaded.", "", None
209
+
210
+ model_components = loaded_models[model_name]
211
+
212
+ if model_components["type"] == "smolvlm":
213
+ description = generate_description_smolvlm(
214
+ model_components["model"],
215
+ model_components["processor"],
216
+ image
217
+ )
218
+ else:
219
+ description = generate_description_standard(
220
+ model_components["model"],
221
+ model_components["tokenizer"],
222
+ image
223
+ )
224
+
225
+ cell_type = extract_cell_type(description)
226
+ return cell_type, description, image
227
+
228
+ except Exception as e:
229
+ return f"Error occurred: {str(e)}", "", None
230
+
231
+ # Initialize all models before starting the interface
232
+ initialize_models()
233
+
234
+ # Gradio Interface
235
+ with gr.Blocks() as iface:
236
+ gr.HTML("<h1>Blood Cell Analyzer</h1>")
237
+ gr.HTML("<p>Upload a microscopic blood cell image for instant classification and detailed analysis</p>")
238
+
239
+ with gr.Row():
240
+ with gr.Column():
241
+ input_image = gr.Image(
242
+ label="Upload Blood Cell Image",
243
+ type="pil",
244
+ sources=["upload"]
245
+ )
246
+ model_dropdown = gr.Dropdown(
247
+ choices=list(MODELS.keys()),
248
+ value=list(MODELS.keys())[0],
249
+ label="Select Model Version"
250
+ )
251
+ submit_btn = gr.Button("Analyze Cell")
252
+
253
+ with gr.Column():
254
+ cell_type = gr.Textbox(label="Identified Cell Type")
255
+ description = gr.Textbox(label="Analysis Details", lines=8)
256
+ output_image = gr.Image(label="Analyzed Image")
257
+
258
+ submit_btn.click(
259
+ fn=analyze_cell,
260
+ inputs=[input_image, model_dropdown],
261
+ outputs=[cell_type, description, output_image]
262
+ )
263
+
264
+ # Enhanced CSS with modern color scheme
265
+ custom_css = """
266
+ .container {
267
+ max-width: 1000px;
268
+ margin: auto;
269
+ padding: 30px;
270
+ background: linear-gradient(135deg, #f6f9fc 0%, #ffffff 100%);
271
+ border-radius: 20px;
272
+ box-shadow: 0 10px 20px rgba(0,0,0,0.05);
273
+ }
274
+ .title {
275
+ text-align: center;
276
+ color: #2d3436;
277
+ font-size: 3em;
278
+ font-weight: 700;
279
+ margin-bottom: 20px;
280
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
281
+ }
282
+ .subtitle {
283
+ text-align: center;
284
+ color: #636e72;
285
+ font-size: 1.2em;
286
+ margin-bottom: 40px;
287
+ }
288
+ .input-image {
289
+ border: 2px dashed #74b9ff;
290
+ border-radius: 15px;
291
+ padding: 20px;
292
+ transition: all 0.3s ease;
293
+ }
294
+ .input-image:hover {
295
+ border-color: #0984e3;
296
+ transform: translateY(-2px);
297
+ }
298
+ .model-dropdown {
299
+ background: #f8f9fa;
300
+ border-radius: 10px;
301
+ border: 1px solid #dfe6e9;
302
+ margin: 15px 0;
303
+ }
304
+ .submit-button {
305
+ background: linear-gradient(45deg, #0984e3, #74b9ff);
306
+ color: white;
307
+ border: none;
308
+ padding: 12px 25px;
309
+ border-radius: 10px;
310
+ font-weight: 600;
311
+ transition: all 0.3s ease;
312
+ }
313
+ .submit-button:hover {
314
+ transform: translateY(-2px);
315
+ box-shadow: 0 5px 15px rgba(9, 132, 227, 0.3);
316
+ }
317
+ .result-box {
318
+ background: white;
319
+ border-radius: 10px;
320
+ border: 1px solid #dfe6e9;
321
+ padding: 15px;
322
+ margin: 10px 0;
323
+ }
324
+ .output-image {
325
+ border-radius: 15px;
326
+ overflow: hidden;
327
+ box-shadow: 0 5px 15px rgba(0,0,0,0.1);
328
+ }
329
+ """
330
+ # Interface
331
+ with gr.Blocks(css=custom_css) as iface:
332
+ gr.HTML("<h1 class='title'>Blood Cell Classifier</h1>")
333
+ gr.HTML("<p class='subtitle'>Upload a microscopic blood cell image for instant classification and detailed analysis</p>")
334
+ with gr.Row():
335
+ with gr.Column():
336
+ input_image = gr.Image(
337
+ label="Upload Blood Cell Image",
338
+ type="pil",
339
+ sources=["upload"], # Only allow computer uploads
340
+ elem_classes="input-image"
341
+ )
342
+ model_dropdown = gr.Dropdown(
343
+ choices=list(MODELS.keys()),
344
+ value=list(MODELS.keys())[0],
345
+ label="Select Model Version",
346
+ elem_classes="model-dropdown"
347
+ )
348
+ submit_btn = gr.Button(
349
+ "Analyze Cell",
350
+ variant="primary",
351
+ elem_classes="submit-button"
352
+ )
353
+ with gr.Column():
354
+ cell_type = gr.Textbox(
355
+ label="Identified Cell Type",
356
+ elem_classes="result-box"
357
+ )
358
+ description = gr.Textbox(
359
+ label="Analysis Details",
360
+ lines=8,
361
+ elem_classes="result-box"
362
+ )
363
+ output_image = gr.Image(
364
+ label="Analyzed Image",
365
+ elem_classes="output-image"
366
+ )
367
+ submit_btn.click(
368
+ fn=analyze_cell,
369
+ inputs=[input_image, model_dropdown],
370
+ outputs=[cell_type, description, output_image]
371
+ )
372
+ gr.HTML("""
373
+ <div style="text-align: center; margin-top: 30px; padding: 20px;">
374
+ <p style="color: #636e72;">Developed by Laura Ruiz | MSc Bioinformatics and Biostatistics</p>
375
+ <a href="https://github.com/laurru01" target="_blank"
376
+ style="color: #0984e3; text-decoration: none; font-weight: 600;">
377
+ View on GitHub
378
+ </a>
379
+ </div>
380
+ """)
381
+
382
+ # Launch the interface
383
+ iface.launch()