# handler.py — custom Hugging Face Inference Endpoints handler
# (IDEFICS-style image-to-text model; accepts an image, returns generated text)
from typing import Dict, List, Any
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
class EndpointHandler:
    """Custom Inference Endpoints handler for an IDEFICS-style image-to-text model.

    Loads the model in bfloat16, preprocesses an input image with a custom
    960x960 transform, and generates a text description via beam-sampled decoding.
    """

    def __init__(self, model_path: str):
        # Prefer GPU when available; bfloat16 halves the model's memory footprint.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoProcessor.from_pretrained(
            model_path,
            # token=api_token
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # token=api_token,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        # Number of <image> placeholder tokens the perceiver resampler expects
        # in the prompt for a single image.
        self.image_seq_len = self.model.config.perceiver_config.resampler_n_latents
        self.bos_token = self.processor.tokenizer.bos_token
        # Forbid the model from emitting the image placeholder tokens itself.
        self.bad_words_ids = self.processor.tokenizer(
            ["<image>", "<fake_token_around_image>"], add_special_tokens=False
        ).input_ids

    def convert_to_rgb(self, image: Image.Image) -> Image.Image:
        """Return *image* as RGB, flattening any transparency onto white.

        A plain ``convert("RGB")`` would blend transparent pixels onto black;
        compositing over a white RGBA background first avoids that.
        """
        if image.mode == "RGB":
            return image
        image_rgba = image.convert("RGBA")
        background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, image_rgba)
        return alpha_composite.convert("RGB")

    def custom_transform(self, image: Image.Image) -> torch.Tensor:
        """Resize to 960x960, rescale to [0, 1], normalize, and return a CHW tensor."""
        image = self.convert_to_rgb(image)
        image = to_numpy_array(image)
        image = resize(image, (960, 960), resample=PILImageResampling.BILINEAR)
        image = self.processor.image_processor.rescale(image, scale=1 / 255)
        image = self.processor.image_processor.normalize(
            image,
            mean=self.processor.image_processor.image_mean,
            std=self.processor.image_processor.image_std,
        )
        image = to_channel_dimension_format(image, ChannelDimension.FIRST)
        return torch.tensor(image)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a text description for the image under ``data["inputs"]``.

        Accepts a filesystem path (str), raw image bytes (the usual Inference
        Endpoints payload), or an already-opened PIL image.

        Returns a classification-pipeline-shaped list: one dict with the
        generated text as ``label`` and a constant ``score`` of 1.0.
        """
        image = data.get("inputs")
        # FIX: accept raw bytes in addition to a file path; PIL images pass
        # through unchanged.
        if isinstance(image, (bytes, bytearray)):
            from io import BytesIO
            image = Image.open(BytesIO(image))
        elif isinstance(image, str):
            image = Image.open(image)
        # The prompt is just BOS plus the image placeholder sequence — no text query.
        inputs = self.processor.tokenizer(
            f"{self.bos_token}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
            return_tensors="pt",
            add_special_tokens=False,
        )
        inputs["pixel_values"] = self.processor.image_processor([image], transform=self.custom_transform)
        # FIX: the model was loaded in bfloat16 but custom_transform yields
        # float32 pixel values; cast floating tensors to the model dtype and
        # leave integer token ids untouched. (Assumes the image processor
        # returns torch tensors here — it does when a transform is supplied.)
        inputs = {
            k: v.to(self.device, dtype=self.model.dtype) if v.is_floating_point() else v.to(self.device)
            for k, v in inputs.items()
        }
        generated_ids = self.model.generate(
            **inputs,
            bad_words_ids=self.bad_words_ids,
            max_length=2048,
            early_stopping=True,
            do_sample=True,
            num_beams=4,
            top_k=100,
            temperature=0.7,
        )
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return [{"label": generated_text, "score": 1.0}]