import gradio as gr
from gradio.themes.base import Base
from PIL import Image
import torch
import torchvision.transforms as transforms
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM
# Load the models
caption_model = VisionEncoderDecoderModel.from_pretrained('Mayada/AIC-transformer')
caption_tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')
question_model = AutoModelForSeq2SeqLM.from_pretrained("Mihakram/AraT5-base-question-generation")
question_tokenizer = AutoTokenizer.from_pretrained("Mihakram/AraT5-base-question-generation")
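# Optional: move both models to a GPU when one is available. This is a minimal
# sketch (the models run on CPU by default); the .to(...) calls in the
# generation functions below keep the input tensors on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
caption_model.to(device)
question_model.to(device)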
# Define the normalization and preprocessing transforms
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # ImageNet mean
    std=[0.229, 0.224, 0.225],   # ImageNet standard deviation
)
inference_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
# Load the spelling-correction dictionary (one tab-separated pair per line);
# the "\t" guard skips blank or malformed lines that would otherwise make dict() raise
with open("DICTIONARY (3).txt", "r", encoding="utf-8") as file:
    dictionary = dict(
        line.strip().split("\t", 1) for line in file if "\t" in line
    )
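# The dictionary file itself is not shown here; the parsing above assumes a
# two-column TSV of the form (placeholder entries):
#
#   <word_as_generated><TAB><corrected_word>
#
# where the first column is a word as the captioner may emit it and the
# second is the corrected form looked up by correct_caption below.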
# Correct words in the caption using the dictionary
def correct_caption(caption):
    corrected_words = [dictionary.get(word, word) for word in caption.split()]
    return " ".join(corrected_words)
# Generate captions for an image
def generate_captions(image):
    # Preprocess the PIL image (force RGB in case of greyscale/alpha input)
    # and add a batch dimension
    img_tensor = inference_transforms(image.convert("RGB")).unsqueeze(0)
    img_tensor = img_tensor.to(caption_model.device)
    generated = caption_model.generate(
        img_tensor,
        num_beams=3,
        max_length=10,
        early_stopping=True,
        do_sample=True,
        top_k=1000,
        num_return_sequences=1,
    )
    captions = [caption_tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated]
    return captions
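# Example (hypothetical) standalone usage of generate_captions:
#
#   from PIL import Image
#   img = Image.open("example.jpg")  # "example.jpg" is a placeholder path
#   print(generate_captions(img))    # -> a list with one Arabic caption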
# Generate questions given a context and an answer
def generate_questions(context, answer):
    # Input format expected by the AraT5 question-generation model
    text = "context: " + context + " answer: " + answer + " </s>"
    text_encoding = question_tokenizer(text, return_tensors="pt").to(question_model.device)
    question_model.eval()
    generated_ids = question_model.generate(
        input_ids=text_encoding["input_ids"],
        attention_mask=text_encoding["attention_mask"],
        max_length=64,
        num_beams=5,
        num_return_sequences=1,
    )
    # The model prefixes its output with "question: "; strip that off
    questions = [
        question_tokenizer.decode(
            g, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).replace("question: ", "").strip()
        for g in generated_ids
    ]
    return questions
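# Example (hypothetical) standalone usage of generate_questions, with
# placeholder arguments:
#
#   qs = generate_questions("an Arabic context sentence", "an answer word")
#   print(qs)  # -> a list with one generated Arabic question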
# Interface theme (a bare subclass of Gradio's Base theme)
class Seafoam(Base):
    pass

seafoam = Seafoam()
def caption_question_interface(image):
    # Generate captions and proofread them with the dictionary
    captions = generate_captions(image)
    corrected_captions = [correct_caption(caption) for caption in captions]

    # For each caption, use its leading words as candidate answers: the first
    # word, the second word, the first two words combined, the third word,
    # and the fourth word (when the caption is long enough)
    questions_with_answers = []
    for caption in corrected_captions:
        words = caption.split()
        candidate_answers = list(words[:2])
        if len(words) > 1:
            candidate_answers.append(" ".join(words[:2]))
        candidate_answers.extend(words[2:4])
        for answer in candidate_answers:
            questions = generate_questions(caption, answer)
            questions_with_answers.extend((q, answer) for q in questions)

    # Format the questions with their answer keywords
    formatted_questions = "\n".join(
        f"Question: {q}\nKeyword: {a}" for q, a in questions_with_answers
    )

    # Return the corrected captions and the formatted questions with answers
    return "\n".join(corrected_captions), formatted_questions
gr_interface = gr.Interface(
    fn=caption_question_interface,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Textbox(label="Generated Captions"),
        gr.Textbox(label="Generated Questions"),
    ],
    title="Visual Question Generator",
    description="Generate captions and questions for images using an Arabic image-captioning model and a question-generation model.",
    theme=seafoam,
)
# Launch the interface (share=True also creates a temporary public Gradio URL)
gr_interface.launch(share=True)