kgt5 / app.py
Apoorv Saxena
Update app.py
a1b42e5
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
def getScores(ids, scores, pad_token_id):
    """Get per-sequence log-probability scores from ``model.generate`` output.

    Args:
        ids: Generated token ids (``out.sequences`` in ``topkSample``). The
            first column (decoder start token) is dropped so the remaining
            columns line up one-to-one with ``scores``.
        scores: Sequence of per-step logit tensors (``out.scores``). They are
            stacked along dim 1, so each is assumed to be
            (num_sequences, vocab_size).
        pad_token_id: Token id treated as padding; positions holding it add
            nothing to the sequence score.

    Returns:
        numpy array with one summed log-probability per sequence.
    """
    # (num_sequences, num_steps, vocab_size) after stacking along dim 1.
    log_probs = torch.log_softmax(torch.stack(scores, dim=1), dim=2)
    # Remove the start token so token i corresponds to generation step i.
    target_ids = ids[:, 1:]
    # Gather only the chosen token's log-prob at each step. A trailing
    # singleton index dim avoids materialising a full vocab-sized output the
    # way an index expanded to ``log_probs.shape`` would.
    token_log_probs = torch.gather(log_probs, 2, target_ids.unsqueeze(-1)).squeeze(-1)
    # Padding emitted after EOS must not contribute to the sequence score.
    token_log_probs[target_ids == pad_token_id] = 0
    return token_log_probs.sum(dim=-1).detach().cpu().numpy()
def topkSample(input, model, tokenizer,
               num_samples=5,
               num_beams=1,
               max_output_length=30):
    """Sample ``num_samples`` outputs for ``input`` and rank them by score.

    Args:
        input: Source text passed to the tokenizer.
        model: Seq2seq model whose ``generate`` is called with sampling.
        tokenizer: Tokenizer used both to encode ``input`` and to decode the
            generated sequences; its eos/pad ids are forwarded to ``generate``.
        num_samples: Number of sequences to return from ``generate``.
        num_beams: Forwarded to ``generate``.
        max_output_length: Forwarded to ``generate`` as ``max_length``.

    Returns:
        List of ``(decoded_text, log_prob)`` tuples, highest log_prob first.
    """
    generation = model.generate(
        **tokenizer(input, return_tensors="pt"),
        do_sample=True,
        num_return_sequences=num_samples,
        num_beams=num_beams,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        output_scores=True,
        return_dict_in_generate=True,
        max_length=max_output_length,
    )
    sequences = generation.sequences
    texts = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    # NOTE(review): with do_sample=True the returned scores may be the
    # post-processed (e.g. top-k filtered) logits, so these are log-probs under
    # that distribution — confirm against the transformers version in use.
    log_probs = getScores(sequences, generation.scores, tokenizer.pad_token_id)
    ranked = list(zip(texts, log_probs))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked
def greedyPredict(input, model, tokenizer, max_output_length=None):
    """Return the single greedily decoded output for ``input``.

    Args:
        input: Source text to encode.
        model: Seq2seq model exposing ``generate``.
        tokenizer: Tokenizer used to encode ``input`` and decode the output.
        max_output_length: Optional cap forwarded to ``generate`` as
            ``max_length``; when None, nothing is passed and ``generate``
            falls back to its own default.

    Returns:
        The decoded string with special tokens removed.
    """
    # Pass the whole encoding (input_ids and attention_mask), matching how
    # topkSample calls generate, instead of dropping the attention mask.
    tokenized = tokenizer([input], return_tensors="pt")
    generate_kwargs = {}
    if max_output_length is not None:
        generate_kwargs["max_length"] = max_output_length
    out_tokens = model.generate(**tokenized, **generate_kwargs)
    return tokenizer.batch_decode(out_tokens, skip_special_tokens=True)[0]
def predict_tail(entity, relation):
    """Predict tail candidates for (entity, relation) as label -> probability.

    Builds the query ``entity + "| " + relation``, samples 25 candidates with
    the module-level ``model`` and ``tokenizer``, and converts each summed
    log-probability into a probability with ``exp``.

    Returns:
        dict mapping each decoded candidate string to its probability.
    """
    # model/tokenizer are only read here, so no ``global`` declaration is needed.
    query = entity + "| " + relation
    candidates = topkSample(query, model, tokenizer, num_samples=25)
    probabilities = {}
    # topkSample returns candidates sorted best-first. If two samples decode
    # to the same string, keep the first (highest) score rather than letting a
    # later, lower one overwrite it.
    for label, log_prob in candidates:
        if label not in probabilities:
            probabilities[label] = np.exp(log_prob).item()
    return probabilities
# The tokenizer and the seq2seq model are loaded from the same checkpoint id.
_checkpoint = "apoorvumang/kgt5-base-wikikg90mv2"
tokenizer = AutoTokenizer.from_pretrained(_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(_checkpoint)
# Two single-line text boxes (entity, relation) are the interface inputs; a
# Label component shows the label -> probability dict predict_tail returns.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy namespaces
# (deprecated in 3.x, removed in 4.x) — confirm the pinned gradio version.
ent_input = gr.inputs.Textbox(default="Apoorv Umang Saxena", lines=1)
rel_input = gr.inputs.Textbox(default="country", lines=1)
output = gr.outputs.Label()
# Clickable example (entity, relation) pairs shown under the demo.
examples = [
    ['Adrian Kochsiek', 'sex or gender'],
    ['Apoorv Umang Saxena', 'family name'],
    ['World War II', 'followed by'],
    ['Apoorv Umang Saxena', 'country'],
    ['Ippolito Boccolini', 'writing language'],
    ['Roelant', 'writing system'],
    ['The Accountant 2227', 'language of work or name'],
    ['Microbial Infection and AMR in Hospitalized Patients With Covid 19', 'study type'],
    ['Carla Fracci', 'manner of death'],
    ['list of programs broadcast by Comet', 'is a list of'],
    ['Loreta Podhradí', 'continent'],
    ['Opistognathotrema', 'taxon rank'],
    ['Museum Arbeitswelt Steyr', 'wheelchair accessibility'],
    ['Heliotropium tytoides', 'subject has role'],
    ['School bus crash rates on routine and nonroutine routes.', 'sponsor'],
    ['Tachigalieae', 'taxon rank'],
    ['Irena Salusová', 'place of detention'],
]
# Text shown on the demo page: title, HTML description above the interface,
# and HTML article below it.
title = "Interactive demo: KGT5"
description = """Demo for <a href='https://arxiv.org/abs/2203.10321'>Sequence-to-Sequence Knowledge Graph Completion and Question Answering</a> (KGT5). This particular model is a T5-base model trained on the task of tail prediction on the WikiKG90Mv2 dataset and obtains 0.239 validation MRR on this task (<a href="https://ogb.stanford.edu/docs/lsc/leaderboards/#wikikg90mv2">leaderboard</a>, see paper for details).
To use it, simply give an entity name and relation and click 'Submit'. Up to 25 model predictions will show up in a few seconds. The model works best when the exact entity/relation names that it has been trained on are used.
It is sometimes able to generalize to unseen entities as well (see examples).
"""
article = """
Under the hood, this demo concatenates the entity and relation, feeds the result to the model and then samples 25 sequences, which are then ranked according to their sequence probabilities.
<br>
The text representations of the relations and entities can be downloaded from here: <a href="https://storage.googleapis.com/kgt5-wikikg90mv2/rel_alias_list.pickle">https://storage.googleapis.com/kgt5-wikikg90mv2/rel_alias_list.pickle</a> and
<a href="https://storage.googleapis.com/kgt5-wikikg90mv2/ent_alias_list.pickle">https://storage.googleapis.com/kgt5-wikikg90mv2/ent_alias_list.pickle</a>.
<br>
For more details see the <a href='https://github.com/apoorvumang/kgt5'>Github repo</a> or the <a href="https://huggingface.co/apoorvumang/kgt5-base-wikikg90mv2">hf model page</a>.
"""
# Wire predict_tail into a Gradio interface (entity + relation in, label
# probabilities out) and start serving it.
interface_kwargs = dict(
    fn=predict_tail,
    inputs=[ent_input, rel_input],
    outputs=output,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface = gr.Interface(**interface_kwargs)
iface.launch()