# MKG_Analogy / app.py
import gradio as gr
import torch
from torch import nn
from huggingface_hub import hf_hub_download
from transformers import BertTokenizer, BertConfig, CLIPConfig, CLIPProcessor
from modeling_unimo import UnimoForMaskedLM

def load_dict_text(path):
    """Load a tab-separated mapping file: one key/value pair per line."""
    load_data = {}
    with open(path, 'r') as f:
        for line in f:
            key, value = line.split('\t')
            load_data[key] = value.replace('\n', '')
    return load_data

def load_text(path):
    """Load a text file as a list of stripped lines."""
    with open(path, 'r') as f:
        load_data = [line.strip() for line in f]
    return load_data

class MKGformerModel(nn.Module):
    """Thin wrapper around UnimoForMaskedLM used to load the fine-tuned checkpoint."""

    def __init__(self, vision_config, text_config):
        # Argument order matches the construction below: CLIP vision config first, BERT config second.
        super().__init__()
        self.model = UnimoForMaskedLM(vision_config, text_config)

    def forward(self, batch):
        return self.model(**batch, return_dict=True)

# tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# entity and relation
ent2text = load_dict_text('./dataset/MarKG/entity2text.txt')
rel2text = load_dict_text('./dataset/MarKG/relation2text.txt')
analogy_entities = load_text('./dataset/MARS/analogy_entities.txt')
analogy_relations = load_text('./dataset/MARS/analogy_relations.txt')
ent2description = load_dict_text('./dataset/MarKG/entity2textlong.txt')
text2ent = {text: ent for ent, text in ent2text.items()}
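# Every KG entity/relation gets its own special token ([ENTITY_i] / [RELATION_i]); these are
# added to the tokenizer below, so analogical reasoning reduces to masked-token prediction
# over the added entity tokens.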
ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(ent2description)}
rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(rel2text)}
analogy_ent2token = {ent : f"[ENTITY_{i}]" for i, ent in enumerate(ent2description) if ent in analogy_entities}
analogy_rel2token = {rel : f"[RELATION_{i}]" for i, rel in enumerate(rel2text) if rel in analogy_relations}
entity_list = list(ent2token.values())
relation_list = list(rel2token.values())
analogy_ent_list = list(analogy_ent2token.values())
analogy_rel_list = list(analogy_rel2token.values())
num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': entity_list})
num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': relation_list})
vocab = tokenizer.get_added_vocab() # dict: word: idx
relation_id_st = vocab[relation_list[0]]
relation_id_ed = vocab[relation_list[-1]] + 1
entity_id_st = vocab[entity_list[0]]
entity_id_ed = vocab[entity_list[-1]] + 1
# analogy entities and relations
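# Candidate ids for decoding: the [MASK] logits are later restricted to these analogy entity
# tokens, so the predicted answer is always one of the analogy entities.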
analogy_entity_ids = [vocab[ent] for ent in analogy_ent_list]
analogy_relation_ids = [vocab[rel] for rel in analogy_rel_list]
num_added_tokens = tokenizer.add_special_tokens({'additional_special_tokens': ["[R]"]})
# model
checkpoint_path = hf_hub_download(repo_id='flow3rdown/mkgformer_mart_ft', filename="mkgformer_mart_ft", repo_type='model')
clip_config = CLIPConfig.from_pretrained('openai/clip-vit-base-patch32').vision_config
clip_config.device = 'cpu'
bert_config = BertConfig.from_pretrained('bert-base-uncased')
mkgformer = MKGformerModel(clip_config, bert_config)
mkgformer.model.resize_token_embeddings(len(tokenizer))
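# The fine-tuned checkpoint stores its weights under a "state_dict" key; map it to CPU when loading.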
mkgformer.load_state_dict(torch.load(checkpoint_path, map_location='cpu')["state_dict"])
# processor
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
def single_inference_iit(head_img, head_id, tail_img, tail_id, question_txt, question_id):
# (I, I) -> (T, ?)
ques_ent_text = ent2description[question_id]
inputs = tokenizer(
tokenizer.sep_token.join([analogy_ent2token[head_id] + " ", "[R] ", analogy_ent2token[tail_id] + " "]),
tokenizer.sep_token.join([analogy_ent2token[question_id] + " " + ques_ent_text, "[R] ", "[MASK]"]),
truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
inputs['sep_idx'] = torch.tensor(sep_idx)
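    # sep_idx holds the positions of all [SEP] tokens; idx[2] is the [SEP] that closes the example
    # triple. The 2D padding mask is expanded into a per-example (seq_len x seq_len) mask, and the
    # block below stops the example tokens from attending to the question tokens.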
inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand([inputs['input_ids'].size(0), inputs['input_ids'].size(1), inputs['input_ids'].size(1)]).clone()
for i, idx in enumerate(sep_idx):
inputs['attention_mask'][i, :idx[2], idx[2]:] = 0
    # image: the head and tail images fill the model's two image slots
pixel_values = processor(images=[head_img, tail_img], return_tensors='pt')['pixel_values'].squeeze()
inputs['pixel_values'] = pixel_values.unsqueeze(0)
input_ids = inputs['input_ids']
model_output = mkgformer.model(**inputs, return_dict=True)
logits = model_output[0].logits
bsz = input_ids.shape[0]
_, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True) # bsz
    mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]  # (bsz, num_analogy_entities)
answer = ent2text[list(analogy_ent2token.keys())[mask_logits.argmax().item()]]
return answer
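
# Example (hypothetical call, using the sample assets from the Gradio examples below):
#   from PIL import Image
#   single_inference_iit(Image.open('examples/tree.jpg'), 'Q10884',
#                        Image.open('examples/forest.jpg'), 'Q4421',
#                        'Anhui', 'Q40956')
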
def single_inference_tti(head_txt, head_id, tail_txt, tail_id, question_img, question_id):
# (T, T) -> (I, ?)
head_ent_text, tail_ent_text = ent2description[head_id], ent2description[tail_id]
inputs = tokenizer(
tokenizer.sep_token.join([analogy_ent2token[head_id] + " " + head_ent_text, "[R] ", analogy_ent2token[tail_id] + " " + tail_ent_text]),
tokenizer.sep_token.join([analogy_ent2token[question_id] + " ", "[R] ", "[MASK]"]),
truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
inputs['sep_idx'] = torch.tensor(sep_idx)
inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand([inputs['input_ids'].size(0), inputs['input_ids'].size(1), inputs['input_ids'].size(1)]).clone()
for i, idx in enumerate(sep_idx):
inputs['attention_mask'][i, :idx[2], idx[2]:] = 0
    # image: only the question image is given, so the second image slot is zero-padded
pixel_values = processor(images=question_img, return_tensors='pt')['pixel_values'].unsqueeze(1)
pixel_values = torch.cat((pixel_values, torch.zeros_like(pixel_values)), dim=1)
inputs['pixel_values'] = pixel_values
input_ids = inputs['input_ids']
model_output = mkgformer.model(**inputs, return_dict=True)
logits = model_output[0].logits
bsz = input_ids.shape[0]
_, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True) # bsz
    mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]  # (bsz, num_analogy_entities)
answer = ent2text[list(analogy_ent2token.keys())[mask_logits.argmax().item()]]
return answer
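
# Example (hypothetical call, mirroring the Gradio example below):
#   single_inference_tti('scrap', 'Q3217573', 'watch', 'Q178794',
#                        Image.open('examples/root.jpg'), 'Q111029')
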
def blended_inference_iti(head_img, head_id, tail_txt, tail_id, question_img, question_id):
# (I, T) -> (I, ?)
head_ent_text, tail_ent_text = ent2description[head_id], ent2description[tail_id]
inputs = tokenizer(
tokenizer.sep_token.join([analogy_ent2token[head_id], "[R] ", analogy_ent2token[tail_id] + " " + tail_ent_text]),
tokenizer.sep_token.join([analogy_ent2token[question_id] + " ", "[R] ", "[MASK]"]),
truncation="longest_first", max_length=128, padding="longest", return_tensors='pt', add_special_tokens=True)
sep_idx = [[i for i, ids in enumerate(input_ids) if ids == tokenizer.sep_token_id] for input_ids in inputs['input_ids']]
inputs['sep_idx'] = torch.tensor(sep_idx)
inputs['attention_mask'] = inputs['attention_mask'].unsqueeze(1).expand([inputs['input_ids'].size(0), inputs['input_ids'].size(1), inputs['input_ids'].size(1)]).clone()
for i, idx in enumerate(sep_idx):
inputs['attention_mask'][i, :idx[2], idx[2]:] = 0
    # image: the head and question images fill the model's two image slots
pixel_values = processor(images=[head_img, question_img], return_tensors='pt')['pixel_values'].squeeze()
inputs['pixel_values'] = pixel_values.unsqueeze(0)
input_ids = inputs['input_ids']
model_output = mkgformer.model(**inputs, return_dict=True)
logits = model_output[0].logits
bsz = input_ids.shape[0]
_, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True) # bsz
    mask_logits = logits[torch.arange(bsz), mask_idx][:, analogy_entity_ids]  # (bsz, num_analogy_entities)
answer = ent2text[list(analogy_ent2token.keys())[mask_logits.argmax().item()]]
return answer
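
# Example (hypothetical call, mirroring the Gradio example below):
#   blended_inference_iti(Image.open('examples/watermelon.jpg'), 'Q38645', 'fruit', 'Q3314483',
#                         Image.open('examples/coffee.jpeg'), 'Q8486')
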
def single_tab_iit():
with gr.Column():
gr.Markdown(""" $(I_h, I_t) : (T_q, ?)$
""")
with gr.Row():
with gr.Column():
head_image = gr.Image(type='pil', label="Head Image")
head_ent = gr.Textbox(lines=1, label="Head Entity")
with gr.Column():
tail_image = gr.Image(type='pil', label="Tail Image")
tail_ent = gr.Textbox(lines=1, label="Tail Entity")
with gr.Column():
question_text = gr.Textbox(lines=1, label="Question Name")
question_ent = gr.Textbox(lines=1, label="Question Entity")
submit_btn = gr.Button("Submit")
output_text = gr.Textbox(label="Output")
submit_btn.click(fn=single_inference_iit,
inputs=[head_image, head_ent, tail_image, tail_ent, question_text, question_ent],
outputs=[output_text])
examples=[['examples/tree.jpg', 'Q10884', 'examples/forest.jpg', 'Q4421', "Anhui", 'Q40956']]
ex = gr.Examples(
examples=examples,
fn=single_inference_iit,
inputs=[head_image, head_ent, tail_image, tail_ent, question_text, question_ent],
outputs=[output_text],
cache_examples=False,
run_on_click=False
)
def single_tab_tti():
with gr.Column():
gr.Markdown(""" $(T_h, T_t) : (I_q, ?)$
""")
with gr.Row():
with gr.Column():
head_text = gr.Textbox(lines=1, label="Head Name")
head_ent = gr.Textbox(lines=1, label="Head Entity")
with gr.Column():
tail_text = gr.Textbox(lines=1, label="Tail Name")
tail_ent = gr.Textbox(lines=1, label="Tail Entity")
with gr.Column():
question_image = gr.Image(type='pil', label="Question Image")
question_ent = gr.Textbox(lines=1, label="Question Entity")
submit_btn = gr.Button("Submit")
output_text = gr.Textbox(label="Output")
submit_btn.click(fn=single_inference_tti,
inputs=[head_text, head_ent, tail_text, tail_ent, question_image, question_ent],
outputs=[output_text])
examples=[['scrap', 'Q3217573', 'watch', 'Q178794', 'examples/root.jpg', 'Q111029']]
ex = gr.Examples(
examples=examples,
        fn=single_inference_tti,
inputs=[head_text, head_ent, tail_text, tail_ent, question_image, question_ent],
outputs=[output_text],
cache_examples=False,
run_on_click=False
)
def blended_tab_iti():
with gr.Column():
gr.Markdown(""" $(I_h, T_t) : (I_q, ?)$
""")
with gr.Row():
with gr.Column():
head_image = gr.Image(type='pil', label="Head Image")
head_ent = gr.Textbox(lines=1, label="Head Entity")
with gr.Column():
tail_txt = gr.Textbox(lines=1, label="Tail Name")
tail_ent = gr.Textbox(lines=1, label="Tail Entity")
with gr.Column():
question_image = gr.Image(type='pil', label="Question Image")
question_ent = gr.Textbox(lines=1, label="Question Entity")
submit_btn = gr.Button("Submit")
output_text = gr.Textbox(label="Output")
submit_btn.click(fn=blended_inference_iti,
inputs=[head_image, head_ent, tail_txt, tail_ent, question_image, question_ent],
outputs=[output_text])
examples=[['examples/watermelon.jpg', 'Q38645', 'fruit', 'Q3314483', 'examples/coffee.jpeg', 'Q8486']]
ex = gr.Examples(
examples=examples,
        fn=blended_inference_iti,
inputs=[head_image, head_ent, tail_txt, tail_ent, question_image, question_ent],
outputs=[output_text],
cache_examples=False,
run_on_click=False
)
TITLE = """<h1 align="center">MKG Analogy</h1>"""
with gr.Blocks() as block:
with gr.Column(elem_id="col-container"):
gr.HTML(TITLE)
with gr.Tab("Single Analogical Reasoning"):
single_tab_iit()
single_tab_tti()
with gr.Tab("Blended Analogical Reasoning"):
blended_tab_iti()
# gr.HTML(ARTICLE)
block.queue(max_size=64).launch()  # .queue() already enables the request queue; enable_queue is deprecated