# NOTE: lines below this header were Hugging Face file-viewer residue
# ("Penny_V2 / gemma_utils.py", "Upload 25 files", raw/history/blame links,
# "7.66 kB") and have been converted to comments so the module parses.
# models/gemma/gemma_utils.py
"""
Gemma Model Utilities for PENNY Project
Handles text generation using the Gemma-based core language model pipeline.
Provides async generation with structured error handling and logging.
"""
import asyncio
import time
from typing import Dict, Any, Optional
# --- Logging Imports ---
from app.logging_utils import log_interaction, sanitize_for_logging
# --- Model Loader Import ---
try:
from app.model_loader import load_model_pipeline
MODEL_LOADER_AVAILABLE = True
except ImportError:
MODEL_LOADER_AVAILABLE = False
import logging
logging.getLogger(__name__).warning("Could not import load_model_pipeline. Gemma service unavailable.")
# Global variable to store the loaded pipeline for re-use across requests;
# populated by _initialize_gemma_pipeline(), stays None on load failure.
GEMMA_PIPELINE: Optional[Any] = None
# Model key handed to load_model_pipeline() to select the Gemma core model.
AGENT_NAME = "penny-core-agent"
# Guard so the (potentially expensive or failing) load is attempted only once.
INITIALIZATION_ATTEMPTED = False
def _initialize_gemma_pipeline() -> bool:
    """
    Load the Gemma pipeline into the module-level cache, at most once.

    Subsequent calls are no-ops that simply report whether the earlier
    attempt produced a usable pipeline. All outcomes (loader missing,
    loader returned None, loader raised, success) are logged via
    log_interaction rather than raised.

    Returns:
        bool: True if the pipeline is loaded and usable, False otherwise.
    """
    global GEMMA_PIPELINE, INITIALIZATION_ATTEMPTED
    # Re-entry: report the outcome of the first (and only) attempt.
    if INITIALIZATION_ATTEMPTED:
        return GEMMA_PIPELINE is not None
    INITIALIZATION_ATTEMPTED = True
    # Guard: the model loader import itself may have failed at module load.
    if not MODEL_LOADER_AVAILABLE:
        log_interaction(
            intent="gemma_initialization",
            success=False,
            error="model_loader unavailable"
        )
        return False
    try:
        # success=None marks an in-progress attempt in the interaction log.
        log_interaction(
            intent="gemma_initialization",
            success=None,
            details=f"Loading {AGENT_NAME}"
        )
        pipeline = load_model_pipeline(AGENT_NAME)
        GEMMA_PIPELINE = pipeline
        if pipeline is None:
            # Loader signalled failure by returning None instead of raising.
            log_interaction(
                intent="gemma_initialization",
                success=False,
                error="Pipeline returned None"
            )
            return False
        log_interaction(
            intent="gemma_initialization",
            success=True,
            details=f"Model {AGENT_NAME} loaded successfully"
        )
        return True
    except Exception as exc:
        # Broad catch is deliberate: any loader failure degrades the service
        # gracefully instead of crashing module import.
        log_interaction(
            intent="gemma_initialization",
            success=False,
            error=str(exc)
        )
        return False
# Attempt initialization eagerly at module import so the first request does
# not pay the model-loading cost; failures are logged, never raised.
_initialize_gemma_pipeline()
def is_gemma_available() -> bool:
    """
    Report whether the Gemma service is ready to serve requests.

    Returns:
        bool: True if the module-level pipeline was loaded successfully.
    """
    pipeline_loaded = GEMMA_PIPELINE is not None
    return pipeline_loaded
async def generate_response(
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    tenant_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Runs text generation using the loaded Gemma pipeline.

    Inference is executed in the default thread-pool executor so the event
    loop is not blocked. Every outcome (unavailable pipeline, bad input,
    slow/normal success, unexpected output, exception) is logged via
    log_interaction, and failures return a friendly fallback message rather
    than raising (cancellation is re-raised).

    Args:
        prompt: The conversational or instruction prompt. Must be a
            non-empty string.
        max_new_tokens: The maximum number of tokens to generate (default: 256).
        temperature: Controls randomness in generation (default: 0.7).
            Sampling is disabled when temperature is 0.0 or lower.
        tenant_id: Optional tenant identifier for logging.

    Returns:
        A dictionary containing:
            - response (str): The generated text
            - available (bool): Whether the service was available
            - error (str, optional): Error message if generation failed
            - response_time_ms (int, optional): Generation time in milliseconds
    """
    start_time = time.time()
    # Guard: pipeline never loaded (or load failed) — return a fallback reply.
    if not is_gemma_available():
        log_interaction(
            intent="gemma_generate",
            tenant_id=tenant_id,
            success=False,
            error="Gemma pipeline not available",
            fallback_used=True
        )
        return {
            "response": "I'm having trouble accessing my language model right now. Please try again in a moment!",
            "available": False,
            "error": "Pipeline not initialized"
        }
    # Guard: reject empty or non-string prompts before touching the model.
    if not prompt or not isinstance(prompt, str):
        log_interaction(
            intent="gemma_generate",
            tenant_id=tenant_id,
            success=False,
            error="Invalid prompt provided"
        )
        return {
            "response": "I didn't receive a valid prompt. Could you try again?",
            "available": True,
            "error": "Invalid input"
        }
    # Configure generation parameters; greedy decoding at temperature <= 0.
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "do_sample": temperature > 0.0,
        "return_full_text": False
    }
    try:
        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()
        # Run blocking model inference in the default thread executor so the
        # event loop stays responsive.
        results = await loop.run_in_executor(
            None,
            lambda: GEMMA_PIPELINE(prompt, **gen_kwargs)
        )
        response_time_ms = int((time.time() - start_time) * 1000)
        # Parse results: expect a HF-pipeline-style [{"generated_text": ...}].
        if results and isinstance(results, list) and len(results) > 0:
            if isinstance(results[0], dict) and 'generated_text' in results[0]:
                generated_text = results[0]['generated_text'].strip()
                # Flag slow generations (> 5s) with a dedicated log entry.
                if response_time_ms > 5000:
                    log_interaction(
                        intent="gemma_generate_slow",
                        tenant_id=tenant_id,
                        success=True,
                        response_time_ms=response_time_ms,
                        details="Slow generation detected"
                    )
                log_interaction(
                    intent="gemma_generate",
                    tenant_id=tenant_id,
                    success=True,
                    response_time_ms=response_time_ms,
                    prompt_preview=sanitize_for_logging(prompt[:100])
                )
                return {
                    "response": generated_text,
                    "available": True,
                    "response_time_ms": response_time_ms
                }
        # Unexpected output format from the pipeline.
        log_interaction(
            intent="gemma_generate",
            tenant_id=tenant_id,
            success=False,
            error="Unexpected model output format",
            response_time_ms=response_time_ms
        )
        return {
            "response": "I got an unexpected response from my language model. Let me try to help you another way!",
            "available": True,
            "error": "Unexpected output format"
        }
    except asyncio.CancelledError:
        # Cancellation must propagate so callers/tasks can shut down cleanly.
        log_interaction(
            intent="gemma_generate",
            tenant_id=tenant_id,
            success=False,
            error="Generation cancelled"
        )
        raise
    except Exception as e:
        response_time_ms = int((time.time() - start_time) * 1000)
        log_interaction(
            intent="gemma_generate",
            tenant_id=tenant_id,
            success=False,
            error=str(e),
            response_time_ms=response_time_ms,
            fallback_used=True
        )
        return {
            "response": "I'm having trouble generating a response right now. Please try again!",
            "available": False,
            "error": str(e),
            "response_time_ms": response_time_ms
        }