Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import ViTModel, BertModel, BertTokenizer | |
from torchvision import transforms | |
from PIL import Image | |
import json | |
from torch import nn | |
from huggingface_hub import hf_hub_download | |
# Model definition: ViT image encoder + BERT text encoder fused by a linear head.
class VQAModel(nn.Module):
    """Visual-question-answering classifier over a fixed answer vocabulary.

    The first-token ([CLS]) embedding of a ViT image encoder and the first-token
    embedding of a BERT text encoder are fused as ``[img, txt, img * txt]`` and
    passed through a dropout + linear head that scores ``num_answers`` classes.

    Args:
        num_answers: Number of output classes (size of the answer vocabulary).
    """

    def __init__(self, num_answers):
        super().__init__()
        # Hidden width shared by both "base" backbones; the element-wise product
        # in forward() requires the two embeddings to have the same width.
        hidden = 768
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Sequential indices (0: Dropout, 1: Linear) are part of the checkpoint's
        # state_dict keys, so the layer order must stay as-is.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(3 * hidden, num_answers),
        )

    def forward(self, image, input_ids, attention_mask):
        """Return answer logits of shape (batch, num_answers).

        Args:
            image: Pixel tensor for the ViT encoder (predict() feeds
                (1, 3, 224, 224)).
            input_ids: BERT token ids (predict() feeds (1, 32)).
            attention_mask: BERT attention mask matching ``input_ids``.
        """
        vision_out = self.vit(image)
        text_out = self.bert(input_ids, attention_mask=attention_mask)
        # Take the first token of each sequence as the pooled representation.
        img_cls = vision_out.last_hidden_state[:, 0]
        txt_cls = text_out.last_hidden_state[:, 0]
        fused = torch.cat((img_cls, txt_cls, img_cls * txt_cls), dim=1)
        return self.classifier(fused)
# Load the fine-tuned model, its tokenizer, and the answer vocabulary from the
# Hugging Face Hub repository below.
repo_id = "duyan2803/vqa-model-vit-bert"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
try:
    # config.json provides the number of answer classes (classifier output width).
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    # Explicit UTF-8: JSON is UTF-8 by spec, and the platform default codec
    # (e.g. cp1252 on Windows) would choke on non-ASCII content.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    num_answers = config["num_answers"]
    print(f"Number of answers: {num_answers}")

    # Weights: build the architecture, then load the fine-tuned state dict.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code.
    weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
    model = VQAModel(num_answers=num_answers)
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # inference mode: disables the classifier's dropout
    print("Đã load mô hình thành công!")

    # The tokenizer is loaded from the same repository as the weights.
    tokenizer = BertTokenizer.from_pretrained(repo_id)
    print("Đã load tokenizer thành công!")

    # Answer vocabulary; predict() indexes it with the classifier's argmax.
    answer_list_path = hf_hub_download(repo_id=repo_id, filename="answer_list.json")
    with open(answer_list_path, "r", encoding="utf-8") as f:
        answer_list = json.load(f)
    print("Đã load answer list thành công!")
except Exception as e:
    # Report which stage failed, then propagate: the app cannot run without
    # these artifacts. Bare `raise` re-raises with the original traceback.
    print(f"Lỗi khi load mô hình hoặc file: {str(e)}")
    raise
# Prediction function used as the Gradio callback.
def predict(image, question):
    """Answer a free-text question about an image.

    Args:
        image: A PIL image (the Gradio input is configured with ``type="pil"``).
            Any mode is accepted; it is converted to RGB before preprocessing.
        question: The question text.

    Returns:
        The predicted answer from ``answer_list``. If anything fails, an error
        message string prefixed with "Lỗi khi dự đoán:" is returned instead, so
        the UI shows the error rather than crashing.
    """
    try:
        # Image preprocessing; this must mirror the transform used in training.
        # NOTE(review): these are ImageNet mean/std, whereas the processor for
        # google/vit-base-patch16-224 uses mean=std=0.5 — confirm which one the
        # checkpoint was fine-tuned with.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Normalize has exactly 3 channel statistics: RGBA, greyscale, or
        # palette images would fail there, so coerce to RGB first.
        rgb_image = image.convert("RGB")
        image_tensor = transform(rgb_image).unsqueeze(0).to(device)

        # Question preprocessing: pad/truncate to a fixed 32 tokens.
        tokenized = tokenizer(question, padding='max_length', truncation=True, max_length=32, return_tensors='pt')
        input_ids = tokenized['input_ids'].to(device)
        attention_mask = tokenized['attention_mask'].to(device)

        # Inference without autograd bookkeeping.
        with torch.no_grad():
            output = model(image_tensor, input_ids, attention_mask)
        pred_idx = output.argmax(dim=1).item()
        # Assumes answer_list is a JSON array ordered by class index — confirm.
        return answer_list[pred_idx]
    except Exception as e:
        return f"Lỗi khi dự đoán: {str(e)}"
# Gradio UI: an image upload and a question box go in, a text answer comes out.
_question_inputs = [
    gr.Image(type="pil", label="Upload an image"),
    gr.Textbox(label="Ask a question"),
]
_answer_output = gr.Textbox(label="Answer")

interface = gr.Interface(
    fn=predict,
    inputs=_question_inputs,
    outputs=_answer_output,
    title="VQA Demo - Car Recognition",
    description="Upload an image of a car and ask a question (e.g., 'What color is this car?' or 'What is this car?').",
)
# Start the web server (blocks until shutdown).
interface.launch()