kkpathak91 committed on
Commit eb56186
1 Parent(s): 547d257

Create app.py

Files changed (1)
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
+ """
+ @Author : Karan Kumar Pathak
+ @Contact : 2020fc04334@wilp.bits-pilani.com
+ @Description: Gradio demo app for LOREN, an interpretable fact verification model.
+ """
+
+ import os
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+ from prettytable import PrettyTable
+ import pandas as pd
+ import torch
+ import traceback
+
+ config = {
+     "model_type": "roberta",
+     "model_name_or_path": "roberta-large",
+     "logic_lambda": 0.5,
+     "prior": "random",
+     "mask_rate": 0.0,
+     "cand_k": 1,
+     "max_seq1_length": 256,
+     "max_seq2_length": 128,
+     "max_num_questions": 8,
+     "do_lower_case": False,
+     "seed": 42,
+     "n_gpu": torch.cuda.device_count(),
+ }
+
+ os.system('git clone https://github.com/jiangjiechen/LOREN/')
+ os.system('rm -r LOREN/data/')
+ os.system('rm -r LOREN/results/')
+ os.system('rm -r LOREN/models/')
+ os.system('mv LOREN/* ./')
+
+ model_dir = snapshot_download('Jiangjie/loren')
+ config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/')
+ config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/')
+ config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/')
+
+
+ from src.loren import Loren
+
+
+ loren = Loren(config, verbose=False)
+ try:
+     js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
+ except Exception as e:
+     raise ValueError(e)
+
+
+ def highlight_phrase(text, phrase):
+     text = loren.fc_client.tokenizer.clean_up_tokenization(text)
+     return text.replace('<mask>', f'<i><b>{phrase}</b></i>')
+
+
+ def highlight_entity(text, entity):
+     return text.replace(entity, f'<i><b>{entity}</b></i>')
+
+
+ def gradio_formatter(js, output_type):
+     zebra_css = '''
+     tr:nth-child(even) {
+         background: #f1f1f1;
+     }
+     thead{
+         background: #f1f1f1;
+     }'''
+     if output_type == 'e':
+         data = {'Evidence': [highlight_entity(x, e) for x, e in zip(js['evidence'], js['entities'])]}
+     elif output_type == 'z':
+         p_sup, p_ref, p_nei = [], [], []
+         for x in js['phrase_veracity']:
+             max_idx = torch.argmax(torch.tensor(x)).tolist()
+             x = ['%.4f' % xx for xx in x]
+             x[max_idx] = f'<i><b>{x[max_idx]}</b></i>'
+             p_sup.append(x[2])
+             p_ref.append(x[0])
+             p_nei.append(x[1])
+
+         data = {
+             'Claim Phrase': js['claim_phrases'],
+             'Local Premise': [highlight_phrase(q, x[0]) for q, x in zip(js['cloze_qs'], js['evidential'])],
+             'p_SUP': p_sup,
+             'p_REF': p_ref,
+             'p_NEI': p_nei,
+         }
+     else:
+         raise NotImplementedError
+     data = pd.DataFrame(data)
+     pt = PrettyTable(field_names=list(data.columns),
+                      align='l', border=True, hrules=1, vrules=1)
+     for v in data.values:
+         pt.add_row(v)
+     html = pt.get_html_string(attributes={
+         'style': 'border-width: 2px; bordercolor: black'
+     }, format=True)
+     html = f'<head> <style type="text/css"> {zebra_css} </style> </head>\n' + html
+     html = html.replace('&lt;', '<').replace('&gt;', '>')
+     return html
+
+
+ def run(claim):
+     try:
+         js = loren.check(claim)
+     except Exception as error_msg:
+         exc = traceback.format_exc()
+         msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
+         loren.logger.error(claim)
+         loren.logger.error(msg)
+         return 'Oops, something went wrong.', '', ''
+     label = js['claim_veracity']
+     loren.logger.warning(label + str(js))
+     ev_html = gradio_formatter(js, 'e')
+     z_html = gradio_formatter(js, 'z')
+     return label, z_html, ev_html
+
+
+ iface = gr.Interface(
+     fn=run,
+     inputs="text",
+     outputs=[
+         'text',
+         'html',
+         'html',
+     ],
+     examples=['Donald Trump won the U.S. 2020 presidential election.',
+               'The first inauguration of Bill Clinton was in the United States.',
+               'The Cry of the Owl is based on a book by an American.',
+               'Smriti Mandhana is an Indian woman.'],
+     title="LOREN",
+     layout='horizontal',
+     description="LOREN is an interpretable fact verification model using Wikipedia as its knowledge source. "
+                 "This is a demo system for the AAAI 2022 paper: \"LOREN: Logic-Regularized Reasoning for Interpretable Fact Verification\" (https://arxiv.org/abs/2012.13577). "
+                 "See the paper for more details. You can add a *FLAG* at the bottom to record interesting or bad cases! "
+                 "(Note that the demo system directly retrieves evidence from an up-to-date Wikipedia, which is different from the evidence used in the paper.)",
+     flagging_dir='results/flagged/',
+     allow_flagging=True,
+     flagging_options=['Interesting!', 'Error: Claim Phrase Parsing', 'Error: Local Premise',
+                       'Error: Require Commonsense', 'Error: Evidence Retrieval'],
+     enable_queue=True
+ )
+ iface.launch()