Spaces:
Runtime error
Runtime error
Upload app.py with huggingface_hub
Browse files
app.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard Python imports
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import json
|
5 |
+
from typing import List, Dict, Any
|
6 |
+
|
7 |
+
# Data processing and visualization
|
8 |
+
from PIL import Image
|
9 |
+
from tqdm import tqdm
|
10 |
+
from tqdm.notebook import tqdm
|
11 |
+
|
12 |
+
# Deep Learning & ML
|
13 |
+
import torch
|
14 |
+
from transformers import (
|
15 |
+
AutoProcessor,
|
16 |
+
AutoModelForVision2Seq,
|
17 |
+
AutoTokenizer,
|
18 |
+
AutoModelForCausalLM,
|
19 |
+
TextStreamer,
|
20 |
+
Idefics3ForConditionalGeneration,
|
21 |
+
BitsAndBytesConfig
|
22 |
+
|
23 |
+
)
|
24 |
+
|
25 |
+
from unsloth import FastVisionModel
|
26 |
+
|
27 |
+
# Dataset handling
|
28 |
+
from datasets import load_from_disk
|
29 |
+
|
30 |
+
# API & Authentication
|
31 |
+
from huggingface_hub import login
|
32 |
+
|
33 |
+
# UI & Environment
|
34 |
+
import gradio as gr
|
35 |
+
from dotenv import load_dotenv
|
36 |
+
|
37 |
+
# Available models: UI display name -> Hugging Face Hub repository id.
# Insertion order matters: the first entry is used as the dropdown default.
MODELS: Dict[str, str] = {
    "Blood Cell Classifier with Llama-3.2":
        "laurru01/Llama-3.2-11B-Vision-Instruct-ft-PeripherallBloodCells",
    "Blood Cell Classifier with Qwen2-VL":
        "laurru01/Qwen2-VL-2B-Instruct-ft-bloodcells-big",
    "Blood Cell Classifier with SmolVLM":
        "laurru01/SmolVLM-Instruct-ft-PeripherallBloodCells",
}

# Global registry of loaded models, keyed by the same display names as MODELS.
# Each value is a dict of components ("model", "processor", "type", ...).
loaded_models: Dict[str, Dict[str, Any]] = {}
|
46 |
+
|
47 |
+
def initialize_models():
    """Preload every model listed in ``MODELS`` into ``loaded_models``.

    Entries whose display name contains ``"SmolVLM"`` are loaded as the
    ``HuggingFaceTB/SmolVLM-Instruct`` base checkpoint in 4-bit, with the
    fine-tuned adapter from ``MODELS`` attached on top. All other entries are
    loaded through ``FastVisionModel.from_pretrained`` and switched to
    inference mode.

    A model that fails to load is logged (with traceback) and skipped, so the
    remaining models are still attempted. Nothing is returned; the results
    are stored in the module-level ``loaded_models`` dict.
    """
    import traceback

    print("Initializing models...")
    for model_name, model_path in MODELS.items():
        print(f"Loading {model_name}...")
        try:
            if "SmolVLM" in model_name:
                # SmolVLM-specific loading: quantized base model + adapter.
                base_model = Idefics3ForConditionalGeneration.from_pretrained(
                    "HuggingFaceTB/SmolVLM-Instruct",
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    # Passing load_in_4bit straight to from_pretrained is
                    # deprecated; the quantization settings belong in a
                    # BitsAndBytesConfig.
                    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                    max_memory={0: "12GB"},
                )
                base_model.load_adapter(model_path)
                processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")

                loaded_models[model_name] = {
                    "model": base_model,
                    "processor": processor,
                    "type": "smolvlm",
                }
            else:
                # Original loading path for Llama and Qwen (unchanged).
                model, tokenizer = FastVisionModel.from_pretrained(
                    model_name=model_path,
                    load_in_4bit=True,
                    use_gradient_checkpointing="unsloth",
                )
                FastVisionModel.for_inference(model)
                processor = AutoProcessor.from_pretrained(model_path)

                loaded_models[model_name] = {
                    "model": model,
                    "tokenizer": tokenizer,
                    "processor": processor,
                    "type": "standard",
                }
            print(f"Successfully loaded {model_name}")

        except Exception as e:
            # Best effort: keep going so the other models can still load, but
            # leave the full traceback in the logs for diagnosis.
            print(f"Error loading {model_name}: {str(e)}")
            traceback.print_exc()

    print("Model initialization complete")
|
92 |
+
|
93 |
+
def extract_cell_type(text):
    """Extract the cell type named in a generated description.

    The description is searched case-insensitively for the known cell type
    names. When several are present, the one mentioned first in the text
    wins, so the result does not depend on the order of the lookup table.

    Args:
        text: The generated description. ``None``, non-string and empty
            values are treated as "nothing identified".

    Returns:
        The capitalized cell type (e.g. ``"Neutrophil"``), or
        ``"Unidentified Cell Type"`` when no known type is mentioned.
    """
    cell_types = ('neutrophil', 'lymphocyte', 'monocyte', 'eosinophil', 'basophil')
    if not isinstance(text, str) or not text:
        return "Unidentified Cell Type"

    text_lower = text.lower()
    # Position of the first mention of each type; -1 means "not mentioned".
    positions = {cell_type: text_lower.find(cell_type) for cell_type in cell_types}
    mentioned = [cell_type for cell_type in cell_types if positions[cell_type] != -1]
    if not mentioned:
        return "Unidentified Cell Type"

    return min(mentioned, key=positions.__getitem__).capitalize()
|
101 |
+
|
102 |
+
@torch.no_grad()
def generate_description_standard(model, tokenizer, image):
    """Generate a description using the standard models (Llama and Qwen).

    Args:
        model: The loaded vision-language model; must expose ``generate`` and
            ``device``.
        tokenizer: The object returned next to the model by
            ``FastVisionModel.from_pretrained``. It is called with both the
            image and the prompt text, so it is presumably a processor-like
            object -- verify against the loader.
        image: The image to describe.

    Returns:
        The generated answer only, stripped of surrounding whitespace. The
        prompt is not included in the returned text.
    """
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "As a hematologist, carefully identify the type of blood cell in this image and describe its key characteristics."}
        ]}]

    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    # Send the inputs to wherever the model actually lives instead of
    # assuming a device named "cuda".
    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    # Streams the generated text to stdout while generating.
    text_streamer = TextStreamer(tokenizer, skip_prompt=True)
    output = model.generate(
        **inputs,
        streamer=text_streamer,
        max_new_tokens=1024,
        use_cache=True,
        temperature=1.5,
        min_p=0.1
    )

    # generate() returns prompt + continuation. Decode only the continuation
    # rather than searching the decoded text for an "assistant" marker: that
    # search kept the marker itself in the result and, when the marker was
    # missing (find() == -1), reduced the answer to its last character.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_ids = output[0][prompt_length:]

    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
133 |
+
|
134 |
+
@torch.no_grad()
def generate_description_smolvlm(model, processor, image):
    """Generate a description using the SmolVLM model with memory-efficient settings.

    The image is converted to RGB and downscaled to fit within 192x192 on a
    private copy, so the caller's image object is left untouched. The CUDA
    cache is emptied before generation and again afterwards, whether
    generation succeeds or raises.

    Args:
        model: The loaded SmolVLM model; must expose ``generate`` and
            ``device``.
        processor: The matching processor (chat template, preprocessing,
            decoding).
        image: A PIL image.

    Returns:
        The generated answer, or a fixed error message when the answer
        contains fewer than five distinct whitespace-separated words.

    Raises:
        Any exception raised by preprocessing or generation is propagated
        unchanged.
    """
    # Image.thumbnail() resizes in place, so always work on a copy; otherwise
    # the caller's image (which is also displayed back in the UI) would be
    # shrunk as a side effect.
    image = image.convert("RGB") if image.mode != "RGB" else image.copy()

    # Resize to a smaller size to reduce memory usage.
    max_size = 192
    image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

    sample = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "As a hematologist, carefully identify the type of blood cell in this image and describe its key characteristics."}
        ]
    }]

    text_input = processor.apply_chat_template(
        sample,
        add_generation_prompt=True
    )

    torch.cuda.empty_cache()
    try:
        # torch.autocast("cuda") replaces the deprecated
        # torch.cuda.amp.autocast().
        # NOTE(review): generation is assumed to run inside the autocast
        # context, as the original layout suggests -- confirm.
        with torch.autocast(device_type="cuda"):
            model_inputs = processor(
                text=text_input,
                images=[[image]],
                return_tensors="pt",
            ).to(model.device)

            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.5,
                no_repeat_ngram_size=3,
                num_beams=2,
                length_penalty=1.0,
                early_stopping=True,
                use_cache=True,
                pad_token_id=processor.tokenizer.pad_token_id,
            )

        # Drop the prompt tokens; keep only the newly generated ones.
        response_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        output_text = processor.decode(
            response_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        ).strip()

        # Very low vocabulary diversity is treated as a degenerate answer.
        if len(set(output_text.split())) < 5:
            output_text = "Error: Generated response was too repetitive. Please try again."

        # Release the tensors before the cache is emptied in `finally`.
        del model_inputs, generated_ids, response_ids

        return output_text
    finally:
        # Runs on success and on failure; exceptions propagate unchanged.
        torch.cuda.empty_cache()
|
200 |
+
|
201 |
+
def analyze_cell(image, model_name):
    """Analyze a cell image with the selected model.

    Args:
        image: The uploaded image; anything that is not a PIL image is
            rejected.
        model_name: Display name of the model, i.e. a key of
            ``loaded_models``.

    Returns:
        A ``(cell_type, description, image)`` tuple. On any failure the first
        element carries a human-readable message and the other two are ``""``
        and ``None``; this function does not raise for errors that occur
        during generation.
    """
    import traceback

    if not isinstance(image, Image.Image):
        return "Invalid image format. Please upload a valid image.", "", None

    model_components = loaded_models.get(model_name)
    if model_components is None:
        return f"Model {model_name} not loaded.", "", None

    try:
        if model_components["type"] == "smolvlm":
            description = generate_description_smolvlm(
                model_components["model"],
                model_components["processor"],
                image
            )
        else:
            description = generate_description_standard(
                model_components["model"],
                model_components["tokenizer"],
                image
            )

        cell_type = extract_cell_type(description)
        return cell_type, description, image

    except Exception as e:
        # UI boundary: show a readable message to the user, but keep the full
        # traceback in the server logs instead of discarding it.
        traceback.print_exc()
        return f"Error occurred: {str(e)}", "", None
|
230 |
+
|
231 |
+
# Initialize all models before starting the interface.
initialize_models()

# The Gradio interface (`iface`) is built exactly once, further down, with
# the custom CSS applied. Building an additional unstyled copy here would be
# dead code: `iface` and every component variable are rebound by the styled
# interface before `iface.launch()` is called.
|
263 |
+
|
264 |
+
# Enhanced CSS with modern color scheme.
# Class names are attached to components through `elem_classes` (and through
# `class=` attributes in the raw HTML snippets).
custom_css = """
.container {
    max-width: 1000px;
    margin: auto;
    padding: 30px;
    background: linear-gradient(135deg, #f6f9fc 0%, #ffffff 100%);
    border-radius: 20px;
    box-shadow: 0 10px 20px rgba(0,0,0,0.05);
}
.title {
    text-align: center;
    color: #2d3436;
    font-size: 3em;
    font-weight: 700;
    margin-bottom: 20px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
}
.subtitle {
    text-align: center;
    color: #636e72;
    font-size: 1.2em;
    margin-bottom: 40px;
}
.input-image {
    border: 2px dashed #74b9ff;
    border-radius: 15px;
    padding: 20px;
    transition: all 0.3s ease;
}
.input-image:hover {
    border-color: #0984e3;
    transform: translateY(-2px);
}
.model-dropdown {
    background: #f8f9fa;
    border-radius: 10px;
    border: 1px solid #dfe6e9;
    margin: 15px 0;
}
.submit-button {
    background: linear-gradient(45deg, #0984e3, #74b9ff);
    color: white;
    border: none;
    padding: 12px 25px;
    border-radius: 10px;
    font-weight: 600;
    transition: all 0.3s ease;
}
.submit-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(9, 132, 227, 0.3);
}
.result-box {
    background: white;
    border-radius: 10px;
    border: 1px solid #dfe6e9;
    padding: 15px;
    margin: 10px 0;
}
.output-image {
    border-radius: 15px;
    overflow: hidden;
    box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
"""
|
330 |
+
# Interface: inputs in the left column, results in the right column, and a
# credits footer below. The submit button runs `analyze_cell` on the uploaded
# image with the selected model.
model_choices = list(MODELS.keys())

footer_html = """
<div style="text-align: center; margin-top: 30px; padding: 20px;">
<p style="color: #636e72;">Developed by Laura Ruiz | MSc Bioinformatics and Biostatistics</p>
<a href="https://github.com/laurru01" target="_blank"
style="color: #0984e3; text-decoration: none; font-weight: 600;">
View on GitHub
</a>
</div>
"""

with gr.Blocks(css=custom_css) as iface:
    gr.HTML("<h1 class='title'>Blood Cell Classifier</h1>")
    gr.HTML("<p class='subtitle'>Upload a microscopic blood cell image for instant classification and detailed analysis</p>")

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            input_image = gr.Image(
                label="Upload Blood Cell Image",
                type="pil",
                sources=["upload"],  # Only allow computer uploads
                elem_classes="input-image",
            )
            model_dropdown = gr.Dropdown(
                choices=model_choices,
                value=model_choices[0],
                label="Select Model Version",
                elem_classes="model-dropdown",
            )
            submit_btn = gr.Button(
                "Analyze Cell",
                variant="primary",
                elem_classes="submit-button",
            )

        # Right column: analysis results.
        with gr.Column():
            cell_type = gr.Textbox(
                label="Identified Cell Type",
                elem_classes="result-box",
            )
            description = gr.Textbox(
                label="Analysis Details",
                lines=8,
                elem_classes="result-box",
            )
            output_image = gr.Image(
                label="Analyzed Image",
                elem_classes="output-image",
            )

    # The three return values of analyze_cell map onto the three outputs.
    submit_btn.click(
        fn=analyze_cell,
        inputs=[input_image, model_dropdown],
        outputs=[cell_type, description, output_image],
    )

    gr.HTML(footer_html)

# Launch the interface
iface.launch()
|