ChirathD committed
Commit d149168
1 Parent(s): 543ffeb

Update handler.py

Files changed (1): handler.py (+25 -9)
handler.py CHANGED
@@ -3,9 +3,13 @@ from typing import Dict, List, Any
 from PIL import Image
 import torch
 import os
+import io
+import base64
 from io import BytesIO
 # from transformers import BlipForConditionalGeneration, BlipProcessor
 from transformers import Blip2Processor, Blip2ForConditionalGeneration
+
+
 # -
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -30,19 +34,31 @@ class EndpointHandler():
         A :obj:`dict`:. The object returned should be a dict of one list like {"captions": ["A hugging face at the office"]} containing :
         - "caption": A string corresponding to the generated caption.
         """
+        print(data)
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
+        print(inputs)
+        image_bytes = base64.b64decode(inputs)
+        image_io = io.BytesIO(image_bytes)
+        image = Image.open(image_io)
+
+        inputs = self.processor(images=image, return_tensors="pt").to(device, torch.float16)
+        pixel_values = inputs.pixel_values
+
+        generated_ids = self.model.generate(pixel_values=pixel_values, max_length=25)
+        generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        print(generated_caption)
 
-        raw_images = [Image.open(BytesIO(_img)) for _img in inputs]
+        # raw_images = [Image.open(BytesIO(_img)) for _img in inputs]
 
-        processed_image = self.processor(images=raw_images, return_tensors="pt")
-        processed_image["pixel_values"] = processed_image["pixel_values"].to(device)
-        processed_image = {**processed_image, **parameters}
+        # processed_image = self.processor(images=raw_images, return_tensors="pt")
+        # processed_image["pixel_values"] = processed_image["pixel_values"].to(device)
+        # processed_image = {**processed_image, **parameters}
 
-        with torch.no_grad():
-            out = self.model.generate(
-                **processed_image
-            )
-        captions = self.processor.batch_decode(out, skip_special_tokens=True)
+        # with torch.no_grad():
+        #     out = self.model.generate(
+        #         **processed_image
+        #     )
+        # captions = self.processor.batch_decode(out, skip_special_tokens=True)
 
-        return {"captions": captions}
+        return {"captions": [generated_caption]}
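Since the updated handler decodes a base64 string from the "inputs" field, a minimal local smoke test could look like the sketch below. It assumes the usual Inference Endpoints convention that EndpointHandler.__init__(path) loads Blip2Processor into self.processor and Blip2ForConditionalGeneration into self.model (the constructor is not part of this diff); "cat.jpg" and the printed caption are placeholders.

import base64

from handler import EndpointHandler

# Build the payload shape the updated handler expects:
# a single image, base64-encoded as a UTF-8 string, under "inputs".
with open("cat.jpg", "rb") as f:  # placeholder image path
    payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}

# Assumed constructor signature: EndpointHandler is expected to load the
# BLIP-2 processor and model in __init__ (not shown in this diff).
handler = EndpointHandler(path=".")

result = handler(payload)
print(result)  # e.g. {"captions": ["a cat sitting on a couch"]}

Because the handler moves the processed inputs to device with torch.float16, the sketch presumes a CUDA-capable machine; half-precision generation on CPU will generally fail.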