| | from __future__ import annotations |
| |
|
| | import base64 |
| | from curses import raw |
| | import re |
| | import unicodedata |
| | from dataclasses import dataclass |
| | from io import BytesIO |
| | from typing import Any, Dict, Optional, List, Set |
| |
|
| | import torch |
| | from PIL import Image |
| | from transformers import AutoProcessor, LlavaForConditionalGeneration |
| | from transformers.utils import logging |
| |
|
| |
|
# Hugging Face Hub id of the base Pixtral checkpoint.
BASE_MODEL_ID: str = "mistral-community/pixtral-12b"

# Module logger obtained through transformers' logging helpers; the library
# verbosity is raised to INFO at import time.
# NOTE(review): set_verbosity_info() targets the `transformers` library logger.
# Unless this module's __name__ sits under that hierarchy, whether `logger.info`
# records are emitted depends on how the host configures the root logger —
# confirm INFO output actually reaches the endpoint logs.
logger = logging.get_logger(__name__)
logging.set_verbosity_info()
| |
|
# Category vocabulary offered to the model, in dictionary order: entry i is
# rendered as "i. name". "dried fruits" and "cherry tomato" contain a space
# rather than an underscore; they are kept verbatim because the prompt asks the
# model to reproduce spellings exactly.
_DEFAULT_CATEGORIES = (
    "beer", "cabbage", "whipped_porridge", "cider", "orange",  # 1-5
    "dairy_product_milk", "yoghurt", "herbs", "potato_product", "strawberry",  # 6-10
    "minced_chicken", "soda", "sauce_paste", "rice_porridge", "blueberry",  # 11-15
    "light_bread", "semolina_porridge", "egg", "crepes", "sliced_cheese",  # 16-20
    "chicken", "food_waste", "candy", "vegetarian_sauce", "cheese",  # 21-25
    "cereal", "pork_product", "vegetable", "honey", "plant_based_cream",  # 26-30
    "vegetarian_pizza", "dried fruits", "rice_or_pasta", "pork_steak", "wine",  # 31-35
    "soft_drink", "minced_beef", "steak", "frankfurter", "vegetarian_soup",  # 36-40
    "beef_frankfurter", "alcoholic_long_drink", "raspberry", "rice", "mandarin",  # 41-45
    "juice", "porridge", "fish_soup", "fish_hamburger", "fruit",  # 46-50
    "coffee", "plant_based_ice_cream", "beef_sausage", "minced_pork", "meat",  # 51-55
    "sweet_pastry", "ice_cream", "fish_salad", "flour", "snack",  # 56-60
    "fish_product", "cherry", "shellfish", "cream", "dessert",  # 61-65
    "cold_cut_chicken", "onion", "dark_bread", "plant_based_milk", "fermented_milk",  # 66-70
    "cocoa", "bell_pepper", "beef", "sausage", "dairy_product",  # 71-75
    "pork", "fish_fillet", "strong_alcoholic_beverage", "biscuit", "currant",  # 76-80
    "sweet_soup", "popcorn", "meat_sauce", "meat_pizza", "lemon",  # 81-85
    "plum", "grain_products", "sugar_honey_syrup", "vegetarian_hamburger", "jam",  # 86-90
    "vegetarian_stew", "beef_steak", "tomato", "meat_soup", "block_cheese",  # 91-95
    "potato_based", "potato", "flakes", "soft_cheese", "beef_product",  # 96-100
    "chocolate", "cauliflower", "nectarine", "banana", "apple",  # 101-105
    "alcoholic_beverage", "meat_salad", "pork_sausage", "lingonberry", "chicken_sausage",  # 106-110
    "pork_frankfurter", "minced_meat", "baked_goods", "broccoli", "cold_cut_beef",  # 111-115
    "meat_dish", "plant_cheese", "butter", "children_milk", "fish_sauce",  # 116-120
    "buttermilk", "cucumber", "meat_hamburger", "pasta", "tea",  # 121-125
    "berries", "milk", "fish_pizza", "fresh_cheese", "fish_stew",  # 126-130
    "sirup", "vegetarian_salad", "spice", "quark", "cold_cut_pork",  # 131-135
    "sauce_seasoning", "oatmeal", "sugar", "meat_product", "vegetarian_dish",  # 136-140
    "savory_pastry", "chicken_product", "vegetable_mix", "potato_chip", "muesli",  # 141-145
    "carrot", "plant_based_yogurt", "sweet", "eggs", "chicken_frankfurter",  # 146-150
    "bread", "fish_strips", "lettuce", "fish", "meat_stew",  # 151-155
    "hot_beverage", "cherry tomato", "fish_dish", "fat_oil", "grape",  # 156-160
    "plant_protein", "pastry", "oil", "cold_cut_meat",  # 161-164
)

# Instruction prompt sent with every image unless the request supplies its own.
# Its numbered "CATEGORY DICTIONARY" section is also parsed back by
# _extract_categories_from_prompt to build the allow-list for output validation.
DEFAULT_PROMPT = "".join(
    [
        "Here is a picture showing some food waste.\n\n",
        "Task: Provide a list of the food waste items visible in the picture.\n",
        "For each item, output EXACTLY one CATEGORY chosen from the CATEGORY DICTIONARY below.\n\n",
        "Rules:\n",
        "- Output must be a single line (no extra text, no markdown).\n",
        "- Use the exact spelling of CATEGORY as listed.\n",
        "- If you are unsure, choose the closest broader category (e.g., 'fruit', 'vegetable', 'food_waste').\n\n",
        "CATEGORY DICTIONARY:\n",
        *(f"{index}. {name}\n" for index, name in enumerate(_DEFAULT_CATEGORIES, start=1)),
        "\n",
        "OUTPUT FORMAT (single line):\n",
        "CATEGORY1,CATEGORY2,CATEGORY3\n",
    ]
)
| |
|
| |
|
@dataclass
class GenerationConfig:
    """Default decoding and preprocessing settings for the endpoint handler.

    ``max_new_tokens``, ``temperature``, ``max_length`` and ``max_side`` act as
    per-request defaults; the two repetition controls are applied as-is.
    """

    # Cap on generated tokens (further limited by the remaining context budget).
    max_new_tokens: int = 256
    # 0.0 means greedy decoding; a positive value enables sampling.
    temperature: float = 0.0
    # Size of n-grams that may not repeat in the generated text.
    no_repeat_ngram_size: int = 6
    # Values above 1.0 discourage re-emitting already generated tokens.
    repetition_penalty: float = 1.1
    # Token budget shared by the prompt (image tokens included) and the answer.
    max_length: int = 4_096
    # Longest image side, in pixels, after downscaling.
    max_side: int = 512
| |
|
| |
|
# Sanitizer for model output: every character outside this allow-list is deleted
# (it is applied after lower-casing, so only lower-case letters need listing).
_ALLOWED_RE = re.compile(r"[^a-z0-9_\(\);,\|\/\-\s\?\.\:]")

# "(text) | number" item pattern.
# NOTE(review): not referenced by the code visible in this file; kept in case
# other modules import it.
ITEM_RE = re.compile(r"\(\s*(.*?)\s*\)\s*\|\s*(\d+)", flags=re.DOTALL)

# Separators accepted between categories in model output. Commas are what the
# prompt asks for; newlines, semicolons and pipes are common deviations.
_ITEM_SEPARATOR_RE = re.compile(r"[,;|\r\n]+")

# Leading list markers ("3.", "3)", "3:", "-") that models sometimes prepend.
_LIST_MARKER_RE = re.compile(r"^(?:\d+\s*[.):]|-)\s*")

# Punctuation trimmed from both ends of an item before the lenient lookup.
_EDGE_PUNCTUATION = " .:;,-?/|()"


def _clean_text(s: str) -> str:
    """Normalize text for comparison.

    Applies NFKC, maps NBSP to a space, removes U+200B..U+200F (zero-width and
    directional marks), collapses whitespace runs to one space and strips.
    """
    s = unicodedata.normalize("NFKC", s).replace("\u00A0", " ")
    s = re.sub(r"[\u200B-\u200F]", "", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s


# Numbered dictionary lines such as "12. soda"; group 1 is the category name.
_CATEGORY_LINE_RE = re.compile(r"^\s*\d+\.\s*(.+?)\s*$", flags=re.MULTILINE)


def _extract_categories_from_prompt(prompt: str) -> Set[str]:
    """Return the lower-cased names of all numbered lines ("N. name") in ``prompt``."""
    cats = {m.group(1).strip().lower() for m in _CATEGORY_LINE_RE.finditer(prompt)}
    return {c for c in cats if c}


def _clean_model_output(s: str) -> str:
    """Normalize, lower-case and strip characters outside ``_ALLOWED_RE``'s allow-list."""
    s = _clean_text(s).lower()
    s = _ALLOWED_RE.sub("", s)
    return s


def _canonical_key(s: str) -> str:
    """Fold runs of whitespace, '_' and '-' into one '_' ("dried fruits" -> "dried_fruits")."""
    return re.sub(r"[\s_\-]+", "_", s).strip("_")


def _match_category(
    item: str, allowed: Set[str], by_key: Dict[str, str]
) -> Optional[str]:
    """Map one cleaned output item to its allowed spelling, or ``None`` if no match.

    Exact matches win so previously accepted output is unaffected; then a
    leading list marker and edge punctuation are removed; finally a
    separator-insensitive key lookup is tried.
    """
    if item in allowed:
        return item
    trimmed = _LIST_MARKER_RE.sub("", item).strip(_EDGE_PUNCTUATION)
    if not trimmed:
        return None
    if trimmed in allowed:
        return trimmed
    return by_key.get(_canonical_key(trimmed))


def _parse_and_validate_categories(raw: str, allowed: Set[str]) -> List[str]:
    """Extract the allowed categories mentioned in raw model output.

    The text is NFKC-normalized and then split on commas (the requested
    format) as well as newlines, semicolons and pipes. Splitting happens before
    whitespace collapsing so that newline-separated answers are not merged into
    one unmatched item. Each piece is cleaned with ``_clean_model_output`` and
    resolved via ``_match_category``.

    Args:
        raw: Decoded model output.
        allowed: Accepted category spellings.

    Returns:
        Allowed spellings in first-seen order, without duplicates; pieces that
        match nothing are dropped.
    """
    # Sorted so that, if two allowed names share a canonical key, the winner is deterministic.
    by_key = {_canonical_key(c): c for c in sorted(allowed)}
    out: List[str] = []
    seen: Set[str] = set()
    # NFKC first so full-width separators (e.g. U+FF0C) become ASCII and split too.
    normalized = unicodedata.normalize("NFKC", raw)
    for chunk in _ITEM_SEPARATOR_RE.split(normalized):
        item = _clean_model_output(chunk).strip()
        if not item:
            continue
        match = _match_category(item, allowed, by_key)
        if match is not None and match not in seen:
            out.append(match)
            seen.add(match)
    return out
| |
|
| |
|
class EndpointHandler:
    """Serve food-waste categorisation requests with a Pixtral/Llava model.

    Construct once with the repository directory, then call per request; see
    ``__call__`` for the payload keys and the response shape.
    """

    def __init__(self, path: str = ".") -> None:
        """Load the processor and model and prepare generation settings.

        Args:
            path: Directory expected to contain the merged weights
                (pixtral-12b-foodwaste-merged). The model is loaded from it when
                it has a ``config.json`` and the processor when it has a
                ``preprocessor_config.json``; otherwise each falls back to
                ``BASE_MODEL_ID`` with a warning.

        Raises:
            ValueError: if no EOS token id is found on the model config or tokenizer.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing EndpointHandler on device: %s", self.device)

        # Both used to be loaded from BASE_MODEL_ID unconditionally, which ignored
        # `path` (and so the fine-tuned weights) despite the docstring and log line.
        processor_source = self._resolve_source(path, "preprocessor_config.json")
        model_source = self._resolve_source(path, "config.json")

        self.processor = AutoProcessor.from_pretrained(
            processor_source,
            trust_remote_code=True,
        )

        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_source,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map={"": self.device},
            trust_remote_code=True,
        )
        self.model.eval()
        logger.info(
            "Model loaded from '%s', processor loaded from '%s'.",
            model_source,
            processor_source,
        )

        tokenizer = getattr(self.processor, "tokenizer", None)
        # Generation needs a pad id; reuse EOS when the tokenizer defines none.
        if tokenizer is not None and tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        self.eos_token_ids: List[int] = self._collect_eos_token_ids(tokenizer)
        if not self.eos_token_ids:
            raise ValueError("No EOS token id found on model or tokenizer.")

        pad_id: Optional[int] = getattr(self.model.config, "pad_token_id", None)
        if pad_id is None and tokenizer is not None:
            pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.eos_token_ids[0]
        self.pad_token_id: int = pad_id

        self.gen_config = GenerationConfig()
        logger.info(
            "Generation config: max_new_tokens=%d, temperature=%.3f",
            self.gen_config.max_new_tokens,
            self.gen_config.temperature,
        )

        # Allow-list used when a request's prompt carries no numbered dictionary.
        self.default_allowed_categories: Set[str] = _extract_categories_from_prompt(DEFAULT_PROMPT)
        logger.info("Extracted %d categories from DEFAULT_PROMPT.", len(self.default_allowed_categories))

    @staticmethod
    def _resolve_source(path: str, marker_file: str) -> str:
        """Return ``path`` if it contains ``marker_file``, else ``BASE_MODEL_ID`` (logged)."""
        import os

        if os.path.isfile(os.path.join(path, marker_file)):
            return path
        logger.warning(
            "'%s' not found in '%s'; falling back to '%s'.",
            marker_file,
            path,
            BASE_MODEL_ID,
        )
        return BASE_MODEL_ID

    def _collect_eos_token_ids(self, tokenizer: Any) -> List[int]:
        """Gather EOS ids from the model config, its text sub-config and the tokenizer.

        Config ``eos_token_id`` values may be a single int or a list of ints;
        both forms are accepted. Ids are de-duplicated preserving first-seen
        order, so index 0 is deterministic (it is the pad fallback).
        """
        config = self.model.config
        candidates = (
            getattr(config, "eos_token_id", None),
            getattr(getattr(config, "text_config", None), "eos_token_id", None),
            getattr(tokenizer, "eos_token_id", None),
        )
        ids: List[int] = []
        for value in candidates:
            values = value if isinstance(value, (list, tuple)) else [value]
            for token_id in values:
                if token_id is not None and int(token_id) not in ids:
                    ids.append(int(token_id))
        return ids

    @staticmethod
    def _decode_image(image_b64: str) -> Image.Image:
        """Decode a base64 string (optionally a ``data:...;base64,`` URL) into an RGB image.

        Raises:
            ValueError: if the payload cannot be decoded or opened as an image.
        """
        # b64decode's default (non-validating) mode discards non-alphabet chars but
        # keeps the letters of a "data:image/png;base64," prefix, corrupting the bytes.
        if isinstance(image_b64, str) and image_b64.startswith("data:") and "," in image_b64:
            image_b64 = image_b64.split(",", 1)[1]
        try:
            img_bytes = base64.b64decode(image_b64)
            return Image.open(BytesIO(img_bytes)).convert("RGB")
        except Exception as exc:
            raise ValueError(f"Could not decode base64 image: {exc}") from exc

    @staticmethod
    def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
        """Downscale ``img`` so its longer side is at most ``max_side``, keeping aspect ratio.

        Images already within the limit are returned unchanged.

        Raises:
            ValueError: if ``max_side`` is not positive.
        """
        if max_side <= 0:
            raise ValueError(f"max_side must be positive, got {max_side}.")
        w, h = img.size
        longest = max(w, h)
        if longest <= max_side:
            return img
        scale = max_side / longest
        # Clamp to 1 px so very elongated images cannot collapse to a zero-sized side.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        return img.resize(new_size, resample=Image.Resampling.LANCZOS)

    def _build_chat_text(self, prompt: str) -> str:
        """Render a single user turn (text followed by one image slot) with the chat template."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image"},
                ],
            }
        ]
        return self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

    def _log_token_budget(
        self,
        enc: Any,
        image: Image.Image,
        max_length: int,
        prompt_len: int,
        tokens_left: int,
    ) -> None:
        """Log prompt-length diagnostics and the decoded head/tail of the prompt (debug mode)."""
        logger.info("===== TOKEN BUDGET DEBUG (inference) =====")
        logger.info("img size: %s", getattr(image, "size", None))
        logger.info("max_length: %d", max_length)
        logger.info("prompt_len: %d", prompt_len)
        logger.info("tokens_left_for_answer(approx): %d", tokens_left)
        tok = getattr(self.processor, "tokenizer", None)
        if tok is not None:
            ids = enc["input_ids"][0].tolist()
            logger.info("[prompt head]\n%s", tok.decode(ids[:120], skip_special_tokens=False))
            logger.info("[prompt tail]\n%s", tok.decode(ids[-120:], skip_special_tokens=False))
        logger.info("=========================================")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Categorise the food waste in one image.

        ``data`` is either the payload itself or ``{"inputs": payload}``. Payload
        keys: ``image`` (required; base64 or data URL) and optional ``prompt``,
        ``debug``, ``max_length``, ``max_side``, ``max_new_tokens``, ``temperature``.

        Returns:
            ``{"generated_text": str, "categories": list[str]}``, where
            ``generated_text`` is the comma-joined validated categories.

        Raises:
            ValueError: on a malformed payload, a missing/undecodable image or
                non-numeric numeric options.
        """
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            raise ValueError("'inputs' must be an object containing at least an 'image' field.")
        debug = bool(inputs.get("debug", False))

        prompt: str = inputs.get("prompt") or DEFAULT_PROMPT
        # Validate against the dictionary in the prompt actually sent; fall back to
        # the default vocabulary when a custom prompt contains no numbered list.
        allowed_categories = _extract_categories_from_prompt(prompt) or self.default_allowed_categories

        image_b64: Optional[str] = inputs.get("image")
        if not image_b64:
            raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")
        image = self._decode_image(image_b64)

        max_length = int(inputs.get("max_length", self.gen_config.max_length))
        max_side = int(inputs.get("max_side", self.gen_config.max_side))
        image = self._resize_max_side(image, max_side=max_side)

        max_new_tokens = int(inputs.get("max_new_tokens", self.gen_config.max_new_tokens))
        temperature = float(inputs.get("temperature", self.gen_config.temperature))

        chat_text = self._build_chat_text(prompt)

        # NOTE(review): with truncation=True a max_length smaller than the prompt
        # can cut the prompt tail (generation prompt / image tokens); consider
        # rejecting such requests instead of truncating.
        enc = self.processor(
            text=[chat_text],
            images=[image],
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=False,
        )

        prompt_len = int(enc["input_ids"].shape[1])
        # Keep prompt + answer within max_length but always allow at least one new token.
        tokens_left = max(1, max_length - prompt_len)
        max_new_tokens = min(max_new_tokens, tokens_left)

        if debug:
            self._log_token_budget(enc, image, max_length, prompt_len, tokens_left)

        enc = {k: (v.to(self.device) if torch.is_tensor(v) else v) for k, v in enc.items()}
        if torch.is_tensor(enc.get("pixel_values")):
            enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self.model.dtype)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0.0,
            "eos_token_id": self.eos_token_ids,
            "pad_token_id": self.pad_token_id,
            "no_repeat_ngram_size": self.gen_config.no_repeat_ngram_size,
            "repetition_penalty": self.gen_config.repetition_penalty,
        }
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature

        with torch.inference_mode():
            output_ids = self.model.generate(**enc, **gen_kwargs)

        # generate() returns prompt + continuation; keep only the new tokens.
        generated_only = output_ids[:, prompt_len:]
        raw_text = self.processor.batch_decode(
            generated_only,
            skip_special_tokens=True,
        )[0].strip()
        if debug:
            logger.info("Raw model output: %s", raw_text)

        cats = _parse_and_validate_categories(raw_text, allowed_categories)
        if not cats:
            cats = ["food_waste"] if "food_waste" in allowed_categories else []
        generated_text = ",".join(cats)

        logger.info("Generated text: %s", generated_text)
        return {"generated_text": generated_text, "categories": cats}
| |
|