AnwenHu commited on
Commit
483d723
1 Parent(s): 536a2a7

Update model_worker.py

Browse files
Files changed (1) hide show
  1. model_worker.py +5 -4
model_worker.py CHANGED
@@ -76,7 +76,8 @@ class ModelWorker:
76
  @torch.inference_mode()
77
  def generate_stream(self, params):
78
  tokenizer, model = self.tokenizer, self.model
79
-
 
80
  prompt = params["prompt"]
81
  ori_prompt = prompt
82
  images = params.get("images", None)
@@ -90,9 +91,9 @@ class ModelWorker:
90
  assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1
91
 
92
  images, patch_positions, prompt = self.doc_image_processor(images=image, query=prompt)
93
- images = images.to(self.model.device, dtype=torch.float16)
94
- # images = images.to(self.model.device, dtype=torch.bfloat16)
95
- patch_positions = patch_positions.to(self.model.device)
96
 
97
  replace_token = DEFAULT_IMAGE_TOKEN
98
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
 
76
  @torch.inference_mode()
77
  def generate_stream(self, params):
78
  tokenizer, model = self.tokenizer, self.model
79
+ # for adjust to zero environment of huggingface
80
+ model.to(self.device)
81
  prompt = params["prompt"]
82
  ori_prompt = prompt
83
  images = params.get("images", None)
 
91
  assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1
92
 
93
  images, patch_positions, prompt = self.doc_image_processor(images=image, query=prompt)
94
+ images = images.to(self.device, dtype=torch.float16)
95
+ # images = images.to(self.device, dtype=torch.bfloat16)
96
+ patch_positions = patch_positions.to(self.device)
97
 
98
  replace_token = DEFAULT_IMAGE_TOKEN
99
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)