import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering
from PIL import Image
import torch

# Load the processor and model
processor = ViltProcessor.from_pretrained("MariaK/vilt_finetuned_200")
model = ViltForQuestionAnswering.from_pretrained("MariaK/vilt_finetuned_200")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def predict(image, question):
    # prepare inputs
    inputs = processor(image, question, return_tensors="pt").to(device)

    # forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    idx = logits.argmax(-1).item()
    predicted_answer = model.config.id2label[idx]
    return predicted_answer


# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(lines=1, placeholder="Enter your question here..."),
    ],
    outputs="text",
    title="Visual Question Answering with Fine-tuned ViLT",
    description="Upload an image and ask a question about it!",
)

# Launch the interface
iface.launch(share=True)  # Set share=True to share the space