adasdimchom commited on
Commit
b2d86ed
·
1 Parent(s): 1347a75

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +8 -15
handler.py CHANGED
@@ -10,15 +10,10 @@ class EndpointHandler():
10
  """
11
  path:
12
  """
13
- # Preload all the elements you are going to need at inference.
14
- # pseudo:
15
- # self.model= load_model(path)
16
- #self.processor = Blip2Processor.from_pretrained(path)
17
- #self.pipeline = pipeline(model = path)
18
- self.path = path
19
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
- #self.processor = Blip2Processor.from_pretrained(path)
21
- #self.model = Blip2Model.from_pretrained(path, torch_dtype=torch.float16)
 
22
 
23
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
24
  """
@@ -30,10 +25,8 @@ class EndpointHandler():
30
  """
31
  inputs = data.pop("inputs", data)
32
  image_url = inputs['image_url']
33
- #image = Image.open(requests.get(image_url, stream=True).raw)
34
- #processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
35
-
36
- #generated_ids = self.pipeline(**inputs)
37
- #generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
38
-
39
- return image_url, self.path, self.device
 
10
  """
11
  path:
12
  """
 
 
 
 
 
 
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ self.processor = Blip2Processor.from_pretrained(path)
15
+ self.model = Blip2Model.from_pretrained(path, torch_dtype=torch.float16)
16
+ self.model.to(self.device)
17
 
18
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
19
  """
 
25
  """
26
  inputs = data.pop("inputs", data)
27
  image_url = inputs['image_url']
28
+ image = Image.open(requests.get(image_url, stream=True).raw)
29
+ processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
30
+ generated_ids = self.model.generate(**processed_image)
31
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
32
+ return image_url, generated_text