#ํ—ˆ๊น…ํŽ˜์ด์Šค์—์„œ ๋Œ์•„๊ฐˆ ์ˆ˜ ์žˆ๋„๋ก ๋ฐ”๊พธ์–ด ๋ณด์•˜์Œ
import torch
from transformers import BertTokenizerFast, BertForQuestionAnswering, Trainer, TrainingArguments
from datasets import load_dataset
from collections import defaultdict
# ๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
dataset_load = load_dataset('Multimodal-Fatima/OK-VQA_train')
dataset = dataset_load['train'].select(range(300))
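# Only the first 300 rows are used, presumably to keep training fast on the
# limited Spaces hardware.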
# Keep only the features we need
selected_features = ['image', 'answers', 'question']
selected_dataset = dataset.map(
    lambda ex: {feature: ex[feature] for feature in selected_features},
    remove_columns=[col for col in dataset.column_names if col not in selected_features],
)
# Soft encoding: assign each distinct answer string an integer id on first sight
answers_to_id = defaultdict(lambda: len(answers_to_id))
selected_dataset = selected_dataset.map(lambda ex: {
    'answers': [answers_to_id[ans] for ans in ex['answers']],
    'question': ex['question'],
    'image': ex['image']
})
id_to_answers = {v: k for k, v in answers_to_id.items()}
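# Illustration (hypothetical values): the defaultdict hands out ids in
# first-seen order, so answers_to_id['umbrella'] == 0 and
# answers_to_id['rain'] == 1 if 'umbrella' appeared first;
# id_to_answers inverts that mapping for decoding predictions later.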
# Map each row index to its encoded answer list, then replace each example's
# answers with the list looked up via its first answer id, falling back to the
# example's own encoding when that id exceeds the number of rows
id_to_labels = {k: ex['answers'] for k, ex in enumerate(selected_dataset)}
selected_dataset = selected_dataset.map(lambda ex: {
    'answers': id_to_labels.get(ex['answers'][0], ex['answers']),
    'question': ex['question'],
    'image': ex['image']
})
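# Flatten each example into a plain dict; note this list is built for
# inspection only and is not consumed by the Trainer below.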
flattened_features = []
for ex in selected_dataset:
    flattened_example = {
        'answers': ex['answers'],
        'question': ex['question'],
        'image': ex['image'],
    }
    flattened_features.append(flattened_example)
# ๋ชจ๋ธ ๊ฐ€์ ธ์˜ค๊ธฐ
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
model_name = 'microsoft/git-base-vqav2'
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Fine-tune the model with the Trainer API
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
# Assumption: the checkpoint ships an image-processor config, so we can encode
# images into real pixel tensors rather than hard-coded placeholder shapes
image_processor = AutoImageProcessor.from_pretrained(model_name)
def preprocess_function(examples):
    tokenized_inputs = tokenizer(examples['question'], truncation=True, padding=True)
    # Encode the PIL images; this replaces the original placeholder shape tuples
    pixel_values = image_processor(images=examples['image'], return_tensors='pt')['pixel_values']
    return {
        'input_ids': tokenized_inputs['input_ids'],
        'attention_mask': tokenized_inputs['attention_mask'],
        'pixel_values': pixel_values,
        'pixel_mask': [1] * len(tokenized_inputs['input_ids']),
        # Use each example's first encoded answer id as its single class label
        'labels': [ans[0] for ans in examples['answers']]
    }
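# Quick smoke test of the preprocessing on a couple of rows (hypothetical):
# sample = selected_dataset.select(range(2)).map(preprocess_function, batched=True)
# print(len(sample[0]['input_ids']), len(sample[0]['pixel_values']))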
# Reuse the dataset encoded above; calling load_dataset() again here would
# discard the integer answer labels and feed raw answer strings to the Trainer
ok_vqa_dataset = selected_dataset.map(preprocess_function, batched=True)
ok_vqa_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=20,
    per_device_train_batch_size=4,
    logging_steps=500,
)
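# Assumption: Trainer defaults (AdamW, learning rate 5e-5, no evaluation) are
# acceptable here; pass e.g. learning_rate=... or save_strategy='epoch' to
# TrainingArguments to override them.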
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ok_vqa_dataset
)
# ๋ชจ๋ธ ํ•™์Šต
trainer.train()
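# Optionally persist the fine-tuned weights so the Space can reload them
# without retraining (hypothetical output path):
# trainer.save_model('./results/final')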
import gradio as gr

# Reuse the fine-tuned model and tokenizer from above; re-loading the
# pretrained checkpoint here (as the original did) would discard the weights
# we just trained, and BERT classes cannot load a GIT checkpoint anyway.
model.eval()
# Define the prediction function
def predict_answer(image, question):
    inputs = tokenizer(question, return_tensors='pt')
    # Encode the input image the same way it was encoded for training
    pixel_values = image_processor(images=image, return_tensors='pt')['pixel_values']
    # Run the model without tracking gradients
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        pixel_values=pixel_values)
    # Take the highest-scoring label id and map it back to its answer string
    # (the original referenced an undefined id_to_label_fn; id_to_answers is
    # the decoding table built during soft encoding)
    predicted_label_id = torch.argmax(outputs.logits, dim=-1).item()
    return id_to_answers.get(predicted_label_id, str(predicted_label_id))
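# Example call (hypothetical local image file):
# from PIL import Image
# print(predict_answer(Image.open('example.jpg'), 'What is shown in the image?'))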
iface = gr.Interface(
    fn=predict_answer,
    inputs=["image", "text"],
    outputs="text",
    title="Visual Question Answering",
    description="Input an image and a question to get the model's answer.",
    # gr.Interface takes `examples` (a list of input lists), not `example`
    examples=[
        ["https://your_image_url.jpg",
         "What is shown in the image?"]
    ]
)
iface.launch()
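# On Hugging Face Spaces the default launch() is picked up automatically; when
# running locally you can pass share=True for a temporary public link.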