rishabh-zuma committed
Commit 7c90bcf
Parent: 28b2957

Added new handler

Files changed (1)
  1. handler.py +69 -0
handler.py ADDED
@@ -0,0 +1,69 @@
+from typing import Any, Dict
+import base64
+from io import BytesIO
+
+import torch
+from PIL import Image
+from transformers import Blip2Processor, Blip2ForConditionalGeneration
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class EndpointHandler:
+    def __init__(self, path: str = ""):
+        # Load the BLIP-2 processor and model, then move the model to the GPU when available.
+        self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+        self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
+        self.model.eval()
+        self.model = self.model.to(device)
+
+    def __call__(self, data: Any) -> Dict[str, Any]:
+        """
+        Args:
+            data (:obj:`dict`):
+                Includes the input data and the parameters for the inference.
+        Return:
+            A :obj:`dict` with a single list, e.g. {"captions": ["A hugging face at the office"]}, containing:
+            - "captions": the generated caption(s) as strings.
+        """
+        img_data = data.pop("image", data)
+        prompt = data.pop("prompt", None)
+        parameters = data.pop("parameters", {})
+
+        # Accept either a PIL image or a base64-encoded image string.
+        if isinstance(img_data, Image.Image):
+            raw_image = img_data
+        else:
+            raw_image = Image.open(BytesIO(base64.b64decode(img_data))).convert("RGB")
+
+        # Preprocess, generate, and decode the caption; extra parameters are forwarded to generate().
+        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(device)
+        with torch.no_grad():
+            out = self.model.generate(**inputs, **parameters)
+        caption = self.processor.decode(out[0], skip_special_tokens=True)
+
+        return {"captions": [caption]}