apoorvkh committed
Commit ca0693b
Parent(s): 60609f1

Upload handler.py

Files changed (1): handler.py (+64, -0)
handler.py ADDED
from typing import Any, Dict
import base64
from io import BytesIO

import torch
import torch.nn.functional as F
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor


class EndpointHandler:
    def __init__(self, path=""):
        # BLIP-2 with a Flan-T5-XXL language model, loaded in fp16 and
        # sharded across available devices (device_map="auto").
        self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-flan-t5-xxl",
            device_map="auto",
            torch_dtype=torch.float16,
            # load_in_8bit=True,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data["inputs"]

        if inputs["mode"] == "generate_text":
            input_text: str = inputs["input_text"]
            image: Image.Image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
            max_new_tokens: int = inputs["max_new_tokens"]
            stop: str = inputs["stop"]
            temperature: float = inputs["temperature"]

            model_inputs = self.processor(
                images=image, text=input_text, return_tensors="pt"
            ).to(self.model.device, self.model.dtype)
            output = self.model.generate(
                **model_inputs,
                do_sample=True,  # temperature has no effect under greedy decoding
                max_new_tokens=max_new_tokens,
                temperature=temperature,
            )[0]
            output_text = self.processor.decode(output, skip_special_tokens=True).strip()
            # Truncate at the stop sequence, if one appears in the output.
            if stop in output_text:
                output_text = output_text[: output_text.find(stop)]

            return {"output_text": output_text}

        elif inputs["mode"] == "get_continuation_likelihood":
            prompt: str = inputs["prompt"]
            continuation: str = inputs["continuation"]
            image: Image.Image = Image.open(BytesIO(base64.b64decode(inputs["image"])))

            model_inputs = self.processor(
                images=image, text=(prompt + continuation), return_tensors="pt"
            ).to(self.model.device, self.model.dtype)
            # With labels set, the T5 decoder consumes the right-shifted label
            # sequence, so logits[i] scores labels[i] (= input_ids[i]).
            model_inputs["labels"] = model_inputs["input_ids"]
            input_ids = model_inputs["input_ids"][0]
            tokens = [self.processor.decode([t]) for t in input_ids]

            logits = self.model(**model_inputs).logits[0]
            logprobs = F.log_softmax(logits, dim=-1)
            # Per-token log-probabilities; .item() makes them JSON-serializable.
            logprobs = [logprobs[i, input_ids[i]].item() for i in range(len(tokens))]

            return {
                "prompt": prompt,
                "continuation": continuation,
                "tokens": tokens,
                "logprobs": logprobs,
            }
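
The handler exposes two request modes: generate_text (image-conditioned generation with a stop sequence) and get_continuation_likelihood (per-token log-probabilities for a prompt + continuation pair). A quick way to sanity-check it is to call it directly, the same way Inference Endpoints would. The following is a minimal smoke-test sketch, not part of the commit: the image path "test.jpg" and the prompt strings are placeholder assumptions, while the payload keys mirror exactly what __call__ reads above. Note that constructing the handler loads the full multi-billion-parameter BLIP-2 Flan-T5-XXL checkpoint.

import base64

from handler import EndpointHandler

handler = EndpointHandler()  # loads Salesforce/blip2-flan-t5-xxl

# "test.jpg" is a placeholder path for any local image.
with open("test.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# Mode 1: sample text conditioned on the image and a text prefix.
out = handler({
    "inputs": {
        "mode": "generate_text",
        "input_text": "Question: What is shown in this image? Answer:",
        "image": image_b64,
        "max_new_tokens": 32,
        "stop": "\n",
        "temperature": 1.0,
    }
})
print(out["output_text"])

# Mode 2: per-token log-probabilities for prompt + continuation.
out = handler({
    "inputs": {
        "mode": "get_continuation_likelihood",
        "prompt": "This is a picture of",
        "continuation": " a dog.",
        "image": image_b64,
    }
})
# Summing over all returned tokens (prompt and continuation alike)
# gives the log-probability of the whole sequence.
print(out["tokens"], sum(out["logprobs"]))

Deployed on Inference Endpoints, the same dict would be sent as the JSON body of the request instead of being passed to the handler directly.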