kkpathak91 committed on
Commit
6830ff4
1 Parent(s): a474698

Create new file

Browse files
Files changed (1) hide show
  1. app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import subprocess
import sys

# Install the OCR dependencies at start-up. `sys.executable -m pip` targets
# the interpreter that is running this script, and `check_call` raises
# CalledProcessError on a failed install instead of silently continuing to
# the imports below.
for _package in ('paddlepaddle', 'paddleocr'):
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', _package])

from paddleocr import PaddleOCR, draw_ocr
from PIL import Image
import gradio as gr
import torch

# Download the sample picture that the demo lists as 'example.jpg'.
torch.hub.download_url_to_file('https://i.imgur.com/aqMBT0i.jpg', 'example.jpg')
10
+
11
# One PaddleOCR engine per language code, created lazily and reused so that
# the models are not reloaded on every request.
_OCR_ENGINES = {}


def _get_ocr(lang):
    """Return the cached PaddleOCR engine for *lang*, creating it on first use."""
    engine = _OCR_ENGINES.get(lang)
    if engine is None:
        engine = PaddleOCR(use_angle_cls=True, lang=lang, use_gpu=False)
        _OCR_ENGINES[lang] = engine
    return engine


def inference(img, lang):
    """Run OCR on an uploaded image and return the path of the annotated copy.

    Args:
        img: Uploaded file object; only its ``name`` attribute (the path on
            disk) is read.
        lang: Language code passed to PaddleOCR (e.g. ``'en'``, ``'ch'``).

    Returns:
        The string ``'result.jpg'``: the file the annotated image is saved to.
    """
    ocr = _get_ocr(lang)
    img_path = img.name
    # NOTE(review): `or []` guards against a None result, which PaddleOCR is
    # reported to return for images without detected text — confirm against
    # the installed version.
    result = ocr.ocr(img_path, cls=True) or []
    image = Image.open(img_path).convert('RGB')
    # Each entry is indexed as [box, (text, ...)].
    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    im_show = draw_ocr(image, boxes, txts, font_path='simfang.ttf')
    im_show = Image.fromarray(im_show)
    im_show.save('result.jpg')
    return 'result.jpg'
24
+
25
# Static texts and styling for the OCR demo page.
title = 'A Framework for Data-Driven Document Evaluation and scoring - Image to Text Extraction '
description = 'Demo for Optical character recognition(OCR)'
article = ""
examples = [['example.jpg','en']]
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"

# NOTE(review): launch(debug=True) appears to block the main thread when run
# as a script, in which case the statements after this call would not execute
# until this server stops — confirm against the installed gradio version.
gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Image(type='file', label='Input'),
        gr.inputs.Dropdown(
            choices=['ch', 'en', 'fr', 'german', 'korean', 'japan'],
            type="value",
            default='en',
            label='language',
        ),
    ],
    outputs=gr.outputs.Image(type='file', label='Output'),
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
    enable_queue=True,
).launch(debug=True)
41
+
42
+
43
+ ##########################################################################################################
44
+
45
+ import os
46
+ import gradio as gr
47
+ from huggingface_hub import snapshot_download
48
+ from prettytable import PrettyTable
49
+ import pandas as pd
50
+ import torch
51
+ import traceback
52
+
53
# Settings handed to Loren(...); the *_dir entries are added later, once the
# model snapshot has been downloaded.
config = dict(
    model_type="roberta",
    model_name_or_path="roberta-large",
    logic_lambda=0.5,
    prior="random",
    mask_rate=0.0,
    cand_k=1,
    max_seq1_length=256,
    max_seq2_length=128,
    max_num_questions=8,
    do_lower_case=False,
    seed=42,
    # Number of CUDA devices visible to torch (0 on a CPU-only host).
    n_gpu=torch.cuda.device_count(),
)
67
+
68
# Fetch the project sources and move them into the working directory,
# presumably so that the `src` package imported below becomes importable.
# The steps are deliberately best-effort (exit codes are ignored): on a
# restart the clone target already exists and these commands fail harmlessly.
for _command in (
    'git clone https://github.com/kkpathak91/project_metch/',
    'rm -r project_metch/data/',
    'rm -r project_metch/results/',
    'rm -r project_metch/models/',
    'mv project_metch/* ./',
):
    os.system(_command)

# Download the model snapshot and point the config at its sub-directories.
model_dir = snapshot_download('kkpathak91/FVM')
config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/')
config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/')
config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/')

# Imported late on purpose: it comes after the clone/move steps above.
from src.loren import Loren

loren = Loren(config, verbose=False)
# Run one claim at start-up so that a broken pipeline fails here rather than
# on the first user request.
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as e:
    # Keep the ValueError type, but chain the original error explicitly so
    # its traceback is reported as the direct cause.
    raise ValueError(e) from e
88
+
89
+
90
def highlight_phrase(text, phrase):
    """Fill every ``<mask>`` in *text* with *phrase* rendered bold-italic.

    The text is first passed through the fact-checking tokenizer's
    ``clean_up_tokenization`` before the substitution is made.
    """
    cleaned = loren.fc_client.tokenizer.clean_up_tokenization(text)
    emphasized = f'<i><b>{phrase}</b></i>'
    return cleaned.replace('<mask>', emphasized)
93
+
94
+
95
def highlight_entity(text, entity):
    """Wrap every occurrence of *entity* in *text* in bold-italic HTML tags.

    Args:
        text: The string to search.
        entity: The substring to emphasize.

    Returns:
        *text* with each occurrence of *entity* wrapped in ``<i><b>…</b></i>``;
        *text* unchanged when *entity* is empty.
    """
    # str.replace('', x) would insert the markup between every character,
    # so an empty entity is treated as "nothing to highlight".
    if not entity:
        return text
    return text.replace(entity, f'<i><b>{entity}</b></i>')
97
+
98
+
99
def gradio_formatter(js, output_type):
    """Render one view of a result dict as an HTML table string.

    Args:
        js: Result dict. For ``'e'`` the keys ``'evidence'`` and
            ``'entities'`` are read; for ``'z'`` the keys
            ``'phrase_veracity'``, ``'claim_phrases'``, ``'cloze_qs'`` and
            ``'evidential'`` are read.
        output_type: ``'e'`` for the evidence table, ``'z'`` for the
            per-phrase probability table.

    Returns:
        An HTML string: a ``<style>`` header followed by the table.

    Raises:
        NotImplementedError: If *output_type* is neither ``'e'`` nor ``'z'``.
    """
    zebra_css = '''
    tr:nth-child(even) {
        background: #f1f1f1;
    }
    thead{
        background: #f1f1f1;
    }'''
    if output_type == 'e':
        data = {'Evidence': [highlight_entity(x, e) for x, e in zip(js['evidence'], js['entities'])]}
    elif output_type == 'z':
        p_sup, p_ref, p_nei = [], [], []
        for x in js['phrase_veracity']:
            # Emphasize the largest of the three probabilities.
            max_idx = torch.argmax(torch.tensor(x)).tolist()
            x = ['%.4f' % xx for xx in x]
            x[max_idx] = f'<i><b>{x[max_idx]}</b></i>'
            # Index 2 feeds the p_SUP column, 0 the p_REF column, 1 p_NEI.
            p_sup.append(x[2])
            p_ref.append(x[0])
            p_nei.append(x[1])

        data = {
            'Claim Phrase': js['claim_phrases'],
            'Local Premise': [highlight_phrase(q, x[0]) for q, x in zip(js['cloze_qs'], js['evidential'])],
            'p_SUP': p_sup,
            'p_REF': p_ref,
            'p_NEI': p_nei,
        }
    else:
        raise NotImplementedError(f'Unsupported output_type: {output_type!r}')
    data = pd.DataFrame(data)
    pt = PrettyTable(field_names=list(data.columns),
                     align='l', border=True, hrules=1, vrules=1)
    for v in data.values:
        pt.add_row(v)
    html = pt.get_html_string(attributes={
        # `border-color` is the CSS property name; `bordercolor` is not a
        # CSS property.
        'style': 'border-width: 2px; border-color: black'
    }, format=True)
    html = f'<head> <style type="text/css"> {zebra_css} </style> </head>\n' + html
    # Undo the escaping of angle brackets so the <i><b> highlight tags render.
    # NOTE(review): this also un-escapes any angle brackets that came from the
    # claim or evidence text itself, so that text reaches the page as raw
    # HTML — consider escaping cell text before highlighting.
    html = html.replace('&lt;', '<').replace('&gt;', '>')
    return html
139
+
140
+
141
def run(claim):
    """Check *claim* with Loren and return the three values shown in the UI.

    Returns:
        A ``(label, phrase_table_html, evidence_table_html)`` tuple. If the
        check raises, the error is logged and
        ``('Oops, something went wrong.', '', '')`` is returned instead.
    """
    try:
        result = loren.check(claim)
    except Exception as err:
        trace = traceback.format_exc()
        # Log the offending claim first, then the error with its traceback.
        loren.logger.error(claim)
        loren.logger.error(f'[Error]: {err}.\n[Traceback]: {trace}')
        return 'Oops, something went wrong.', '', ''

    verdict = result['claim_veracity']
    loren.logger.warning(verdict + str(result))
    evidence_table = gradio_formatter(result, 'e')
    phrase_table = gradio_formatter(result, 'z')
    return verdict, phrase_table, evidence_table
155
+
156
+
157
# Claim-checking page: one text input, and a text label plus two HTML tables
# (the three values returned by `run`) as outputs.
iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=['text', 'html', 'html'],
    title="A Framework for Data-Driven Document Evaluation and Scoring",
    description="[Student Name: Karan Kumar Pathak]  [Roll No.: 2020fc04334] ",
    examples=[
        'Kanpur is a city in Nepal',
        'PV Sindhu is an Indian Badminton Player.',
    ],
    layout='horizontal',
    allow_flagging=True,
    flagging_dir='results/flagged/',
    flagging_options=[
        'Interesting!',
        'Error: Claim Phrase Parsing',
        'Error: Local Premise',
        'Error: Require Commonsense',
        'Error: Evidence Retrieval',
    ],
    enable_queue=True,
)
iface.launch()