"""Gradio demo for MAPL: ask a question about an image, or caption it."""

import torch
from PIL import Image
import gradio as gr

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Half precision is only used on GPU: many PyTorch CPU kernels are missing or
# very slow in float16, so the CPU fallback runs in float32 instead.
dtype = torch.float16 if device == 'cuda' else torch.float32

model = torch.hub.load('mair-lab/mapl', 'mapl')
model.eval()
model.to(device, dtype)


def predict(image: Image.Image, question: str) -> str:
    """Generate text for ``image``, optionally conditioned on ``question``.

    Args:
        image: The PIL image supplied by the Gradio image component.
        question: Free-form text. If it contains a ``?`` it is wrapped in a
            question-answering prompt; otherwise it is passed to the model
            unchanged. If empty or whitespace-only, no text prompt is given
            to the model (the UI describes this as captioning).

    Returns:
        The first decoded sequence, with special tokens removed and
        surrounding whitespace stripped.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None when the image input is left empty; report it
        # clearly rather than failing inside the image transform.
        raise gr.Error("Please provide an image.")

    # Autograd bookkeeping is not needed for inference.
    with torch.no_grad():
        pixel_values = model.image_transform(image).unsqueeze(0).to(device, dtype)

        input_ids = None
        # Treat a blank (None / empty / whitespace-only) question as "no prompt".
        if question and question.strip():
            prompt = f"Please answer the question. Question: {question} Answer:" if '?' in question else question
            input_ids = model.text_transform(prompt).input_ids.to(device)

        generated_ids = model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            max_new_tokens=100,
            num_beams=5
        )

    answer = model.text_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return answer


image = gr.components.Image(type='pil', label="Image")
question = gr.components.Textbox(info="Ask a visual question or leave empty for captioning", placeholder="What is this?", label="Question")
answer = gr.components.Textbox(label="Answer")

interface = gr.Interface(
    fn=predict,
    inputs=[image, question],
    outputs=answer,
    title="MAPL🍁",
    description="Paper: [https://arxiv.org/abs/2210.07179](https://arxiv.org/abs/2210.07179)\nCode and weights: [https://github.com/mair-lab/mapl](https://github.com/mair-lab/mapl)",
    allow_flagging='never')

interface.launch()