salma-remyx committed
Commit ae0cb85 · verified · 1 Parent(s): 1093fc5

Upload mllava/utils.py with huggingface_hub

Files changed (1):
  1. mllava/utils.py +188 -0
mllava/utils.py ADDED
@@ -0,0 +1,188 @@
+ import PIL.Image
+ import torch
+ from .modeling_llava import LlavaForConditionalGeneration
+ from .processing_llava import MLlavaProcessor
+ # from ..conversation import conv_mllava_v1_mmtag as default_conv
+ from ..conversation import conv_mllava_v1 as default_conv, conv_templates
+
+ from typing import Iterator, List, Optional, Tuple, Union
+
+ def chat_mllava(
+     text: str,
+     images: List[Union[PIL.Image.Image, str]],
+     model: LlavaForConditionalGeneration,
+     processor: MLlavaProcessor,
+     max_input_length: Optional[int] = None,
+     history: Optional[List[dict]] = None,
+     **kwargs) -> Tuple[str, List[dict]]:
+     """
+     Chat with the Mllava model.
+     Args:
+         text: str, the text to send to the model, where <image> is the placeholder for an image
+         images: List[Union[PIL.Image.Image, str]], the images (or image file paths) to send to the model, or None
+         model: LlavaForConditionalGeneration, the model to use
+         processor: MLlavaProcessor, the processor to use
+         max_input_length: int, the maximum input length
+         history: List[dict], the messages of the conversation so far. Each message is a dict {"role": "USER"/"ASSISTANT", "text": "the message"}. If None, the conversation starts from scratch
+         kwargs: dict, generation kwargs forwarded to model.generate
+     Returns:
+         Tuple[str, List[dict]], the generated text and the updated conversation history
+     """
+     if "llama-3" in model.language_model.name_or_path.lower():
+         conv = conv_templates['llama_3']
+         terminators = [
+             processor.tokenizer.eos_token_id,
+             processor.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+         ]
+     else:
+         conv = default_conv
+         terminators = None  # generate() falls back to the model's default eos token
+     kwargs["eos_token_id"] = terminators
+     conv = conv.copy()
+     conv.messages = []
+     if history is not None:
+         for message in history:
+             assert message["role"] in conv.roles
+             conv.append_message(message["role"], message["text"])
+         if text:
+             assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be from the assistant, if the given text is not empty"
+             conv.append_message(conv.roles[0], text)
+             conv.append_message(conv.roles[1], "")
+             history.append({"role": conv.roles[0], "text": text})
+             history.append({"role": conv.roles[1], "text": ""})
+         else:
+             if conv.messages[-1][0] == conv.roles[1]:
+                 assert conv.messages[-1][1] == "", "No user message should be provided"
+             else:
+                 assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be from the user, if the given text is empty"
+                 # append an empty assistant turn for the model to complete
+                 conv.append_message(conv.roles[1], "")
+                 history.append({"role": conv.roles[1], "text": ""})
+     else:
+         history = []
+         history.append({"role": conv.roles[0], "text": text})
+         history.append({"role": conv.roles[1], "text": ""})
+         conv.append_message(conv.roles[0], text)
+         conv.append_message(conv.roles[1], "")
+     assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
+     assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
+
+     prompt = conv.get_prompt()
+     if images:
+         for i in range(len(images)):
+             if isinstance(images[i], str):
+                 images[i] = PIL.Image.open(images[i]).convert("RGB")
+
+     inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
+     # move every tensor (or list of tensors) onto the model's device
+     for k, v in inputs.items():
+         if v is not None:
+             if isinstance(v, torch.Tensor):
+                 inputs[k] = v.to(model.device)
+             elif isinstance(v, list):
+                 inputs[k] = [x.to(model.device) for x in v]
+             else:
+                 raise ValueError(f"Invalid input type: {type(v)}")
+
+     output_ids = model.generate(**inputs, **kwargs)
+     output_ids = output_ids[0]
+
+     # strip the prompt tokens, keeping only the newly generated ones
+     generated_ids = output_ids[inputs["input_ids"].shape[-1]:]
+     generated_text = processor.decode(generated_ids, skip_special_tokens=True)
+
+     history[-1]["text"] = generated_text
+
+     return generated_text, history
+
+
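For context (not part of the uploaded file), a minimal usage sketch for chat_mllava. The checkpoint id, image file, and top-level import paths below are placeholders; the actual repo id and package layout may differ:

import PIL.Image
from mllava.modeling_llava import LlavaForConditionalGeneration
from mllava.processing_llava import MLlavaProcessor
from mllava.utils import chat_mllava

# placeholder checkpoint id; substitute the actual model repo
model = LlavaForConditionalGeneration.from_pretrained("path/to/checkpoint").to("cuda")
processor = MLlavaProcessor.from_pretrained("path/to/checkpoint")

# first turn: one <image> placeholder in the text per image in the list;
# file paths are opened and converted to RGB internally
response, history = chat_mllava(
    "What is shown in <image>?",
    ["example.jpg"],
    model,
    processor,
    max_new_tokens=128,
)

# follow-up turn: pass the returned history back in; the images accompany the full prompt again
response, history = chat_mllava(
    "Describe it in more detail.",
    ["example.jpg"],
    model,
    processor,
    history=history,
    max_new_tokens=128,
)
print(response)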
+ def chat_mllava_stream(
+     text: str,
+     images: List[Union[PIL.Image.Image, str]],
+     model: LlavaForConditionalGeneration,
+     processor: MLlavaProcessor,
+     max_input_length: Optional[int] = None,
+     history: Optional[List[dict]] = None,
+     **kwargs) -> Iterator[Tuple[str, List[dict]]]:
+     """
+     Chat with the Mllava model, streaming the response as it is generated.
+     Args:
+         text: str, the text to send to the model, where <image> is the placeholder for an image
+         images: List[Union[PIL.Image.Image, str]], the images (or image file paths) to send to the model, or None
+         model: LlavaForConditionalGeneration, the model to use
+         processor: MLlavaProcessor, the processor to use
+         max_input_length: int, the maximum input length
+         history: List[dict], the messages of the conversation so far. Each message is a dict {"role": "USER"/"ASSISTANT", "text": "the message"}. If None, the conversation starts from scratch
+         kwargs: dict, generation kwargs forwarded to model.generate
+     Yields:
+         Tuple[str, List[dict]], the partial generated text so far and the updated conversation history
+     """
+     if "llama-3" in model.language_model.name_or_path.lower():
+         conv = conv_templates['llama_3']
+         terminators = [
+             processor.tokenizer.eos_token_id,
+             processor.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+         ]
+     else:
+         conv = default_conv
+         terminators = None  # generate() falls back to the model's default eos token
+     kwargs["eos_token_id"] = terminators
+     conv = conv.copy()
+     conv.messages = []
+     if history is not None:
+         for message in history:
+             assert message["role"] in conv.roles
+             conv.append_message(message["role"], message["text"])
+         if text:
+             assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be from the assistant, if the given text is not empty"
+             conv.append_message(conv.roles[0], text)
+             conv.append_message(conv.roles[1], "")
+             history.append({"role": conv.roles[0], "text": text})
+             history.append({"role": conv.roles[1], "text": ""})
+         else:
+             if conv.messages[-1][0] == conv.roles[1]:
+                 assert conv.messages[-1][1] == "", "No user message should be provided"
+             else:
+                 assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be from the user, if the given text is empty"
+                 # append an empty assistant turn for the model to complete
+                 conv.append_message(conv.roles[1], "")
+                 history.append({"role": conv.roles[1], "text": ""})
+     else:
+         history = []
+         history.append({"role": conv.roles[0], "text": text})
+         history.append({"role": conv.roles[1], "text": ""})
+         conv.append_message(conv.roles[0], text)
+         conv.append_message(conv.roles[1], "")
+     assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
+     assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
+
+     prompt = conv.get_prompt()
+     if images:
+         for i in range(len(images)):
+             if isinstance(images[i], str):
+                 images[i] = PIL.Image.open(images[i])
+             images[i] = images[i].convert("RGB")
+
+     inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
+     # print(processor.tokenizer.decode(inputs["input_ids"][0]))  # debug: inspect the tokenized prompt
+     for k, v in inputs.items():
+         if v is not None:
+             if isinstance(v, torch.Tensor):
+                 inputs[k] = v.to(model.device)
+             elif isinstance(v, list):
+                 inputs[k] = [x.to(model.device) for x in v]
+             else:
+                 raise ValueError(f"Invalid input type: {type(v)}")
+
+     from transformers import TextIteratorStreamer
+     from threading import Thread
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     kwargs["streamer"] = streamer
+     inputs.update(kwargs)
+     # run generation in a background thread; the streamer yields decoded text as it arrives
+     thread = Thread(target=model.generate, kwargs=inputs)
+     thread.start()
+     for _output in streamer:
+         history[-1]["text"] += _output
+         yield history[-1]["text"], history
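And a sketch of consuming the streaming variant, using the same placeholder names as the example above. Since chat_mllava_stream is a generator, iterating over it yields the growing response text turn by turn:

# each iteration yields the full text generated so far plus the history
stream = chat_mllava_stream(
    "What is shown in <image>?",
    ["example.jpg"],
    model,
    processor,
    max_new_tokens=128,
)
for partial_text, history in stream:
    print(partial_text, end="\r")  # redraw the line as tokens arrive
print()
final_text = history[-1]["text"]  # the completed assistant turn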