# Author: Kalaoke
# Commit 1b02147 (verified): "Update handler.py"
from __future__ import annotations
import base64
from curses import raw
import re
import unicodedata
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Dict, Optional, List, Set
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers.utils import logging
# Module logger, obtained through transformers' logging utilities.
# NOTE(review): get_logger(__name__) presumably returns a stdlib logger named
# after this module, i.e. outside the "transformers" logger hierarchy. If so,
# set_verbosity_info() below does not lower *this* logger's level, and the
# logger.info(...) calls in this file may be dropped unless the hosting runtime
# configures logging. Confirm in the endpoint logs.
logger = logging.get_logger(__name__)
# Set the transformers library log verbosity to INFO.
logging.set_verbosity_info()
# Hub id of the base Pixtral checkpoint used by EndpointHandler.
BASE_MODEL_ID = "mistral-community/pixtral-12b"
# Default instruction sent with every image when the request has no "prompt".
# The numbered "N. name" lines are also the category whitelist:
# EndpointHandler builds its default allowed-category set by running
# _extract_categories_from_prompt over this string, and model output is
# filtered against that set. Editing an entry therefore changes both what the
# model is asked for and what is accepted back.
# NOTE(review): "dried fruits" and "cherry tomato" use spaces where every other
# entry uses underscores, "sirup" differs from "sugar_honey_syrup", and both
# "egg" and "eggs" are listed. Presumably these mirror the training labels, so
# they are left untouched. Confirm before normalizing.
DEFAULT_PROMPT = (
    "Here is a picture showing some food waste.\n\n"
    "Task: Provide a list of the food waste items visible in the picture.\n"
    "For each item, output EXACTLY one CATEGORY chosen from the CATEGORY DICTIONARY below.\n\n"
    "Rules:\n"
    "- Output must be a single line (no extra text, no markdown).\n"
    "- Use the exact spelling of CATEGORY as listed.\n"
    "- If you are unsure, choose the closest broader category (e.g., 'fruit', 'vegetable', 'food_waste').\n\n"
    "CATEGORY DICTIONARY:\n"
    "1. beer\n"
    "2. cabbage\n"
    "3. whipped_porridge\n"
    "4. cider\n"
    "5. orange\n"
    "6. dairy_product_milk\n"
    "7. yoghurt\n"
    "8. herbs\n"
    "9. potato_product\n"
    "10. strawberry\n"
    "11. minced_chicken\n"
    "12. soda\n"
    "13. sauce_paste\n"
    "14. rice_porridge\n"
    "15. blueberry\n"
    "16. light_bread\n"
    "17. semolina_porridge\n"
    "18. egg\n"
    "19. crepes\n"
    "20. sliced_cheese\n"
    "21. chicken\n"
    "22. food_waste\n"
    "23. candy\n"
    "24. vegetarian_sauce\n"
    "25. cheese\n"
    "26. cereal\n"
    "27. pork_product\n"
    "28. vegetable\n"
    "29. honey\n"
    "30. plant_based_cream\n"
    "31. vegetarian_pizza\n"
    "32. dried fruits\n"
    "33. rice_or_pasta\n"
    "34. pork_steak\n"
    "35. wine\n"
    "36. soft_drink\n"
    "37. minced_beef\n"
    "38. steak\n"
    "39. frankfurter\n"
    "40. vegetarian_soup\n"
    "41. beef_frankfurter\n"
    "42. alcoholic_long_drink\n"
    "43. raspberry\n"
    "44. rice\n"
    "45. mandarin\n"
    "46. juice\n"
    "47. porridge\n"
    "48. fish_soup\n"
    "49. fish_hamburger\n"
    "50. fruit\n"
    "51. coffee\n"
    "52. plant_based_ice_cream\n"
    "53. beef_sausage\n"
    "54. minced_pork\n"
    "55. meat\n"
    "56. sweet_pastry\n"
    "57. ice_cream\n"
    "58. fish_salad\n"
    "59. flour\n"
    "60. snack\n"
    "61. fish_product\n"
    "62. cherry\n"
    "63. shellfish\n"
    "64. cream\n"
    "65. dessert\n"
    "66. cold_cut_chicken\n"
    "67. onion\n"
    "68. dark_bread\n"
    "69. plant_based_milk\n"
    "70. fermented_milk\n"
    "71. cocoa\n"
    "72. bell_pepper\n"
    "73. beef\n"
    "74. sausage\n"
    "75. dairy_product\n"
    "76. pork\n"
    "77. fish_fillet\n"
    "78. strong_alcoholic_beverage\n"
    "79. biscuit\n"
    "80. currant\n"
    "81. sweet_soup\n"
    "82. popcorn\n"
    "83. meat_sauce\n"
    "84. meat_pizza\n"
    "85. lemon\n"
    "86. plum\n"
    "87. grain_products\n"
    "88. sugar_honey_syrup\n"
    "89. vegetarian_hamburger\n"
    "90. jam\n"
    "91. vegetarian_stew\n"
    "92. beef_steak\n"
    "93. tomato\n"
    "94. meat_soup\n"
    "95. block_cheese\n"
    "96. potato_based\n"
    "97. potato\n"
    "98. flakes\n"
    "99. soft_cheese\n"
    "100. beef_product\n"
    "101. chocolate\n"
    "102. cauliflower\n"
    "103. nectarine\n"
    "104. banana\n"
    "105. apple\n"
    "106. alcoholic_beverage\n"
    "107. meat_salad\n"
    "108. pork_sausage\n"
    "109. lingonberry\n"
    "110. chicken_sausage\n"
    "111. pork_frankfurter\n"
    "112. minced_meat\n"
    "113. baked_goods\n"
    "114. broccoli\n"
    "115. cold_cut_beef\n"
    "116. meat_dish\n"
    "117. plant_cheese\n"
    "118. butter\n"
    "119. children_milk\n"
    "120. fish_sauce\n"
    "121. buttermilk\n"
    "122. cucumber\n"
    "123. meat_hamburger\n"
    "124. pasta\n"
    "125. tea\n"
    "126. berries\n"
    "127. milk\n"
    "128. fish_pizza\n"
    "129. fresh_cheese\n"
    "130. fish_stew\n"
    "131. sirup\n"
    "132. vegetarian_salad\n"
    "133. spice\n"
    "134. quark\n"
    "135. cold_cut_pork\n"
    "136. sauce_seasoning\n"
    "137. oatmeal\n"
    "138. sugar\n"
    "139. meat_product\n"
    "140. vegetarian_dish\n"
    "141. savory_pastry\n"
    "142. chicken_product\n"
    "143. vegetable_mix\n"
    "144. potato_chip\n"
    "145. muesli\n"
    "146. carrot\n"
    "147. plant_based_yogurt\n"
    "148. sweet\n"
    "149. eggs\n"
    "150. chicken_frankfurter\n"
    "151. bread\n"
    "152. fish_strips\n"
    "153. lettuce\n"
    "154. fish\n"
    "155. meat_stew\n"
    "156. hot_beverage\n"
    "157. cherry tomato\n"
    "158. fish_dish\n"
    "159. fat_oil\n"
    "160. grape\n"
    "161. plant_protein\n"
    "162. pastry\n"
    "163. oil\n"
    "164. cold_cut_meat\n\n"
    "OUTPUT FORMAT (single line):\n"
    "CATEGORY1,CATEGORY2,CATEGORY3\n"
)
@dataclass()
class GenerationConfig:
    """Default decoding and preprocessing settings for the endpoint.

    EndpointHandler.__call__ lets a request override ``max_new_tokens``,
    ``temperature``, ``max_length`` and ``max_side``. The n-gram and repetition
    settings are only ever read from here.
    """

    # Cap on generated tokens; further limited by the remaining token budget.
    max_new_tokens: int = 256
    # 0.0 selects greedy decoding; a positive value enables sampling.
    temperature: float = 0.0
    # Forbids repeating any 6-token sequence during generation.
    no_repeat_ngram_size: int = 6
    # Penalty factor applied to already-generated tokens.
    repetition_penalty: float = 1.1
    # Total token budget (prompt + answer).
    max_length: int = 4096
    # Longest image side, in pixels, after downscaling.
    max_side: int = 512
_ALLOWED_RE = re.compile(r"[^a-z0-9_\(\);,\|\/\-\s\?\.\:]")
ITEM_RE = re.compile(r"\(\s*(.*?)\s*\)\s*\|\s*(\d+)", flags=re.DOTALL)
def _clean_text(s: str) -> str:
s = unicodedata.normalize("NFKC", s).replace("\u00A0", " ")
s = re.sub(r"[\u200B-\u200F]", "", s)
s = re.sub(r"\s+", " ", s).strip()
return s
_CATEGORY_LINE_RE = re.compile(r"^\s*\d+\.\s*(.+?)\s*$", flags=re.MULTILINE)
def _extract_categories_from_prompt(prompt: str) -> Set[str]:
cats = {m.group(1).strip().lower() for m in _CATEGORY_LINE_RE.finditer(prompt)}
return {c for c in cats if c}
def _clean_model_output(s: str) -> str:
    """Normalize raw model output and strip characters outside the allowed set.

    The text is normalized by ``_clean_text`` and lower-cased. Every character
    matched by ``_ALLOWED_RE`` (i.e. not in its allowed class) is then removed.
    """
    lowered = _clean_text(s).lower()
    return _ALLOWED_RE.sub("", lowered)
# Separators accepted between categories. The prompt asks for commas, but ";"
# and "|" also survive _clean_model_output and occur in no DEFAULT_PROMPT
# category name, so they are treated as separators too.
_CATEGORY_SEPARATOR_RE = re.compile(r"[,;|]")
# List numbering the model may copy from the numbered dictionary ("12. beer", "3) egg").
_LEADING_NUMBER_RE = re.compile(r"^\d+\s*[.)]\s*")
# Punctuation that may wrap or trail an item ("beer.", "(egg)"). None of these
# characters starts or ends a DEFAULT_PROMPT category name.
_ITEM_EDGE_CHARS = " .:;-?()"


def _category_key(name: str) -> str:
    """Spelling-insensitive lookup key: runs of whitespace/underscores become one "_"."""
    return re.sub(r"[\s_]+", "_", name.strip())


def _parse_and_validate_categories(raw: str, allowed: Set[str]) -> List[str]:
    """Extract the allowed categories named in raw model output.

    The output is cleaned with ``_clean_model_output`` (lower-cased, disallowed
    characters removed) and split on ``,``, ``;`` or ``|``. Each item is
    accepted if it is exactly in ``allowed``. Otherwise, list numbering and
    edge punctuation are stripped, and the item is matched ignoring the
    difference between spaces and underscores. For example, ``dried_fruits``
    matches the dictionary entry ``dried fruits``.

    Args:
        raw: text generated by the model.
        allowed: valid category names, e.g. as produced by
            ``_extract_categories_from_prompt``.

    Returns:
        Matched categories in first-seen order, spelled as in ``allowed``,
        without duplicates. Empty if nothing matched.
    """
    cleaned = _clean_model_output(raw)

    # sorted() keeps tie-breaking deterministic if two allowed names share a key.
    by_key: Dict[str, str] = {}
    for name in sorted(allowed):
        by_key.setdefault(_category_key(name), name)

    out: List[str] = []
    seen: Set[str] = set()
    for part in _CATEGORY_SEPARATOR_RE.split(cleaned):
        item = part.strip()
        if not item:
            continue
        if item in allowed:
            category: Optional[str] = item
        else:
            item = _LEADING_NUMBER_RE.sub("", item).strip(_ITEM_EDGE_CHARS)
            category = item if item in allowed else by_key.get(_category_key(item))
        if category is not None and category not in seen:
            out.append(category)
            seen.add(category)
    return out
class EndpointHandler:
    """Inference Endpoint handler that tags food-waste items in an image.

    A Llava-architecture (Pixtral) vision-language model is prompted with the
    image and a category dictionary. Its free-text answer is reduced to the
    categories listed in that dictionary.
    """

    def __init__(self, path: str = ".") -> None:
        """
        Initializes the model and processor from the `path` directory,
        which contains the merged weights (pixtral-12b-foodwaste-merged).

        If `path` lacks loadable processor or model files, that component is
        loaded from ``BASE_MODEL_ID`` instead and a warning is logged.

        Raises:
            ValueError: if no EOS token id is found on the model or tokenizer.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing EndpointHandler on device: %s", self.device)

        self.processor = self._from_pretrained_with_fallback(
            AutoProcessor,
            path,
            trust_remote_code=True,
        )
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.model = self._from_pretrained_with_fallback(
            LlavaForConditionalGeneration,
            path,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map={"": self.device},
            trust_remote_code=True,
        )
        self.model.eval()
        logger.info("Model and processor successfully loaded (requested path: '%s').", path)

        # pad token management: reuse EOS as PAD when the tokenizer has none.
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is not None and tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # Preparation of EOS/PAD IDs for generate. A config may store a single id
        # or a list of ids, so flatten first, then de-duplicate keeping order.
        eos_candidates = self._as_id_list(getattr(self.model.config, "eos_token_id", None))
        if tokenizer is not None:
            eos_candidates.extend(self._as_id_list(tokenizer.eos_token_id))
        self.eos_token_ids: List[int] = list(dict.fromkeys(eos_candidates))
        if not self.eos_token_ids:
            raise ValueError("No EOS token id found on model or tokenizer.")

        pad_id: Optional[int] = getattr(self.model.config, "pad_token_id", None)
        if pad_id is None and tokenizer is not None:
            pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.eos_token_ids[0]
        self.pad_token_id: int = pad_id

        self.gen_config = GenerationConfig()
        logger.info(
            "Generation config: max_new_tokens=%d, temperature=%.3f",
            self.gen_config.max_new_tokens,
            self.gen_config.temperature,
        )
        self.default_allowed_categories: Set[str] = _extract_categories_from_prompt(DEFAULT_PROMPT)
        logger.info("Extracted %d categories from DEFAULT_PROMPT.", len(self.default_allowed_categories))

    @staticmethod
    def _from_pretrained_with_fallback(loader: Any, path: str, **kwargs: Any) -> Any:
        """Call ``loader.from_pretrained(path, **kwargs)``, falling back to ``BASE_MODEL_ID``.

        Only ``OSError``/``ValueError`` from the first attempt (the errors raised
        for missing or unrecognised files) trigger the fallback. Errors from the
        fallback attempt propagate.
        """
        try:
            obj = loader.from_pretrained(path, **kwargs)
            source = path
        except (OSError, ValueError) as exc:
            logger.warning(
                "Could not load %s from '%s' (%s); falling back to '%s'.",
                getattr(loader, "__name__", loader),
                path,
                exc,
                BASE_MODEL_ID,
            )
            obj = loader.from_pretrained(BASE_MODEL_ID, **kwargs)
            source = BASE_MODEL_ID
        logger.info("Loaded %s from '%s'.", getattr(loader, "__name__", loader), source)
        return obj

    @staticmethod
    def _as_id_list(value: Any) -> List[int]:
        """Normalize a token-id setting (None, an int, or a list/tuple of ints) to a list."""
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return [int(v) for v in value if v is not None]
        return [int(value)]

    @staticmethod
    def _decode_image(image_b64: str) -> Image.Image:
        """Decode a base64 image, optionally wrapped in a ``data:<mime>;base64,`` URI.

        Returns:
            The image converted to RGB.

        Raises:
            ValueError: if the payload is not decodable base64 or not a readable image.
        """
        try:
            payload = image_b64
            if isinstance(payload, str):
                payload = payload.strip()
                # Without this, the data-URI header would be decoded as image bytes.
                if payload.startswith("data:") and "," in payload:
                    payload = payload.split(",", 1)[1]
            img_bytes = base64.b64decode(payload)
            return Image.open(BytesIO(img_bytes)).convert("RGB")
        except Exception as exc:  # pragma: no cover - log production
            raise ValueError(f"Could not decode base64 image: {exc}") from exc

    @staticmethod
    def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
        """Downscale ``img`` so its longest side is at most ``max_side`` pixels.

        Aspect ratio is preserved, images already within the limit are returned
        unchanged (never upscaled), and each output side is at least 1 pixel.

        Raises:
            ValueError: if ``max_side`` is not positive.
        """
        if max_side <= 0:
            raise ValueError(f"'max_side' must be a positive integer, got {max_side}.")
        w, h = img.size
        m = max(w, h)
        if m <= max_side:
            return img
        # Integer arithmetic makes the longest side exactly max_side. max(1, ...)
        # avoids a 0-pixel side for extreme aspect ratios.
        new_size = (max(1, w * max_side // m), max(1, h * max_side // m))
        # LANCZOS = better downscale
        return img.resize(new_size, resample=Image.Resampling.LANCZOS)

    def _build_chat_text(self, prompt: str) -> str:
        """Render a single user turn (text followed by one image) with the chat template."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image"},
                ],
            }
        ]
        chat_text = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        return chat_text

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference on one request.

        Args:
            data: ``{"inputs": {...}}`` or the inner dict directly. ``inputs`` may
                also be a bare base64 string, which is treated as the image.
                Recognised keys: ``image`` (required, base64, optionally a
                ``data:`` URI), ``prompt``, ``debug``, ``max_length``,
                ``max_side``, ``max_new_tokens``, ``temperature``.

        Returns:
            ``{"generated_text": <comma-joined categories>, "categories": [...]}``.

        Raises:
            ValueError: malformed ``inputs``, missing or undecodable image,
                non-positive ``max_side``, or a prompt that leaves no room for an
                answer within ``max_length`` tokens.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            # Bare payload: {"inputs": "<base64 image>"}.
            inputs = {"image": inputs}
        if not isinstance(inputs, dict):
            raise ValueError("'inputs' must be an object or a base64-encoded image string.")

        debug = bool(inputs.get("debug", False))
        prompt: str = inputs.get("prompt") or DEFAULT_PROMPT
        # Only categories listed in the prompt's numbered dictionary are accepted.
        allowed_categories = _extract_categories_from_prompt(prompt) or self.default_allowed_categories

        image_b64: Optional[str] = inputs.get("image")
        if not image_b64:
            raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")
        image = self._decode_image(image_b64)

        max_length = int(inputs.get("max_length", self.gen_config.max_length))
        max_side = int(inputs.get("max_side", self.gen_config.max_side))
        image = self._resize_max_side(image, max_side=max_side)
        max_new_tokens = int(inputs.get("max_new_tokens", self.gen_config.max_new_tokens))
        temperature = float(inputs.get("temperature", self.gen_config.temperature))

        chat_text = self._build_chat_text(prompt)
        # No truncation: the image placeholder follows the text in the user turn,
        # so right-side truncation would cut image tokens and the end of the chat
        # template. An over-long prompt is rejected below instead.
        enc = self.processor(
            text=[chat_text],
            images=[image],
            return_tensors="pt",
            padding=False,  # important for correct prompt_len
        )
        prompt_len = int(enc["input_ids"].shape[1])
        tokens_left = max_length - prompt_len

        if debug:
            logger.info("===== TOKEN BUDGET DEBUG (inference) =====")
            logger.info("img size: %s", getattr(image, "size", None))
            logger.info("max_length: %d", max_length)
            logger.info("prompt_len: %d", prompt_len)
            logger.info("tokens_left_for_answer(approx): %d", tokens_left)
            tok = getattr(self.processor, "tokenizer", None)
            if tok is not None:
                ids = enc["input_ids"][0].tolist()
                logger.info("[prompt head]\n%s", tok.decode(ids[:120], skip_special_tokens=False))
                logger.info("[prompt tail]\n%s", tok.decode(ids[-120:], skip_special_tokens=False))
            logger.info("=========================================")

        if tokens_left <= 0:
            raise ValueError(
                f"Prompt uses {prompt_len} tokens, leaving no room for an answer within "
                f"max_length={max_length}; reduce 'max_side' or the prompt, or raise 'max_length'."
            )
        max_new_tokens = min(max_new_tokens, tokens_left)

        # Move tensors to the model device. Non-tensor entries are passed through.
        model_inputs: Dict[str, Any] = {
            key: value.to(self.device) if torch.is_tensor(value) else value
            for key, value in enc.items()
        }
        pixel_values = model_inputs.get("pixel_values")
        if torch.is_tensor(pixel_values):
            model_inputs["pixel_values"] = pixel_values.to(self.device, dtype=self.model.dtype)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0.0,
            "eos_token_id": self.eos_token_ids,
            "pad_token_id": self.pad_token_id,
            "no_repeat_ngram_size": self.gen_config.no_repeat_ngram_size,
            "repetition_penalty": self.gen_config.repetition_penalty,
        }
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature

        with torch.inference_mode():
            output_ids = self.model.generate(**model_inputs, **gen_kwargs)

        # generate() returns prompt + continuation; keep only the new tokens.
        generated_only = output_ids[:, prompt_len:]
        generated_text = self.processor.batch_decode(
            generated_only,
            skip_special_tokens=True,
        )[0].strip()

        cats = _parse_and_validate_categories(generated_text, allowed_categories)
        if not cats:
            cats = ["food_waste"] if "food_waste" in allowed_categories else []
        generated_text = ",".join(cats)
        logger.info("Generated text: %s", generated_text)
        return {"generated_text": generated_text, "categories": cats}