# QwenStem-7b / handler.py — uploaded by analist (commit 9853ed9, verified)
"""
Custom Handler for QwenStem-7b on Hugging Face Endpoints
Handles both text and multimodal (text+image) inputs
"""
import torch
import base64
import logging
from io import BytesIO
from typing import Dict, List, Any, Optional
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
# Logging configuration: INFO-level root config plus a module-level logger
# used throughout the handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Custom inference handler for QwenStem-7b on Hugging Face Endpoints.

    Accepts either a plain-text prompt or a dict with "text" and an optional
    base64-encoded "image", and returns the generated text.
    """

    def __init__(self, path=""):
        """Load the processor and model for HF Endpoints.

        Args:
            path: Path to the model directory (provided by HF Endpoints).
                  Falls back to the hub id "analist/QwenStem-7b" when empty.

        Raises:
            Exception: re-raised unchanged if the processor or model fails
                to load (HF Endpoints treats this as a failed startup).
        """
        logger.info(f"Initializing model from path: {path}")
        # Pick the best available device once; reused by every request.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            self.device = torch.device("cpu")
            logger.info("Using CPU")
        # Single source of truth for the model id (was duplicated per call).
        model_id = path if path else "analist/QwenStem-7b"
        try:
            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(
                model_id,
                trust_remote_code=True
            )
            # Load WITHOUT quantization: the endpoint infrastructure applies
            # quantization itself when needed. fp16 only makes sense on GPU.
            logger.info("Loading model...")
            self.model = AutoModelForVision2Seq.from_pretrained(
                model_id,
                trust_remote_code=True,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                low_cpu_mem_usage=True
            ).to(self.device)
            # Inference only: disables dropout and other training-time layers.
            self.model.eval()
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise
        # Per-request "parameters" are merged on top of these defaults.
        # NOTE(review): 9192 * 10 (= 91920) new tokens is extremely large and
        # looks like a typo for 8192 — kept as-is to preserve behavior; confirm.
        self.default_generation_config = {
            "max_new_tokens": 9192 * 10,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
            "repetition_penalty": 1.05
        }

    def _generation_kwargs(self, config: dict) -> Dict[str, Any]:
        """Translate a merged request config into model.generate() kwargs.

        Centralizes the generation defaults that were previously duplicated
        in the text-only and multimodal paths.
        """
        eos_id = self.processor.tokenizer.eos_token_id
        return {
            "max_new_tokens": config.get("max_new_tokens", 9192 * 10),
            "temperature": config.get("temperature", 0.7),
            "top_p": config.get("top_p", 0.9),
            "do_sample": config.get("do_sample", True),
            "repetition_penalty": config.get("repetition_penalty", 1.05),
            # Use EOS as padding too: the tokenizer may define no pad token.
            "pad_token_id": eos_id,
            "eos_token_id": eos_id,
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process one incoming request for HF Endpoints.

        Args:
            data: Dictionary containing:
                - inputs: text prompt (str), or dict with "text" and an
                  optional base64 "image"
                - parameters: optional generation-parameter overrides (dict)

        Returns:
            [{"generated_text": ...}] on success, or
            [{"error": ..., "error_type": ...}] on failure (never raises).
        """
        try:
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})
            logger.info(f"Processing request - Input type: {type(inputs)}")
            # Request parameters override the handler-wide defaults.
            gen_config = {**self.default_generation_config, **parameters}
            if isinstance(inputs, dict):
                # Structured input: text plus optional image.
                text = inputs.get("text", "")
                image_data = inputs.get("image", None)
                if image_data:
                    logger.info("Processing multimodal input (text + image)")
                    response = self._process_multimodal(text, image_data, gen_config)
                else:
                    logger.info("Processing text-only input from dict")
                    response = self._process_text(text, gen_config)
            elif isinstance(inputs, str):
                logger.info("Processing text-only input")
                response = self._process_text(inputs, gen_config)
            else:
                raise ValueError(f"Unsupported input type: {type(inputs)}")
            return [{"generated_text": response}]
        except Exception as e:
            # Surface the failure to the client instead of an opaque 500.
            logger.error(f"Error during inference: {str(e)}")
            return [{"error": str(e), "error_type": type(e).__name__}]

    def _process_text(self, text: str, config: dict) -> str:
        """Generate a completion for a text-only prompt.

        Args:
            text: user prompt; must be non-empty.
            config: merged generation config (defaults + request overrides).

        Raises:
            ValueError: if the prompt is empty.
        """
        if not text:
            raise ValueError("Empty text input")
        messages = [
            {"role": "user", "content": [
                {"type": "text", "text": text}
            ]}
        ]
        # apply_chat_template with tokenize=True returns input ids ready for
        # generate(); presumably shape (1, seq) — TODO confirm per processor.
        text_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                text_inputs,
                **self._generation_kwargs(config)
            )
        full_response = self.processor.decode(outputs[0], skip_special_tokens=True)
        # Prefer splitting on the chat template's "assistant" marker; otherwise
        # strip the decoded prompt prefix from the full output.
        if "assistant" in full_response:
            response = full_response.split("assistant")[-1].strip()
        else:
            prompt_text = self.processor.decode(text_inputs[0], skip_special_tokens=True)
            response = full_response[len(prompt_text):].strip()
        return response

    def _process_multimodal(self, text: str, image_b64: str, config: dict) -> str:
        """Generate a completion for a text + base64-image prompt.

        Args:
            text: optional user prompt (a default prompt is substituted in the
                chat message when empty).
            image_b64: base64-encoded image, with or without a data-URL header.
            config: merged generation config (defaults + request overrides).

        Raises:
            ValueError: if the image cannot be decoded.
        """
        try:
            if image_b64.startswith('data:image'):
                # Strip the "data:image/...;base64," header. maxsplit=1 keeps
                # any later commas in the payload and avoids IndexError-style
                # surprises on malformed headers.
                image_b64 = image_b64.split(',', 1)[1]
            image_bytes = base64.b64decode(image_b64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            logger.info(f"Image loaded: {image.size}")
        except Exception as e:
            logger.error(f"Image decode error: {str(e)}")
            raise ValueError(f"Failed to decode image: {str(e)}")
        messages = [
            {"role": "user", "content": [
                {"type": "text", "text": text if text else "Analyse cette image."},
                {"type": "image"}
            ]}
        ]
        # Build the text prompt first, then tokenize together with the image.
        prompt = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False
        )
        inputs = self.processor(
            text=prompt,
            images=[image],
            return_tensors="pt"
        )
        # Move every tensor to the model's device; leave non-tensors untouched.
        inputs = {k: v.to(self.device) if hasattr(v, 'to') else v
                  for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **self._generation_kwargs(config)
            )
        full_response = self.processor.decode(outputs[0], skip_special_tokens=True)
        # Prefer the chat template's "assistant" marker. The prompt-stripping
        # fallback must not call str.split("") (raises ValueError), so an
        # empty prompt falls through to the raw decoded output.
        if "assistant" in full_response:
            response = full_response.split("assistant")[-1].strip()
        elif text and text in full_response:
            response = full_response.split(text)[-1].strip()
        else:
            response = full_response
        return response

    def health(self) -> Dict[str, Any]:
        """Health-check endpoint for monitoring.

        Returns:
            Dict with overall "status" ("healthy" or "degraded"), model and
            system info, GPU memory stats when running on CUDA, and the
            result of a one-token smoke-test generation.
        """
        health_status = {
            "status": "healthy",
            "model": {
                "name": "QwenStem-7b",
                "type": "Vision-Language Model",
                "loaded": hasattr(self, 'model') and self.model is not None,
                "device": str(self.device) if hasattr(self, 'device') else "unknown"
            },
            "system": {
                "torch_version": torch.__version__,
                "cuda_available": torch.cuda.is_available(),
                "gpu_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
            }
        }
        # GPU memory statistics, only when the model actually runs on CUDA.
        if torch.cuda.is_available() and hasattr(self, 'device') and self.device.type == 'cuda':
            try:
                gpu_props = torch.cuda.get_device_properties(0)
                health_status["gpu"] = {
                    "name": gpu_props.name,
                    "memory_total_gb": round(gpu_props.total_memory / (1024**3), 2),
                    "memory_allocated_gb": round(torch.cuda.memory_allocated() / (1024**3), 2),
                    "memory_reserved_gb": round(torch.cuda.memory_reserved() / (1024**3), 2),
                    "utilization_percent": round(torch.cuda.memory_allocated() / gpu_props.total_memory * 100, 2)
                }
            except Exception as e:
                # GPU stats are best-effort; failing here does not degrade status.
                logger.warning(f"Could not get GPU stats: {e}")
                health_status["gpu"] = {"error": str(e)}
        # One-token smoke test to verify that the model still responds.
        if hasattr(self, 'model') and self.model is not None:
            try:
                with torch.no_grad():
                    test_input = self.processor.apply_chat_template(
                        [{"role": "user", "content": [{"type": "text", "text": "test"}]}],
                        tokenize=True,
                        add_generation_prompt=True,
                        return_tensors="pt"
                    ).to(self.device)
                    # Greedy, single-token generation: cheapest possible probe.
                    _ = self.model.generate(
                        test_input,
                        max_new_tokens=1,
                        do_sample=False
                    )
                health_status["model"]["responsive"] = True
            except Exception as e:
                logger.error(f"Model test failed: {e}")
                health_status["model"]["responsive"] = False
                health_status["model"]["error"] = str(e)
                health_status["status"] = "degraded"
        return health_status