""" File: vlm.py Description: Vision language model utility functions. Author: Didier Guillevic Date: 2025-05-08 """ from transformers import AutoProcessor from transformers import Mistral3ForConditionalGeneration from transformers import TextIteratorStreamer from threading import Thread import re import time import torch import base64 import spaces import logging logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) # # Load the model: OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym # model_id = "OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym" device = 'cuda' if torch.cuda.is_available() else 'cpu' processor = AutoProcessor.from_pretrained(model_id) model = Mistral3ForConditionalGeneration.from_pretrained( model_id, #_attn_implementation="flash_attention_2", torch_dtype=torch.float16 ).eval().to(device) # # Encode images as base64 # def encode_image(image_path): """Encode the image to base64.""" try: with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode('utf-8') except FileNotFoundError: print(f"Error: The file {image_path} was not found.") return None except Exception as e: # Added general exception handling print(f"Error: {e}") return None # # Build messages # def normalize_message_content(msg: dict) -> dict: content = msg.get("content") # Case 1: Already in expected format if isinstance(content, list) and all(isinstance(item, dict) for item in content): return {"role": msg["role"], "content": content} # Case 2: String (assume text) if isinstance(content, str): return {"role": msg["role"], "content": [{"type": "text", "text": content}]} # Case 3: Tuple with image path(s) if isinstance(content, tuple): return { "role": msg["role"], "content": [ {"type": "image", "image": encode_image(path)} # your `encode_image()` function for path in content if isinstance(path, str) ] } logger.warning(f"Unexpected content format in message: {msg}") return {"role": msg["role"], "content": [{"type": "text", "text": str(content)}]} def build_messages(message: dict, history: list[dict]): """Build messages given message & history from a **multimodal** chat interface. Args: message: dictionary with keys: 'text', 'files' history: list of dictionaries Returns: list of messages (to be sent to the model) """ logger.info(f"{message=}") logger.info(f"{history=}") # Get the user's text and list of images user_text = message.get("text", "") user_images = message.get("files", []) # List of images # Build the user message's content from the provided message user_content = [] if user_text: user_content.append({"type": "text", "text": user_text}) for image in user_images: user_content.append( { "type": "image", "image": f"data:image/jpeg;base64,{encode_image(image)}" } ) # Normalize existing history content messages = [normalize_message_content(msg) for msg in history] # Append new user message messages.append({'role': 'user', 'content': user_content}) logger.info(f"{messages=}") return messages # # stream response # @spaces.GPU @torch.inference_mode() def stream_response( messages: list[dict], max_new_tokens: int=1_024, temperature: float=0.15 ): """Stream the model's response to the chat interface. 

    Args:
        messages: list of messages to send to the model
        max_new_tokens: maximum number of new tokens to generate
        temperature: sampling temperature
    """
    # Build the chat inputs (text + images) with the processor's chat template
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.float16)

    # Run generation in a background thread and stream tokens as they are produced
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=0.9,
        do_sample=True,
    )
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    # Yield the accumulated text so the UI shows a growing response
    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message
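

# ---------------------------------------------------------------------------
# Example wiring (illustrative sketch, not part of the original module).
# These utilities are written for a multimodal chat interface that passes
# messages as {"text": ..., "files": [...]}. The `_respond` helper and the
# `gr.ChatInterface` arguments below are assumptions about how a hosting
# Gradio app might call build_messages() and stream_response().
# ---------------------------------------------------------------------------
def _respond(message: dict, history: list[dict]):
    """Bridge a multimodal chat message to the streaming generator (example only)."""
    messages = build_messages(message, history)
    yield from stream_response(messages)


if __name__ == "__main__":
    import gradio as gr  # assumed available in the hosting Space

    demo = gr.ChatInterface(_respond, multimodal=True, type="messages")
    demo.launch()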