"""Gradio demo for the MAPL vision-language model loaded via ``torch.hub``.

The user supplies an image and optional text; the model generates a short
answer (at most 50 new tokens) which is shown in the UI.
"""
import torch
from PIL import Image
import gradio as gr

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = torch.hub.load('mair-lab/mapl', 'mapl')
model.eval()
model.to(device, torch.bfloat16)


def predict(image: Image.Image, question: str) -> str:
    """Generate a text answer for ``image``, optionally conditioned on ``question``.

    Args:
        image: Image from the Gradio ``Image`` component; ``None`` when the
            user submits without uploading one.
        question: Optional prompt text. If it contains ``'?'`` it is wrapped
            in a question-answering template; otherwise it is passed to the
            model unchanged (presumably as a free-form prompt/prefix — TODO
            confirm against MAPL usage). Empty, ``None`` or whitespace-only
            text means "no text prompt".

    Returns:
        The decoded generation for the first (and only) batch item, with
        surrounding whitespace stripped.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Add a batch dimension and match the device/dtype the model was moved to.
    pixel_values = model.image_transform(image).unsqueeze(0).to(device, torch.bfloat16)

    input_ids = None
    # Whitespace-only text is treated as no prompt rather than fed to the model.
    if question and question.strip():
        if '?' in question:
            text = f"Please answer the question. Question: {question} Answer:"
        else:
            text = question
        input_ids = model.text_transform(text).input_ids.to(device)

    # Pure inference: disable autograd tracking to save memory and time.
    with torch.inference_mode():
        generated_ids = model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            max_new_tokens=50,
        )

    answer = model.text_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return answer


image = gr.components.Image(type='pil', label="Image")
question = gr.components.Textbox(value="What is this?", label="Question")
answer = gr.components.Textbox(label="Answer")

interface = gr.Interface(
    fn=predict,
    inputs=[image, question],
    outputs=answer,
    allow_flagging='never',
)

# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()