Upload modeling_OneChart.py

modeling_OneChart.py (+26 -12) CHANGED
```diff
@@ -393,8 +393,10 @@ class OneChartOPTForCausalLM(OPTForCausalLM):
         setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
     def chat(self, tokenizer, image_file, reliable_check=True, print_prompt=False):
-        device
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        # dtype = torch.bfloat16 if device=="cuda" else next(self.get_model().parameters()).dtype
+        dtype=torch.float16 if device=="cuda" else torch.float32
+        # print(device, dtype)
         def list_json_value(json_dict):
             rst_str = []
             sort_flag = True
```
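The first hunk makes `chat()` device-aware: the previous device setup is replaced by a runtime check that picks CUDA with float16 when a GPU is available and falls back to float32 on CPU, where half-precision kernels are slow or unsupported in many builds. A minimal, self-contained sketch of the same pattern, with illustrative tensors that are not part of the commit:

```python
import torch

# Pick the device first, then a dtype that is safe on it:
# fp16 is fast on CUDA, but most CPU kernels expect fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Illustrative only (not from the commit): the matmul runs in fp16
# on GPU and in fp32 on CPU, which is what the fallback guarantees.
x = torch.randn(2, 3, device=device, dtype=dtype)
w = torch.randn(3, 3, device=device, dtype=dtype)
print(device, dtype, (x @ w).dtype)
```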
```diff
@@ -456,17 +458,29 @@ class OneChartOPTForCausalLM(OPTForCausalLM):
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
+        if device=='cuda':
+            with torch.autocast(device, dtype=dtype):
+                output_ids = self.generate(
+                    input_ids,
+                    images=[image_tensor_1.unsqueeze(0)],
+                    do_sample=False,
+                    num_beams = 1,
+                    # no_repeat_ngram_size = 20,
+                    # streamer=streamer,
+                    max_new_tokens=4096,
+                    stopping_criteria=[stopping_criteria]
+                )
+        else:
             output_ids = self.generate(
+                input_ids,
+                images=[image_tensor_1.unsqueeze(0)],
+                do_sample=False,
+                num_beams = 1,
+                # no_repeat_ngram_size = 20,
+                # streamer=streamer,
+                max_new_tokens=4096,
+                stopping_criteria=[stopping_criteria]
+            )
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
         outputs = outputs.replace("<Number>", "")
         outputs = outputs.strip()
```
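The second hunk is where the new dtype matters: on CUDA, generation now runs inside `torch.autocast(device, dtype=dtype)` so the forward pass executes in float16, while the CPU path calls `self.generate` directly in float32. The two branches repeat the same `generate` call verbatim; one possible refactor, an assumption of mine rather than anything in the commit, collapses them with a no-op context manager:

```python
import contextlib
import torch

def run_with_optional_autocast(generate_fn, *args, **kwargs):
    """Hypothetical helper (not in modeling_OneChart.py): run a
    generate-style callable under autocast on CUDA, plainly on CPU."""
    if torch.cuda.is_available():
        ctx = torch.autocast("cuda", dtype=torch.float16)
    else:
        ctx = contextlib.nullcontext()  # no-op: CPU stays in float32
    with ctx:
        return generate_fn(*args, **kwargs)

# Usage inside chat() would then be a single call, e.g.:
# output_ids = run_with_optional_autocast(
#     self.generate, input_ids, images=[image_tensor_1.unsqueeze(0)], ...)
```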
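Both generate branches stop decoding through `KeywordsStoppingCriteria`, which the diff uses but does not show. As a rough sketch, keyword criteria for `transformers` are commonly written like the following; the actual class in modeling_OneChart.py may differ in detail:

```python
import torch
from transformers import StoppingCriteria

class KeywordsStoppingCriteria(StoppingCriteria):
    """Sketch only, not the file's actual implementation: stop generation
    once any keyword appears in the newly generated text."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.prompt_len = input_ids.shape[1]  # prompt tokens to skip when decoding

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        new_text = self.tokenizer.decode(output_ids[0, self.prompt_len:],
                                         skip_special_tokens=True)
        return any(keyword in new_text for keyword in self.keywords)
```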