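"""Gradio demo for visual question answering (VQA) on car images.

A ViT image encoder and a BERT text encoder are fused and fed to a linear
classifier over a fixed answer vocabulary; the fine-tuned weights, tokenizer,
and answer list are downloaded from the Hugging Face Hub.
"""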
import json

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torch import nn
from torchvision import transforms
from transformers import BertModel, BertTokenizer, ViTModel


class VQAModel(nn.Module):
    """Fuses ViT and BERT [CLS] features for answer classification."""

    def __init__(self, num_answers):
        super().__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(768 * 3, num_answers),
        )

    def forward(self, image, input_ids, attention_mask):
        # [CLS] embeddings from each encoder, shape (batch, 768).
        image_features = self.vit(image).last_hidden_state[:, 0, :]
        text_features = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        # Fuse: concatenate both features with their elementwise product.
        combined = torch.cat([image_features, text_features, image_features * text_features], dim=1)
        return self.classifier(combined)
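
# Shape sketch: with batch size B, each [CLS] feature is (B, 768), so the
# fused vector is (B, 768 * 3) = (B, 2304), mapped to (B, num_answers).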


# Hub repo hosting the fine-tuned weights, tokenizer, and answer list.
repo_id = "duyan2803/vqa-model-vit-bert"
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # Answer vocabulary size, read from the repo config.
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    num_answers = config["num_answers"]

    # Model weights.
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    model = VQAModel(num_answers=num_answers)
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print("Model loaded successfully!")

    tokenizer = BertTokenizer.from_pretrained(repo_id)
    print("Tokenizer loaded successfully!")

    # Index-to-answer mapping used to decode predictions.
    answer_list_path = hf_hub_download(repo_id=repo_id, filename="answer_list.json")
    with open(answer_list_path, "r") as f:
        answer_list = json.load(f)
    print("Answer list loaded successfully!")
except Exception as e:
    print(f"Error loading model or files: {e}")
    raise
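
# Optional sanity check (hypothetical; classifier[-1] is the output Linear):
# assert model.classifier[-1].out_features == len(answer_list)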


def predict(image, question):
    try:
        # Preprocess to ViT's expected input: 224x224, ImageNet normalization.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Convert to RGB first so grayscale/RGBA uploads also work.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

        # Tokenize the question, padded/truncated to 32 tokens.
        tokenized = tokenizer(question, padding='max_length', truncation=True, max_length=32, return_tensors='pt')
        input_ids = tokenized['input_ids'].to(device)
        attention_mask = tokenized['attention_mask'].to(device)

        # Run the model and decode the highest-scoring answer.
        with torch.no_grad():
            output = model(image_tensor, input_ids, attention_mask)
        pred_idx = output.argmax(dim=1).item()
        return answer_list[pred_idx]
    except Exception as e:
        return f"Prediction error: {e}"
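
# Example programmatic usage (hypothetical local image path):
# from PIL import Image
# img = Image.open("car.jpg").convert("RGB")
# print(predict(img, "What color is this car?"))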


interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Textbox(label="Ask a question"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="VQA Demo - Car Recognition",
    description="Upload an image of a car and ask a question (e.g., 'What color is this car?' or 'What is this car?').",
)
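
# Tip: when running locally, interface.launch(share=True) creates a
# temporary public link.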
interface.launch()