apoorvkh committed on
Commit
ca66b5c
1 Parent(s): 2126f25

device map fix

Browse files
Files changed (1) hide show
  1. handler.py +14 -4
handler.py CHANGED
@@ -1,6 +1,9 @@
1
  from typing import Dict, Any
 
2
  import torch
3
- from transformers import Blip2ForConditionalGeneration, Blip2Processor
 
 
4
  from PIL import Image
5
  from io import BytesIO
6
  import base64
@@ -10,8 +13,15 @@ import torch.nn.functional as F
10
  class EndpointHandler():
11
  def __init__(self, path=""):
12
  self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
 
 
 
 
 
 
 
13
  self.model = Blip2ForConditionalGeneration.from_pretrained(
14
- "Salesforce/blip2-flan-t5-xxl", device_map="auto",
15
  torch_dtype=torch.float16
16
  # load_in_8bit=True,
17
  )
@@ -28,7 +38,7 @@ class EndpointHandler():
28
  temperature: float = inputs['temperature']
29
 
30
  inputs = self.processor(images=image, text=input_text, return_tensors="pt").to(
31
- 0, self.model.dtype
32
  )
33
  output = self.model.generate(
34
  **inputs, max_new_tokens=max_new_tokens, temperature=temperature
@@ -47,7 +57,7 @@ class EndpointHandler():
47
 
48
  inputs = self.processor(
49
  images=image, text=(prompt + continuation), return_tensors="pt"
50
- ).to(0, self.model.dtype)
51
  inputs["labels"] = inputs["input_ids"]
52
  input_ids = inputs["input_ids"][0]
53
  tokens = [self.processor.decode([t]) for t in input_ids]
 
1
  from typing import Dict, Any
2
+
3
  import torch
4
+ from transformers import Blip2Processor, Blip2Config, Blip2ForConditionalGeneration
5
+ from accelerate import init_empty_weights, infer_auto_device_map
6
+
7
  from PIL import Image
8
  from io import BytesIO
9
  import base64
 
13
  class EndpointHandler():
14
  def __init__(self, path=""):
15
  self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
16
+
17
+ config = Blip2Config.from_pretrained("Salesforce/blip2-flan-t5-xxl")
18
+ with init_empty_weights():
19
+ model = Blip2ForConditionalGeneration(config)
20
+ device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])
21
+ device_map['language_model.lm_head'] = device_map["language_model.encoder.embed_tokens"]
22
+
23
  self.model = Blip2ForConditionalGeneration.from_pretrained(
24
+ "Salesforce/blip2-flan-t5-xxl", device_map=device_map,
25
  torch_dtype=torch.float16
26
  # load_in_8bit=True,
27
  )
 
38
  temperature: float = inputs['temperature']
39
 
40
  inputs = self.processor(images=image, text=input_text, return_tensors="pt").to(
41
+ self.model.device, self.model.dtype
42
  )
43
  output = self.model.generate(
44
  **inputs, max_new_tokens=max_new_tokens, temperature=temperature
 
57
 
58
  inputs = self.processor(
59
  images=image, text=(prompt + continuation), return_tensors="pt"
60
+ ).to(self.model.device, self.model.dtype)
61
  inputs["labels"] = inputs["input_ids"]
62
  input_ids = inputs["input_ids"][0]
63
  tokens = [self.processor.decode([t]) for t in input_ids]