File size: 3,593 Bytes
73f064f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57c77c4
73f064f
 
9695e26
 
 
 
3965865
9695e26
73f064f
57c77c4
 
73f064f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a64b07
73f064f
 
 
 
 
 
 
 
3a64b07
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
""" vlm.py

Utilities for working with Vision Language Models

:author: Didier Guillevic
:email: didier@guillevic.net
:creation: 2024-12-28
"""

import logging
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig configures the *root* logger as a side effect of
# importing this module (it is a no-op if the root logger already has
# handlers), so every module in the process gets INFO-level logging.
logging.basicConfig(level=logging.INFO)

import os
from mistralai import Mistral
import base64

#
# Mistral AI client
#
# Read at import time: importing this module raises KeyError when the
# MISTRAL_API_KEY environment variable is not set.
api_key = os.environ["MISTRAL_API_KEY"]
# Single module-wide client used by get_response() and stream_response().
client = Mistral(api_key=api_key)
# Model used for every request made by this module.
model_id = "mistral-small-latest" # 128k context window

#
# Encode images as base64
#
def encode_image(image_path):
    """Read an image file and return its contents base64-encoded.

    Args:
        image_path: path of the file to read (anything accepted by ``open``).

    Returns:
        The base64 encoding of the file's bytes as a ``str``, or ``None`` if
        the file could not be read. Failures are reported through the module
        logger (with a traceback for unexpected errors) instead of raising.
    """
    # Keep the try body to the I/O that can actually fail.
    try:
        with open(image_path, "rb") as image_file:
            raw_bytes = image_file.read()
    except FileNotFoundError:
        logger.error(f"Error: The file {image_path} was not found.")
        return None
    except Exception:  # best effort: report and let the caller skip the image
        logger.exception(f"Error: could not read image {image_path}")
        return None
    return base64.b64encode(raw_bytes).decode("utf-8")


#
# Build messages
#
def _image_parts(image_paths):
    """Return ``image_url`` content parts for the readable images in *image_paths*.

    Files that ``encode_image`` cannot read are skipped (it has already
    logged the problem), so no ``...;base64,None`` URL is ever produced.
    The data-URL MIME type is guessed from the file name; anything that is
    not recognised as an image falls back to ``image/jpeg``.
    """
    import mimetypes

    parts = []
    for path in image_paths:
        encoded = encode_image(path)
        if encoded is None:
            continue
        mime_type, _ = mimetypes.guess_type(path)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"
        parts.append(
            {
                "type": "image_url",
                "image_url": f"data:{mime_type};base64,{encoded}",
            }
        )
    return parts


def build_messages(message: dict, history: list[tuple]):
    """Build messages given message & history from a **multimodal** chat interface.

    Args:
        message: dictionary with keys: 'text', 'files' (a list of image paths).
            Missing or ``None`` values are treated as empty.
        history: list of tuples with (message, response). A user entry that is
            a tuple is treated as a tuple of image paths, a ``str`` as text.
            User entries are accumulated until one has a non-empty response,
            so images and text sent in the same turn end up in one user message.

    Returns:
        list of messages (to be sent to the model): alternating user/assistant
        messages from the history, followed by one user message for the
        current turn. User parts from trailing history entries that never got
        a response are prepended to that final user message, not dropped.
    """
    logger.info(f"{message=}")
    logger.info(f"{history=}")

    messages = []
    pending_user_content = []  # user parts still waiting for a bot reply
    for user_turn, bot_turn in history:
        if isinstance(user_turn, tuple):  # Image input
            pending_user_content.extend(_image_parts(user_turn))
        elif isinstance(user_turn, str):  # Text input
            pending_user_content.append({"type": "text", "text": user_turn})
        if pending_user_content and bot_turn:
            messages.append({'role': 'user', 'content': pending_user_content})
            messages.append({'role': 'assistant', 'content': [{"type": "text", "text": bot_turn}]})
            pending_user_content = []

    # Current turn: unanswered history parts first, then text, then images.
    user_content = pending_user_content
    user_text = message.get("text") or ""
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    user_content.extend(_image_parts(message.get("files") or []))

    messages.append({'role': 'user', 'content': user_content})
    # The full list embeds base64 image data, so only dump it at DEBUG.
    logger.info(f"built {len(messages)} messages")
    logger.debug(f"{messages=}")

    return messages

#
# get response
#
def get_response(messages: list[dict]):
    """Return the model's complete (non-streamed) reply to *messages*.

    Args:
        messages: chat messages to send to the model, e.g. as built by
            ``build_messages``.

    Returns:
        The ``content`` of the first choice's message in the completion.
    """
    completion = client.chat.complete(model=model_id, messages=messages)
    logger.info(f"response={completion!r}")
    first_choice = completion.choices[0]
    return first_choice.message.content

#
# stream response
#
def stream_response(messages: list[dict]):
    """Stream the model's response.

    Args:
        messages: list of messages to send to the model

    Yields:
        The accumulated response text so far, once per streamed event. Each
        value extends the previous one, so a UI can simply display the
        latest value.
    """
    response = ""
    for chunk in client.chat.stream(model=model_id, messages=messages):
        choices = chunk.data.choices
        # NOTE(review): the SDK types delta.content as optional, and events may
        # carry no choices at all; appending None would raise TypeError
        # mid-stream, so treat missing content as "no new text".
        delta = choices[0].delta.content if choices else None
        if isinstance(delta, list):
            # Content may also arrive as a list of chunks; keep only text parts.
            delta = "".join(getattr(part, "text", None) or "" for part in delta)
        response += delta or ""
        yield response