import gradio as gr
import torch
from transformers import ViltProcessor, ViltForQuestionAnswering
from PIL import Image

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

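# Ask ViLT a fixed set of attribute questions about an uploaded portrait
# and return the top-scoring answer for each.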
def answer_questions(image: Image.Image):
    questions = {
        "face_check": "Does the image contain a face?",
        "hair_style": "What is the hair style of the person?",
        "hair_color": "What is the hair color of the person?",
        "eye_color": "What is the eye color of the person?",
        "person_sex": "What is the person's sex?",
        "facial_hair": "What facial hair does the person have?",
        "glasses": "Does the person wear glasses?",
        "age": "What is the age of the person?"
    }

    answers = {}

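    # Run the face-check question first so we can bail out early
    # when no person is visible in the image.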
    face_check_input = processor(images=image, text=questions["face_check"], return_tensors="pt", padding=True)
    with torch.no_grad():
        face_check_output = model(**face_check_input)
    face_check_idx = face_check_output.logits.argmax(-1).item()
    face_check_answer = model.config.id2label[face_check_idx]

    answers["face_check"] = face_check_answer

    if face_check_answer == "no":
        return {"error": "Cannot detect a person in this image"}

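    # Ask the remaining attribute questions and keep the highest-scoring
    # label from the model's answer vocabulary for each one.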
    for key, question in questions.items():
        if key != "face_check":
            inputs = processor(images=image, text=question, return_tensors="pt", padding=True)
            with torch.no_grad():
                outputs = model(**inputs)
            answer_idx = outputs.logits.argmax(-1).item()
            answer_text = model.config.id2label[answer_idx]
            answers[key] = answer_text

    return answers

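# Image in, dict of answers out; gr.JSON renders the returned dict more
# readably than a plain textbox would.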
iface = gr.Interface(
    fn=answer_questions,
    inputs=gr.Image(type="pil"),
    outputs=gr.JSON(),
)

iface.launch()