##########################################################################################################
# Gradio demo: checks a free-text claim with the Loren pipeline and renders the verdict, per-phrase
# veracity probabilities and retrieved evidence as HTML tables.
import os
import traceback
from html import escape as html_escape

import gradio as gr
import pandas as pd
import torch
from huggingface_hub import snapshot_download
from prettytable import PrettyTable

config = {
    "model_type": "roberta",
    "model_name_or_path": "roberta-large",
    "logic_lambda": 0.5,
    "prior": "random",
    "mask_rate": 0.0,
    "cand_k": 1,
    "max_seq1_length": 256,
    "max_seq2_length": 128,
    "max_num_questions": 8,
    "do_lower_case": False,
    "seed": 42,
    "n_gpu": torch.cuda.device_count(),
}

# Fetch the project code and move it into the working directory. These calls are best-effort:
# os.system return codes are ignored, so a failing step (e.g. clone target already present on a
# restart) does not abort start-up.
os.system('git clone https://github.com/kkpathak91/project_metch/')
os.system('rm -r project_metch/data/')
os.system('rm -r project_metch/results/')
os.system('rm -r project_metch/models/')
os.system('mv project_metch/* ./')

model_dir = snapshot_download('kkpathak91/FVM')
config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/')
config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/')
config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/')

# Presumably `src` comes from the repository moved into ./ above, hence the late import.
from src.loren import Loren  # noqa: E402

loren = Loren(config, verbose=False)

# Run one check at start-up so a broken model/config fails immediately instead of on first request.
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as e:
    raise ValueError(e) from e


def _emphasize(escaped_text):
    """Wrap already-HTML-escaped text in bold-italic markup."""
    return f'<i><b>{escaped_text}</b></i>'


def highlight_phrase(text, phrase):
    """Return *text* as safe HTML with its mask placeholder replaced by an emphasized *phrase*.

    Args:
        text: A cloze question; it is first passed through the fact-checking tokenizer's
            ``clean_up_tokenization``.
        phrase: The phrase to show in place of the mask placeholder.

    Returns:
        HTML-escaped text with the placeholder(s) replaced by emphasized, escaped *phrase*.
        If no placeholder is present the escaped text is returned unchanged.
    """
    tokenizer = loren.fc_client.tokenizer
    text = tokenizer.clean_up_tokenization(text)
    # NOTE(review): assumes the cloze blank is the tokenizer's mask token ('<mask>' for
    # RoBERTa/BART) — confirm against the question generator.
    mask = getattr(tokenizer, 'mask_token', None) or '<mask>'
    marked = _emphasize(html_escape(str(phrase)))
    # Escape the surrounding text piecewise so only our own markup is live HTML.
    return marked.join(html_escape(part) for part in text.split(mask))


def highlight_entity(text, entity):
    """Return *text* as safe HTML with every occurrence of *entity* emphasized.

    An empty/falsy *entity* yields the escaped text unchanged (replacing an empty string
    would otherwise insert markup between every character).
    """
    if not entity:
        return html_escape(text)
    marked = _emphasize(html_escape(entity))
    return marked.join(html_escape(part) for part in text.split(entity))


def gradio_formatter(js, output_type):
    """Render part of a ``loren.check`` result as an HTML table string.

    Args:
        js: Result dict from ``loren.check``.
        output_type: ``'e'`` for the evidence table (reads ``evidence``, ``entities``) or
            ``'z'`` for the phrase-veracity table (reads ``phrase_veracity``,
            ``claim_phrases``, ``cloze_qs``, ``evidential``).

    Returns:
        An HTML string: a ``<style>`` element with zebra striping followed by the table.

    Raises:
        NotImplementedError: for any other *output_type*.
    """
    zebra_css = '''
    tr:nth-child(even) {
        background: #f1f1f1;
    }
    thead{
        background: #f1f1f1;
    }'''
    if output_type == 'e':
        data = {'Evidence': [highlight_entity(x, e) for x, e in zip(js['evidence'], js['entities'])]}
    elif output_type == 'z':
        p_sup, p_ref, p_nei = [], [], []
        for probs in js['phrase_veracity']:
            max_idx = torch.argmax(torch.tensor(probs)).tolist()
            cells = ['%.4f' % p for p in probs]
            cells[max_idx] = _emphasize(cells[max_idx])
            # Column wiring as in the original: index 2 -> SUP, 0 -> REF, 1 -> NEI.
            p_sup.append(cells[2])
            p_ref.append(cells[0])
            p_nei.append(cells[1])
        data = {
            'Claim Phrase': [html_escape(str(p)) for p in js['claim_phrases']],
            'Local Premise': [highlight_phrase(q, ev[0]) for q, ev in zip(js['cloze_qs'], js['evidential'])],
            'p_SUP': p_sup,
            'p_REF': p_ref,
            'p_NEI': p_nei,
        }
    else:
        raise NotImplementedError(output_type)

    data = pd.DataFrame(data)
    pt = PrettyTable(field_names=list(data.columns), align='l', border=True, hrules=1, vrules=1)
    for row in data.values:
        pt.add_row(row)
    table_html = pt.get_html_string(attributes={
        'style': 'border-width: 2px; bordercolor: black'
    }, format=True)
    # PrettyTable HTML-escapes every cell. Undo exactly one level so our injected markup renders.
    # Cell text was pre-escaped above, so it comes back as inert entities. '&amp;' must be restored
    # LAST: escaped user text like '&amp;lt;' then becomes the entity '&lt;', never a live '<'.
    table_html = table_html.replace('&lt;', '<').replace('&gt;', '>').replace('&amp;', '&')
    return f'<style type="text/css">{zebra_css}</style>\n' + table_html


def run(claim):
    """Gradio callback: check *claim* and return (label, phrase-table HTML, evidence-table HTML).

    On any exception from ``loren.check`` the claim and traceback are logged and a generic
    message with two empty HTML outputs is returned instead.
    """
    try:
        js = loren.check(claim)
    except Exception as error_msg:
        exc = traceback.format_exc()
        msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
        loren.logger.error(claim)
        loren.logger.error(msg)
        return 'Oops, something went wrong.', '', ''

    label = js['claim_veracity']
    loren.logger.warning(label + str(js))
    ev_html = gradio_formatter(js, 'e')
    z_html = gradio_formatter(js, 'z')
    return label, z_html, ev_html


iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=[
        'text',
        'html',
        'html',
    ],
    examples=['Kanpur is a city in Nepal', 'PV Sindhu is an Indian Badminton Player.'],
    title="A Framework for Data-Driven Document Evaluation and Scoring",
    layout='horizontal',
    description="[Student Name: Karan Kumar Pathak] "
                " [Roll No.: 2020fc04334] ",
    flagging_dir='results/flagged/',
    allow_flagging=True,
    flagging_options=['Interesting!',
                      'Error: Claim Phrase Parsing',
                      'Error: Local Premise',
                      'Error: Require Commonsense',
                      'Error: Evidence Retrieval'],
    enable_queue=True
)
iface.launch()