kppkkp committed on
Commit
49c6f36
1 Parent(s): eb50750

Upload modeling_OneChart.py

Browse files
Files changed (1) hide show
  1. modeling_OneChart.py +26 -12
modeling_OneChart.py CHANGED
@@ -393,8 +393,10 @@ class OneChartOPTForCausalLM(OPTForCausalLM):
393
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
394
 
395
  def chat(self, tokenizer, image_file, reliable_check=True, print_prompt=False):
396
- dtype=torch.bfloat16
397
- device="cuda"
 
 
398
  def list_json_value(json_dict):
399
  rst_str = []
400
  sort_flag = True
@@ -456,17 +458,29 @@ class OneChartOPTForCausalLM(OPTForCausalLM):
456
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
457
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
458
 
459
- with torch.autocast(device, dtype=dtype):
 
 
 
 
 
 
 
 
 
 
 
 
460
  output_ids = self.generate(
461
- input_ids,
462
- images=[image_tensor_1.unsqueeze(0).half()],
463
- do_sample=False,
464
- num_beams = 1,
465
- # no_repeat_ngram_size = 20,
466
- # streamer=streamer,
467
- max_new_tokens=4096,
468
- stopping_criteria=[stopping_criteria]
469
- )
470
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
471
  outputs = outputs.replace("<Number>", "")
472
  outputs = outputs.strip()
 
393
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
394
 
395
  def chat(self, tokenizer, image_file, reliable_check=True, print_prompt=False):
396
+ device = "cuda" if torch.cuda.is_available() else "cpu"
397
+ # dtype = torch.bfloat16 if device=="cuda" else next(self.get_model().parameters()).dtype
398
+ dtype=torch.float16 if device=="cuda" else torch.float32
399
+ # print(device, dtype)
400
  def list_json_value(json_dict):
401
  rst_str = []
402
  sort_flag = True
 
458
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
459
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
460
 
461
+ if device=='cuda':
462
+ with torch.autocast(device, dtype=dtype):
463
+ output_ids = self.generate(
464
+ input_ids,
465
+ images=[image_tensor_1.unsqueeze(0)],
466
+ do_sample=False,
467
+ num_beams = 1,
468
+ # no_repeat_ngram_size = 20,
469
+ # streamer=streamer,
470
+ max_new_tokens=4096,
471
+ stopping_criteria=[stopping_criteria]
472
+ )
473
+ else:
474
  output_ids = self.generate(
475
+ input_ids,
476
+ images=[image_tensor_1.unsqueeze(0)],
477
+ do_sample=False,
478
+ num_beams = 1,
479
+ # no_repeat_ngram_size = 20,
480
+ # streamer=streamer,
481
+ max_new_tokens=4096,
482
+ stopping_criteria=[stopping_criteria]
483
+ )
484
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
485
  outputs = outputs.replace("<Number>", "")
486
  outputs = outputs.strip()