| """ | |
| File: vlm.py | |
| Description: Vision language model utility functions. | |
| Author: Didier Guillevic | |
| Date: 2025-03-16 | |
| """ | |
import spaces  # Hugging Face Spaces (ZeroGPU) support
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from transformers import TextIteratorStreamer
from threading import Thread
import torch
#
# Load the model: google/gemma-3-4b-it
#
# Select the best available device rather than hard-coding one
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
model_id = "google/gemma-3-4b-it"
processor = AutoProcessor.from_pretrained(model_id, use_fast=True, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16
).to(device).eval()
#
# Build messages
#
def build_messages(message: dict, history: list[tuple]):
    """Build messages given message & history from a **multimodal** chat interface.

    Args:
        message: dictionary with keys: 'text', 'files'
        history: list of tuples with (message, response)

    Returns:
        list of messages (to be sent to the model)
    """
    # Get the user's text and list of images
    user_text = message.get("text", "")
    user_images = message.get("files", [])  # list of images

    # Build the message list including history
    messages = []
    combined_user_input = []  # combine images and text found in the same turn
    for user_turn, bot_turn in history:
        if isinstance(user_turn, tuple):  # image input
            image_content = [{"type": "image", "url": image_url} for image_url in user_turn]
            combined_user_input.extend(image_content)
        elif isinstance(user_turn, str):  # text input
            combined_user_input.append({"type": "text", "text": user_turn})
        if combined_user_input and bot_turn:
            messages.append({'role': 'user', 'content': combined_user_input})
            messages.append({'role': 'assistant', 'content': [{"type": "text", "text": bot_turn}]})
            combined_user_input = []  # reset the combined user input

    # Build the user message's content from the provided message
    user_content = []
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    for image in user_images:
        user_content.append({"type": "image", "url": image})
    messages.append({'role': 'user', 'content': user_content})

    return messages
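
# Illustrative example (hypothetical inputs): with one completed text turn in
# the history and a new message carrying an image, build_messages returns the
# Gemma 3 chat format below.
#
# >>> build_messages(
# ...     {"text": "What is in this photo?", "files": ["cat.png"]},
# ...     [("Hello", "Hi! How can I help?")])
# [{'role': 'user', 'content': [{'type': 'text', 'text': 'Hello'}]},
#  {'role': 'assistant', 'content': [{'type': 'text', 'text': 'Hi! How can I help?'}]},
#  {'role': 'user', 'content': [{'type': 'text', 'text': 'What is in this photo?'},
#                               {'type': 'image', 'url': 'cat.png'}]}]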
#
# Streaming response
#
def stream_response(messages: list[dict]):
    """Stream the model's response to the chat interface.

    Args:
        messages: list of messages to send to the model
    """
    # Tokenize the chat-formatted messages
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)

    # Generate in a background thread, streaming tokens as they are produced
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=1_024,
        do_sample=False
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated response as new text arrives
    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message
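
# Illustrative wiring (an assumption about the calling app, not part of this
# module): a Gradio multimodal ChatInterface could delegate to these helpers;
# `respond` and `demo` are hypothetical names.
#
# import gradio as gr
#
# def respond(message, history):
#     yield from stream_response(build_messages(message, history))
#
# demo = gr.ChatInterface(respond, multimodal=True)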
#
# Response (non-streaming)
#
def get_response(messages: list[dict]):
    """Get the model's response.

    Args:
        messages: list of messages to send to the model
    """
    # Tokenize the chat-formatted messages
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    # Generate, then keep only the newly generated tokens
    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        generation = generation[0][input_len:]

    decoded = processor.decode(generation, skip_special_tokens=True)
    return decoded
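
#
# Minimal smoke test (a sketch, assuming the module is run directly; the
# prompt below is illustrative, not part of the original module)
#
if __name__ == "__main__":
    demo_messages = build_messages(
        {"text": "Describe the Gemma 3 model family in one sentence.", "files": []},
        history=[],
    )
    print(get_response(demo_messages))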