File size: 12,059 Bytes
c296fdd
 
ae51174
94f0f9e
c296fdd
 
 
ae51174
 
 
c296fdd
ae51174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94f0f9e
ae51174
94f0f9e
ae51174
 
94f0f9e
 
ae51174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c296fdd
 
ae51174
 
 
 
 
 
 
c296fdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae51174
 
 
 
 
 
 
c296fdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae51174
 
 
 
 
 
 
 
 
c296fdd
 
 
 
 
 
 
 
 
 
 
ae51174
 
 
 
 
 
 
 
 
c296fdd
 
 
 
 
 
 
ae51174
 
 
 
 
 
 
c296fdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201

from typing import Dict, Any
from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow
from aiflows.utils.general_helpers import encode_image,encode_from_buffer
import cv2


class VisionAtomicFlow(ChatAtomicFlow):
    """ This class implements the atomic flow for the VisionFlowModule. It is a flow that, given a textual input, and a set of images and/or videos, generates a textual output.
    It uses the litellm library as a backend. See https://docs.litellm.ai/docs/providers for supported models and APIs.
    
    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "VisionAtomicFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
    Default: "A flow that, given a textual input, and a set of images and/or videos, generates a textual output."
    - `enable_cache` (bool): If True, the flow will use the cache. Default: True
    - `n_api_retries` (int): The number of times to retry the API call in case of failure. Default: 6
    - `wait_time_between_api_retries` (int): The time to wait between API retries in seconds. Default: 20
    - `system_name` (str): The name of the system. Default: "system"
    - `user_name` (str): The name of the user. Default: "user"
    - `assistant_name` (str): The name of the assistant. Default: "assistant"
    - `backend` (Dict[str, Any]): The configuration of the backend which is used to fetch api keys. Default: LiteLLMBackend with the
    default parameters of ChatAtomicFlow (see Flow card of ChatAtomicFlowModule). Except for the following parameters
    whose default value is overwritten:
        - `api_infos` (List[Dict[str, Any]]): The list of api infos. Default: No default value, this parameter is required.
        - `model_name` (Union[Dict[str,str],str]): The name of the model to use. 
        When using multiple API providers, the model_name can be a dictionary of the form 
        {"provider_name": "model_name"}.
        Default: "gpt-4-vision-preview" (the name needs to follow the name of the model in litellm  https://docs.litellm.ai/docs/providers).
        - `n` (int) : The number of answers to generate. Default: 1
        - `max_tokens` (int): The maximum number of tokens to generate. Default: 2000
        - `temperature` (float): The temperature to use. Default: 0.3
        - `top_p` (float): An alternative to sampling with temperature. It instructs the model to consider the results of
        the tokens with top_p probability. Default: 0.2
        - `frequency_penalty` (float): The higher this value, the more likely the model will repeat itself. Default: 0.0
        - `presence_penalty` (float): The higher this value, the less likely the model will talk about a new topic. Default: 0.0
    - `system_message_prompt_template` (Dict[str,Any]): The template of the system message. It is used to generate the system message.
    By default it is of type aiflows.prompt_template.JinjaPrompt.
    None of the parameters of the prompt are defined by default and therefore need to be defined if one wants to use the system prompt. 
    Default parameters are defined in aiflows.prompt_template.jinja2_prompts.JinjaPrompt.
    - `init_human_message_prompt_template` (Dict[str,Any]): The prompt template of the human/user message used to initialize the conversation
    (first time in). It is used to generate the human message. It's passed as the user message to the LLM.
    By default it is of type aiflows.prompt_template.JinjaPrompt. None of the parameters of the prompt are defined by default and therefore need to be defined if one
    wants to use the init_human_message_prompt_template. Default parameters are defined in aiflows.prompt_template.jinja2_prompts.JinjaPrompt.
    - `previous_messages` (Dict[str,Any]): Defines which previous messages to include in the input of the LLM. Note that if `first_k` and `last_k` are both none,
    all the messages of the flow's history are added to the input of the LLM. Default:
        - `first_k` (int): If defined, adds the first_k earliest messages of the flow's chat history to the input of the LLM. Default: None
        - `last_k` (int): If defined, adds the last_k latest messages of the flow's chat history to the input of the LLM. Default: None
    - Other parameters are inherited from the default configuration of ChatAtomicFlow (see Flow card of ChatAtomicFlowModule).
    
    *Input Interface Initialized (Expected input the first time in flow)*:
    
    - `query` (str): The textual query to run the model on.
    - `data` (Dict[str, Any]): The data (images or video) to run the model on. It can contain the following keys:
        - `images` (List[Dict[str, Any]]): A list of images to run the model on. Each image is a dictionary that contains the following keys:
            - `type` (str): The type of the image. It can be "local_path" or "url".
            - `image` (str): The image. If type is "local_path", it is a local path to the image. If type is "url", it is a url to the image.
        - `video` (Dict[str, Any]): A video to run the model on. It is a dictionary that contains the following keys:
            - `video_path` (str): The path to the video.
            - `resize` (int): The resize we want to apply on the frames of the video.
            - `frame_step_size` (int): The step size between the frames of the video (to send to the model).
            - `start_frame` (int): The start frame of the video (to send to the model).
            - `end_frame` (int): The last frame of the video (to send to the model).
    
    *Input Interface (Expected input after the first time in flow)*:
    
    - `query` (str): The textual query to run the model on.
    - `data` (Dict[str, Any]): The data (images or video) to run the model on. It can contain the following keys:
        - `images` (List[Dict[str, Any]]): A list of images to run the model on. Each image is a dictionary that contains the following keys:
            - `type` (str): The type of the image. It can be "local_path" or "url".
            - `image` (str): The image. If type is "local_path", it is a local path to the image. If type is "url", it is a url to the image.
        - `video` (Dict[str, Any]): A video to run the model on. It is a dictionary that contains the following keys:
            - `video_path` (str): The path to the video.
            - `resize` (int): The resize we want to apply on the frames of the video.
            - `frame_step_size` (int): The step size between the frames of the video (to send to the model).
            - `start_frame` (int): The start frame of the video (to send to the model).
            - `end_frame` (int): The last frame of the video (to send to the model).
    
    *Output Interface*:
    
        - `api_output` (str): The api output of the flow to the query and data
        
    """
    @staticmethod
    def get_image(image):
        """ This method returns an image in the appropriate format for the API.
        
        :param image: The image dictionary (keys: "type" in {"local_path", "url"}, "image").
        :type image: Dict[str, Any]
        :return: The image content item ({"type": "image_url", "image_url": {"url": ...}}).
        :rtype: Dict[str, Any]
        """
        # Maps a file extension to the MIME subtype used in the data URI.
        extension_dict = {
            "jpg": "jpeg",
            "jpeg": "jpeg",
            "png": "png",
            "webp": "webp",
            "gif": "gif"
        }
        supported_image_types = ["local_path", "url"]
        image_type = image.get("type", None)
        assert image_type in supported_image_types, f"Must define a valid image type for every image \n your type: {image_type} \n supported types{supported_image_types} "
        
        if image_type == "local_path":
            encoded_image = encode_image(image.get("image"))
            # Lowercase so e.g. "photo.JPG" still resolves; fail with a clear
            # message instead of a bare KeyError for unsupported extensions.
            image_extension_type = image.get("image").split(".")[-1].lower()
            assert image_extension_type in extension_dict, (
                f"Unsupported image extension '{image_extension_type}'. "
                f"Supported extensions: {list(extension_dict)}"
            )
            # NOTE: no space after "base64," — whitespace inside a data URI is
            # invalid (RFC 2397) and rejected by strict parsers.
            url = f"data:image/{extension_dict[image_extension_type]};base64,{encoded_image}"
        else:  # image_type == "url"
            url = image.get("image")
            
        return {"type": "image_url", "image_url": {"url": url}}
    
    @staticmethod
    def get_video(video):
        """ This method returns the video frames in the appropriate format for the API.
        
        :param video: The video dictionary (keys: "video_path"; optional "resize",
            "frame_step_size", "start_frame", "end_frame").
        :type video: Dict[str, Any]
        :return: The selected frames, each as {"image": <base64 jpeg>, "resize": <int>}.
        :rtype: List[Dict[str, Any]]
        """
        video_path = video["video_path"]
        resize = video.get("resize", 768)
        frame_step_size = video.get("frame_step_size", 10)
        start_frame = video.get("start_frame", 0)
        end_frame = video.get("end_frame", None)
        
        base64_frames = []
        # Use a dedicated name for the capture object (the original shadowed
        # the `video` argument), and release it even if encoding raises.
        capture = cv2.VideoCapture(video_path)
        try:
            while capture.isOpened():
                success, frame = capture.read()
                if not success:
                    break
                _, buffer = cv2.imencode(".jpg", frame)
                base64_frames.append(encode_from_buffer(buffer))
        finally:
            capture.release()
        
        # Subsample the requested frame range. Return a list (reusable) rather
        # than a one-shot map iterator.
        return [
            {"image": encoded_frame, "resize": resize}
            for encoded_frame in base64_frames[start_frame:end_frame:frame_step_size]
        ]

    @staticmethod
    def get_user_message(prompt_template, input_data: Dict[str, Any]):
        """ This method constructs the user message to be passed to the API.
        
        :param prompt_template: The prompt template to use.
        :type prompt_template: PromptTemplate
        :param input_data: The input data (must contain a "data" dict with
            optional "video" and/or "images" entries).
        :type input_data: Dict[str, Any]
        :return: The constructed user message (text, then video frames, then images).
        :rtype: List[Dict[str, Any]]
        """
        content = VisionAtomicFlow._get_message(prompt_template=prompt_template, input_data=input_data)
        media_data = input_data["data"]
        if "video" in media_data:
            # Keep the text item first, then append the video frames.
            content = [content[0], *VisionAtomicFlow.get_video(media_data["video"])]
        if "images" in media_data:
            content.extend(VisionAtomicFlow.get_image(image) for image in media_data["images"])
        return content
    
    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """ This method constructs the textual message to be passed to the API.
        
        :param prompt_template: The prompt template to use.
        :type prompt_template: PromptTemplate
        :param input_data: The input data; must provide every input variable of the template.
        :type input_data: Dict[str, Any]
        :return: The constructed textual message (a single "text" content item).
        :rtype: List[Dict[str, Any]]
        """
        template_kwargs = {
            input_variable: input_data[input_variable]
            for input_variable in prompt_template.input_variables
        }
        msg_content = prompt_template.format(**template_kwargs)
        return [{"type": "text", "text": msg_content}]

    def _process_input(self, input_data: Dict[str, Any]):
        """ This method processes the input data (prepares the messages to send to the API).
        
        :param input_data: The input data.
        :type input_data: Dict[str, Any]
        :return: The processed input data.
        :rtype: Dict[str, Any]
        """
        if self._is_conversation_initialized():
            # Conversation already running: use the regular human message template.
            prompt_template = self.human_message_prompt_template
        else:
            # Initialize the conversation (add the system message, and potentially the demonstrations)
            self._initialize_conversation(input_data)
            # First time in: prefer the init template when it is defined.
            init_template = getattr(self, "init_human_message_prompt_template", None)
            prompt_template = init_template if init_template is not None else self.human_message_prompt_template

        user_message_content = self.get_user_message(prompt_template, input_data)
        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)