xzl12306 commited on
Commit
1b2e2ca
1 Parent(s): d6bc023

change to fp32

Browse files
Files changed (2) hide show
  1. app.py +13 -13
  2. tinychart/model/builder.py +3 -3
app.py CHANGED
@@ -123,10 +123,10 @@ def get_response(params):
123
 
124
  if type(images) is list:
125
  images = [
126
- image.to(model.device, dtype=torch.float16) for image in images
127
  ]
128
  else:
129
- images = images.to(model.device, dtype=torch.float16)
130
 
131
  replace_token = DEFAULT_IMAGE_TOKEN
132
  if getattr(model.config, "mm_use_im_start_end", False):
@@ -343,44 +343,43 @@ def build_demo():
343
  visible=False,
344
  )
345
 
346
- # cur_dir = os.path.dirname(os.path.abspath(__file__))
347
  cur_dir = Path(__file__).parent
348
  gr.Examples(
349
  examples=[
350
  [
351
- f"{cur_dir}/examples/market.png",
352
  "What is the highest number of companies in the domestic market? Answer with detailed steps.",
353
  ],
354
  [
355
- f"{cur_dir}/examples/college.png",
356
  "What is the difference between Asians and Whites degree distribution? Answer with detailed steps."
357
  ],
358
  [
359
- f"{cur_dir}/examples/immigrants.png",
360
  "How many immigrants are there in 1931?",
361
  ],
362
  [
363
- f"{cur_dir}/examples/sails.png",
364
  "By how much percentage wholesale is less than retail? Answer with detailed steps."
365
  ],
366
  [
367
- f"{cur_dir}/examples/diseases.png",
368
  "Is the median value of all the bars greater than 30? Answer with detailed steps.",
369
  ],
370
  [
371
- f"{cur_dir}/examples/economy.png",
372
  "Which team has higher economy in 28 min?"
373
  ],
374
  [
375
- f"{cur_dir}/examples/workers.png",
376
  "Generate underlying data table for the chart."
377
  ],
378
  [
379
- f"{cur_dir}/examples/sports.png",
380
  "Create a brief summarization or extract key insights based on the chart image."
381
  ],
382
  [
383
- f"{cur_dir}/examples/albums.png",
384
  "Redraw the chart with Python code."
385
  ]
386
  ],
@@ -489,7 +488,8 @@ if __name__ == "__main__":
489
  model_name=args.model_name,
490
  device="cpu",
491
  load_4bit=args.load_4bit,
492
- load_8bit=args.load_8bit
 
493
  )
494
 
495
  demo = build_demo()
 
123
 
124
  if type(images) is list:
125
  images = [
126
+ image.to(model.device, dtype=torch.float32) for image in images
127
  ]
128
  else:
129
+ images = images.to(model.device, dtype=torch.float32)
130
 
131
  replace_token = DEFAULT_IMAGE_TOKEN
132
  if getattr(model.config, "mm_use_im_start_end", False):
 
343
  visible=False,
344
  )
345
 
 
346
  cur_dir = Path(__file__).parent
347
  gr.Examples(
348
  examples=[
349
  [
350
+ f"{cur_dir}/images/market.png",
351
  "What is the highest number of companies in the domestic market? Answer with detailed steps.",
352
  ],
353
  [
354
+ f"{cur_dir}/images/college.png",
355
  "What is the difference between Asians and Whites degree distribution? Answer with detailed steps."
356
  ],
357
  [
358
+ f"{cur_dir}/images/immigrants.png",
359
  "How many immigrants are there in 1931?",
360
  ],
361
  [
362
+ f"{cur_dir}/images/sails.png",
363
  "By how much percentage wholesale is less than retail? Answer with detailed steps."
364
  ],
365
  [
366
+ f"{cur_dir}/images/diseases.png",
367
  "Is the median value of all the bars greater than 30? Answer with detailed steps.",
368
  ],
369
  [
370
+ f"{cur_dir}/images/economy.png",
371
  "Which team has higher economy in 28 min?"
372
  ],
373
  [
374
+ f"{cur_dir}/images/workers.png",
375
  "Generate underlying data table for the chart."
376
  ],
377
  [
378
+ f"{cur_dir}/images/sports.png",
379
  "Create a brief summarization or extract key insights based on the chart image."
380
  ],
381
  [
382
+ f"{cur_dir}/images/albums.png",
383
  "Redraw the chart with Python code."
384
  ]
385
  ],
 
488
  model_name=args.model_name,
489
  device="cpu",
490
  load_4bit=args.load_4bit,
491
+ load_8bit=args.load_8bit,
492
+ torch_dtype=torch.float32,
493
  )
494
 
495
  demo = build_demo()
tinychart/model/builder.py CHANGED
@@ -40,7 +40,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
40
  bnb_4bit_use_double_quant=True,
41
  bnb_4bit_quant_type='nf4'
42
  )
43
- else:
44
  kwargs['torch_dtype'] = torch.float16
45
 
46
  # Load LLaVA model
@@ -97,7 +97,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
97
  **kwargs)
98
 
99
  mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
100
- mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
101
  model.load_state_dict(mm_projector_weights, strict=False)
102
  else:
103
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="right")
@@ -115,7 +115,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
115
  vision_tower.load_model()
116
 
117
  if device != "auto":
118
- vision_tower.to(device=device, dtype=torch.float16)
119
 
120
  image_processor = vision_tower.image_processor
121
 
 
40
  bnb_4bit_use_double_quant=True,
41
  bnb_4bit_quant_type='nf4'
42
  )
43
+ elif 'torch_dtype' not in kwargs:
44
  kwargs['torch_dtype'] = torch.float16
45
 
46
  # Load LLaVA model
 
97
  **kwargs)
98
 
99
  mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
100
+ mm_projector_weights = {k: v.to(kwargs['torch_dtype']) for k, v in mm_projector_weights.items()}
101
  model.load_state_dict(mm_projector_weights, strict=False)
102
  else:
103
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="right")
 
115
  vision_tower.load_model()
116
 
117
  if device != "auto":
118
+ vision_tower.to(device=device, dtype=kwargs['torch_dtype'])
119
 
120
  image_processor = vision_tower.image_processor
121