Upload 43 files

- modules/models/XMChat.py +12 -4
- modules/shared.py +0 -6
modules/models/XMChat.py CHANGED

@@ -19,6 +19,13 @@ from ..utils import *
 from .base_model import BaseLLMModel
 from .. import shared
 
+imp_model = AutoModelForCausalLM.from_pretrained(
+    "MILVLG/imp-v1-3b",
+    torch_dtype=torch.float16,
+    device_map="auto",
+    trust_remote_code=True)
+imp_tokenizer = AutoTokenizer.from_pretrained("MILVLG/imp-v1-3b", trust_remote_code=True)
+
 # print('model loading')
 # model = AutoModelForCausalLM.from_pretrained(
 #     "/home/shaozw/labs/imp-v0",

@@ -173,16 +180,17 @@ A chat between a curious user and an artificial intelligence assistant. This art
     def get_answer_at_once(self):
         # question = self.history[-1]["content"].strip()
         # question = f"{self.system_prompt.strip()} USER: <image>\n{question} ASSISTANT:"
+        global imp_model, imp_tokenizer
         prompt = self._get_imp_style_inputs()
         logging.info(prompt)
         # image_tok_cnt = prompt.count('<image>')
         # global model, tokenizer
-        input_ids =
+        input_ids = imp_tokenizer(prompt, return_tensors='pt').input_ids
         image_tensor = None
         if '<image>' in prompt:
             # logging.info("Preprocessing...")
-            image_tensor =
-        output_ids =
+            image_tensor = imp_model.image_preprocess(self.image_bytes)
+        output_ids = imp_model.generate(
             input_ids,
             max_new_tokens=3000,
             images=image_tensor,

@@ -194,5 +202,5 @@ A chat between a curious user and an artificial intelligence assistant. This art
             # repetition_penalty=self.repetition_penalty,
             num_return_sequences=1,
             use_cache=True)[0]
-        response =
+        response = imp_tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
         return response, len(response)
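Taken together, the XMChat.py changes load MILVLG/imp-v1-3b once at module import and route get_answer_at_once through it: tokenize the IMP-style prompt, preprocess the image only when the prompt contains an <image> token, generate, and decode only the ids produced after the prompt. Below is a minimal standalone sketch of that flow; answer_at_once is a hypothetical helper, and it assumes the model's remote-code API (image_preprocess, generate(..., images=...)) behaves as this diff uses it.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load once at import time, mirroring the module-level init in XMChat.py.
# trust_remote_code=True pulls in imp-v1-3b's custom generate()/image_preprocess().
model = AutoModelForCausalLM.from_pretrained(
    "MILVLG/imp-v1-3b",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("MILVLG/imp-v1-3b", trust_remote_code=True)

def answer_at_once(prompt, image=None):
    # Tokenize the already-formatted IMP-style prompt.
    # (Moving input_ids to model.device may be needed with device_map="auto".)
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    # Only preprocess an image when the prompt actually references one.
    image_tensor = None
    if '<image>' in prompt and image is not None:
        image_tensor = model.image_preprocess(image)
    output_ids = model.generate(
        input_ids,
        max_new_tokens=3000,
        images=image_tensor,
        num_return_sequences=1,
        use_cache=True)[0]
    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(output_ids[input_ids.shape[1]:],
                                skip_special_tokens=True).strip()
    return response, len(response)

Note that the diff passes self.image_bytes straight to image_preprocess; whether the remote code expects raw bytes or a decoded PIL image is not visible here, so the image argument above is an assumption.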
modules/shared.py CHANGED

@@ -16,12 +16,6 @@ class State:
     usage_api_url = USAGE_API_URL
     openai_api_base = OPENAI_API_BASE
     images_completion_url = IMAGES_COMPLETION_URL
-    imp_model = AutoModelForCausalLM.from_pretrained(
-        "MILVLG/imp-v1-3b",
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True)
-    imp_tokenizer = AutoTokenizer.from_pretrained("MILVLG/imp-v1-3b", trust_remote_code=True)
 
     def interrupt(self):
         self.interrupted = True
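The shared.py side of the change is pure removal: the imp-v1-3b initialization no longer runs when shared.State is defined, so importing modules/shared.py no longer pulls the 3B model into every process. If importing XMChat.py should stay cheap as well, lazy initialization is one option; the sketch below (get_imp and its module-level caches are hypothetical names) is a possible refactor, not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

_imp_model = None
_imp_tokenizer = None

def get_imp():
    # Defer the expensive load until the first chat request instead of import time.
    global _imp_model, _imp_tokenizer
    if _imp_model is None:
        _imp_model = AutoModelForCausalLM.from_pretrained(
            "MILVLG/imp-v1-3b",
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True)
        _imp_tokenizer = AutoTokenizer.from_pretrained(
            "MILVLG/imp-v1-3b", trust_remote_code=True)
    return _imp_model, _imp_tokenizer

Under that variant, get_answer_at_once would call model, tokenizer = get_imp() instead of touching module-level globals.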
|