import util as cu
import gradio as gr
from collections import defaultdict
from io import StringIO
from urllib.parse import urlparse
# Text placed between consecutive chunk texts when the supporting chunks of a
# single URL are joined into one display string (see output_credit_dist).
chunk_separator = '[...]\n\n'
def get_url_to_supporting_cid_ctext_tuples(atom_support_l):
    """Collect, per URL, the distinct chunks that support at least one atom.

    Args:
        atom_support_l: Iterable of dicts mapping a URL to a determination
            dict. Each determination is read through the keys ``'true'``
            (truthy when the URL supports the atom), ``'id_l'`` (chunk ids)
            and ``'chunk_text_l'`` (chunk texts, paired with the ids by
            position).

    Returns:
        A ``defaultdict(list)`` mapping each URL to ``(chunk_id, chunk_text)``
        tuples ordered by chunk id. A chunk id is recorded once per URL, with
        the text from its first occurrence. URLs that contribute no chunk are
        absent from the mapping.
    """
    seen_ids_by_url = defaultdict(set)
    chunks_by_url = defaultdict(list)

    for support in atom_support_l:
        for url, verdict in support.items():
            if not verdict['true']:
                continue
            for chunk_id, chunk_text in zip(verdict['id_l'], verdict['chunk_text_l']):
                seen_ids = seen_ids_by_url[url]
                if chunk_id in seen_ids:
                    continue
                seen_ids.add(chunk_id)
                chunks_by_url[url].append((chunk_id, chunk_text))

    # Order every URL's chunks by chunk id.
    for chunk_pairs in chunks_by_url.values():
        chunk_pairs.sort(key=lambda pair: pair[0])

    return chunks_by_url
def output_credit_dist(msg, cur_idx, _out_credit, _out_claims):
"""Render source attribution and a per-claim breakdown for the text in ``msg``.

NOTE(review): this copy of the function carries no indentation, and several
f-strings open on one line and close on a later one with nothing (or only
plain text) in between, which looks like markup that went missing. The
nesting described in the comments below is inferred from statement order;
confirm both points against a properly formatted copy.

Args:
    msg: Target text; it is only processed when longer than 10 characters.
    cur_idx: One-element list used as mutable view state; element 0 is
        reset to 0 just before the final return.
    _out_credit: Writable text buffer (truncate/seek/getvalue are called on
        it) that receives the attribution view.
    _out_claims: Writable text buffer that receives the claim breakdown.

Returns:
    A ``(label, content)`` pair: ``('', <credit buffer contents>)`` when no
    chunk matches were found, otherwise
    ``('Show claim breakdown', <credit buffer contents>)``.
"""
print('Start output_credit_dist.')
# Empty both buffers and rewind them so this call starts from a clean slate.
_out_credit.truncate(0)
_out_credit.seek(0)
_out_claims.truncate(0)
_out_claims.seek(0)
# Both views begin with the shared style string from util.
print(cu.style_str, file=_out_credit)
print(cu.style_str, file=_out_claims)
atoms_l, atom_topkmatches_l, credit_l = [], [], []
# Very short targets (10 characters or fewer) are not split into atoms at all.
if len(msg) > 10:
atoms_l = cu.get_atoms_list(msg)
# Drop atoms of 10 characters or fewer.
atoms_l = list(filter(lambda x: len(x) > 10, atoms_l))
# Pipeline over the remaining atoms. The cu.* helpers are opaque from here;
# presumably: match chunks -> aggregate per URL -> judge support -> assign
# credit. TODO confirm against util.
if atoms_l:
atom_topkmatches_l = cu.get_atom_topk_matches_l_concurrent(atoms_l, max_workers=8)
print('Got atom chunk matches')
atomidx_w_single_url_aggmatch_l = cu.aggregate_atom_topkmatches_l(atom_topkmatches_l)
print('Aggregated atom chunk matches')
atom_support_l = cu.get_atmom_support_l_from_atomidx_w_single_url_aggmatch_l_concurrent(atoms_l, atomidx_w_single_url_aggmatch_l, max_workers=8)
print('Got atom support list')
credit_dist = cu.credit_atom_support_list(atom_support_l)
print('Computed credit distribution')
url_to_supporting_cid_ctext_tuples = get_url_to_supporting_cid_ctext_tuples(atom_support_l)
# Title lookup built from every match's metadata; a later match for the same
# URL overwrites an earlier one.
url_to_title = {}
for atom_topkmatches in atom_topkmatches_l:
for match in atom_topkmatches:
url_to_title[match['metadata']['url']] = match['metadata']['title']
# (url, weight) pairs, largest weight first.
credit_l = [(url, w) for url, w in credit_dist.items()]
credit_l = sorted(credit_l, key=lambda x: x[1], reverse=True)
print('Computed credit_l')
# No chunk matches at all: write a notice to the credit buffer and return
# early with an empty label.
if not atom_topkmatches_l:
print(f"
", file=_out_credit)
print(f"
No sources were found that are relevant this target.
", file=_out_credit)
print(f"
", file=_out_credit)
return '', _out_credit.getvalue()
# Matches exist but no URL received credit: write a notice; there is no
# return here, so execution continues below.
if not credit_l:
print(f"", file=_out_credit)
print(f"
No sources were found that strongly support this target.
", file=_out_credit)
print(f"
", file=_out_credit)
# Attribution view: one entry per credited URL with its title, its weight as a
# whole-number percentage, and its supporting chunk texts joined with
# chunk_separator.
for url, w in credit_l:
match_text = chunk_separator.join([x[1] for x in url_to_supporting_cid_ctext_tuples[url]])
# Console-only debug line (this print has no file= argument).
print(f"{url} cids: {[x[0] for x in url_to_supporting_cid_ctext_tuples[url]]}")
print(f"", file=_out_credit)
favicon = f"
"
print(f"
{favicon}  {url_to_title[url]}{100*w:.0f}%
", file=_out_credit)
print(f"
", file=_out_credit)
print(f"
{match_text}
", file=_out_credit)
print(f"
", file=_out_credit)
# Claim-breakdown view, written to the second buffer.
print(f"", file=_out_claims)
print(f"
Breakdown of article support for each extracted claim
", file=_out_claims)
# For each atom and each URL considered for it: title, determination,
# rationale and the chunks that were examined.
for j, atom_support in enumerate(atom_support_l):
# NOTE(review): j, n_urls and n_support are not referenced by any other
# visible line of this function.
n_urls = len(atom_support.keys())
n_support = sum([1 if determination['true'] else 0 for determination in atom_support.values()])
print(f"", file=_out_claims)
for url, aggmatch_determination in atom_support.items():
title = url_to_title[url]
print(f"
{title}
", file=_out_claims)
print(f"
", file=_out_claims)
print(f"
Determination: {'Supported' if aggmatch_determination['true'] else 'NOT supported'}.
", file=_out_claims)
print(f"
Rationale: {aggmatch_determination['rationale']}
", file=_out_claims)
for cid, ctext in zip(aggmatch_determination['id_l'], aggmatch_determination['chunk_text_l']):
print(f"
Chunk {cid}: {ctext}
", file=_out_claims)
print(f"
", file=_out_claims)
# Reset the toggle counter so the attribution view counts as the current one.
cur_idx[0] = 0
print('End output_credit_dist.')
return 'Show claim breakdown', _out_credit.getvalue()
def toggle_output(cur_idx, _out_credit, _out_claims):
    """Switch the results pane between attribution and claim breakdown.

    Args:
        cur_idx: One-element list used as a mutable toggle counter. Element 0
            is incremented in place on every effective toggle; a negative
            value disables toggling.
        _out_credit: Text buffer whose ``getvalue()`` yields the attribution
            view.
        _out_claims: Text buffer whose ``getvalue()`` yields the claim
            breakdown view.

    Returns:
        A ``(button_label, content)`` pair on every path, matching the two
        output components this handler is wired to. When toggling is
        disabled the pair is ``('', '')``.
    """
    if cur_idx[0] < 0:
        # This branch used to return a bare '' although the handler feeds two
        # outputs; return one value per output like the other branches.
        return '', ''

    cur_idx[0] += 1
    if cur_idx[0] % 2 == 0:
        return 'Show claim breakdown', _out_credit.getvalue()
    return 'Back to attribution', _out_claims.getvalue()
# Assemble the UI: a target textbox, an HTML results pane, and a button that
# flips the pane between the two rendered views.
with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
    # State components. The StringIO class itself (not an instance) is passed
    # as the initial value; presumably Gradio calls it to create the buffer --
    # confirm against the Gradio version in use.
    _out_credit_var = gr.State(value=StringIO)
    _out_claims_var = gr.State(value=StringIO)
    cur_idx_var = gr.State(value=[0])

    # Visible components, in display order.
    msg = gr.Textbox(label='Target')
    results_box = gr.HTML(label='Matches')
    toggle = gr.Button(value="")

    # Submitting the textbox runs the analysis; both handlers update the
    # button label and the results pane.
    msg.submit(
        fn=output_credit_dist,
        inputs=[msg, cur_idx_var, _out_credit_var, _out_claims_var],
        outputs=[toggle, results_box],
        queue=False,
    )
    toggle.click(
        fn=toggle_output,
        inputs=[cur_idx_var, _out_credit_var, _out_claims_var],
        outputs=[toggle, results_box],
        queue=False,
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()