kkpathak91 committed
Commit 8ae3ac6
1 Parent(s): 921df19

Upload app.py

Files changed (1)
app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
+ # -*- coding: utf-8 -*-
+
+ """
+ @Author : Jiangjie Chen
+ @Time : 2021/12/13 17:17
+ @Contact : jjchen19@fudan.edu.cn
+ @Description:
+ """
+
+ import os
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+ from prettytable import PrettyTable
+ import pandas as pd
+ import torch
+ import traceback
+
+ config = {
+     "model_type": "roberta",
+     "model_name_or_path": "roberta-large",
+     "logic_lambda": 0.5,
+     "prior": "random",
+     "mask_rate": 0.0,
+     "cand_k": 1,
+     "max_seq1_length": 256,
+     "max_seq2_length": 128,
+     "max_num_questions": 8,
+     "do_lower_case": False,
+     "seed": 42,
+     "n_gpu": torch.cuda.device_count(),
+ }
+
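+ # Fetch the LOREN source tree from GitHub, drop its bundled data/, results/ and
+ # models/ directories, and move the code into the working directory.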
+ os.system('git clone https://github.com/jiangjiechen/LOREN/')
+ os.system('rm -r LOREN/data/')
+ os.system('rm -r LOREN/results/')
+ os.system('rm -r LOREN/models/')
+ os.system('mv LOREN/* ./')
+
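+ # Download the pretrained LOREN checkpoints from the Hugging Face Hub and point the
+ # config at the fact-checking, MRC seq2seq, and evidence-retrieval components.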
+ model_dir = snapshot_download('Jiangjie/loren')
+ config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/')
+ config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/')
+ config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/')
+
+
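+ # src/ comes from the repository cloned above, so this import must run after the clone.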
+ from src.loren import Loren
+
+
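+ # Build the LOREN pipeline and run a single warm-up check; any failure here is
+ # re-raised immediately so loading problems surface at startup rather than at query time.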
+ loren = Loren(config, verbose=False)
+ try:
+     js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
+ except Exception as e:
+     raise ValueError(e)
+
+
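+ # HTML helpers: wrap the filled-in phrase or matched entity in <i><b> tags for display.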
+ def highlight_phrase(text, phrase):
+     text = loren.fc_client.tokenizer.clean_up_tokenization(text)
+     return text.replace('<mask>', f'<i><b>{phrase}</b></i>')
+
+
+ def highlight_entity(text, entity):
+     return text.replace(entity, f'<i><b>{entity}</b></i>')
+
+
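+ # Render LOREN's output as an HTML table: 'e' lists the retrieved evidence with
+ # entities highlighted; 'z' lists each claim phrase with its local premise and
+ # SUP/REF/NEI probabilities, the largest one shown in bold.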
+ def gradio_formatter(js, output_type):
+     zebra_css = '''
+     tr:nth-child(even) {
+         background: #f1f1f1;
+     }
+     thead{
+         background: #f1f1f1;
+     }'''
+     if output_type == 'e':
+         data = {'Evidence': [highlight_entity(x, e) for x, e in zip(js['evidence'], js['entities'])]}
+     elif output_type == 'z':
+         p_sup, p_ref, p_nei = [], [], []
+         for x in js['phrase_veracity']:
+             max_idx = torch.argmax(torch.tensor(x)).tolist()
+             x = ['%.4f' % xx for xx in x]
+             x[max_idx] = f'<i><b>{x[max_idx]}</b></i>'
+             p_sup.append(x[2])
+             p_ref.append(x[0])
+             p_nei.append(x[1])
+
+         data = {
+             'Claim Phrase': js['claim_phrases'],
+             'Local Premise': [highlight_phrase(q, x[0]) for q, x in zip(js['cloze_qs'], js['evidential'])],
+             'p_SUP': p_sup,
+             'p_REF': p_ref,
+             'p_NEI': p_nei,
+         }
+     else:
+         raise NotImplementedError
+     data = pd.DataFrame(data)
+     pt = PrettyTable(field_names=list(data.columns),
+                      align='l', border=True, hrules=1, vrules=1)
+     for v in data.values:
+         pt.add_row(v)
+     html = pt.get_html_string(attributes={
+         'style': 'border-width: 2px; bordercolor: black'
+     }, format=True)
+     html = f'<head> <style type="text/css"> {zebra_css} </style> </head>\n' + html
+     html = html.replace('&lt;', '<').replace('&gt;', '>')
+     return html
+
+
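+ # Gradio callback: check the claim with LOREN and return the verdict plus the
+ # phrase-level and evidence HTML tables; exceptions are logged and a fallback
+ # message is returned instead.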
+ def run(claim):
+     try:
+         js = loren.check(claim)
+     except Exception as error_msg:
+         exc = traceback.format_exc()
+         msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
+         loren.logger.error(claim)
+         loren.logger.error(msg)
+         return 'Oops, something went wrong.', '', ''
+     label = js['claim_veracity']
+     loren.logger.warning(label + str(js))
+     ev_html = gradio_formatter(js, 'e')
+     z_html = gradio_formatter(js, 'z')
+     return label, z_html, ev_html
+
+
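+ # Assemble the Gradio demo: a text input, a text verdict plus two HTML outputs,
+ # example claims, and flagging options for recording interesting or failing cases.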
+ iface = gr.Interface(
+     fn=run,
+     inputs="text",
+     outputs=[
+         'text',
+         'html',
+         'html',
+     ],
+     examples=['Donald Trump won the U.S. 2020 presidential election.',
+               'The first inauguration of Bill Clinton was in the United States.',
+               'The Cry of the Owl is based on a book by an American.',
+               'Smriti Mandhana is an Indian woman.'],
+     title="LOREN",
+     layout='horizontal',
+     description="LOREN is an interpretable Fact Verification model using Wikipedia as its knowledge source. "
+                 "This is a demo system for the AAAI 2022 paper: \"LOREN: Logic-Regularized Reasoning for Interpretable Fact Verification\" (https://arxiv.org/abs/2012.13577). "
+                 "See the paper for more details. You can add a *FLAG* at the bottom to record interesting or bad cases! "
+                 "(Note that the demo system directly retrieves evidence from an up-to-date Wikipedia, which is different from the evidence used in the paper.)",
+     flagging_dir='results/flagged/',
+     allow_flagging=True,
+     flagging_options=['Interesting!', 'Error: Claim Phrase Parsing', 'Error: Local Premise',
+                       'Error: Require Commonsense', 'Error: Evidence Retrieval'],
+     enable_queue=True
+ )
+ iface.launch()