import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from typing import Dict, Any

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageToTextPipeline:
    def __init__(self, model_path: str):
        # Load processor and model
        self.PROCESSOR = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            # token=API_TOKEN,
        )
        self.MODEL = AutoModelForCausalLM.from_pretrained(
            model_path,
            # token=API_TOKEN,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(DEVICE)
        # Number of <image> placeholder tokens the perceiver resampler expects
        self.image_seq_len = self.MODEL.config.perceiver_config.resampler_n_latents
        self.BOS_TOKEN = self.PROCESSOR.tokenizer.bos_token
        # Keep the image placeholder tokens out of the generated output
        self.BAD_WORDS_IDS = self.PROCESSOR.tokenizer(
            ["<image>", "<fake_token_around_image>"], add_special_tokens=False
        ).input_ids

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        image = Image.open(data["file"]).convert("RGB")
        # Build the text prompt: BOS followed by the image placeholder tokens
        inputs = self.PROCESSOR.tokenizer(
            f"{self.BOS_TOKEN}<fake_token_around_image>{'<image>' * self.image_seq_len}<fake_token_around_image>",
            return_tensors="pt",
            add_special_tokens=False,
        )
        inputs["pixel_values"] = self.PROCESSOR.image_processor([image], transform=self.custom_transform)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        generated_ids = self.MODEL.generate(**inputs, bad_words_ids=self.BAD_WORDS_IDS, max_length=4096)
        generated_text = self.PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return {"text": generated_text}

    def convert_to_rgb(self, image):
        # Already RGB: nothing to do
        if image.mode == "RGB":
            return image
        # Flatten any alpha channel onto a white background, then convert to RGB
        image_rgba = image.convert("RGBA")
        background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, image_rgba)
        return alpha_composite.convert("RGB")

    def custom_transform(self, x):
        # Resize, rescale, and normalize a PIL image into a CHW float tensor
        x = self.convert_to_rgb(x)
        x = to_numpy_array(x)
        x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
        x = self.PROCESSOR.image_processor.rescale(x, scale=1 / 255)
        x = self.PROCESSOR.image_processor.normalize(
            x,
            mean=self.PROCESSOR.image_processor.image_mean,
            std=self.PROCESSOR.image_processor.image_std,
        )
        x = to_channel_dimension_format(x, ChannelDimension.FIRST)
        return torch.tensor(x)
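

# Minimal usage sketch: instantiate the pipeline and run it on a local
# screenshot. The image path "screenshot.png" is a placeholder for
# illustration; the expected input is a dict with a "file" key, as consumed
# by __call__ above.
if __name__ == "__main__":
    pipeline = ImageToTextPipeline("marutitecblic/HtmlTocode")
    result = pipeline({"file": "screenshot.png"})
    print(result["text"])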