File size: 2,304 Bytes
9b164d1 f6721ff 9b164d1 f6721ff 9b164d1 44df4d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
from io import BytesIO
import re
from typing import Any, Dict, List, Tuple, Union

import requests
import torch
from PIL import Image
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration, LlavaNextProcessor
class EndpointHandler:
    """Inference-endpoint handler that scores an image with LLaVA-NeXT (Mistral-7B).

    Each request downloads the image referenced by ``data["url"]``, runs the
    prompt in ``data["prompt"]`` through a 4-bit quantised LLaVA-NeXT model, and
    parses lines of the form ``"<n>. <category>: <score>"`` out of the answer.
    """

    # Hugging Face Hub repository the processor and the model are loaded from.
    MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
    # Maximum number of tokens generated per request.
    MAX_NEW_TOKENS = 100
    # Seconds to wait on the image server before the download fails.
    REQUEST_TIMEOUT = 30
    # Matches "<number>. <category name>: <integer score>"; '.' does not cross newlines.
    _SCORE_PATTERN = re.compile(r'(\d+)\.\s*(.*?):\s*(\d+)')

    def __init__(self, path=""):
        """Load the processor and the quantised model once, at handler construction.

        Args:
            path: Path supplied by the endpoint runtime. It is currently unused;
                weights are always loaded from ``MODEL_ID``.
        """
        # Quantise the weights to 4-bit NF4; computation runs in float16.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.processor = LlavaNextProcessor.from_pretrained(self.MODEL_ID)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            quantization_config=quantization_config,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> Union[List[Tuple[str, int]], Dict[str, str]]:
        """Score the image at ``data["url"]`` using the prompt ``data["prompt"]``.

        Args:
            data: Request payload with keys ``"url"`` (image URL) and
                ``"prompt"`` (full model prompt, including the image placeholder).

        Returns:
            A list of ``(category, score)`` pairs sorted by score, highest first,
            or ``{"error": message}`` when a field is missing or the image
            cannot be downloaded/decoded.
        """
        image_url = data.get("url")
        prompt = data.get("prompt")
        if not image_url:
            return {"error": "Missing required field 'url'."}
        if not prompt:
            return {"error": "Missing required field 'prompt'."}

        try:
            image = self._load_image(image_url)
        except Exception as e:  # request boundary: report network/HTTP/decoding failures to the client
            return {"error": str(e)}

        # Keyword arguments are robust to the text/images order change across transformers versions;
        # move inputs to wherever device_map="auto" placed the model instead of assuming "cuda".
        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.model.device)
        output = self.model.generate(**inputs, max_new_tokens=self.MAX_NEW_TOKENS)
        # generate() returns prompt + continuation; decode only the continuation so
        # score-like text inside the prompt is never parsed.
        prompt_length = inputs["input_ids"].shape[-1]
        result = self.processor.decode(output[0][prompt_length:], skip_special_tokens=True)

        scores = self.extract_scores(result)
        return sorted(scores.items(), key=lambda item: item[1], reverse=True)

    def _load_image(self, image_url: str) -> "Image.Image":
        """Download ``image_url`` and return it as a 3-channel RGB PIL image.

        Raises:
            requests.RequestException: on connection errors, timeouts or a non-2xx status.
            PIL.UnidentifiedImageError: if the payload is not a readable image.
        """
        # NOTE(review): the URL comes from the request payload; if this endpoint is
        # exposed publicly, consider restricting schemes/hosts (SSRF).
        with requests.get(image_url, timeout=self.REQUEST_TIMEOUT) as response:
            response.raise_for_status()
            payload = response.content
        # Normalise every mode (RGBA, P, L, CMYK, ...) to RGB, not just PNG inputs.
        return Image.open(BytesIO(payload)).convert("RGB")

    def extract_scores(self, response):
        """Parse ``"<n>. <category>: <score>"`` entries from a model answer.

        Only the text after the last ``[/INST]`` marker is considered (the whole
        string if the marker is absent). If a category appears more than once,
        the last score wins.

        Args:
            response: Decoded model output text.

        Returns:
            dict mapping the stripped category name to its integer score, with
            keys in first-seen order.
        """
        answer = response.split("[/INST]")[-1].strip()
        return {
            name.strip(): int(score)
            for _number, name, score in self._SCORE_PATTERN.findall(answer)
        }
|