RezinWiz committed on
Commit 4721ef4 · verified · 1 Parent(s): dc04619

Update api.py

Files changed (1)
  1. api.py +729 -744
api.py CHANGED
@@ -1,744 +1,729 @@
1
- """
2
- EASI Severity Prediction REST API
3
- ==================================
4
-
5
- FastAPI-based REST API for predicting EASI scores from dermatological images.
6
- Designed for integration with Flutter mobile applications.
7
-
8
- Endpoints:
9
- - POST /predict - Upload image and get EASI predictions
10
- - GET /health - Health check endpoint
11
- - GET /conditions - Get list of available conditions
12
-
13
- Installation:
14
- pip install fastapi uvicorn python-multipart pillow tensorflow numpy pandas huggingface-hub requests
15
-
16
- Run:
17
- uvicorn api:app --host 0.0.0.0 --port 8000 --reload
18
- """
19
-
20
- import os
21
- import warnings
22
- import logging
23
- from typing import List, Dict, Any, Optional
24
- from io import BytesIO
25
- from pathlib import Path
26
-
27
- # Suppress warnings
28
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
29
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
30
- os.environ['MLIR_CRASH_REPRODUCER_DIRECTORY'] = ''
31
- warnings.filterwarnings('ignore')
32
- logging.getLogger('absl').setLevel(logging.ERROR)
33
-
34
- import tensorflow as tf
35
- tf.get_logger().setLevel('ERROR')
36
- tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
37
-
38
- from fastapi import FastAPI, File, UploadFile, HTTPException, status
39
- from fastapi.middleware.cors import CORSMiddleware
40
- from fastapi.responses import JSONResponse
41
- from pydantic import BaseModel, Field
42
- import numpy as np
43
- from PIL import Image
44
- import pickle
45
- import pandas as pd
46
- import requests
47
- from huggingface_hub import hf_hub_download, login
48
-
49
- # Initialize FastAPI app
50
- app = FastAPI(
51
- title="EASI Severity Prediction API",
52
- description="REST API for predicting EASI scores from skin images",
53
- version="1.0.0"
54
- )
55
-
56
- # CORS middleware for Flutter web/mobile
57
- app.add_middleware(
58
- CORSMiddleware,
59
- allow_origins=["*"], # In production, specify your Flutter app domain
60
- allow_credentials=True,
61
- allow_methods=["*"],
62
- allow_headers=["*"],
63
- )
64
-
65
- # Configuration
66
- HF_REPO_ID = "google/derm-foundation"
67
- DERM_FOUNDATION_PATH = "./derm_foundation/"
68
- R2_BASE_URL = os.environ.get("R2_BASE_URL", "https://r2-worker.eczemanage.workers.dev")
69
-
70
- # Get Hugging Face token from environment variable
71
- HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
72
-
73
- # Response Models
74
- class ConditionPrediction(BaseModel):
75
- condition: str
76
- probability: float = Field(..., ge=0, le=1)
77
- confidence: float = Field(..., ge=0)
78
- weight: float = Field(..., ge=0)
79
- easi_category: Optional[str] = None
80
- easi_contribution: int = Field(..., ge=0, le=3)
81
-
82
- class EASIComponent(BaseModel):
83
- name: str
84
- score: int = Field(..., ge=0, le=3)
85
- contributing_conditions: List[Dict[str, Any]]
86
-
87
- class PredictionResponse(BaseModel):
88
- success: bool
89
- total_easi_score: int = Field(..., ge=0, le=12)
90
- severity_interpretation: str
91
- easi_components: Dict[str, EASIComponent]
92
- predicted_conditions: List[ConditionPrediction]
93
- summary_statistics: Dict[str, float]
94
- image_info: Dict[str, Any]
95
-
96
- class HealthResponse(BaseModel):
97
- status: str
98
- models_loaded: Dict[str, bool]
99
- available_conditions: int
100
- hf_token_configured: bool
101
- model_source: str
102
-
103
- class ErrorResponse(BaseModel):
104
- success: bool = False
105
- error: str
106
- detail: Optional[str] = None
107
-
108
-
109
- # Model wrapper class
110
- class DermFoundationNeuralNetwork:
111
- def __init__(self):
112
- self.model = None
113
- self.mlb = None
114
- self.embedding_scaler = None
115
- self.confidence_scaler = None
116
- self.weighted_scaler = None
117
-
118
- def load_model(self, filepath):
119
- try:
120
- with open(filepath, 'rb') as f:
121
- model_data = pickle.load(f)
122
-
123
- self.mlb = model_data['mlb']
124
- self.embedding_scaler = model_data['embedding_scaler']
125
- self.confidence_scaler = model_data['confidence_scaler']
126
- self.weighted_scaler = model_data['weighted_scaler']
127
-
128
- keras_model_path = model_data['keras_model_path']
129
- if os.path.exists(keras_model_path):
130
- self.model = tf.keras.models.load_model(keras_model_path)
131
- return True
132
- else:
133
- return False
134
- except Exception as e:
135
- print(f"Error loading model: {e}")
136
- return False
137
-
138
- def predict(self, embedding):
139
- if self.model is None:
140
- return None
141
-
142
- if len(embedding.shape) == 1:
143
- embedding = embedding.reshape(1, -1)
144
-
145
- embedding_scaled = self.embedding_scaler.transform(embedding)
146
- predictions = self.model.predict(embedding_scaled, verbose=0)
147
-
148
- condition_probs = predictions['conditions'][0]
149
- individual_confidences = predictions['individual_confidences'][0]
150
- individual_weights = predictions['individual_weights'][0]
151
-
152
- condition_threshold = 0.3
153
- predicted_condition_indices = np.where(condition_probs > condition_threshold)[0]
154
-
155
- predicted_conditions = []
156
- predicted_confidences = []
157
- predicted_weights_dict = {}
158
-
159
- for idx in predicted_condition_indices:
160
- condition_name = self.mlb.classes_[idx]
161
- condition_prob = float(condition_probs[idx])
162
-
163
- if individual_confidences[idx] > 0:
164
- confidence_orig = self.confidence_scaler.inverse_transform([[individual_confidences[idx]]])[0, 0]
165
- else:
166
- confidence_orig = 0.0
167
-
168
- if individual_weights[idx] > 0:
169
- weight_orig = self.weighted_scaler.inverse_transform([[individual_weights[idx]]])[0, 0]
170
- else:
171
- weight_orig = 0.0
172
-
173
- predicted_conditions.append(condition_name)
174
- predicted_confidences.append(max(0, confidence_orig))
175
- predicted_weights_dict[condition_name] = max(0, weight_orig)
176
-
177
- all_condition_probs = {}
178
- all_confidences = {}
179
- all_weights = {}
180
-
181
- for i, class_name in enumerate(self.mlb.classes_):
182
- all_condition_probs[class_name] = float(condition_probs[i])
183
-
184
- if individual_confidences[i] > 0:
185
- conf_orig = self.confidence_scaler.inverse_transform([[individual_confidences[i]]])[0, 0]
186
- all_confidences[class_name] = max(0, conf_orig)
187
- else:
188
- all_confidences[class_name] = 0.0
189
-
190
- if individual_weights[i] > 0:
191
- weight_orig = self.weighted_scaler.inverse_transform([[individual_weights[i]]])[0, 0]
192
- all_weights[class_name] = max(0, weight_orig)
193
- else:
194
- all_weights[class_name] = 0.0
195
-
196
- return {
197
- 'dermatologist_skin_condition_on_label_name': predicted_conditions,
198
- 'dermatologist_skin_condition_confidence': predicted_confidences,
199
- 'weighted_skin_condition_label': predicted_weights_dict,
200
- 'all_condition_probabilities': all_condition_probs,
201
- 'all_individual_confidences': all_confidences,
202
- 'all_individual_weights': all_weights,
203
- 'condition_threshold': condition_threshold
204
- }
205
-
206
-
207
- # Helper function to download from Cloudflare R2 with chunked streaming
208
- def download_derm_foundation_from_r2(output_dir):
209
- """Download Derm Foundation model from Cloudflare R2 using memory-efficient streaming"""
210
- try:
211
- print(f"Downloading Derm Foundation model from R2 ({R2_BASE_URL})...")
212
- os.makedirs(output_dir, exist_ok=True)
213
-
214
- # Files to download
215
- files_to_download = [
216
- "saved_model.pb",
217
- "variables/variables.index",
218
- "variables/variables.data-00000-of-00001"
219
- ]
220
-
221
- for file_path in files_to_download:
222
- print(f"Downloading {file_path}...")
223
- url = f"{R2_BASE_URL}/{file_path}"
224
- local_path = os.path.join(output_dir, file_path)
225
-
226
- # Create subdirectories if needed
227
- os.makedirs(os.path.dirname(local_path), exist_ok=True)
228
-
229
- # Download file with streaming (ULTRA MEMORY EFFICIENT)
230
- # Use tiny chunk size and aggressive garbage collection
231
- import gc
232
-
233
- with requests.get(url, stream=True, timeout=900) as response:
234
- response.raise_for_status()
235
-
236
- total_size = int(response.headers.get('content-length', 0))
237
- downloaded = 0
238
- chunk_count = 0
239
-
240
- # Write directly to disk in tiny chunks (256KB to minimize memory)
241
- with open(local_path, 'wb') as f:
242
- for chunk in response.iter_content(chunk_size=256*1024): # 256KB chunks
243
- if chunk:
244
- f.write(chunk)
245
- f.flush() # Force write to disk
246
- downloaded += len(chunk)
247
- chunk_count += 1
248
-
249
- # Aggressive garbage collection every 10 chunks (~2.5MB)
250
- if chunk_count % 10 == 0:
251
- gc.collect()
252
-
253
- # Less frequent progress updates to reduce print overhead
254
- if total_size > 0 and chunk_count % 20 == 0:
255
- progress = (downloaded / total_size) * 100
256
- mb_downloaded = downloaded / (1024*1024)
257
- mb_total = total_size / (1024*1024)
258
- print(f" Progress: {progress:.1f}% ({mb_downloaded:.1f}/{mb_total:.1f} MB)")
259
-
260
- print() # New line after progress
261
- gc.collect() # Final cleanup
262
-
263
- print(f"✓ Downloaded: {file_path}")
264
-
265
- print(f"✓ Derm Foundation model downloaded successfully from R2")
266
- return True
267
- except Exception as e:
268
- print(f"✗ Error downloading from R2: {e}")
269
- import traceback
270
- traceback.print_exc()
271
- return False
272
-
273
-
274
- # Helper function to download from Hugging Face (Fallback) with memory-efficient streaming
275
- def download_derm_foundation_from_hf(output_dir):
276
- """Download Derm Foundation model from Hugging Face using memory-efficient streaming"""
277
- try:
278
- # Login to Hugging Face if token is available
279
- if HF_TOKEN:
280
- print("Authenticating with Hugging Face...")
281
- login(token=HF_TOKEN)
282
- else:
283
- print("WARNING: No HF token found. Attempting download without authentication...")
284
-
285
- print(f"Downloading Derm Foundation model from Hugging Face...")
286
- os.makedirs(output_dir, exist_ok=True)
287
-
288
- # Files to download
289
- files_to_download = [
290
- "saved_model.pb",
291
- "variables/variables.data-00000-of-00001",
292
- "variables/variables.index"
293
- ]
294
-
295
- for file_path in files_to_download:
296
- print(f"Downloading {file_path}...")
297
- local_path = os.path.join(output_dir, file_path)
298
-
299
- # Create subdirectories if needed
300
- os.makedirs(os.path.dirname(local_path), exist_ok=True)
301
-
302
- # Download file with token if available
303
- # hf_hub_download handles streaming internally
304
- downloaded_path = hf_hub_download(
305
- repo_id=HF_REPO_ID,
306
- filename=file_path,
307
- token=HF_TOKEN,
308
- cache_dir=None,
309
- local_dir=output_dir,
310
- local_dir_use_symlinks=False,
311
- resume_download=True # Resume if interrupted
312
- )
313
- print(f"✓ Downloaded: {file_path}")
314
-
315
- print(f"✓ Derm Foundation model downloaded successfully from HuggingFace")
316
- return True
317
- except Exception as e:
318
- print(f"✗ Error downloading from Hugging Face: {e}")
319
- print(f"Make sure HUGGINGFACE_TOKEN is set in Render environment variables")
320
- import traceback
321
- traceback.print_exc()
322
- return False
323
-
324
-
325
- # EASI calculation functions
326
- def calculate_easi_scores(predictions):
327
- easi_categories = {
328
- 'erythema': {
329
- 'name': 'Erythema (Redness)',
330
- 'conditions': [
331
- 'Post-Inflammatory hyperpigmentation', 'Erythema ab igne', 'Erythema annulare centrifugum',
332
- 'Erythema elevatum diutinum', 'Erythema gyratum repens', 'Erythema multiforme',
333
- 'Erythema nodosum', 'Flagellate erythema', 'Annular erythema', 'Drug Rash',
334
- 'Allergic Contact Dermatitis', 'Irritant Contact Dermatitis', 'Contact dermatitis',
335
- 'Acute dermatitis', 'Chronic dermatitis', 'Acute and chronic dermatitis',
336
- 'Sunburn', 'Photodermatitis', 'Phytophotodermatitis', 'Rosacea',
337
- 'Seborrheic Dermatitis', 'Stasis Dermatitis', 'Perioral Dermatitis',
338
- 'Burn erythema of abdominal wall', 'Burn erythema of back of hand',
339
- 'Burn erythema of lower leg', 'Cellulitis', 'Infection of skin',
340
- 'Viral Exanthem', 'Infected eczema', 'Crusted eczematous dermatitis',
341
- 'Inflammatory dermatosis', 'Vasculitis of the skin', 'Leukocytoclastic Vasculitis',
342
- 'Cutaneous lupus', 'CD - Contact dermatitis', 'Acute dermatitis, NOS',
343
- 'Herpes Simplex', 'Hypersensitivity', 'Impetigo', 'Pigmented purpuric eruption',
344
- 'Pityriasis rosea', 'Tinea', 'Tinea Versicolor'
345
- ]
346
- },
347
- 'induration': {
348
- 'name': 'Induration/Papulation (Swelling/Bumps)',
349
- 'conditions': [
350
- 'Prurigo nodularis', 'Urticaria', 'Granuloma annulare', 'Morphea',
351
- 'Scleroderma', 'Lichen Simplex Chronicus', 'Lichen planus', 'lichenoid eruption',
352
- 'Lichen nitidus', 'Lichen spinulosus', 'Lichen striatus', 'Keratosis pilaris',
353
- 'Molluscum Contagiosum', 'Verruca vulgaris', 'Folliculitis', 'Acne',
354
- 'Hidradenitis', 'Nodular vasculitis', 'Sweet syndrome', 'Necrobiosis lipoidica',
355
- 'Basal Cell Carcinoma', 'SCC', 'SCCIS', 'SK', 'ISK',
356
- 'Cutaneous T Cell Lymphoma', 'Skin cancer', 'Adnexal neoplasm',
357
- 'Insect Bite', 'Milia', 'Miliaria', 'Xanthoma', 'Psoriasis',
358
- 'Lichen planus/lichenoid eruption'
359
- ]
360
- },
361
- 'excoriation': {
362
- 'name': 'Excoriation (Scratching Damage)',
363
- 'conditions': [
364
- 'Inflicted skin lesions', 'Scabies', 'Abrasion', 'Abrasion of wrist',
365
- 'Superficial wound of body region', 'Scrape', 'Animal bite - wound',
366
- 'Pruritic dermatitis', 'Prurigo', 'Atopic dermatitis', 'Scab'
367
- ]
368
- },
369
- 'lichenification': {
370
- 'name': 'Lichenification (Skin Thickening)',
371
- 'conditions': [
372
- 'Lichenified eczematous dermatitis', 'Acanthosis nigricans',
373
- 'Hyperkeratosis of skin', 'HK - Hyperkeratosis', 'Keratoderma',
374
- 'Ichthyosis', 'Ichthyosiform dermatosis', 'Chronic eczema',
375
- 'Psoriasis', 'Xerosis'
376
- ]
377
- }
378
- }
379
-
380
- def probability_to_score(prob):
381
- if prob < 0.171:
382
- return 0
383
- elif prob < 0.238:
384
- return 1
385
- elif prob < 0.421:
386
- return 2
387
- elif prob < 0.614:
388
- return 3
389
- else:
390
- return 3
391
-
392
- easi_results = {}
393
- all_condition_probs = predictions['all_condition_probabilities']
394
-
395
- for component, category_info in easi_categories.items():
396
- category_conditions = []
397
-
398
- for condition_name, probability in all_condition_probs.items():
399
- if condition_name.lower() == 'eczema':
400
- continue
401
-
402
- if condition_name in category_info['conditions']:
403
- category_conditions.append({
404
- 'condition': condition_name,
405
- 'probability': probability,
406
- 'individual_score': probability_to_score(probability)
407
- })
408
-
409
- category_conditions = [c for c in category_conditions if c['individual_score'] > 0]
410
- category_conditions.sort(key=lambda x: x['probability'], reverse=True)
411
-
412
- component_score = sum(c['individual_score'] for c in category_conditions)
413
- component_score = min(component_score, 3)
414
-
415
- easi_results[component] = {
416
- 'name': category_info['name'],
417
- 'score': component_score,
418
- 'contributing_conditions': category_conditions
419
- }
420
-
421
- total_easi = sum(result['score'] for result in easi_results.values())
422
-
423
- return easi_results, total_easi
424
-
425
-
426
- def get_severity_interpretation(total_easi):
427
- if total_easi == 0:
428
- return "No significant EASI features detected"
429
- elif total_easi <= 3:
430
- return "Mild EASI severity"
431
- elif total_easi <= 6:
432
- return "Moderate EASI severity"
433
- elif total_easi <= 9:
434
- return "Severe EASI severity"
435
- else:
436
- return "Very Severe EASI severity"
437
-
438
-
439
- # Image processing functions
440
- def smart_crop_to_square(image):
441
- width, height = image.size
442
- if width == height:
443
- return image
444
-
445
- size = min(width, height)
446
- left = (width - size) // 2
447
- top = (height - size) // 2
448
- right = left + size
449
- bottom = top + size
450
-
451
- return image.crop((left, top, right, bottom))
452
-
453
-
454
- def generate_derm_foundation_embedding(model, image):
455
- try:
456
- if image.mode != 'RGB':
457
- image = image.convert('RGB')
458
-
459
- buf = BytesIO()
460
- image.save(buf, format='JPEG')
461
- image_bytes = buf.getvalue()
462
-
463
- input_tensor = tf.train.Example(features=tf.train.Features(
464
- feature={'image/encoded': tf.train.Feature(
465
- bytes_list=tf.train.BytesList(value=[image_bytes]))
466
- })).SerializeToString()
467
-
468
- infer = model.signatures["serving_default"]
469
- output = infer(inputs=tf.constant([input_tensor]))
470
-
471
- if 'embedding' in output:
472
- embedding_vector = output['embedding'].numpy().flatten()
473
- else:
474
- key = list(output.keys())[0]
475
- embedding_vector = output[key].numpy().flatten()
476
-
477
- return embedding_vector
478
- except Exception as e:
479
- raise HTTPException(status_code=500, detail=f"Error generating embedding: {str(e)}")
480
-
481
-
482
- # Global model instances
483
- derm_model = None
484
- easi_model = None
485
- model_source = "not_loaded"
486
-
487
-
488
- @app.on_event("startup")
489
- async def load_models():
490
- """Load models on startup"""
491
- global derm_model, easi_model, model_source
492
-
493
- # Force garbage collection before starting
494
- import gc
495
- gc.collect()
496
-
497
- # Check if model exists (should be pre-downloaded in Docker or already cached)
498
- if not os.path.exists(DERM_FOUNDATION_PATH) or not os.path.exists(os.path.join(DERM_FOUNDATION_PATH, "saved_model.pb")):
499
- print("=" * 60)
500
- print("Derm Foundation model not found locally.")
501
- print("=" * 60)
502
-
503
- # Try R2 first (fast)
504
- print("\n[1/2] Attempting download from Cloudflare R2...")
505
- success = download_derm_foundation_from_r2(DERM_FOUNDATION_PATH)
506
-
507
- if success:
508
- model_source = "cloudflare_r2"
509
- else:
510
- # Fallback to HuggingFace
511
- print("\n[2/2] R2 failed, trying HuggingFace as fallback...")
512
-
513
- if not HF_TOKEN:
514
- print("=" * 60)
515
- print("WARNING: HUGGINGFACE_TOKEN environment variable not set!")
516
- print("Set it in Render Dashboard > Environment > Environment Variables")
517
- print("Variable name: HUGGINGFACE_TOKEN")
518
- print("Variable value: <your-hf-token>")
519
- print("=" * 60)
520
-
521
- success = download_derm_foundation_from_hf(DERM_FOUNDATION_PATH)
522
- if success:
523
- model_source = "huggingface"
524
- else:
525
- print("=" * 60)
526
- print("ERROR: Failed to download model from both R2 and HuggingFace!")
527
- print("=" * 60)
528
- model_source = "failed"
529
- else:
530
- print("✓ Derm Foundation model found locally (pre-downloaded or cached)")
531
- model_source = "local_cache"
532
-
533
- # Load Derm Foundation model
534
- if os.path.exists(os.path.join(DERM_FOUNDATION_PATH, "saved_model.pb")):
535
- try:
536
- print(f"Loading Derm-Foundation model from: {DERM_FOUNDATION_PATH}")
537
- # Force garbage collection before loading large model
538
- gc.collect()
539
-
540
- derm_model = tf.saved_model.load(DERM_FOUNDATION_PATH)
541
- print(f" Derm-Foundation model loaded successfully (source: {model_source})")
542
-
543
- # Cleanup after loading
544
- gc.collect()
545
- except Exception as e:
546
- print(f"✗ Failed to load Derm Foundation model: {str(e)}")
547
-
548
- # Load EASI model (keep this local in your repo)
549
- model_path = './trained_model/easi_severity_model_derm_foundation_individual.pkl'
550
- if os.path.exists(model_path):
551
- easi_model = DermFoundationNeuralNetwork()
552
- success = easi_model.load_model(model_path)
553
- if success:
554
- print(f"✓ EASI model loaded from: {model_path}")
555
- else:
556
- print(f" Failed to load EASI model")
557
- easi_model = None
558
- else:
559
- print(f"✗ EASI model not found at: {model_path}")
560
-
561
- if derm_model is None or easi_model is None:
562
- print("=" * 60)
563
- print("WARNING: Some models failed to load!")
564
- print(f"Derm Foundation: {'✓' if derm_model else '✗'}")
565
- print(f"EASI Model: {'✓' if easi_model else '✗'}")
566
- print("=" * 60)
567
- else:
568
- print("=" * 60)
569
- print(" All models loaded successfully!")
570
- print(f"Model source: {model_source}")
571
- print("=" * 60)
572
-
573
-
574
- # API Endpoints
575
-
576
- @app.get("/")
577
- async def root():
578
- """Root endpoint"""
579
- return {
580
- "message": "EASI Severity Prediction API",
581
- "version": "1.0.0",
582
- "model_source": model_source,
583
- "docs": "/docs",
584
- "health": "/health",
585
- "predict": "/predict",
586
- "conditions": "/conditions"
587
- }
588
-
589
-
590
- @app.get("/health", response_model=HealthResponse)
591
- async def health_check():
592
- """Health check endpoint"""
593
- return {
594
- "status": "ok" if (derm_model is not None and easi_model is not None) else "degraded",
595
- "models_loaded": {
596
- "derm_foundation": derm_model is not None,
597
- "easi_model": easi_model is not None
598
- },
599
- "available_conditions": len(easi_model.mlb.classes_) if easi_model else 0,
600
- "hf_token_configured": HF_TOKEN is not None,
601
- "model_source": model_source
602
- }
603
-
604
-
605
- @app.get("/conditions", response_model=Dict[str, List[str]])
606
- async def get_conditions():
607
- """Get list of available conditions"""
608
- if easi_model is None:
609
- raise HTTPException(status_code=503, detail="EASI model not loaded")
610
-
611
- return {
612
- "conditions": easi_model.mlb.classes_.tolist()
613
- }
614
-
615
-
616
- @app.post("/predict", response_model=PredictionResponse)
617
- async def predict_easi(
618
- file: UploadFile = File(..., description="Skin image file (JPG, JPEG, PNG)")
619
- ):
620
- """
621
- Predict EASI scores from uploaded skin image.
622
-
623
- - **file**: Image file (JPG, JPEG, PNG)
624
- - Returns: EASI scores, component breakdown, and condition predictions
625
- """
626
-
627
- # Validate models loaded
628
- if derm_model is None or easi_model is None:
629
- raise HTTPException(
630
- status_code=503,
631
- detail="Models not loaded. Check server logs."
632
- )
633
-
634
- # Validate file type
635
- if not file.content_type.startswith('image/'):
636
- raise HTTPException(
637
- status_code=400,
638
- detail="File must be an image (JPG, JPEG, PNG)"
639
- )
640
-
641
- try:
642
- # Read and process image
643
- image_bytes = await file.read()
644
- original_image = Image.open(BytesIO(image_bytes)).convert('RGB')
645
- original_size = original_image.size
646
-
647
- # Process to 448x448
648
- cropped_img = smart_crop_to_square(original_image)
649
- processed_img = cropped_img.resize((448, 448), Image.Resampling.LANCZOS)
650
-
651
- # Generate embedding
652
- embedding = generate_derm_foundation_embedding(derm_model, processed_img)
653
-
654
- # Make prediction
655
- predictions = easi_model.predict(embedding)
656
-
657
- if predictions is None:
658
- raise HTTPException(status_code=500, detail="Prediction failed")
659
-
660
- # Calculate EASI scores
661
- easi_results, total_easi = calculate_easi_scores(predictions)
662
- severity = get_severity_interpretation(total_easi)
663
-
664
- # Format predicted conditions
665
- predicted_conditions = []
666
- for i, condition in enumerate(predictions['dermatologist_skin_condition_on_label_name']):
667
- prob = predictions['all_condition_probabilities'][condition]
668
- conf = predictions['dermatologist_skin_condition_confidence'][i]
669
- weight = predictions['weighted_skin_condition_label'][condition]
670
-
671
- # Find EASI category
672
- easi_category = None
673
- easi_contribution = 0
674
- for cat_key, cat_info in easi_results.items():
675
- for contrib in cat_info['contributing_conditions']:
676
- if contrib['condition'] == condition:
677
- easi_category = cat_info['name']
678
- easi_contribution = contrib['individual_score']
679
- break
680
-
681
- predicted_conditions.append(ConditionPrediction(
682
- condition=condition,
683
- probability=float(prob),
684
- confidence=float(conf),
685
- weight=float(weight),
686
- easi_category=easi_category,
687
- easi_contribution=easi_contribution
688
- ))
689
-
690
- # Summary statistics
691
- summary_stats = {
692
- "total_conditions": len(predicted_conditions),
693
- "average_confidence": float(np.mean(predictions['dermatologist_skin_condition_confidence'])) if predicted_conditions else 0.0,
694
- "average_weight": float(np.mean(list(predictions['weighted_skin_condition_label'].values()))) if predicted_conditions else 0.0,
695
- "total_weight": float(sum(predictions['weighted_skin_condition_label'].values()))
696
- }
697
-
698
- # Format EASI components
699
- easi_components_formatted = {
700
- component: EASIComponent(
701
- name=result['name'],
702
- score=result['score'],
703
- contributing_conditions=result['contributing_conditions']
704
- )
705
- for component, result in easi_results.items()
706
- }
707
-
708
- return PredictionResponse(
709
- success=True,
710
- total_easi_score=total_easi,
711
- severity_interpretation=severity,
712
- easi_components=easi_components_formatted,
713
- predicted_conditions=predicted_conditions,
714
- summary_statistics=summary_stats,
715
- image_info={
716
- "original_size": f"{original_size[0]}x{original_size[1]}",
717
- "processed_size": "448x448",
718
- "filename": file.filename
719
- }
720
- )
721
-
722
- except HTTPException:
723
- raise
724
- except Exception as e:
725
- raise HTTPException(
726
- status_code=500,
727
- detail=f"Error processing image: {str(e)}"
728
- )
729
-
730
-
731
- @app.exception_handler(HTTPException)
732
- async def http_exception_handler(request, exc):
733
- return JSONResponse(
734
- status_code=exc.status_code,
735
- content=ErrorResponse(
736
- error=exc.detail,
737
- detail=str(exc)
738
- ).dict()
739
- )
740
-
741
-
742
- if __name__ == "__main__":
743
- import uvicorn
744
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ """
2
+ EASI Severity Prediction REST API
3
+ ==================================
4
+
5
+ FastAPI-based REST API for predicting EASI scores from dermatological images.
6
+ Designed for integration with Flutter mobile applications.
7
+
8
+ Endpoints:
9
+ - POST /predict - Upload image and get EASI predictions
10
+ - GET /health - Health check endpoint
11
+ - GET /conditions - Get list of available conditions
12
+
13
+ Installation:
14
+ pip install fastapi uvicorn python-multipart pillow tensorflow numpy pandas huggingface-hub requests
15
+
16
+ Run:
17
+ uvicorn api:app --host 0.0.0.0 --port 8000 --reload
18
+ """
19
+
20
+ import os
21
+ import warnings
22
+ import logging
23
+ from typing import List, Dict, Any, Optional
24
+ from io import BytesIO
25
+ from pathlib import Path
26
+
27
+ # Suppress warnings
28
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
29
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
30
+ os.environ['MLIR_CRASH_REPRODUCER_DIRECTORY'] = ''
31
+ warnings.filterwarnings('ignore')
32
+ logging.getLogger('absl').setLevel(logging.ERROR)
33
+
34
+ import tensorflow as tf
35
+ tf.get_logger().setLevel('ERROR')
36
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
37
+
38
+ from fastapi import FastAPI, File, UploadFile, HTTPException, status
39
+ from fastapi.middleware.cors import CORSMiddleware
40
+ from fastapi.responses import JSONResponse
41
+ from pydantic import BaseModel, Field
42
+ import numpy as np
43
+ from PIL import Image
44
+ import pickle
45
+ import pandas as pd
46
+ import requests
47
+ from huggingface_hub import hf_hub_download, login
48
+
49
+ # Initialize FastAPI app
50
+ app = FastAPI(
51
+ title="EASI Severity Prediction API",
52
+ description="REST API for predicting EASI scores from skin images",
53
+ version="1.0.0"
54
+ )
55
+
56
+ # CORS middleware for Flutter web/mobile
57
+ app.add_middleware(
58
+ CORSMiddleware,
59
+ allow_origins=["*"], # In production, specify your Flutter app domain
60
+ allow_credentials=True,
61
+ allow_methods=["*"],
62
+ allow_headers=["*"],
63
+ )
64
+
65
+ # Configuration
66
+ HF_REPO_ID = "google/derm-foundation"
67
+ DERM_FOUNDATION_PATH = "./derm_foundation/"
68
+ R2_BASE_URL = os.environ.get("R2_BASE_URL", "https://r2-worker.eczemanage.workers.dev")
69
+
70
+ # Get Hugging Face token from environment variable
71
+ HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
72
+
73
+ # Response Models
74
+ class ConditionPrediction(BaseModel):
75
+ condition: str
76
+ probability: float = Field(..., ge=0, le=1)
77
+ confidence: float = Field(..., ge=0)
78
+ weight: float = Field(..., ge=0)
79
+ easi_category: Optional[str] = None
80
+ easi_contribution: int = Field(..., ge=0, le=3)
81
+
82
+ class EASIComponent(BaseModel):
83
+ name: str
84
+ score: int = Field(..., ge=0, le=3)
85
+ contributing_conditions: List[Dict[str, Any]]
86
+
87
+ class PredictionResponse(BaseModel):
88
+ success: bool
89
+ total_easi_score: int = Field(..., ge=0, le=12)
90
+ severity_interpretation: str
91
+ easi_components: Dict[str, EASIComponent]
92
+ predicted_conditions: List[ConditionPrediction]
93
+ summary_statistics: Dict[str, float]
94
+ image_info: Dict[str, Any]
95
+
96
+ class HealthResponse(BaseModel):
97
+ status: str
98
+ models_loaded: Dict[str, bool]
99
+ available_conditions: int
100
+ hf_token_configured: bool
101
+ model_source: str
102
+
103
+ class ErrorResponse(BaseModel):
104
+ success: bool = False
105
+ error: str
106
+ detail: Optional[str] = None
107
+
108
+
109
+ # Model wrapper class
110
+ class DermFoundationNeuralNetwork:
111
+ def __init__(self):
112
+ self.model = None
113
+ self.mlb = None
114
+ self.embedding_scaler = None
115
+ self.confidence_scaler = None
116
+ self.weighted_scaler = None
117
+
118
+ def load_model(self, filepath):
119
+ try:
120
+ with open(filepath, 'rb') as f:
121
+ model_data = pickle.load(f)
122
+
123
+ self.mlb = model_data['mlb']
124
+ self.embedding_scaler = model_data['embedding_scaler']
125
+ self.confidence_scaler = model_data['confidence_scaler']
126
+ self.weighted_scaler = model_data['weighted_scaler']
127
+
128
+ keras_model_path = model_data['keras_model_path']
129
+ if os.path.exists(keras_model_path):
130
+ self.model = tf.keras.models.load_model(keras_model_path)
131
+ return True
132
+ else:
133
+ return False
134
+ except Exception as e:
135
+ print(f"Error loading model: {e}")
136
+ return False
137
+
138
+ def predict(self, embedding):
139
+ if self.model is None:
140
+ return None
141
+
142
+ if len(embedding.shape) == 1:
143
+ embedding = embedding.reshape(1, -1)
144
+
145
+ embedding_scaled = self.embedding_scaler.transform(embedding)
146
+ predictions = self.model.predict(embedding_scaled, verbose=0)
147
+
148
+ condition_probs = predictions['conditions'][0]
149
+ individual_confidences = predictions['individual_confidences'][0]
150
+ individual_weights = predictions['individual_weights'][0]
151
+
152
+ condition_threshold = 0.3
153
+ predicted_condition_indices = np.where(condition_probs > condition_threshold)[0]
154
+
155
+ predicted_conditions = []
156
+ predicted_confidences = []
157
+ predicted_weights_dict = {}
158
+
159
+ for idx in predicted_condition_indices:
160
+ condition_name = self.mlb.classes_[idx]
161
+ condition_prob = float(condition_probs[idx])
162
+
163
+ if individual_confidences[idx] > 0:
164
+ confidence_orig = self.confidence_scaler.inverse_transform([[individual_confidences[idx]]])[0, 0]
165
+ else:
166
+ confidence_orig = 0.0
167
+
168
+ if individual_weights[idx] > 0:
169
+ weight_orig = self.weighted_scaler.inverse_transform([[individual_weights[idx]]])[0, 0]
170
+ else:
171
+ weight_orig = 0.0
172
+
173
+ predicted_conditions.append(condition_name)
174
+ predicted_confidences.append(max(0, confidence_orig))
175
+ predicted_weights_dict[condition_name] = max(0, weight_orig)
176
+
177
+ all_condition_probs = {}
178
+ all_confidences = {}
179
+ all_weights = {}
180
+
181
+ for i, class_name in enumerate(self.mlb.classes_):
182
+ all_condition_probs[class_name] = float(condition_probs[i])
183
+
184
+ if individual_confidences[i] > 0:
185
+ conf_orig = self.confidence_scaler.inverse_transform([[individual_confidences[i]]])[0, 0]
186
+ all_confidences[class_name] = max(0, conf_orig)
187
+ else:
188
+ all_confidences[class_name] = 0.0
189
+
190
+ if individual_weights[i] > 0:
191
+ weight_orig = self.weighted_scaler.inverse_transform([[individual_weights[i]]])[0, 0]
192
+ all_weights[class_name] = max(0, weight_orig)
193
+ else:
194
+ all_weights[class_name] = 0.0
195
+
196
+ return {
197
+ 'dermatologist_skin_condition_on_label_name': predicted_conditions,
198
+ 'dermatologist_skin_condition_confidence': predicted_confidences,
199
+ 'weighted_skin_condition_label': predicted_weights_dict,
200
+ 'all_condition_probabilities': all_condition_probs,
201
+ 'all_individual_confidences': all_confidences,
202
+ 'all_individual_weights': all_weights,
203
+ 'condition_threshold': condition_threshold
204
+ }
205
+
206
+
207
+ # Helper function to download from Cloudflare R2 with chunked streaming
208
+ def download_derm_foundation_from_r2(output_dir):
209
+ """Download Derm Foundation model from Cloudflare R2 using memory-efficient streaming"""
210
+ try:
211
+ print(f"Downloading Derm Foundation model from R2 ({R2_BASE_URL})...")
212
+ os.makedirs(output_dir, exist_ok=True)
213
+
214
+ # Files to download
215
+ files_to_download = [
216
+ "saved_model.pb",
217
+ "variables/variables.index",
218
+ "variables/variables.data-00000-of-00001"
219
+ ]
220
+
221
+ for file_path in files_to_download:
222
+ print(f"Downloading {file_path}...")
223
+ url = f"{R2_BASE_URL}/{file_path}"
224
+ local_path = os.path.join(output_dir, file_path)
225
+
226
+ # Create subdirectories if needed
227
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
228
+
229
+ # Download file with streaming (ULTRA MEMORY EFFICIENT)
230
+ # Use tiny chunk size and aggressive garbage collection
231
+ import gc
232
+
233
+ with requests.get(url, stream=True, timeout=900) as response:
234
+ response.raise_for_status()
235
+
236
+ total_size = int(response.headers.get('content-length', 0))
237
+ downloaded = 0
238
+ chunk_count = 0
239
+
240
+ # Write directly to disk in tiny chunks (256KB to minimize memory)
241
+ with open(local_path, 'wb') as f:
242
+ for chunk in response.iter_content(chunk_size=256*1024): # 256KB chunks
243
+ if chunk:
244
+ f.write(chunk)
245
+ f.flush() # Force write to disk
246
+ downloaded += len(chunk)
247
+ chunk_count += 1
248
+
249
+ # Aggressive garbage collection every 10 chunks (~2.5MB)
250
+ if chunk_count % 10 == 0:
251
+ gc.collect()
252
+
253
+ # Less frequent progress updates to reduce print overhead
254
+ if total_size > 0 and chunk_count % 20 == 0:
255
+ progress = (downloaded / total_size) * 100
256
+ mb_downloaded = downloaded / (1024*1024)
257
+ mb_total = total_size / (1024*1024)
258
+ print(f" Progress: {progress:.1f}% ({mb_downloaded:.1f}/{mb_total:.1f} MB)")
259
+
260
+ print() # New line after progress
261
+ gc.collect() # Final cleanup
262
+
263
+ print(f"✓ Downloaded: {file_path}")
264
+
265
+ print(f"✓ Derm Foundation model downloaded successfully from R2")
266
+ return True
267
+ except Exception as e:
268
+ print(f"✗ Error downloading from R2: {e}")
269
+ import traceback
270
+ traceback.print_exc()
271
+ return False
272
+
273
+
274
+ # Helper function to download from Hugging Face (Fallback) with memory-efficient streaming
275
+ def download_derm_foundation_from_hf(output_dir):
276
+ """Download Derm Foundation model from Hugging Face using memory-efficient streaming"""
277
+ try:
278
+ # Login to Hugging Face if token is available
279
+ if HF_TOKEN:
280
+ print("Authenticating with Hugging Face...")
281
+ login(token=HF_TOKEN)
282
+ else:
283
+ print("WARNING: No HF token found. Attempting download without authentication...")
284
+
285
+ print(f"Downloading Derm Foundation model from Hugging Face...")
286
+ os.makedirs(output_dir, exist_ok=True)
287
+
288
+ # Files to download
289
+ files_to_download = [
290
+ "saved_model.pb",
291
+ "variables/variables.data-00000-of-00001",
292
+ "variables/variables.index"
293
+ ]
294
+
295
+ for file_path in files_to_download:
296
+ print(f"Downloading {file_path}...")
297
+ local_path = os.path.join(output_dir, file_path)
298
+
299
+ # Create subdirectories if needed
300
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
301
+
302
+ # Download file with token if available
303
+ # hf_hub_download handles streaming internally
304
+ downloaded_path = hf_hub_download(
305
+ repo_id=HF_REPO_ID,
306
+ filename=file_path,
307
+ token=HF_TOKEN,
308
+ cache_dir=None,
309
+ local_dir=output_dir,
310
+ local_dir_use_symlinks=False,
311
+ resume_download=True # Resume if interrupted
312
+ )
313
+ print(f"✓ Downloaded: {file_path}")
314
+
315
+ print(f"✓ Derm Foundation model downloaded successfully from HuggingFace")
316
+ return True
317
+ except Exception as e:
318
+ print(f"✗ Error downloading from Hugging Face: {e}")
319
+ print(f"Make sure HUGGINGFACE_TOKEN is set in Render environment variables")
320
+ import traceback
321
+ traceback.print_exc()
322
+ return False
323
+
324
+
325
+ # EASI calculation functions
326
+ def calculate_easi_scores(predictions):
327
+ easi_categories = {
328
+ 'erythema': {
329
+ 'name': 'Erythema (Redness)',
330
+ 'conditions': [
331
+ 'Post-Inflammatory hyperpigmentation', 'Erythema ab igne', 'Erythema annulare centrifugum',
332
+ 'Erythema elevatum diutinum', 'Erythema gyratum repens', 'Erythema multiforme',
333
+ 'Erythema nodosum', 'Flagellate erythema', 'Annular erythema', 'Drug Rash',
334
+ 'Allergic Contact Dermatitis', 'Irritant Contact Dermatitis', 'Contact dermatitis',
335
+ 'Acute dermatitis', 'Chronic dermatitis', 'Acute and chronic dermatitis',
336
+ 'Sunburn', 'Photodermatitis', 'Phytophotodermatitis', 'Rosacea',
337
+ 'Seborrheic Dermatitis', 'Stasis Dermatitis', 'Perioral Dermatitis',
338
+ 'Burn erythema of abdominal wall', 'Burn erythema of back of hand',
339
+ 'Burn erythema of lower leg', 'Cellulitis', 'Infection of skin',
340
+ 'Viral Exanthem', 'Infected eczema', 'Crusted eczematous dermatitis',
341
+ 'Inflammatory dermatosis', 'Vasculitis of the skin', 'Leukocytoclastic Vasculitis',
342
+ 'Cutaneous lupus', 'CD - Contact dermatitis', 'Acute dermatitis, NOS',
343
+ 'Herpes Simplex', 'Hypersensitivity', 'Impetigo', 'Pigmented purpuric eruption',
344
+ 'Pityriasis rosea', 'Tinea', 'Tinea Versicolor'
345
+ ]
346
+ },
347
+ 'induration': {
348
+ 'name': 'Induration/Papulation (Swelling/Bumps)',
349
+ 'conditions': [
350
+ 'Prurigo nodularis', 'Urticaria', 'Granuloma annulare', 'Morphea',
351
+ 'Scleroderma', 'Lichen Simplex Chronicus', 'Lichen planus', 'lichenoid eruption',
352
+ 'Lichen nitidus', 'Lichen spinulosus', 'Lichen striatus', 'Keratosis pilaris',
353
+ 'Molluscum Contagiosum', 'Verruca vulgaris', 'Folliculitis', 'Acne',
354
+ 'Hidradenitis', 'Nodular vasculitis', 'Sweet syndrome', 'Necrobiosis lipoidica',
355
+ 'Basal Cell Carcinoma', 'SCC', 'SCCIS', 'SK', 'ISK',
356
+ 'Cutaneous T Cell Lymphoma', 'Skin cancer', 'Adnexal neoplasm',
357
+ 'Insect Bite', 'Milia', 'Miliaria', 'Xanthoma', 'Psoriasis',
358
+ 'Lichen planus/lichenoid eruption'
359
+ ]
360
+ },
361
+ 'excoriation': {
362
+ 'name': 'Excoriation (Scratching Damage)',
363
+ 'conditions': [
364
+ 'Inflicted skin lesions', 'Scabies', 'Abrasion', 'Abrasion of wrist',
365
+ 'Superficial wound of body region', 'Scrape', 'Animal bite - wound',
366
+ 'Pruritic dermatitis', 'Prurigo', 'Atopic dermatitis', 'Scab'
367
+ ]
368
+ },
369
+ 'lichenification': {
370
+ 'name': 'Lichenification (Skin Thickening)',
371
+ 'conditions': [
372
+ 'Lichenified eczematous dermatitis', 'Acanthosis nigricans',
373
+ 'Hyperkeratosis of skin', 'HK - Hyperkeratosis', 'Keratoderma',
374
+ 'Ichthyosis', 'Ichthyosiform dermatosis', 'Chronic eczema',
375
+ 'Psoriasis', 'Xerosis'
376
+ ]
377
+ }
378
+ }
379
+
380
+ def probability_to_score(prob):
381
+ if prob < 0.171:
382
+ return 0
383
+ elif prob < 0.238:
384
+ return 1
385
+ elif prob < 0.421:
386
+ return 2
387
+ elif prob < 0.614:
388
+ return 3
389
+ else:
390
+ return 3
391
+
392
+ easi_results = {}
393
+ all_condition_probs = predictions['all_condition_probabilities']
394
+
395
+ for component, category_info in easi_categories.items():
396
+ category_conditions = []
397
+
398
+ for condition_name, probability in all_condition_probs.items():
399
+ if condition_name.lower() == 'eczema':
400
+ continue
401
+
402
+ if condition_name in category_info['conditions']:
403
+ category_conditions.append({
404
+ 'condition': condition_name,
405
+ 'probability': probability,
406
+ 'individual_score': probability_to_score(probability)
407
+ })
408
+
409
+ category_conditions = [c for c in category_conditions if c['individual_score'] > 0]
410
+ category_conditions.sort(key=lambda x: x['probability'], reverse=True)
411
+
412
+ component_score = sum(c['individual_score'] for c in category_conditions)
413
+ component_score = min(component_score, 3)
414
+
415
+ easi_results[component] = {
416
+ 'name': category_info['name'],
417
+ 'score': component_score,
418
+ 'contributing_conditions': category_conditions
419
+ }
420
+
421
+ total_easi = sum(result['score'] for result in easi_results.values())
422
+
423
+ return easi_results, total_easi
424
+
425
+
426
+ def get_severity_interpretation(total_easi):
427
+ if total_easi == 0:
428
+ return "No significant EASI features detected"
429
+ elif total_easi <= 3:
430
+ return "Mild EASI severity"
431
+ elif total_easi <= 6:
432
+ return "Moderate EASI severity"
433
+ elif total_easi <= 9:
434
+ return "Severe EASI severity"
435
+ else:
436
+ return "Very Severe EASI severity"
437
+
438
+
439
+ # Image processing functions
440
+ def smart_crop_to_square(image):
441
+ width, height = image.size
442
+ if width == height:
443
+ return image
444
+
445
+ size = min(width, height)
446
+ left = (width - size) // 2
447
+ top = (height - size) // 2
448
+ right = left + size
449
+ bottom = top + size
450
+
451
+ return image.crop((left, top, right, bottom))
452
+
453
+
454
+ def generate_derm_foundation_embedding(model, image):
455
+ try:
456
+ if image.mode != 'RGB':
457
+ image = image.convert('RGB')
458
+
459
+ buf = BytesIO()
460
+ image.save(buf, format='JPEG')
461
+ image_bytes = buf.getvalue()
462
+
463
+ input_tensor = tf.train.Example(features=tf.train.Features(
464
+ feature={'image/encoded': tf.train.Feature(
465
+ bytes_list=tf.train.BytesList(value=[image_bytes]))
466
+ })).SerializeToString()
467
+
468
+ infer = model.signatures["serving_default"]
469
+ output = infer(inputs=tf.constant([input_tensor]))
470
+
471
+ if 'embedding' in output:
472
+ embedding_vector = output['embedding'].numpy().flatten()
473
+ else:
474
+ key = list(output.keys())[0]
475
+ embedding_vector = output[key].numpy().flatten()
476
+
477
+ return embedding_vector
478
+ except Exception as e:
479
+ raise HTTPException(status_code=500, detail=f"Error generating embedding: {str(e)}")
480
+
481
+
482
+ # Global model instances
483
+ derm_model = None
484
+ easi_model = None
485
+ model_source = "not_loaded"
486
+
487
+
488
+ @app.on_event("startup")
489
+ async def load_models():
490
+ """Load models on startup"""
491
+ global derm_model, easi_model, model_source
492
+
493
+ # Force garbage collection before starting
494
+ import gc
495
+ gc.collect()
496
+
497
+ # Check if model exists locally
498
+ if not os.path.exists(DERM_FOUNDATION_PATH) or not os.path.exists(os.path.join(DERM_FOUNDATION_PATH, "saved_model.pb")):
499
+ print("=" * 60)
500
+ print("Derm Foundation model not found locally.")
501
+ print("Downloading from Hugging Face Hub...")
502
+ print("=" * 60)
503
+
504
+ # Download directly from HuggingFace Hub
505
+ success = download_derm_foundation_from_hf(DERM_FOUNDATION_PATH)
506
+
507
+ if success:
508
+ model_source = "huggingface"
509
+ else:
510
+ print("=" * 60)
511
+ print("ERROR: Failed to download model from HuggingFace!")
512
+ print("=" * 60)
513
+ model_source = "failed"
514
+ else:
515
+ print(" Derm Foundation model found locally (cached)")
516
+ model_source = "local_cache"
517
+
518
+ # Load Derm Foundation model
519
+ if os.path.exists(os.path.join(DERM_FOUNDATION_PATH, "saved_model.pb")):
520
+ try:
521
+ print(f"Loading Derm-Foundation model from: {DERM_FOUNDATION_PATH}")
522
+ # Force garbage collection before loading large model
523
+ gc.collect()
524
+
525
+ derm_model = tf.saved_model.load(DERM_FOUNDATION_PATH)
526
+ print(f" Derm-Foundation model loaded successfully (source: {model_source})")
527
+
528
+ # Cleanup after loading
529
+ gc.collect()
530
+ except Exception as e:
531
+ print(f"✗ Failed to load Derm Foundation model: {str(e)}")
532
+
533
+ # Load EASI model (keep this local in your repo)
534
+ model_path = './trained_model/easi_severity_model_derm_foundation_individual.pkl'
535
+ if os.path.exists(model_path):
536
+ easi_model = DermFoundationNeuralNetwork()
537
+ success = easi_model.load_model(model_path)
538
+ if success:
539
+ print(f"✓ EASI model loaded from: {model_path}")
540
+ else:
541
+ print(f" Failed to load EASI model")
542
+ easi_model = None
543
+ else:
544
+ print(f"✗ EASI model not found at: {model_path}")
545
+
546
+ if derm_model is None or easi_model is None:
547
+ print("=" * 60)
548
+ print("WARNING: Some models failed to load!")
549
+ print(f"Derm Foundation: {'' if derm_model else '✗'}")
550
+ print(f"EASI Model: {'✓' if easi_model else '✗'}")
551
+ print("=" * 60)
552
+ else:
553
+ print("=" * 60)
554
+ print("✓ All models loaded successfully!")
555
+ print(f"Model source: {model_source}")
556
+ print("=" * 60)
557
+
558
+
559
+ # API Endpoints
560
+
561
+ @app.get("/")
562
+ async def root():
563
+ """Root endpoint"""
564
+ return {
565
+ "message": "EASI Severity Prediction API",
566
+ "version": "1.0.0",
567
+ "model_source": model_source,
568
+ "docs": "/docs",
569
+ "health": "/health",
570
+ "predict": "/predict",
571
+ "conditions": "/conditions"
572
+ }
573
+
574
+
575
+ @app.get("/health", response_model=HealthResponse)
576
+ async def health_check():
577
+ """Health check endpoint"""
578
+ return {
579
+ "status": "ok" if (derm_model is not None and easi_model is not None) else "degraded",
580
+ "models_loaded": {
581
+ "derm_foundation": derm_model is not None,
582
+ "easi_model": easi_model is not None
583
+ },
584
+ "available_conditions": len(easi_model.mlb.classes_) if easi_model else 0,
585
+ "hf_token_configured": HF_TOKEN is not None,
586
+ "model_source": model_source
587
+ }
588
+
589
+
590
+ @app.get("/conditions", response_model=Dict[str, List[str]])
591
+ async def get_conditions():
592
+ """Get list of available conditions"""
593
+ if easi_model is None:
594
+ raise HTTPException(status_code=503, detail="EASI model not loaded")
595
+
596
+ return {
597
+ "conditions": easi_model.mlb.classes_.tolist()
598
+ }
599
+
600
+
601
+ @app.post("/predict", response_model=PredictionResponse)
602
+ async def predict_easi(
603
+ file: UploadFile = File(..., description="Skin image file (JPG, JPEG, PNG)")
604
+ ):
605
+ """
606
+ Predict EASI scores from uploaded skin image.
607
+
608
+ - **file**: Image file (JPG, JPEG, PNG)
609
+ - Returns: EASI scores, component breakdown, and condition predictions
610
+ """
611
+
612
+ # Validate models loaded
613
+ if derm_model is None or easi_model is None:
614
+ raise HTTPException(
615
+ status_code=503,
616
+ detail="Models not loaded. Check server logs."
617
+ )
618
+
619
+ # Validate file type
620
+ if not file.content_type.startswith('image/'):
621
+ raise HTTPException(
622
+ status_code=400,
623
+ detail="File must be an image (JPG, JPEG, PNG)"
624
+ )
625
+
626
+ try:
627
+ # Read and process image
628
+ image_bytes = await file.read()
629
+ original_image = Image.open(BytesIO(image_bytes)).convert('RGB')
630
+ original_size = original_image.size
631
+
632
+ # Process to 448x448
633
+ cropped_img = smart_crop_to_square(original_image)
634
+ processed_img = cropped_img.resize((448, 448), Image.Resampling.LANCZOS)
635
+
636
+ # Generate embedding
637
+ embedding = generate_derm_foundation_embedding(derm_model, processed_img)
638
+
639
+ # Make prediction
640
+ predictions = easi_model.predict(embedding)
641
+
642
+ if predictions is None:
643
+ raise HTTPException(status_code=500, detail="Prediction failed")
644
+
645
+ # Calculate EASI scores
646
+ easi_results, total_easi = calculate_easi_scores(predictions)
647
+ severity = get_severity_interpretation(total_easi)
648
+
649
+ # Format predicted conditions
650
+ predicted_conditions = []
651
+ for i, condition in enumerate(predictions['dermatologist_skin_condition_on_label_name']):
652
+ prob = predictions['all_condition_probabilities'][condition]
653
+ conf = predictions['dermatologist_skin_condition_confidence'][i]
654
+ weight = predictions['weighted_skin_condition_label'][condition]
655
+
656
+ # Find EASI category
657
+ easi_category = None
658
+ easi_contribution = 0
659
+ for cat_key, cat_info in easi_results.items():
660
+ for contrib in cat_info['contributing_conditions']:
661
+ if contrib['condition'] == condition:
662
+ easi_category = cat_info['name']
663
+ easi_contribution = contrib['individual_score']
664
+ break
665
+
666
+ predicted_conditions.append(ConditionPrediction(
667
+ condition=condition,
668
+ probability=float(prob),
669
+ confidence=float(conf),
670
+ weight=float(weight),
671
+ easi_category=easi_category,
672
+ easi_contribution=easi_contribution
673
+ ))
674
+
675
+ # Summary statistics
676
+ summary_stats = {
677
+ "total_conditions": len(predicted_conditions),
678
+ "average_confidence": float(np.mean(predictions['dermatologist_skin_condition_confidence'])) if predicted_conditions else 0.0,
679
+ "average_weight": float(np.mean(list(predictions['weighted_skin_condition_label'].values()))) if predicted_conditions else 0.0,
680
+ "total_weight": float(sum(predictions['weighted_skin_condition_label'].values()))
681
+ }
682
+
683
+ # Format EASI components
684
+ easi_components_formatted = {
685
+ component: EASIComponent(
686
+ name=result['name'],
687
+ score=result['score'],
688
+ contributing_conditions=result['contributing_conditions']
689
+ )
690
+ for component, result in easi_results.items()
691
+ }
692
+
693
+ return PredictionResponse(
694
+ success=True,
695
+ total_easi_score=total_easi,
696
+ severity_interpretation=severity,
697
+ easi_components=easi_components_formatted,
698
+ predicted_conditions=predicted_conditions,
699
+ summary_statistics=summary_stats,
700
+ image_info={
701
+ "original_size": f"{original_size[0]}x{original_size[1]}",
702
+ "processed_size": "448x448",
703
+ "filename": file.filename
704
+ }
705
+ )
706
+
707
+ except HTTPException:
708
+ raise
709
+ except Exception as e:
710
+ raise HTTPException(
711
+ status_code=500,
712
+ detail=f"Error processing image: {str(e)}"
713
+ )
714
+
715
+
716
+ @app.exception_handler(HTTPException)
717
+ async def http_exception_handler(request, exc):
718
+ return JSONResponse(
719
+ status_code=exc.status_code,
720
+ content=ErrorResponse(
721
+ error=exc.detail,
722
+ detail=str(exc)
723
+ ).dict()
724
+ )
725
+
726
+
727
+ if __name__ == "__main__":
728
+ import uvicorn
729
+ uvicorn.run(app, host="0.0.0.0", port=8000)
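
A minimal client sketch for exercising the documented endpoints of the updated api.py, assuming the server is running locally via "uvicorn api:app --host 0.0.0.0 --port 8000" as shown in the module docstring. The image filename "lesion.jpg" is a hypothetical placeholder; the response fields match the PredictionResponse model above.

import requests

API_URL = "http://localhost:8000"

# Confirm both models are loaded before sending an image.
health = requests.get(f"{API_URL}/health", timeout=30).json()
print("status:", health["status"], "models_loaded:", health["models_loaded"])

# POST an image to /predict; the multipart form field must be named "file".
with open("lesion.jpg", "rb") as f:
    resp = requests.post(
        f"{API_URL}/predict",
        files={"file": ("lesion.jpg", f, "image/jpeg")},
        timeout=120,
    )
resp.raise_for_status()
result = resp.json()

print("Total EASI score:", result["total_easi_score"])
print("Severity:", result["severity_interpretation"])
for cond in result["predicted_conditions"]:
    print(f"  {cond['condition']}: p={cond['probability']:.2f}")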