from transformers import (
    ViltProcessor,
    ViltForQuestionAnswering,
    Pix2StructProcessor,
    Pix2StructForConditionalGeneration,
    Blip2Processor,
    Blip2ForConditionalGeneration,
)
import torch


class Inference:
    """Convenience wrapper around three vision-language models.

    Loads, at construction time:
      * ViLT fine-tuned for VQA   ("dandelin/vilt-b32-finetuned-vqa")
      * DePlot (Pix2Struct)       ("google/deplot")
      * BLIP-2 OPT-2.7B           ("Salesforce/blip2-opt-2.7b")

    NOTE: construction downloads weights and is expensive; instantiate once
    and reuse.
    """

    def __init__(self):
        # Device is decided first so every method can rely on it.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

        self.deplot_processor = Pix2StructProcessor.from_pretrained('google/deplot')
        self.deplot_model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')

        self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        # NOTE(review): load_in_8bit requires bitsandbytes and a CUDA device
        # (device_map={"": 0} pins everything to GPU 0) — this will fail on a
        # CPU-only host; confirm the deployment target.
        self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            load_in_8bit=True,
            device_map={"": 0},
            torch_dtype=torch.float16,
        )

    def inference_vilt(self, image, text):
        """Answer a visual question with ViLT.

        Args:
            image: a PIL image (or anything the ViLT processor accepts).
            text: the natural-language question.

        Returns:
            The highest-scoring answer label as a string.
        """
        encoding = self.vilt_processor(image, text, return_tensors="pt")
        with torch.no_grad():  # pure inference — no autograd graph needed
            outputs = self.vilt_model(**encoding)
        idx = outputs.logits.argmax(-1).item()
        return f"{self.vilt_model.config.id2label[idx]}"

    def inference_deplot(self, image, text):
        """Extract a chart/plot's underlying data table with DePlot.

        Args:
            image: a chart image.
            text: the instruction prompt (e.g. "Generate underlying data
                table of the figure below:").

        Returns:
            The decoded generation as a string.
        """
        inputs = self.deplot_processor(images=image, text=text, return_tensors="pt")
        with torch.no_grad():
            predictions = self.deplot_model.generate(**inputs, max_new_tokens=512)
        return f"{self.deplot_processor.decode(predictions[0], skip_special_tokens=True)}"

    def inference_blip(self, image, text=None):
        """Caption an image with BLIP-2.

        BUGFIX: this method was previously also named ``inference_vilt``,
        silently shadowing the real ViLT method above; renamed to match the
        model it drives.

        Args:
            image: the image to caption.
            text: accepted for signature parity with the other methods but
                currently unused — the original implementation never passed
                it to the processor (unconditional captioning).

        Returns:
            The generated caption, stripped of surrounding whitespace.
        """
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device, torch.float16)
        with torch.no_grad():
            generated_ids = self.blip_model.generate(**inputs)
        generated_text = self.blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        return f"{generated_text}"