import re
import json
import pandas as pd
import gradio as gr
import pyterrier as pt
# NOTE(review): pt.init() is deliberately kept ahead of the pyt_splade import,
# matching the original statement order -- presumably pyt_splade needs an
# initialised PyTerrier; confirm before reordering.
pt.init()
import pyt_splade
from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md

# One SPLADE factory per aggregation mode offered in the UI dropdowns.
factory_max = pyt_splade.SpladeFactory(agg='max')
factory_sum = pyt_splade.SpladeFactory(agg='sum')
_FACTORIES = {'max': factory_max, 'sum': factory_sum}

COLAB_NAME = 'pyterrier_splade.ipynb'
COLAB_INSTALL = '''
!pip install -q git+https://github.com/naver/splade
!pip install -q git+https://github.com/seanmacavaney/pyt_splade@misc
'''.strip()

# Extracts (weight, token) pairs from the rewritten query string produced by
# the query pipeline; compiled once rather than on every row.
_QUERY_TOKEN_RE = re.compile(r'combine:0=([0-9.]+)\(([^)]+)\)')


def _join_tokens(toks):
    """Return the tokens joined into one space-separated string."""
    return ' '.join(f'{t}' for t in toks)


def generate_vis(df, mode='Document'):
    """Build a text visualisation of original vs. expansion tokens.

    For every row of ``df`` the text is tokenised with ``factory_max``'s
    tokenizer, and every scored token that is not among those original tokens
    is listed as an expansion token, ordered by descending score (ties broken
    alphabetically).

    Args:
        df: Pipeline output. In ``'Query'`` mode the columns ``qid``,
            ``query`` and ``query_0`` are read (``query_0`` is assumed to
            hold the query text before rewriting); otherwise ``docno``,
            ``text`` and ``toks`` (a token -> score mapping) are read.
        mode: ``'Query'`` selects query handling; any other value is treated
            as a document frame. The value is also used as the row label.

    Returns:
        One rendered section per row joined with newlines, or ``''`` when
        ``df`` has no rows.
    """
    if len(df) == 0:
        return ''
    tokenize = factory_max.tokenizer.tokenize
    result = []
    for row in df.itertuples(index=False):
        if mode == 'Query':
            # Token weights are only available embedded in the query string.
            tok_scores = {
                m.group(2): float(m.group(1))
                for m in _QUERY_TOKEN_RE.finditer(row.query)
            }
            orig_tokens = tokenize(row.query_0)
            row_id = row.qid
        else:
            tok_scores = row.toks
            orig_tokens = tokenize(row.text)
            row_id = row.docno
        orig_tokens_set = set(orig_tokens)
        ranked = sorted(tok_scores.items(), key=lambda x: (-x[1], x[0]))
        exp_tokens = [t for t, _ in ranked if t not in orig_tokens_set]
        result.append(f'''
{mode}: {row_id}
{_join_tokens(orig_tokens)}
Expansion Tokens: {_join_tokens(exp_tokens)}
''')
    return '\n'.join(result)


def predict_query(input, agg):
    """Run the SPLADE query pipeline and prepare the demo outputs.

    Args:
        input: Query frame passed straight to the pipeline and to ``df2code``.
        agg: Aggregation mode, ``'max'`` or ``'sum'``; any other value raises
            ``KeyError``.

    Returns:
        A tuple ``(res, code_md, vis)``: the pipeline output, the result of
        ``code2md`` for an equivalent stand-alone snippet, and the
        visualisation string from ``generate_vis``.
    """
    # {agg!r} keeps the quotes, so the snippet reads agg='max' rather than
    # agg=max (which would reference the builtin function).
    code = f'''import pandas as pd
import pyterrier as pt ; pt.init()
import pyt_splade

factory = pyt_splade.SpladeFactory(agg={agg!r})

query_pipeline = factory.query()

query_pipeline({df2code(input)})
'''
    pipeline = _FACTORIES[agg].query()
    res = pipeline(input)
    vis = generate_vis(res, mode='Query')
    return (res, code2md(code, COLAB_INSTALL, COLAB_NAME), vis)


def predict_doc(input, agg):
    """Run the SPLADE indexing pipeline and prepare the demo outputs.

    Args:
        input: Document frame passed straight to the pipeline and to
            ``df2code``.
        agg: Aggregation mode, ``'max'`` or ``'sum'``; any other value raises
            ``KeyError``.

    Returns:
        A tuple ``(res, code_md, vis)``: the pipeline output with its ``toks``
        column replaced by JSON strings (scores rounded to 4 decimals), the
        result of ``code2md`` for an equivalent stand-alone snippet, and the
        visualisation string from ``generate_vis``.
    """
    # {agg!r} keeps the quotes, so the snippet reads agg='max' rather than
    # agg=max (which would reference the builtin function).
    code = f'''import pandas as pd
import pyterrier as pt ; pt.init()
import pyt_splade

factory = pyt_splade.SpladeFactory(agg={agg!r})

doc_pipeline = factory.indexing()

doc_pipeline({df2code(input)})
'''
    pipeline = _FACTORIES[agg].indexing()
    res = pipeline(input)
    # Visualise first: generate_vis needs the token dicts, which are turned
    # into JSON strings for display just below.
    vis = generate_vis(res, mode='Document')
    res['toks'] = [json.dumps({k: round(v, 4) for k, v in t.items()}) for t in res['toks']]
    return (res, code2md(code, COLAB_INSTALL, COLAB_NAME), vis)


# Page layout: README, query explanation + demo, document explanation + demo,
# wrap-up. Launched at import time, as in the original script.
interface(
    MarkdownFile('README.md'),
    MarkdownFile('query.md'),
    Demo(
        predict_query,
        pd.DataFrame([
            {'qid': '1112389', 'query': 'what is the county for grand rapids, mn'},
        ]),
        [
            gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation'),
        ],
        scale=2/3
    ),
    MarkdownFile('doc.md'),
    Demo(
        predict_doc,
        pd.DataFrame([
            {'docno': '0', 'text': 'The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.'},
        ]),
        [
            gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation'),
        ],
        scale=2/3
    ),
    MarkdownFile('wrapup.md'),
).launch(share=False)