File size: 3,452 Bytes
70be7c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from typing import Dict, Any
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
from PIL import Image
import requests
from io import BytesIO
import base64
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for LLaVA-NeXT (Mistral-7B).

    Loads a 4-bit quantized ``llava-hf/llava-v1.6-mistral-7b-hf`` processor and
    model once at construction time. Each request answers a text prompt about a
    single image, supplied either by URL or as base64-encoded bytes.
    """

    # Hub repository the processor and model weights are loaded from.
    MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
    # Seconds to wait on the image server. Without a timeout, requests can
    # block forever on an unresponsive host and stall the endpoint worker.
    REQUEST_TIMEOUT = 30
    # Generation length used when the request does not set 'max_tokens'.
    DEFAULT_MAX_TOKENS = 100

    def __init__(self, path=""):
        """Load the processor and the 4-bit NF4-quantized model.

        Args:
            path: Repository path supplied by the endpoint runtime. It is
                currently unused; weights always come from ``MODEL_ID``.
        """
        # Quantize weights to 4-bit NF4; matmuls are computed in float16.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        # Loaded once per handler instance and reused for every request.
        self.processor = LlavaNextProcessor.from_pretrained(self.MODEL_ID)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            quantization_config=quantization_config,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one image + prompt generation request.

        Args:
            data: Request payload of the form::

                {"inputs": {"prompt": str,
                            "url": str,           # and/or
                            "image_data": str,    # base64, optionally a data: URL
                            "max_tokens": int}}   # optional, default 100

        Returns:
            On success: ``{"input_prompt", "model_output", "logs"}``.
            On any validation, download, decoding or generation failure:
            ``{"error", "logs"}``.
            ``logs`` is a list of progress messages (in Catalan). Failures are
            reported in the returned dict rather than raised.

        NOTE(review): the prompt is passed to the processor verbatim; this
        model presumably expects the "[INST] <image>\\n... [/INST]" chat
        format, so callers must supply it themselves -- confirm with clients.
        """
        logs = ["Iniciant processament de la petició."]

        inputs = data.get("inputs") if isinstance(data, dict) else None
        if not inputs:
            logs.append("Format d'entrada invàlid. Manca la clau 'inputs'.")
            return {"error": "Invalid input format. 'inputs' key is missing.", "logs": logs}
        if not isinstance(inputs, dict):
            # e.g. a bare string payload; previously this crashed on .get().
            logs.append("Format d'entrada invàlid. 'inputs' ha de ser un objecte.")
            return {"error": "Invalid input format. 'inputs' must be an object.", "logs": logs}

        image_url = inputs.get("url")
        image_data = inputs.get("image_data")
        prompt = inputs.get("prompt")

        if not prompt:
            logs.append("S'ha de proporcionar 'prompt' en 'inputs'.")
            return {"error": "The 'prompt' must be provided in 'inputs'.", "logs": logs}
        if not image_url and not image_data:
            logs.append("S'ha de proporcionar 'url' o 'image_data' en 'inputs'.")
            return {"error": "Either 'url' or 'image_data' must be provided in 'inputs'.", "logs": logs}

        # Validate before downloading anything, so a bad value fails cheaply.
        raw_max_tokens = inputs.get("max_tokens")
        if raw_max_tokens is None:
            raw_max_tokens = self.DEFAULT_MAX_TOKENS
        try:
            max_tokens = int(raw_max_tokens)
        except (TypeError, ValueError, OverflowError):
            max_tokens = 0
        if max_tokens <= 0:
            logs.append("'max_tokens' ha de ser un enter positiu.")
            return {"error": "'max_tokens' must be a positive integer.", "logs": logs}

        logs.append(f"Processant entrada: url={image_url}, image_data={'present' if image_data else 'absent'}, prompt={prompt}")

        try:
            image = self._load_image(image_url, image_data, logs)
        except Exception as e:
            logs.append(f"Error carregant imatge: {str(e)}")
            return {"error": str(e), "logs": logs}

        try:
            logs.append("Processant imatge amb el model.")
            # Keyword arguments: the processor's positional order changed
            # across transformers releases (text, images) -> (images, text).
            # Move tensors to wherever device_map placed the model instead of
            # assuming "cuda".
            model_inputs = self.processor(
                images=image, text=prompt, return_tensors="pt"
            ).to(self.model.device)
            output = self.model.generate(**model_inputs, max_new_tokens=max_tokens)
            # generate() returns prompt + continuation for this decoder-only
            # model, so the decoded text echoes the prompt; kept as-is for
            # compatibility with existing clients.
            result = self.processor.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            logs.append(f"Error processant el model: {str(e)}")
            return {"error": str(e), "logs": logs}

        logs.append("Processament complet.")
        return {"input_prompt": prompt, "model_output": result, "logs": logs}

    def _load_image(self, image_url, image_data, logs):
        """Fetch or decode the request image and return it as an RGB PIL image.

        The URL takes precedence when both sources are given. Progress messages
        are appended to ``logs``. Network, HTTP, base64 and image-decoding
        errors propagate to the caller, which turns them into an error response.
        """
        if image_url:
            logs.append(f"Carregant imatge des de URL: {image_url}")
            with requests.get(image_url, timeout=self.REQUEST_TIMEOUT) as response:
                # Fail on 4xx/5xx instead of trying to parse an error page.
                response.raise_for_status()
                # .content is the fully read, transparently decompressed body.
                # response.raw is not gzip/deflate-decoded.
                image = Image.open(BytesIO(response.content))
        else:
            logs.append("Carregant imatge des de dades d'imatge en brut.")
            payload = image_data
            # Accept "data:image/...;base64,<data>" as well as bare base64.
            # b64decode would otherwise silently drop ':' '/' ';' ',' and
            # decode garbage.
            if isinstance(payload, str) and payload.startswith("data:"):
                payload = payload.partition(",")[2]
            image = Image.open(BytesIO(base64.b64decode(payload)))

        # Normalize every input (RGBA PNG, palette GIF, grayscale, CMYK, ...) to
        # 3-channel RGB without a lossy JPEG round-trip. convert() also forces
        # the pixel data to load here, so a corrupt file fails as an image error.
        if image.mode != "RGB":
            logs.append(f"Convertint imatge {image.mode} a RGB.")
        return image.convert("RGB")
|