sooh-j committed
Commit c97c8cb
1 Parent(s): b4bc0d9

Update handler.py

Files changed (1):
  handler.py: +37 -18
handler.py CHANGED
@@ -1,22 +1,24 @@
-import requests
-from PIL import Image
+# import sys
+# import base64
+# import logging
+# import copy
+import numpy as np
 from transformers import Blip2Processor, Blip2ForConditionalGeneration
 from typing import Dict, List, Any
+from PIL import Image
+from transformers import pipeline
+import requests
 import torch
-import sys
-import base64
-import logging
-import copy
-import numpy as np
-
+
 class EndpointHandler():
     def __init__(self, path=""):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model_base = "Salesforce/blip2-opt-2.7b"
         self.model_name = "sooh-j/blip2-vizwizqa"
         self.base_model = Blip2ForConditionalGeneration.from_pretrained(self.model_base, load_in_8bit=True)
-        self.pipe = Blip2ForConditionalGeneration.from_pretrained(self.model_base, load_in_8bit=True, torch_dtype=torch.float16)
+        # self.pipe = Blip2ForConditionalGeneration.from_pretrained(self.model_base, load_in_8bit=True, torch_dtype=torch.float16)
         self.processor = Blip2Processor.from_pretrained(self.base_model_name)
-        self.model = PeftModel.from_pretrained(self.model_name, self.base_model_name)
+        self.model = PeftModel.from_pretrained(self.model_name, self.base_model_name).to(self.device)
 
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
@@ -59,6 +61,13 @@ class EndpointHandler():
        # return { "embeddings": embeddings }
 
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        data args:
+            inputs (:obj: `str` | `PIL.Image` | `np.array`)
+            kwargs
+        Return:
+            A :obj:`list` | `dict`: will be serialized and returned
+        """
        # await hf.visualQuestionAnswering({
        #     model: 'dandelin/vilt-b32-finetuned-vqa',
        #     inputs: {
@@ -66,22 +75,32 @@ class EndpointHandler():
        #         image: await (await fetch('https://placekitten.com/300/300')).blob()
        #     }
        # })
-        inputs = data.get("inputs")
-        imageBase64 = inputs.get("image")
-        question = inputs.get("question")
+        inputs = data.pop("inputs", data)
+        try:
+            imageBase64 = inputs["image"]
+            image = Image.open(BytesIO(base64.b64decode(imageBase64.split(",")[1].encode())))
+
+        except:
+            image_url = inputs['image']
+            image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
+
+        question = inputs["question"]
 
        # data = data.pop("inputs", data)
        # data = data.pop("image", image)
 
        # image = Image.open(requests.get(imageBase64, stream=True).raw)
-        image = Image.open(BytesIO(base64.b64decode(imageBase64.split(",")[1].encode())))
-
+        # image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
+
        prompt = f"Question: {question}, Answer:"
-        processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
+        processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device, torch.float16)
 
        # answer = self._generate_answer(
        #     model_path, prompt, image,
        # )
        out = self.model.generate(**processed)
-
-        return self.processor.decode(out[0], skip_special_tokens=True)
+
+        result = {}
+        text_output = self.processor.decode(out[0], skip_special_tokens=True)
+        result["text_output"] = text_output
+        return result
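
Note that the handler introduced by this commit still references names it never defines or imports: `self.base_model_name` does not exist (only `self.model_base` is set), and `PeftModel`, `BytesIO`, and `base64` are used without imports. The snippet below is not the author's code, only a sketch of one way the handler could be made runnable, assuming the adapter at sooh-j/blip2-vizwizqa is meant to be loaded with peft on top of the 8-bit base model and that the processor comes from the base checkpoint.

# Sketch of a runnable handler under the assumptions stated above; not the committed code.
import base64
from io import BytesIO
from typing import Any, Dict

import requests
import torch
from PIL import Image
from peft import PeftModel  # assumption: the fine-tuned repo is a peft adapter
from transformers import Blip2ForConditionalGeneration, Blip2Processor


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_base = "Salesforce/blip2-opt-2.7b"
        self.model_name = "sooh-j/blip2-vizwizqa"

        # The commit references an undefined `self.base_model_name`;
        # the processor is taken from the base checkpoint instead.
        self.processor = Blip2Processor.from_pretrained(self.model_base)

        # 8-bit loading needs bitsandbytes and a CUDA device; device_map places
        # the weights, so no explicit .to(self.device) on the quantized model.
        base_model = Blip2ForConditionalGeneration.from_pretrained(
            self.model_base, load_in_8bit=True, device_map="auto"
        )
        # peft expects (base_model, adapter_id); the commit passes two model IDs.
        self.model = PeftModel.from_pretrained(base_model, self.model_name)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        inputs = data.pop("inputs", data)
        image_ref = inputs["image"]
        question = inputs["question"]

        # Accept either an HTTP(S) URL or a (data-URL) base64 string.
        if isinstance(image_ref, str) and image_ref.startswith("http"):
            image = Image.open(requests.get(image_ref, stream=True).raw).convert("RGB")
        else:
            encoded = image_ref.split(",")[-1]  # tolerate a missing data-URL prefix
            image = Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")

        prompt = f"Question: {question}, Answer:"
        processed = self.processor(images=image, text=prompt, return_tensors="pt").to(
            self.device, torch.float16
        )

        out = self.model.generate(**processed)
        text_output = self.processor.decode(out[0], skip_special_tokens=True)
        return {"text_output": text_output}

Loading with load_in_8bit=True plus device_map="auto" lets accelerate place the quantized weights, which is why this sketch drops the explicit .to(self.device) call that the committed __init__ still performs on an 8-bit model.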
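
A hypothetical local smoke test for the sketch above, shaped like the payload implied by the commented visualQuestionAnswering() snippet in the diff; the question string is a placeholder, and the kitten URL is the one already used in that comment.

# Placeholder payload; adjust the question and image for a real check.
handler = EndpointHandler()
payload = {
    "inputs": {
        "image": "https://placekitten.com/300/300",
        "question": "What is in the picture?",
    }
}
print(handler(payload))  # expected shape: {"text_output": "..."}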