pythonprincess commited on
Commit
fa3e666
ยท
verified ยท
1 Parent(s): e785234

Upload model_loader.py

Browse files
Files changed (1) hide show
  1. app/model_loader.py +889 -0
app/model_loader.py ADDED
@@ -0,0 +1,889 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/model_loader.py
2
+ """
3
+ ๐Ÿง  PENNY Model Loader - Azure-Ready Multi-Model Orchestration
4
+
5
+ This is Penny's brain loader. She manages multiple specialized models:
6
+ - Gemma 7B for conversational reasoning
7
+ - NLLB-200 for 27-language translation
8
+ - Sentiment analysis for resident wellbeing
9
+ - Bias detection for equitable service
10
+ - LayoutLM for civic document processing
11
+
12
+ MISSION: Load AI models efficiently in memory-constrained environments while
13
+ maintaining Penny's warm, civic-focused personality across all interactions.
14
+
15
+ FEATURES:
16
+ - Lazy loading (models only load when needed)
17
+ - 8-bit quantization for memory efficiency
18
+ - GPU/CPU auto-detection
19
+ - Model caching and reuse
20
+ - Graceful fallbacks for Azure ML deployment
21
+ - Memory monitoring and cleanup
22
+ """
23
+
24
+ import json
25
+ import os
26
+ import torch
27
+ from typing import Dict, Any, Callable, Optional, Union, List
28
+ from pathlib import Path
29
+ import logging
30
+ from dataclasses import dataclass
31
+ from enum import Enum
32
+ from datetime import datetime
33
+
34
+ # --- LOGGING SETUP (Must be before functions that use it) ---
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # ============================================================
38
+ # HUGGING FACE AUTHENTICATION
39
+ # ============================================================
40
+
41
def setup_huggingface_auth() -> bool:
    """
    🔐 Authenticate with the Hugging Face Hub using the HF_TOKEN env var.

    Returns:
        True if authentication succeeded; False when the token is missing,
        the ``huggingface_hub`` package is unavailable, or login failed.
    """
    token = os.getenv("HF_TOKEN")

    if not token:
        # Non-fatal: public models still work, gated ones will not.
        logger.warning("⚠️ HF_TOKEN not found in environment")
        logger.warning(" Some models may not be accessible")
        logger.warning(" Set HF_TOKEN in your environment or Hugging Face Spaces secrets")
        return False

    try:
        from huggingface_hub import login
    except ImportError:
        logger.warning("⚠️ huggingface_hub not installed, skipping authentication")
        return False

    try:
        login(token=token, add_to_git_credential=False)
    except Exception as e:
        logger.error(f"❌ Failed to authenticate with Hugging Face: {e}")
        return False

    logger.info("✅ Authenticated with Hugging Face Hub")
    return True
67
+
68
# Attempt authentication at module load (non-fatal if it fails; see
# setup_huggingface_auth for the failure modes).
setup_huggingface_auth()

# --- PATH CONFIGURATION (Environment-Aware) ---
# Support both local development and Azure ML deployment.
if os.getenv("AZUREML_MODEL_DIR"):
    # Azure ML deployment - models are mounted under AZUREML_MODEL_DIR
    MODEL_ROOT = Path(os.getenv("AZUREML_MODEL_DIR"))
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("☁️ Running in Azure ML environment")
else:
    # Local development - models are in the project structure
    # (this file lives in app/, so parent.parent is the project root).
    PROJECT_ROOT = Path(__file__).parent.parent
    MODEL_ROOT = PROJECT_ROOT / "models"
    CONFIG_PATH = MODEL_ROOT / "model_config.json"
    logger.info("💻 Running in local development environment")

logger.info(f"📂 Model config path: {CONFIG_PATH}")
86
+
87
+ # ============================================================
88
+ # PENNY'S CIVIC IDENTITY & PERSONALITY
89
+ # ============================================================
90
+
91
# System prompt prepended to every text-generation request (unless the
# caller passes skip_system_prompt=True to ModelClient.predict). Defines
# Penny's civic identity, tone, and hard safety rules.
PENNY_SYSTEM_PROMPT = (
    "You are Penny, a smart, civic-focused AI assistant serving local communities. "
    "You help residents navigate city services, government programs, and community resources. "
    "You're warm, professional, accurate, and always stay within your civic mission.\n\n"

    "Your expertise includes:\n"
    "- Connecting people with local services (food banks, shelters, libraries)\n"
    "- Translating information into 27 languages\n"
    "- Explaining public programs and eligibility\n"
    "- Guiding residents through civic processes\n"
    "- Providing emergency resources when needed\n\n"

    "YOUR PERSONALITY:\n"
    "- Warm and approachable, like a helpful community center staff member\n"
    "- Clear and practical, avoiding jargon\n"
    "- Culturally sensitive and inclusive\n"
    "- Patient with repetition or clarification\n"
    "- Funny when appropriate, but never at anyone's expense\n\n"

    "CRITICAL RULES:\n"
    "- When residents greet you by name (e.g., 'Hi Penny'), respond warmly and personally\n"
    "- You are ALWAYS Penny - never ChatGPT, Assistant, Claude, or any other name\n"
    "- If you don't know something, say so clearly and help find the right resource\n"
    "- NEVER make up information about services, eligibility, or contacts\n"
    "- Stay within your civic mission - you don't provide legal, medical, or financial advice\n"
    "- For emergencies, immediately connect to appropriate services (911, crisis lines)\n\n"
)
118
+
119
# --- GLOBAL STATE ---
# Process-wide caches shared by every ModelClient instance.
_MODEL_CACHE: Dict[str, Any] = {}  # Memory-efficient model reuse (name -> pipeline)
_LOAD_TIMES: Dict[str, float] = {}  # Track model loading performance (name -> seconds)
122
+
123
+
124
+ # ============================================================
125
+ # DEVICE MANAGEMENT
126
+ # ============================================================
127
+
128
class DeviceType(str, Enum):
    """Compute devices Penny can run inference on (str-valued for easy compare)."""

    # NVIDIA GPU via CUDA
    CUDA = "cuda"
    # Plain CPU fallback
    CPU = "cpu"
    # Apple Silicon GPU backend
    MPS = "mps"
133
+
134
+
135
def get_optimal_device() -> str:
    """
    🎮 Pick the best available device for model inference.

    Priority:
    1. CUDA GPU (NVIDIA)
    2. MPS (Apple Silicon)
    3. CPU (fallback)

    Returns:
        Device string ("cuda", "mps", or "cpu")
    """
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"🎮 GPU detected: {gpu_name} ({gpu_memory:.1f}GB)")
        return DeviceType.CUDA.value

    # hasattr guard: older torch builds do not expose torch.backends.mps.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        logger.info("🍎 Apple Silicon (MPS) detected")
        return DeviceType.MPS.value

    logger.info("💻 Using CPU for inference")
    logger.warning("⚠️ GPU not available - inference will be slower")
    return DeviceType.CPU.value
164
+
165
+
166
def get_memory_stats() -> Dict[str, float]:
    """
    📊 Snapshot current GPU/CPU memory usage.

    Returns:
        Dict with figures in GB. GPU keys appear only when CUDA is
        available; CPU keys only when psutil is installed.
    """
    stats: Dict[str, float] = {}

    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        stats["gpu_allocated_gb"] = torch.cuda.memory_allocated() / 1e9
        stats["gpu_reserved_gb"] = torch.cuda.memory_reserved() / 1e9
        stats["gpu_total_gb"] = props.total_memory / 1e9

    # CPU memory requires psutil; silently omit those keys without it.
    try:
        import psutil
    except ImportError:
        return stats

    mem = psutil.virtual_memory()
    stats["cpu_used_gb"] = mem.used / 1e9
    stats["cpu_total_gb"] = mem.total / 1e9
    stats["cpu_percent"] = mem.percent
    return stats
191
+
192
+
193
+ # ============================================================
194
+ # MODEL CLIENT (Individual Model Handler)
195
+ # ============================================================
196
+
197
@dataclass
class ModelMetadata:
    """
    📋 Bookkeeping for one loaded model: identity, placement, and
    performance counters (updated by ModelClient after each inference).
    """
    name: str
    task: str
    model_name: str
    device: str
    loaded_at: Optional[datetime] = None
    load_time_seconds: Optional[float] = None
    memory_usage_gb: Optional[float] = None
    inference_count: int = 0
    total_inference_time_ms: float = 0.0

    @property
    def avg_inference_time_ms(self) -> float:
        """Mean inference latency in ms; 0.0 before any inference has run."""
        runs = self.inference_count
        return self.total_inference_time_ms / runs if runs else 0.0
219
+
220
+
221
class ModelClient:
    """
    🤖 Manages a single HuggingFace model with optimized loading and inference.

    Features:
    - Lazy loading (load on first use)
    - Memory optimization (8-bit quantization on CUDA)
    - Performance tracking via ModelMetadata
    - Graceful error handling (errors returned as dicts, not raised)
    - Automatic device placement
    """

    def __init__(
        self,
        name: str,
        model_name: str,
        task: str,
        device: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize model client (doesn't load the model yet).

        Args:
            name: Model identifier (e.g., "penny-core-agent")
            model_name: HuggingFace model ID
            task: Task type (text-generation, translation, etc.)
            device: Target device (auto-detected if None)
            config: Additional model configuration (e.g. default languages)
        """
        self.name = name
        self.model_name = model_name
        self.task = task
        self.device = device or get_optimal_device()
        self.config = config or {}
        # Lazily created transformers pipeline; None until load_pipeline().
        self.pipeline = None
        # Guards against retrying a load that already failed once.
        self._load_attempted = False
        self.metadata = ModelMetadata(
            name=name,
            task=task,
            model_name=model_name,
            device=self.device
        )

        logger.info(f"📦 Initialized ModelClient: {name}")
        logger.debug(f" Model: {model_name}")
        logger.debug(f" Task: {task}")
        logger.debug(f" Device: {self.device}")

    def load_pipeline(self) -> bool:
        """
        🔄 Loads the HuggingFace pipeline with Azure-optimized settings.

        Features:
        - 8-bit quantization for large models (saves ~50% memory)
        - Automatic device placement
        - Memory monitoring
        - Cache checking (shared module-level _MODEL_CACHE)

        Returns:
            True if successful, False otherwise
        """
        if self.pipeline is not None:
            logger.debug(f"✅ {self.name} already loaded")
            return True

        # A failed load is never retried in this process; callers get
        # False immediately on subsequent calls.
        if self._load_attempted:
            logger.warning(f"⚠️ Previous load attempt failed for {self.name}")
            return False

        global _MODEL_CACHE, _LOAD_TIMES

        # Check cache first (another client with the same name may have loaded it)
        if self.name in _MODEL_CACHE:
            logger.info(f"♻️ Using cached pipeline for {self.name}")
            self.pipeline = _MODEL_CACHE[self.name]
            return True

        logger.info(f"🔄 Loading {self.name} from HuggingFace...")
        self._load_attempted = True

        start_time = datetime.now()

        try:
            # Import pipeline from transformers (lazy import to avoid dependency issues)
            from transformers import pipeline

            # === TEXT GENERATION (Gemma 7B, GPT-2, etc.) ===
            if self.task == "text-generation":
                logger.info(" Using 8-bit quantization for memory efficiency...")

                # 8-bit loading is only attempted on CUDA devices.
                use_8bit = self.device == DeviceType.CUDA.value

                if use_8bit:
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device_map="auto",
                        load_in_8bit=True,  # Reduces ~14GB to ~7GB
                        trust_remote_code=True,
                        torch_dtype=torch.float16
                    )
                else:
                    # CPU fallback (full-precision float32)
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model_name,
                        tokenizer=self.model_name,
                        device=-1,  # CPU
                        trust_remote_code=True,
                        torch_dtype=torch.float32
                    )

            # === TRANSLATION (NLLB-200, M2M-100, etc.) ===
            elif self.task == "translation":
                self.pipeline = pipeline(
                    "translation",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    # Default language pair comes from per-model config.
                    src_lang=self.config.get("default_src_lang", "eng_Latn"),
                    tgt_lang=self.config.get("default_tgt_lang", "spa_Latn")
                )

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                self.pipeline = pipeline(
                    "sentiment-analysis",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True,
                    max_length=512
                )

            # === BIAS DETECTION (Zero-Shot Classification) ===
            elif self.task == "bias-detection":
                self.pipeline = pipeline(
                    "zero-shot-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1
                )

            # === TEXT CLASSIFICATION (Generic) ===
            elif self.task == "text-classification":
                self.pipeline = pipeline(
                    "text-classification",
                    model=self.model_name,
                    device=0 if self.device == DeviceType.CUDA.value else -1,
                    truncation=True
                )

            # === PDF/DOCUMENT EXTRACTION (LayoutLMv3) ===
            elif self.task == "pdf-extraction":
                logger.warning("⚠️ PDF extraction requires additional OCR setup")
                logger.info(" Consider using Azure Form Recognizer as alternative")
                # Placeholder - requires pytesseract/OCR infrastructure
                self.pipeline = None
                return False

            else:
                raise ValueError(f"Unknown task type: {self.task}")

            # === SUCCESS HANDLING ===
            if self.pipeline is not None:
                # Calculate load time
                load_time = (datetime.now() - start_time).total_seconds()
                self.metadata.loaded_at = datetime.now()
                self.metadata.load_time_seconds = load_time

                # Cache the pipeline for reuse by other clients
                _MODEL_CACHE[self.name] = self.pipeline
                _LOAD_TIMES[self.name] = load_time

                # Log memory usage (0 when no GPU stats are available)
                mem_stats = get_memory_stats()
                self.metadata.memory_usage_gb = mem_stats.get("gpu_allocated_gb", 0)

                logger.info(f"✅ {self.name} loaded successfully!")
                logger.info(f" Load time: {load_time:.2f}s")

                if "gpu_allocated_gb" in mem_stats:
                    logger.info(
                        f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB / "
                        f"{mem_stats['gpu_total_gb']:.2f}GB"
                    )

            return True

        except Exception as e:
            logger.error(f"❌ Failed to load {self.name}: {e}", exc_info=True)
            self.pipeline = None
            return False

    def predict(
        self,
        input_data: Union[str, Dict[str, Any]],
        **kwargs
    ) -> Dict[str, Any]:
        """
        🎯 Runs inference with the loaded model pipeline.

        Features:
        - Automatic pipeline loading (lazy)
        - Error handling with fallback responses (never raises to callers)
        - Performance tracking (inference_time_ms added to output)
        - Penny's personality injection (for text-generation)

        Args:
            input_data: Text or structured input for the model
            **kwargs: Task-specific parameters (e.g. max_new_tokens,
                source_lang/target_lang, candidate_labels)

        Returns:
            Model output dict with results or error information
        """
        # Track inference start time
        start_time = datetime.now()

        # Ensure pipeline is loaded
        if self.pipeline is None:
            success = self.load_pipeline()
            if not success:
                return {
                    "error": f"{self.name} pipeline unavailable",
                    "detail": "Model failed to load. Check logs for details.",
                    "model": self.name
                }

        try:
            # === TEXT GENERATION ===
            if self.task == "text-generation":
                # Inject Penny's civic identity unless caller opts out
                if not kwargs.get("skip_system_prompt", False):
                    full_prompt = PENNY_SYSTEM_PROMPT + input_data
                else:
                    full_prompt = input_data

                # Extract generation parameters with safe defaults
                max_new_tokens = kwargs.get("max_new_tokens", 256)
                temperature = kwargs.get("temperature", 0.7)
                top_p = kwargs.get("top_p", 0.9)
                # Sampling is on by default whenever temperature > 0
                do_sample = kwargs.get("do_sample", temperature > 0.0)

                result = self.pipeline(
                    full_prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    return_full_text=False,
                    pad_token_id=self.pipeline.tokenizer.eos_token_id,
                    truncation=True
                )

                output = {
                    "generated_text": result[0]["generated_text"],
                    "model": self.name,
                    "success": True
                }

            # === TRANSLATION ===
            elif self.task == "translation":
                src_lang = kwargs.get("source_lang", "eng_Latn")
                tgt_lang = kwargs.get("target_lang", "spa_Latn")

                result = self.pipeline(
                    input_data,
                    src_lang=src_lang,
                    tgt_lang=tgt_lang,
                    max_length=512
                )

                output = {
                    "translation": result[0]["translation_text"],
                    "source_lang": src_lang,
                    "target_lang": tgt_lang,
                    "model": self.name,
                    "success": True
                }

            # === SENTIMENT ANALYSIS ===
            elif self.task == "sentiment-analysis":
                result = self.pipeline(input_data)

                output = {
                    "sentiment": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            # === BIAS DETECTION ===
            elif self.task == "bias-detection":
                candidate_labels = kwargs.get("candidate_labels", [
                    "neutral and objective",
                    "contains political bias",
                    "uses emotional language",
                    "culturally insensitive"
                ])

                result = self.pipeline(
                    input_data,
                    candidate_labels=candidate_labels,
                    multi_label=True
                )

                output = {
                    "labels": result["labels"],
                    "scores": result["scores"],
                    "model": self.name,
                    "success": True
                }

            # === TEXT CLASSIFICATION ===
            elif self.task == "text-classification":
                result = self.pipeline(input_data)

                output = {
                    "label": result[0]["label"],
                    "confidence": result[0]["score"],
                    "model": self.name,
                    "success": True
                }

            else:
                output = {
                    "error": f"Task '{self.task}' not implemented",
                    "model": self.name,
                    "success": False
                }

            # Track performance (milliseconds)
            inference_time = (datetime.now() - start_time).total_seconds() * 1000
            self.metadata.inference_count += 1
            self.metadata.total_inference_time_ms += inference_time
            output["inference_time_ms"] = round(inference_time, 2)

            return output

        except Exception as e:
            logger.error(f"❌ Inference error in {self.name}: {e}", exc_info=True)
            return {
                "error": "Inference failed",
                "detail": str(e),
                "model": self.name,
                "success": False
            }

    def unload(self) -> None:
        """
        🗑️ Unloads the model to free memory.
        Critical for Azure environments with limited resources.
        No-op when the pipeline was never loaded.
        """
        if self.pipeline is not None:
            logger.info(f"🗑️ Unloading {self.name}...")

            # Delete pipeline
            del self.pipeline
            self.pipeline = None

            # Remove from the shared module-level cache
            if self.name in _MODEL_CACHE:
                del _MODEL_CACHE[self.name]

            # Force GPU memory release
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"✅ {self.name} unloaded successfully")

            # Log memory stats after unload
            mem_stats = get_memory_stats()
            if "gpu_allocated_gb" in mem_stats:
                logger.info(f" GPU Memory: {mem_stats['gpu_allocated_gb']:.2f}GB remaining")

    def get_metadata(self) -> Dict[str, Any]:
        """
        📊 Returns model metadata and performance stats as a JSON-friendly dict.
        """
        return {
            "name": self.metadata.name,
            "task": self.metadata.task,
            "model_name": self.metadata.model_name,
            "device": self.metadata.device,
            "loaded": self.pipeline is not None,
            "loaded_at": self.metadata.loaded_at.isoformat() if self.metadata.loaded_at else None,
            "load_time_seconds": self.metadata.load_time_seconds,
            "memory_usage_gb": self.metadata.memory_usage_gb,
            "inference_count": self.metadata.inference_count,
            "avg_inference_time_ms": round(self.metadata.avg_inference_time_ms, 2)
        }
612
+
613
+
614
+ # ============================================================
615
+ # MODEL LOADER (Singleton Manager)
616
+ # ============================================================
617
+
618
class ModelLoader:
    """
    🎛️ Singleton manager for all Penny's specialized models.

    Features:
    - Centralized model configuration (model_config.json)
    - Lazy loading (models only load when needed)
    - Memory management
    - Health monitoring
    - Unified access interface
    """

    # Single shared instance, created on first construction.
    _instance: Optional['ModelLoader'] = None

    def __new__(cls, *args, **kwargs):
        """Singleton pattern - only one ModelLoader instance."""
        if cls._instance is None:
            cls._instance = super(ModelLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize ModelLoader (only runs once due to singleton).

        Args:
            config_path: Path to model_config.json (optional; NOTE: ignored
                on every construction after the first, since the body is
                guarded by _models_loaded)
        """
        # __init__ runs on every ModelLoader() call; this flag makes the
        # body effectively run-once.
        if not hasattr(self, '_models_loaded'):
            self.models: Dict[str, ModelClient] = {}
            self._models_loaded = True
            self._initialization_time = datetime.now()

            # Use provided path or module-level default
            config_file = Path(config_path) if config_path else CONFIG_PATH

            try:
                logger.info(f"📖 Loading model configuration from {config_file}")

                if not config_file.exists():
                    # Missing config is tolerated: loader exists with no models.
                    logger.warning(f"⚠️ Configuration file not found: {config_file}")
                    logger.info(" Create model_config.json with your model definitions")
                    return

                with open(config_file, "r") as f:
                    config = json.load(f)

                # Initialize ModelClients (doesn't load models yet - lazy)
                for model_id, model_info in config.items():
                    self.models[model_id] = ModelClient(
                        name=model_id,
                        model_name=model_info["model_name"],
                        task=model_info["task"],
                        config=model_info.get("config", {})
                    )

                logger.info(f"✅ ModelLoader initialized with {len(self.models)} models:")
                for model_id in self.models.keys():
                    logger.info(f" - {model_id}")

            except json.JSONDecodeError as e:
                logger.error(f"❌ Invalid JSON in model_config.json: {e}")
            except Exception as e:
                logger.error(f"❌ Failed to initialize ModelLoader: {e}", exc_info=True)

    def get(self, model_id: str) -> Optional[ModelClient]:
        """
        🎯 Retrieves a configured ModelClient by ID.

        Args:
            model_id: Model identifier from config

        Returns:
            ModelClient instance or None if not found
        """
        return self.models.get(model_id)

    def list_models(self) -> List[str]:
        """📋 Returns list of all available model IDs."""
        return list(self.models.keys())

    def get_loaded_models(self) -> List[str]:
        """📋 Returns list of currently loaded model IDs (pipeline in memory)."""
        return [
            model_id
            for model_id, client in self.models.items()
            if client.pipeline is not None
        ]

    def unload_all(self) -> None:
        """
        🗑️ Unloads all models to free memory.
        Useful for Azure environments when switching workloads.
        """
        logger.info("🗑️ Unloading all models...")
        for model_client in self.models.values():
            model_client.unload()
        logger.info("✅ All models unloaded")

    def get_status(self) -> Dict[str, Any]:
        """
        📊 Returns comprehensive status of all models.
        Useful for health checks and monitoring endpoints.
        """
        status = {
            "initialization_time": self._initialization_time.isoformat(),
            "total_models": len(self.models),
            "loaded_models": len(self.get_loaded_models()),
            "device": get_optimal_device(),
            "memory": get_memory_stats(),
            "models": {}
        }

        # Per-model metadata keyed by model ID
        for model_id, client in self.models.items():
            status["models"][model_id] = client.get_metadata()

        return status
734
+
735
+
736
+ # ============================================================
737
+ # PUBLIC INTERFACE (Used by all *_utils.py modules)
738
+ # ============================================================
739
+
740
def load_model_pipeline(agent_name: str) -> Callable[..., Dict[str, Any]]:
    """
    🚀 Resolve a configured model by ID and return a ready inference callable.

    This is the main entry point used by other modules
    (translation_utils.py, sentiment_utils.py, etc.) to access Penny's
    models.

    Args:
        agent_name: Model ID from model_config.json

    Returns:
        Callable inference function wrapping ModelClient.predict

    Raises:
        ValueError: If agent_name not found in configuration

    Example:
        >>> translator = load_model_pipeline("penny-translate-agent")
        >>> result = translator("Hello world", target_lang="spa_Latn")
    """
    loader = ModelLoader()
    client = loader.get(agent_name)

    if client is None:
        available = loader.list_models()
        raise ValueError(
            f"Agent ID '{agent_name}' not found in model configuration. "
            f"Available models: {available}"
        )

    # Trigger lazy loading now so the first call is not surprised by it.
    client.load_pipeline()

    def inference_wrapper(input_data, **kwargs):
        # Closure over the resolved client; forwards everything to predict().
        return client.predict(input_data, **kwargs)

    return inference_wrapper
778
+
779
+
780
+ # === CONVENIENCE FUNCTIONS ===
781
+
782
def get_model_status() -> Dict[str, Any]:
    """
    📊 Returns status of all configured models.
    Useful for health checks and monitoring endpoints.
    """
    return ModelLoader().get_status()
789
+
790
+
791
def preload_models(model_ids: Optional[List[str]] = None) -> None:
    """
    🚀 Eagerly loads the given models during startup.

    Args:
        model_ids: Model IDs to preload (None = all configured models)
    """
    loader = ModelLoader()
    targets = loader.list_models() if model_ids is None else model_ids

    logger.info(f"🚀 Preloading {len(targets)} models...")

    for model_id in targets:
        client = loader.get(model_id)
        if client:
            logger.info(f" Loading {model_id}...")
            client.load_pipeline()

    logger.info("✅ Model preloading complete")
812
+
813
+
814
def initialize_model_system() -> bool:
    """
    🏁 Initializes the model system.
    Should be called during app startup.

    Returns:
        True if initialization successful, False if it raised
    """
    logger.info("🧠 Initializing Penny's model system...")

    try:
        # Constructing the singleton parses model_config.json (lazily,
        # no model weights are downloaded here).
        loader = ModelLoader()

        device = get_optimal_device()
        mem_stats = get_memory_stats()

        logger.info("✅ Model system initialized")
        logger.info(f"🎮 Compute device: {device}")

        if "gpu_total_gb" in mem_stats:
            logger.info(f"💾 GPU Memory: {mem_stats['gpu_total_gb']:.1f}GB total")

        logger.info(f"📦 {len(loader.models)} models configured")

        # Optional: Preload critical models at startup.
        # preload_models(["penny-core-agent"])

        return True

    except Exception as e:
        logger.error(f"❌ Failed to initialize model system: {e}", exc_info=True)
        return False
851
+
852
+
853
+ # ============================================================
854
+ # CLI TESTING & DEBUGGING
855
+ # ============================================================
856
+
857
if __name__ == "__main__":
    """
    🧪 Test script for model loading and inference.
    Run with: python -m app.model_loader
    """
    print("=" * 60)
    print("🧪 Testing Penny's Model System")
    print("=" * 60)

    # Initialize the singleton (parses config, no weights downloaded yet)
    loader = ModelLoader()
    print(f"\n📋 Available models: {loader.list_models()}")

    # Get status (default=str handles datetime values in the status dict)
    status = get_model_status()
    print(f"\n📊 System status:")
    print(json.dumps(status, indent=2, default=str))

    # Test model loading (if models configured) - exercises the first
    # configured model only
    if loader.models:
        test_model_id = list(loader.models.keys())[0]
        print(f"\n🧪 Testing model: {test_model_id}")

        client = loader.get(test_model_id)
        if client:
            print(f" Loading pipeline...")
            success = client.load_pipeline()

            if success:
                print(f" ✅ Model loaded successfully!")
                print(f" Metadata: {json.dumps(client.get_metadata(), indent=2, default=str)}")
            else:
                print(f" ❌ Model loading failed")