# handler.py — last change by luisresende13: "Update handler.py" (commit 080a129, verified).
from typing import Dict, List, Any
from transformers import pipeline
from PIL import Image
import requests
class EndpointHandler:
    """Inference handler around a transformers ``image-to-text`` pipeline.

    Each call downloads an image from a URL, wraps the caller's prompt in a
    chat-style template, runs the pipeline and returns its first result with
    the echoed prompt removed from ``generated_text``.
    """

    # Generation budget used when the request does not specify ``max_new_tokens``.
    DEFAULT_MAX_NEW_TOKENS = 1000
    # Seconds allowed for connecting to / reading from the image host, so a
    # slow or unresponsive server cannot hang the endpoint indefinitely.
    IMAGE_REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the ``image-to-text`` pipeline from the model at *path*."""
        self.pipe = pipeline("image-to-text", model=path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Answer a prompt about an image fetched from a URL.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]`` when
                that key is present, otherwise from ``data`` itself:

                - ``url`` (required): address of the image to download.
                - ``prompt`` (optional): instruction/question for the model;
                  an empty string is used when it is missing.
                - ``max_new_tokens`` (optional): passed to generation;
                  defaults to ``DEFAULT_MAX_NEW_TOKENS`` (1000).

        Returns:
            The first pipeline result (a dict) whose ``generated_text`` has had
            the echoed prompt stripped.

        Raises:
            ValueError: if no ``url`` is provided.
            requests.HTTPError: if the image download returns an error status.
        """
        # ``get`` rather than ``pop`` so the caller's payload is not mutated.
        inputs = data.get('inputs', data)
        url = inputs.get('url')
        if not url:
            raise ValueError("Request must include an image 'url'.")
        # Without a default, a missing prompt would be rendered as the text "None".
        prompt = inputs.get('prompt') or ''
        max_new_tokens = inputs.get('max_new_tokens', self.DEFAULT_MAX_NEW_TOKENS)

        image = self._load_image(url)
        full_prompt = self._build_prompt(prompt)
        results = self.pipe(
            image,
            prompt=full_prompt,
            generate_kwargs={"max_new_tokens": max_new_tokens},
        )
        result = results[0]
        result['generated_text'] = self._strip_prompt(result['generated_text'], full_prompt)
        return result

    def _load_image(self, url: str) -> "Image.Image":
        """Download the image at *url* and return it fully decoded."""
        with requests.get(url, stream=True, timeout=self.IMAGE_REQUEST_TIMEOUT) as response:
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            # ``raw`` bypasses requests' content decoding; enable it so
            # gzip/deflate-encoded bodies are decompressed before PIL reads them.
            response.raw.decode_content = True
            image = Image.open(response.raw)
            # ``Image.open`` is lazy; force the pixel data to be read while the
            # connection is still open.
            image.load()
        return image

    @staticmethod
    def _build_prompt(prompt: str) -> str:
        """Wrap *prompt* in the chat template passed to the pipeline."""
        return f'user<image>\n{prompt}\nassistant:'

    @staticmethod
    def _strip_prompt(generated_text: str, full_prompt: str) -> str:
        """Remove the echoed prompt from *generated_text*.

        Assumes the model echoes the prompt with the ``<image>`` token removed
        and followed by a single space (NOTE(review): presumably how this
        model decodes — verify against real output). Every occurrence of that
        echo is removed.
        """
        echoed = full_prompt.replace('<image>', '') + ' '
        return generated_text.replace(echoed, '')