# Author: kkpathak91
# Last commit: 0ebe145 ("Update app.py")
##########################################################################################################
import os
import gradio as gr
from huggingface_hub import snapshot_download
from prettytable import PrettyTable
import pandas as pd
import torch
import traceback
# Model / inference settings passed to ``Loren(config, ...)`` below.
# The model directory entries ('fc_dir', 'mrc_dir', 'er_dir') are filled in
# later, once the model snapshot has been downloaded.
config = dict(
    model_type="roberta",
    model_name_or_path="roberta-large",
    logic_lambda=0.5,
    prior="random",
    mask_rate=0.0,
    cand_k=1,
    max_seq1_length=256,
    max_seq2_length=128,
    max_num_questions=8,
    do_lower_case=False,
    seed=42,
    # Number of CUDA devices visible to torch (0 on a CPU-only host).
    n_gpu=torch.cuda.device_count(),
)
# Fetch the project code from GitHub, drop its data/results/models folders,
# and move the remaining checkout contents into the working directory.
# NOTE(review): os.system exit statuses are ignored, so a failing command
# (clone, rm or mv) does not stop start-up.
for _shell_cmd in (
    'git clone https://github.com/kkpathak91/project_metch/',
    'rm -r project_metch/data/',
    'rm -r project_metch/results/',
    'rm -r project_metch/models/',
    'mv project_metch/* ./',
):
    os.system(_shell_cmd)

# Download the model snapshot from the Hugging Face Hub and point each
# component's config entry at its sub-folder of that snapshot.
model_dir = snapshot_download('kkpathak91/FVM')
for _cfg_key, _sub_path in (
    ('fc_dir', 'fact_checking/roberta-large/'),
    ('mrc_dir', 'mrc_seq2seq/bart-base/'),
    ('er_dir', 'evidence_retrieval/'),
):
    config[_cfg_key] = os.path.join(model_dir, _sub_path)
# Imported here rather than at the top of the file: ``src/`` is presumably
# provided by the project_metch checkout moved into ./ just above.
from src.loren import Loren
loren = Loren(config, verbose=False)
# Start-up smoke test: check one fixed claim so that a broken model setup
# fails when the app starts instead of on the first user request.
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as e:
    # Chain explicitly so the original failure is reported as the direct
    # cause of this ValueError, not as an incidental context.
    raise ValueError(e) from e
def highlight_phrase(text, phrase):
    """Fill the ``<mask>`` slot(s) of a cloze question with an emphasized phrase.

    The text is first passed through the fact-checking tokenizer's
    ``clean_up_tokenization``. Every ``<mask>`` occurrence is then replaced by
    *phrase* wrapped in ``<i><b>...</b></i>``.

    Args:
        text: Cloze-style question text containing ``<mask>``.
        phrase: Text to insert in place of the mask.

    Returns:
        The cleaned text with each mask replaced by the highlighted phrase.
    """
    cleaned = loren.fc_client.tokenizer.clean_up_tokenization(text)
    emphasized = f'<i><b>{phrase}</b></i>'
    return cleaned.replace('<mask>', emphasized)
def highlight_entity(text, entity):
    """Wrap every occurrence of *entity* in *text* with ``<i><b>...</b></i>``.

    Args:
        text: Evidence sentence to annotate.
        entity: Substring to emphasize. Empty or ``None`` means "nothing to
            highlight".

    Returns:
        *text* with each occurrence of *entity* emphasized. If *entity* is
        empty or ``None``, *text* is returned unchanged.
    """
    # str.replace('', x) inserts x between every character (and replace(None)
    # raises), which would garble the evidence row; treat it as "no entity".
    if not entity:
        return text
    return text.replace(entity, f'<i><b>{entity}</b></i>')
def gradio_formatter(js, output_type):
    """Render one part of a ``loren.check`` result as a striped HTML table.

    Args:
        js: Result dict from ``loren.check``.
            For ``output_type == 'e'`` it reads ``'evidence'`` and
            ``'entities'``.
            For ``output_type == 'z'`` it reads ``'phrase_veracity'``,
            ``'claim_phrases'``, ``'cloze_qs'`` and ``'evidential'``.
        output_type: ``'e'`` for the evidence table, or ``'z'`` for the
            per-claim-phrase veracity table.

    Returns:
        An HTML string: a ``<head>`` holding zebra-striping CSS, followed by
        the table produced by PrettyTable. Highlighted fragments appear as
        ``<i><b>...</b></i>``. All other cell text stays HTML-escaped.

    Raises:
        NotImplementedError: if ``output_type`` is neither ``'e'`` nor ``'z'``.
    """
    zebra_css = '''
    tr:nth-child(even) {
        background: #f1f1f1;
    }
    thead {
        background: #f1f1f1;
    }'''
    if output_type == 'e':
        data = {
            'Evidence': [highlight_entity(ev, ent)
                         for ev, ent in zip(js['evidence'], js['entities'])],
        }
    elif output_type == 'z':
        p_sup, p_ref, p_nei = [], [], []
        for probs in js['phrase_veracity']:
            # Index layout used below: probs[0] -> REF, probs[1] -> NEI,
            # probs[2] -> SUP. The most probable class is highlighted.
            max_idx = torch.argmax(torch.tensor(probs)).tolist()
            cells = ['%.4f' % p for p in probs]
            cells[max_idx] = f'<i><b>{cells[max_idx]}</b></i>'
            p_sup.append(cells[2])
            p_ref.append(cells[0])
            p_nei.append(cells[1])
        data = {
            'Claim Phrase': js['claim_phrases'],
            # Only the first entry of each 'evidential' item is shown,
            # presumably the top-ranked answer -- TODO confirm against Loren.
            'Local Premise': [highlight_phrase(q, ev[0])
                              for q, ev in zip(js['cloze_qs'], js['evidential'])],
            'p_SUP': p_sup,
            'p_REF': p_ref,
            'p_NEI': p_nei,
        }
    else:
        raise NotImplementedError(f'unsupported output_type: {output_type!r}')

    # Going through a DataFrame fixes the column order and raises if the
    # columns have different lengths.
    frame = pd.DataFrame(data)
    table = PrettyTable(field_names=list(frame.columns),
                        align='l', border=True, hrules=1, vrules=1)
    for row in frame.values:
        table.add_row(row)
    # 'border-color' is the real CSS property; the previous 'bordercolor'
    # was invalid and silently ignored by browsers.
    html = table.get_html_string(attributes={
        'style': 'border-width: 2px; border-color: black'
    }, format=True)
    # PrettyTable HTML-escapes cell text, which also escapes the <i><b>
    # highlight tags added above. Un-escape ONLY those exact tags: a blanket
    # '&lt;'/'&gt;' replacement would also revive any markup contained in the
    # user's claim or the retrieved evidence (HTML injection into the page).
    html = (html.replace('&lt;i&gt;&lt;b&gt;', '<i><b>')
                .replace('&lt;/b&gt;&lt;/i&gt;', '</b></i>'))
    return f'<head> <style type="text/css"> {zebra_css} </style> </head>\n' + html
def run(claim):
    """Fact-check one claim and build the three Gradio outputs.

    Args:
        claim: The claim text typed into the Gradio text box.

    Returns:
        ``(label, z_html, ev_html)``:
        - the claim-level veracity label,
        - the per-phrase veracity table,
        - the evidence table.
        This matches the order of the interface's ``outputs``. If checking
        *or* formatting fails, the claim and the full traceback are logged
        via ``loren.logger`` and ``('Oops, something went wrong.', '', '')``
        is returned instead.
    """
    try:
        js = loren.check(claim)
        label = js['claim_veracity']
        # f-string instead of '+' so a non-str label cannot raise TypeError.
        loren.logger.warning(f'{label}{js}')
        # Formatting also sits inside the try: a malformed result (e.g. ragged
        # columns) should produce the friendly message, not an unhandled
        # exception surfacing in the UI.
        ev_html = gradio_formatter(js, 'e')
        z_html = gradio_formatter(js, 'z')
    except Exception as error_msg:
        exc = traceback.format_exc()
        msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
        loren.logger.error(claim)
        loren.logger.error(msg)
        return 'Oops, something went wrong.', '', ''
    return label, z_html, ev_html
# Example claims offered in the UI.
_EXAMPLE_CLAIMS = [
    'Kanpur is a city in Nepal',
    'PV Sindhu is an Indian Badminton Player.',
]
# Labels users can attach when flagging an output; flags are written under
# the flagging_dir passed below.
_FLAGGING_OPTIONS = [
    'Interesting!',
    'Error: Claim Phrase Parsing',
    'Error: Local Premise',
    'Error: Require Commonsense',
    'Error: Evidence Retrieval',
]
# One text input. Three outputs (label text, phrase-veracity HTML,
# evidence HTML) in the order that ``run`` returns them.
iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=['text', 'html', 'html'],
    examples=_EXAMPLE_CLAIMS,
    title="A Framework for Data-Driven Document Evaluation and Scoring",
    layout='horizontal',
    description="[Student Name: Karan Kumar Pathak]  [Roll No.: 2020fc04334] ",
    flagging_dir='results/flagged/',
    allow_flagging=True,
    flagging_options=_FLAGGING_OPTIONS,
    enable_queue=True,
)
iface.launch()