NGUYEN, Xuan Phi committed
Commit c1519e7
1 Parent(s): 7eb44d4

add sea lava16

multipurpose_chatbot/engines/__init__.py CHANGED
@@ -9,6 +9,7 @@ BACKENDS = [
    # "llava_llama_cpp",
    "debug",
    "sealmmm_transformers",
+    "sealava16_transformers"
]

ENGINE_LOADED = False
@@ -42,6 +43,9 @@ def load_multipurpose_chatbot_engine(backend: str):
    elif backend == 'sealmmm_transformers':
        from .sealmmm_engine import SeaLMMMv0Engine
        model_engine = SeaLMMMv0Engine()
+    elif backend == 'sealava16_transformers':
+        from .sealava16_transformers_engine import SeaLlava16Engine
+        model_engine = SeaLlava16Engine()
    else:
        raise ValueError(f'backend invalid: {BACKENDS} vs {backend}')

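For context, a minimal usage sketch of the new backend (not part of the commit): the backend name must match the entry added to BACKENDS above, and the checkpoint path below is an assumption; the engine module further down reads it from the MODEL_PATH environment variable.

import os
os.environ["MODEL_PATH"] = "/path/to/sealava16-checkpoint"  # assumed checkpoint location, set before loading

from multipurpose_chatbot.engines import load_multipurpose_chatbot_engine

# An unrecognized backend name raises the ValueError shown above.
load_multipurpose_chatbot_engine("sealava16_transformers")
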
multipurpose_chatbot/engines/sealava16_transformers_engine.py ADDED
@@ -0,0 +1,273 @@
+import os
+import numpy as np
+import argparse
+import torch
+import gradio as gr
+from typing import Any, Iterator
+from typing import Iterator, List, Optional, Tuple
+import filelock
+import glob
+import json
+import time
+from gradio.routes import Request
+from gradio.utils import SyncToAsyncIterator, async_iteration
+from gradio.helpers import special_args
+import anyio
+from typing import AsyncGenerator, Callable, Literal, Union, cast
+
+from gradio_client.documentation import document, set_documentation_group
+
+from typing import List, Optional, Union, Dict, Tuple
+from tqdm.auto import tqdm
+from huggingface_hub import snapshot_download
+
+from gradio.components import Button
+from gradio.events import Dependency, EventListenerMethod
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
+import types
+import sys
+from .base_engine import BaseEngine
+from .transformers_engine import TransformersEngine, NewGenerationMixin
+
+from ..configs import (
+    STREAM_CHECK_MULTIPLE,
+    STREAM_YIELD_MULTIPLE,
+)
+
+CODE_PATH = os.environ.get("CODE_PATH", "")
+MODEL_PATH = os.environ.get("MODEL_PATH", "")
+
+IMAGE_TOKEN = "[IMAGE]<|image|>[/IMAGE]"
+IMAGE_TOKEN = "<|image|>"
+
+# Token budget per image: IMAGE_LENGTH tokens for each of up to MAX_PATCHES patches (see get_multimodal_tokens).
+IMAGE_LENGTH = 576
+MAX_PATCHES = 5
+
+
+BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
+BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
+LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
+KEYWORDS = os.environ.get("KEYWORDS", "").strip()
+KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
+KEYWORDS = [x.lower() for x in KEYWORDS]
+
+LANG_BLOCK_MESSAGE = """Unsupported language."""
+
+KEYWORD_BLOCK_MESSAGE = "Invalid request."
+
+
+def _detect_lang(text):
+    # Disable languages that may carry a safety risk
+    from langdetect import detect as detect_lang
+    dlang = None
+    try:
+        dlang = detect_lang(text)
+    except Exception as e:
+        if "No features in text." in str(e):
+            return "en"
+        else:
+            return "zh"
+    return dlang
+
+
+def block_lang(
+    message: str,
+    history: Optional[List[Tuple[str, str]]] = None,
+) -> bool:
+    # Nothing to block if no languages are configured.
+    if len(BLOCK_LANGS) == 0:
+        return False
+
+    if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
+        return True
+    else:
+        _lang = _detect_lang(message)
+        if _lang in BLOCK_LANGS:
+            # print(f'Detect blocked {_lang}: {message}')
+            return True
+        else:
+            return False
+
+
+def safety_check(text, history=None) -> Optional[str]:
+    """
+    Despite our efforts in safety tuning and red teaming, our models may still generate harmful or illegal content.
+    This provides an additional security measure to enhance safety and compliance with local regulations.
+    """
+    if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
+        return KEYWORD_BLOCK_MESSAGE
+
+    if len(BLOCK_LANGS) > 0:
+        if block_lang(text, history):
+            return LANG_BLOCK_MESSAGE
+
+    return None
+
+
+def safety_check_conversation_string(text, delimiter=None) -> Optional[str]:
+    if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
+        return KEYWORD_BLOCK_MESSAGE
+    if len(BLOCK_LANGS) > 0:
+        import re
+        delimiter = delimiter or (r"</s>\n<\|im_start\|>user\n", r"</s>\n<\|im_start\|>assistant\n", r"<\|im_start\|>system\n")
+        turns = re.split(r"|".join(delimiter), text)
+        turns = [t for t in turns if t.strip() != '']
+        for t in turns:
+            if block_lang(t):
+                return LANG_BLOCK_MESSAGE
+    return None
+
+
+def is_check_safety():
+    return len(KEYWORDS) > 0 or len(BLOCK_LANGS) > 0
+
+
+def safety_check_conversation(conversation) -> Optional[str]:
+    """
+    Despite our efforts in safety tuning and red teaming, our models may still generate harmful or illegal content.
+    This provides an additional security measure to enhance safety and compliance with local regulations.
+    """
+    texts = [c['content'] for c in conversation]
+    for text in texts:
+        if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
+            return KEYWORD_BLOCK_MESSAGE
+
+        if len(BLOCK_LANGS) > 0:
+            if block_lang(text):
+                return LANG_BLOCK_MESSAGE
+    return None
+
+
+class SeaLlava16Engine(TransformersEngine):
+
+    @property
+    def image_token(self):
+        return IMAGE_TOKEN
+
+    @property
+    def max_position_embeddings(self) -> int:
+        return self._model.config.max_position_embeddings
+
+    @property
+    def tokenizer(self):
+        return self._tokenizer
+
+    @property
+    def processor(self):
+        return self._processor
+
+    def load_model(self):
+        from transformers import AutoProcessor
+        import sys
+        # caution: path[0] is reserved for the script path (or '' in a REPL)
+        sys.path.append(CODE_PATH)
+
+        from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration
+        from transformers.models.llava_next.processing_llava_next import LlavaNextProcessor
+        model_path = MODEL_PATH
+        print(f'Loading model from {model_path}')
+
+        print(f'model_path={model_path}')
+        # FSDP checkpoints may only ship pytorch_model_fsdp.bin; expose it under the name from_pretrained expects.
+        if os.path.exists(f"{model_path}/pytorch_model_fsdp.bin") and not os.path.exists(f"{model_path}/pytorch_model.bin"):
+            os.symlink("pytorch_model_fsdp.bin", f"{model_path}/pytorch_model.bin")
+
+        self._processor = LlavaNextProcessor.from_pretrained(model_path)
+        self._model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="cuda").eval()
+
+        # Patch in the streaming sampler so generate() yields tokens one at a time.
+        self._model.sample_old = self._model.sample
+        self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
+
+        self._tokenizer = self._processor.tokenizer
+        print(self._model)
+        print(f"{self.max_position_embeddings=}")
+
+    def get_multimodal_tokens(self, full_prompt, image_paths=None):
+        num_tokens = len(self.tokenizer.encode(full_prompt))
+        for image_path in (image_paths or []):
+            num_tokens += IMAGE_LENGTH * MAX_PATCHES
+        return num_tokens
+
+    def maybe_raise_safety(self, message, gen_index=-1):
+        # gen_index < 0: always check; otherwise only check every STREAM_CHECK_MULTIPLE streamed steps.
+        if is_check_safety():
+            if gen_index < 0:
+                message_safety = safety_check_conversation_string(message)
+                if message_safety is not None:
+                    raise gr.Error(message_safety)
+            else:
+                if STREAM_CHECK_MULTIPLE > 0 and gen_index % STREAM_CHECK_MULTIPLE == 0:
+                    message_safety = safety_check_conversation_string(message)
+                    if message_safety is not None:
+                        raise gr.Error(message_safety)
+
+    def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
+        from transformers.generation.utils import GenerationConfig
+        from PIL import Image
+        image_paths = kwargs.get("image_paths", None)
+        image_paths = image_paths or []
+
+        images = [Image.open(x) for x in image_paths] if len(image_paths) > 0 else None
+
+        with torch.no_grad():
+            # inputs = self.processor(prompt, images, return_tensors='pt', concat_images=True)
+            inputs = self.processor(prompt, images, return_tensors='pt')
+            # inputs = inputs.to("cuda", torch.bfloat16)
+            inputs = {k: v.to("cuda") for k, v in inputs.items() if v is not None}
+            num_tokens = self.get_multimodal_tokens(prompt, image_paths)
+            # non-streaming generation
+            # output = self._model.generate(
+            #     **inputs,
+            #     do_sample=True,
+            #     temperature=temperature,
+            #     max_new_tokens=max_tokens,
+            #     pad_token_id=self.processor.tokenizer.pad_token_id,
+            # )
+            # # response = self.processor.tokenizer.decode(output[0][-inputs.input_ids.size(-1):], skip_special_tokens=True)
+            # full_output_text = self.processor.decode(output[0], skip_special_tokens=True)
+            # response = full_output_text.split("<|im_start|>assistant\n")[-1]
+            # num_tokens = self.get_multimodal_tokens(prompt + response, image_paths)
+            # print(prompt)
+            # print(response)
+            # print(num_tokens)
+            # yield response, num_tokens
+
+            # if i % 4 == 0 and i > 1:
+            #     message_safety = safety_check(response)
+            #     if message_safety is not None:
+            #         history = undo_history(history)
+            #         yield history, "", None
+            #         raise gr.Error(message_safety)
+
+            # Check the prompt once before generation starts.
+            self.maybe_raise_safety(prompt)
+
+            # ! streaming
+            generator = self._model.generate(
+                **inputs,
+                do_sample=True,
+                temperature=temperature,
+                max_new_tokens=max_tokens,
+                pad_token_id=self.processor.tokenizer.pad_token_id,
+            )
+
+            out_tokens = []
+            response = None
+            print(f"{STREAM_YIELD_MULTIPLE=}")
+            for index, token in enumerate(generator):
+                out_tokens.append(token.item())
+                response = self.tokenizer.decode(out_tokens, skip_special_tokens=True)
+
+                # Periodic safety check on the partial response.
+                self.maybe_raise_safety(response, gen_index=index)
+
+                # Throttle updates: only yield every STREAM_YIELD_MULTIPLE tokens when configured.
+                if STREAM_YIELD_MULTIPLE > 0:
+                    if index % STREAM_YIELD_MULTIPLE == 0 and index > 0:
+                        yield response, num_tokens
+                else:
+                    yield response, num_tokens
+
+            del generator
+
+        if response is not None:
+            # Final safety check on the complete response before the last yield.
+            self.maybe_raise_safety(response)
+
+            full_text = prompt + response
+            num_tokens = self.get_multimodal_tokens(full_text, image_paths)
+            yield response, num_tokens
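For reference, a minimal driver sketch (not part of the commit) showing how the streaming generator might be used once a checkpoint is available. It assumes the inherited constructor takes no required arguments; the checkpoint path, image file, and chat template are assumptions for illustration, the latter based on the <|im_start|> delimiters used in this file's safety checks.

import os
os.environ["MODEL_PATH"] = "/path/to/sealava16-checkpoint"  # assumed checkpoint location

from multipurpose_chatbot.engines.sealava16_transformers_engine import SeaLlava16Engine

engine = SeaLlava16Engine()
engine.load_model()

# Assumed prompt format; engine.image_token expands to "<|image|>".
prompt = f"<|im_start|>user\n{engine.image_token}\nDescribe this image.</s>\n<|im_start|>assistant\n"

response = ""
for partial, _num_tokens in engine.generate_yield_string(
    prompt,
    temperature=0.7,
    max_tokens=256,
    image_paths=["example.jpg"],  # assumed local image file
):
    response = partial  # each yield carries the full decoded text so far
print(response)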