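# Gradio demo for "Less Is More: Linear Layers on CLIP Features as Powerful
# VizWiz Model" (https://arxiv.org/abs/2206.05281). A frozen CLIP backbone
# extracts image and question features, and a linear head (NetVQA) predicts
# the answer plus an auxiliary answer category.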
import clip
from PIL import Image
import pandas as pd
import torch
from dataloader.extract_features_dataloader import transform_resize, question_preprocess
from model.vqa_model import NetVQA
from dataclasses import dataclass
from torch.cuda.amp import autocast
import gradio as gr
@dataclass
class InferenceConfig:
    '''
    Describes the configuration of the inference process.
    '''
    model: str = "RN50x64"
    checkpoint_root_clip: str = "./checkpoints/clip"
    checkpoint_root_head: str = "./checkpoints/head"

    use_question_preprocess: bool = True  # True: strip a trailing "?" from the question

    aux_mapping = {0: "unanswerable",
                   1: "unsuitable",
                   2: "yes",
                   3: "no",
                   4: "number",
                   5: "color",
                   6: "other"}
    folds = 10

    # Data
    n_classes: int = 5726

    # class mapping
    class_mapping: str = "./data/annotations/class_mapping.csv"

    device = "cuda" if torch.cuda.is_available() else "cpu"
config = InferenceConfig()
# load class mapping
cm = pd.read_csv(config.class_mapping)
classid_to_answer = {}
for i in range(len(cm)):
    row = cm.iloc[i]
    classid_to_answer[row["class_id"]] = row["answer"]
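# classid_to_answer now maps each of the n_classes answer ids (the "class_id"
# column of class_mapping.csv) to its answer string (the "answer" column).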
clip_model, preprocess = clip.load(config.model, download_root=config.checkpoint_root_clip, device=config.device)
model = NetVQA(config).to(config.device)
config.checkpoint_head = "{}/{}.pt".format(config.checkpoint_root_head, config.model)
model_state_dict = torch.load(config.checkpoint_head, map_location=config.device)
model.load_state_dict(model_state_dict, strict=True)
model.eval()
# Select Preprocessing
image_transforms = transform_resize(clip_model.visual.input_resolution)
if config.use_question_preprocess:
    question_transforms = question_preprocess
else:
    question_transforms = None
clip_model.eval()
def predict(img, text):
    # Gradio hands the image over as a numpy array
    img = Image.fromarray(img)
    img = image_transforms(img)
    img = img.unsqueeze(dim=0)

    if question_transforms is not None:
        question = question_transforms(text)
    else:
        question = text
    question_tokens = clip.tokenize(question, truncate=True)

    with torch.no_grad():
        # CLIP feature extraction
        img = img.to(config.device)
        img_feature = clip_model.encode_image(img)

        question_tokens = question_tokens.to(config.device)
        question_feature = clip_model.encode_text(question_tokens)

        # classification head on top of the CLIP features
        with autocast():
            output, output_aux = model(img_feature, question_feature)

    # map class ids to human-readable answers and answer categories
    prediction_vqa = dict()
    output = output.cpu().squeeze(0)
    for k, v in classid_to_answer.items():
        prediction_vqa[v] = float(output[k])

    prediction_aux = dict()
    output_aux = output_aux.cpu().squeeze(0)
    for k, v in config.aux_mapping.items():
        prediction_aux[v] = float(output_aux[k])

    return prediction_vqa, prediction_aux
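# Minimal local sanity check (a sketch; assumes the example image bundled with
# this Space is present):
#
#   import numpy as np
#   vqa, aux = predict(np.asarray(Image.open("examples/Augustiner.jpg")),
#                      "What is this?")
#   print(max(vqa, key=vqa.get), max(aux, key=aux.get))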
description = """
Less Is More: Linear Layers on CLIP Features as Powerful VizWiz Model

Our approach focuses on visual question answering for visually impaired people. We fine-tuned our approach on the <a href='https://vizwiz.org/tasks-and-datasets/vqa/'>CVPR Grand Challenge VizWiz 2022</a> dataset.
You may click on one of the examples or upload your own image and question. The Gradio app shows the predicted answer to your question and an answer category.
Link to our <a href='https://arxiv.org/abs/2206.05281'>paper</a>.
"""
gr.Interface(fn=predict,
             description=description,
             inputs=[gr.Image(label='Image'), gr.Textbox(label='Question')],
             outputs=[gr.Label(label='Answer', num_top_classes=5),
                      gr.Label(label='Answer Category', num_top_classes=7)],
             examples=[['examples/Augustiner.jpg', 'What is this?'],
                       ['examples/VizWiz_test_00006968.jpg', 'Can you tell me the color of the dog?'],
                       ['examples/VizWiz_test_00005604.jpg', 'What drink is this?'],
                       ['examples/VizWiz_test_00006246.jpg', 'Can you please tell me what kind of tea this is?'],
                       ['examples/VizWiz_train_00004056.jpg', 'Is that a beer or a coke?'],
                       ['examples/VizWiz_train_00017146.jpg', "Can you tell me what's on this envelope please?"],
                       ['examples/VizWiz_val_00003077.jpg', 'What is this?']]
             ).launch()