Spaces:
Running
Running
import gradio as gr | |
import torch | |
from torch import nn | |
from huggingface_hub import hf_hub_download | |
from transformers import BertModel, BertTokenizer, CLIPModel, BertConfig, CLIPConfig, CLIPProcessor | |
from modeling_unimo import UnimoForMaskedLM | |
def load_dict_text(path):
    """Load a tab-separated ``key<TAB>value`` text file into a dict.

    Each non-empty line is split on its FIRST tab, so a value may itself
    contain tabs. Newline characters are removed from the value. Empty lines
    (e.g. a trailing blank line) are skipped. Later duplicate keys overwrite
    earlier ones.

    Args:
        path: Path of the UTF-8 encoded file to read.

    Returns:
        dict mapping each key string to its value string.

    Raises:
        ValueError: if a non-empty line contains no tab character.
    """
    load_data = {}
    # Explicit UTF-8 so entity names decode identically regardless of the host locale.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.rstrip('\n'):
                continue  # tolerate empty lines instead of failing to unpack them
            key, value = line.split('\t', 1)
            load_data[key] = value.replace('\n', '')
    return load_data
def load_text(path):
    """Read a UTF-8 text file and return one stripped string per line.

    Leading/trailing whitespace (including the newline) is removed from each
    line. Blank lines are kept as empty strings, so the result has exactly one
    entry per line of the file.

    Args:
        path: Path of the UTF-8 encoded file to read.

    Returns:
        list of str, in file order.
    """
    # Explicit UTF-8 so the result does not depend on the host locale.
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]
class MKGformerModel(nn.Module):
    """Thin ``nn.Module`` wrapper holding a ``UnimoForMaskedLM`` as ``self.model``.

    Storing the network under the attribute ``model`` prefixes its parameter
    names with ``model.`` in ``state_dict()``, presumably to match the keys of
    the fine-tuned checkpoint loaded elsewhere (TODO confirm).
    """

    def __init__(self, text_config, vision_config):
        super().__init__()
        # NOTE(review): both configs are forwarded positionally, and the
        # visible call site passes (clip_config, bert_config), i.e. the vision
        # config first despite the parameter names. Confirm the order expected
        # by UnimoForMaskedLM before renaming anything here.
        self.model = UnimoForMaskedLM(text_config, vision_config)

    def forward(self, batch):
        """Run the wrapped model on ``batch`` (a dict of keyword inputs) with ``return_dict=True``."""
        return self.model(**batch, return_dict=True)

    # Backward-compatible alias for the original misspelled method name.
    farword = forward
# --- Tokenizer -----------------------------------------------------------------
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# --- Entity and relation metadata (MarKG / MARS dataset files) ----------------
ent2text = load_dict_text('./dataset/MarKG/entity2text.txt')
rel2text = load_dict_text('./dataset/MarKG/relation2text.txt')
analogy_entities = load_text('./dataset/MARS/analogy_entities.txt')
analogy_relations = load_text('./dataset/MARS/analogy_relations.txt')
ent2description = load_dict_text('./dataset/MarKG/entity2textlong.txt')
text2ent = {text: ent for ent, text in ent2text.items()}

# One special token per entity / relation. The index i comes from enumerating
# the FULL entity / relation collections, so an analogy entity gets exactly the
# same token in analogy_ent2token as in ent2token (likewise for relations).
ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(ent2description)}
rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(rel2text)}
# Set views give O(1) membership tests in the filters below; the original
# lists are kept unchanged for any other users.
_analogy_entity_set = set(analogy_entities)
_analogy_relation_set = set(analogy_relations)
analogy_ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(ent2description) if ent in _analogy_entity_set}
analogy_rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(rel2text) if rel in _analogy_relation_set}
entity_list = list(ent2token.values())
relation_list = list(rel2token.values())
analogy_ent_list = list(analogy_ent2token.values())
analogy_rel_list = list(analogy_rel2token.values())

# Register the entity / relation tokens as special tokens so the tokenizer keeps
# each one as a single token.
# NOTE(review): each add_special_tokens call with 'additional_special_tokens'
# may replace tokenizer.additional_special_tokens (the ids remain in the added
# vocab); confirm this matches how the checkpoint's tokenizer was built.
num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': entity_list})
num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': relation_list})
vocab = tokenizer.get_added_vocab()  # dict: token string -> id

# [start, end) id ranges; assumes the added tokens received consecutive ids in
# insertion order — TODO confirm if these ranges are relied on elsewhere.
relation_id_st = vocab[relation_list[0]]
relation_id_ed = vocab[relation_list[-1]] + 1
entity_id_st = vocab[entity_list[0]]
entity_id_ed = vocab[entity_list[-1]] + 1

# Vocabulary ids of the analogy entities / relations, in the same order as
# analogy_ent2token / analogy_rel2token (used to decode predictions).
analogy_entity_ids = [vocab[ent] for ent in analogy_ent_list]
analogy_relation_ids = [vocab[rel] for rel in analogy_rel_list]

# "[R]" is the relation placeholder placed between entities in analogy prompts.
num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': ["[R]"]})
# --- Model -----------------------------------------------------------------------
checkpoint_path = hf_hub_download(repo_id='flow3rdown/mkgformer_mart_ft', filename="mkgformer_mart_ft", repo_type='model')
clip_config = CLIPConfig.from_pretrained('openai/clip-vit-base-patch32').vision_config
# NOTE(review): ad-hoc attribute on the config, presumably read by
# modeling_unimo — confirm it is still needed.
clip_config.device = 'cpu'
bert_config = BertConfig.from_pretrained('bert-base-uncased')
# NOTE(review): MKGformerModel names its parameters (text_config, vision_config)
# but receives (clip_config, bert_config); they are forwarded positionally to
# UnimoForMaskedLM, so this argument order is what counts — TODO confirm.
mkgformer = MKGformerModel(clip_config, bert_config)
# Resize the token embeddings to the extended tokenizer (entity, relation and
# "[R]" tokens were added above) so the checkpoint's weights line up.
mkgformer.model.resize_token_embeddings(len(tokenizer))
# SECURITY NOTE: torch.load unpickles the downloaded file and can execute
# arbitrary code; only load checkpoints from a trusted repository.
mkgformer.load_state_dict(torch.load(checkpoint_path, map_location='cpu')["state_dict"])
# nn.Module starts in training mode; switch to evaluation mode so layers such
# as dropout behave deterministically during inference.
mkgformer.eval()

# --- Image preprocessing -------------------------------------------------------
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
def _check_entity_id(ent_id):
    """Normalize a user-typed entity id and ensure it is a known analogy entity.

    Returns the id with surrounding whitespace removed. Raises ``gr.Error`` (shown
    to the user by Gradio) instead of a bare ``KeyError`` when the id is unknown.
    Every key of ``analogy_ent2token`` is also a key of ``ent2description``,
    because the former is built by filtering the latter.
    """
    key = (ent_id or "").strip()
    if key not in analogy_ent2token:
        raise gr.Error(f"Unknown analogy entity id: {ent_id!r}")
    return key


def _encode_analogy(example_parts, question_parts):
    """Tokenize an (example, question) sentence pair into MKGformer text inputs.

    Each argument is a list of three strings that is joined with the ``[SEP]``
    token. Besides the usual tokenizer outputs, the result contains:

    * ``sep_idx`` — the positions of every ``[SEP]`` token, per row.
    * ``attention_mask`` — expanded to (batch, seq, seq). Tokens before the third
      ``[SEP]`` may not attend to positions from that ``[SEP]`` onward. With
      three-part inputs that ``[SEP]`` is the one closing the example segment,
      assuming the entity texts themselves contain no ``[SEP]``.
    """
    inputs = tokenizer(
        tokenizer.sep_token.join(example_parts),
        tokenizer.sep_token.join(question_parts),
        truncation="longest_first", max_length=128, padding="longest",
        return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids']
    sep_idx = [[i for i, tok_id in enumerate(row) if tok_id == tokenizer.sep_token_id] for row in input_ids]
    inputs['sep_idx'] = torch.tensor(sep_idx)
    bsz, seq_len = input_ids.shape
    inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand([bsz, seq_len, seq_len]).clone()
    for i, idx in enumerate(sep_idx):
        inputs['attention_mask'][i, :idx[2], idx[2]:] = 0
    return inputs


def _predict_analogy_answer(inputs, pixel_values):
    """Run MKGformer and return the text of the best-scoring analogy entity at [MASK].

    Args:
        inputs: output of ``_encode_analogy``; ``pixel_values`` is added to it.
        pixel_values: image tensor passed to the model as ``pixel_values``.
    """
    inputs['pixel_values'] = pixel_values
    input_ids = inputs['input_ids']
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        model_output = mkgformer.model(**inputs, return_dict=True)
    logits = model_output[0].logits  # presumably (bsz, seq_len, vocab) — indexed that way below
    bsz = input_ids.shape[0]
    _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)  # one [MASK] per row
    # One row per example, one column per analogy entity (order of analogy_entity_ids).
    mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]
    best = mask_logits.argmax().item()
    return ent2text[list(analogy_ent2token.keys())[best]]


def single_inference_iit(head_img, head_id, tail_img, tail_id, question_txt, question_id):
    """Answer an (I, I) -> (T, ?) analogy: head/tail given as images, question as text.

    Args:
        head_img, tail_img: images as delivered by ``gr.Image(type='pil')``.
        head_id, tail_id, question_id: analogy entity ids (e.g. "Q10884").
        question_txt: question name shown in the UI; not used by the model (the
            question entity's description is looked up from ``question_id``).

    Returns:
        The predicted answer entity's text.

    Raises:
        gr.Error: if any entity id is not an analogy entity.
    """
    head_id, tail_id, question_id = (_check_entity_id(e) for e in (head_id, tail_id, question_id))
    ques_ent_text = ent2description[question_id]
    inputs = _encode_analogy(
        [analogy_ent2token[head_id] + " ", "[R] ", analogy_ent2token[tail_id] + " "],
        [analogy_ent2token[question_id] + " " + ques_ent_text, "[R] ", "[MASK]"])
    pixel_values = processor(images=[head_img, tail_img], return_tensors='pt')['pixel_values'].squeeze()
    # Add a batch dimension in front of the two stacked images.
    return _predict_analogy_answer(inputs, pixel_values.unsqueeze(0))


def single_inference_tti(head_txt, head_id, tail_txt, tail_id, question_img, question_id):
    """Answer a (T, T) -> (I, ?) analogy: head/tail given as text, question as an image.

    Args:
        head_txt, tail_txt: names shown in the UI; not used by the model (the
            entity descriptions are looked up from the ids).
        head_id, tail_id, question_id: analogy entity ids.
        question_img: image as delivered by ``gr.Image(type='pil')``.

    Returns:
        The predicted answer entity's text.

    Raises:
        gr.Error: if any entity id is not an analogy entity.
    """
    head_id, tail_id, question_id = (_check_entity_id(e) for e in (head_id, tail_id, question_id))
    head_ent_text, tail_ent_text = ent2description[head_id], ent2description[tail_id]
    inputs = _encode_analogy(
        [analogy_ent2token[head_id] + " " + head_ent_text, "[R] ", analogy_ent2token[tail_id] + " " + tail_ent_text],
        [analogy_ent2token[question_id] + " ", "[R] ", "[MASK]"])
    pixel_values = processor(images=question_img, return_tensors='pt')['pixel_values'].unsqueeze(1)
    # Only one image is available; fill the second image slot with zeros.
    pixel_values = torch.cat((pixel_values, torch.zeros_like(pixel_values)), dim=1)
    return _predict_analogy_answer(inputs, pixel_values)


def blended_inference_iti(head_img, head_id, tail_txt, tail_id, question_img, question_id):
    """Answer an (I, T) -> (I, ?) analogy: head and question as images, tail as text.

    Args:
        head_img, question_img: images as delivered by ``gr.Image(type='pil')``.
        head_id, tail_id, question_id: analogy entity ids.
        tail_txt: tail name shown in the UI; not used by the model (the tail
            description is looked up from ``tail_id``).

    Returns:
        The predicted answer entity's text.

    Raises:
        gr.Error: if any entity id is not an analogy entity.
    """
    head_id, tail_id, question_id = (_check_entity_id(e) for e in (head_id, tail_id, question_id))
    tail_ent_text = ent2description[tail_id]
    inputs = _encode_analogy(
        [analogy_ent2token[head_id], "[R] ", analogy_ent2token[tail_id] + " " + tail_ent_text],
        [analogy_ent2token[question_id] + " ", "[R] ", "[MASK]"])
    pixel_values = processor(images=[head_img, question_img], return_tensors='pt')['pixel_values'].squeeze()
    # Add a batch dimension in front of the two stacked images.
    return _predict_analogy_answer(inputs, pixel_values.unsqueeze(0))
def single_tab_iit():
    """Lay out the (I_h, I_t) : (T_q, ?) analogy panel.

    Builds a formula header and three input columns (head image + id, tail
    image + id, question name + id). Below them come a Submit button wired to
    ``single_inference_iit`` and an output textbox. One example row is added
    that only fills the inputs when clicked.
    """
    with gr.Column():
        gr.Markdown(" $(I_h, I_t) : (T_q, ?)$\n")
        with gr.Row():
            with gr.Column():
                h_img = gr.Image(type='pil', label="Head Image")
                h_id = gr.Textbox(lines=1, label="Head Entity")
            with gr.Column():
                t_img = gr.Image(type='pil', label="Tail Image")
                t_id = gr.Textbox(lines=1, label="Tail Entity")
            with gr.Column():
                q_name = gr.Textbox(lines=1, label="Question Name")
                q_id = gr.Textbox(lines=1, label="Question Entity")
        run_button = gr.Button("Submit")
        answer_box = gr.Textbox(label="Output")

    # The same component order feeds both the button and the example table.
    iit_inputs = [h_img, h_id, t_img, t_id, q_name, q_id]
    run_button.click(fn=single_inference_iit, inputs=iit_inputs, outputs=[answer_box])

    example_rows = [['examples/tree.jpg', 'Q10884', 'examples/forest.png', 'Q4421', "Anhui", 'Q40956']]
    gr.Examples(
        examples=example_rows,
        fn=single_inference_iit,
        inputs=iit_inputs,
        outputs=[answer_box],
        cache_examples=False,
        run_on_click=False,
    )
def single_tab_tti():
    """Lay out the (T_h, T_t) : (I_q, ?) analogy panel.

    Builds head/tail name + id textboxes, a question image + id, a Submit button
    wired to ``single_inference_tti`` and an output textbox. One example row is
    added that only fills the inputs when clicked.
    """
    with gr.Column():
        gr.Markdown(" $(T_h, T_t) : (I_q, ?)$\n")
        with gr.Row():
            with gr.Column():
                head_text = gr.Textbox(lines=1, label="Head Name")
                head_ent = gr.Textbox(lines=1, label="Head Entity")
            with gr.Column():
                tail_text = gr.Textbox(lines=1, label="Tail Name")
                tail_ent = gr.Textbox(lines=1, label="Tail Entity")
            with gr.Column():
                question_image = gr.Image(type='pil', label="Question Image")
                question_ent = gr.Textbox(lines=1, label="Question Entity")
        submit_btn = gr.Button("Submit")
        output_text = gr.Textbox(label="Output")

    inputs = [head_text, head_ent, tail_text, tail_ent, question_image, question_ent]
    submit_btn.click(fn=single_inference_tti, inputs=inputs, outputs=[output_text])

    examples = [['scrap', 'Q3217573', 'watch', 'Q178794', 'examples/root.jpg', 'Q111029']]
    gr.Examples(
        examples=examples,
        # Use the same handler as Submit. Gradio only calls fn when caching or
        # run_on_click is enabled (both off here), but it must still match
        # these (text, id, text, id, image, id) inputs.
        fn=single_inference_tti,
        inputs=inputs,
        outputs=[output_text],
        cache_examples=False,
        run_on_click=False,
    )
def blended_tab_iti():
    """Lay out the blended (I_h, T_t) : (I_q, ?) analogy panel.

    Builds a head image + id, a tail name + id, a question image + id, a Submit
    button wired to ``blended_inference_iti`` and an output textbox. One example
    row is added that only fills the inputs when clicked.
    """
    with gr.Column():
        gr.Markdown(" $(I_h, T_t) : (I_q, ?)$\n")
        with gr.Row():
            with gr.Column():
                head_image = gr.Image(type='pil', label="Head Image")
                head_ent = gr.Textbox(lines=1, label="Head Entity")
            with gr.Column():
                tail_txt = gr.Textbox(lines=1, label="Tail Name")
                tail_ent = gr.Textbox(lines=1, label="Tail Entity")
            with gr.Column():
                question_image = gr.Image(type='pil', label="Question Image")
                question_ent = gr.Textbox(lines=1, label="Question Entity")
        submit_btn = gr.Button("Submit")
        output_text = gr.Textbox(label="Output")

    inputs = [head_image, head_ent, tail_txt, tail_ent, question_image, question_ent]
    submit_btn.click(fn=blended_inference_iti, inputs=inputs, outputs=[output_text])

    examples = [['examples/watermelon.jpg', 'Q38645', 'fruit', 'Q3314483', 'examples/coffee.jpeg', 'Q8486']]
    gr.Examples(
        examples=examples,
        # Use the same handler as Submit. Gradio only calls fn when caching or
        # run_on_click is enabled (both off here), but it must still match
        # these (image, id, text, id, image, id) inputs.
        fn=blended_inference_iti,
        inputs=inputs,
        outputs=[output_text],
        cache_examples=False,
        run_on_click=False,
    )
TITLE = """MKG Analogy"""

# Assemble the app: a title and two tabs (single- and blended-modality analogies).
with gr.Blocks() as block:
    with gr.Column(elem_id="col-container"):
        gr.HTML(TITLE)
        with gr.Tab("Single Analogical Reasoning"):
            single_tab_iit()
            single_tab_tti()
        with gr.Tab("Blended Analogical Reasoning"):
            blended_tab_iti()

# Queueing is configured through .queue(). launch()'s enable_queue argument is
# deprecated in Gradio 3.x and removed in 4.x, so it is not passed.
block.queue(max_size=64).launch()