"""Gradio demo: free-text semantic search over StatVar (SV) embedding builds."""
import os

import gradio as gr
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from sentence_transformers.util import semantic_search

# Names of the embedding builds; each is read from `embeddings_{build}.csv`.
BUILDS = ['demographics300', 'uncurated3000']

# Number of nearest corpus rows requested from semantic_search per query.
TOP_K = 15

# Download model
model = SentenceTransformer('all-MiniLM-L6-v2')


def _load_build(build):
    """Load `embeddings_{build}.csv` and split it into (dcids, embeddings).

    The CSV is expected to have a 'dcid' column; every other column is treated
    as one embedding dimension.

    Returns:
        A tuple (dcids, embeddings). `dcids` is a list with one entry per CSV
        row; an entry may hold several comma-separated DCIDs. `embeddings` is a
        float32 torch tensor whose rows line up with `dcids`.
    """
    print('Loading build ', build)
    ds = load_dataset('csv', data_files=f'embeddings_{build}.csv')
    df = ds["train"].to_pandas()
    dcids = df['dcid'].values.tolist()
    embeddings = torch.from_numpy(df.drop('dcid', axis=1).to_numpy()).to(torch.float)
    return dcids, embeddings


# Load embeddings: build name -> tensor / DCID list (row-aligned).
dataset_embeddings_maps = {}
dcid_maps = {}
for build in BUILDS:
    dcid_maps[build], dataset_embeddings_maps[build] = _load_build(build)


def _group_by_score(hits, dcids):
    """Group the DCIDs referenced by `hits` under their best cosine score.

    Note: multiple results may map to the same DCID. As well, the same string
    may map to multiple DCIDs with the same score.

    Args:
        hits: one query's result list from semantic_search, i.e. dicts with
            'corpus_id' and 'score', ordered by descending score.
        dcids: row-aligned DCID strings (comma-separated when a row maps to
            several DCIDs).

    Returns:
        dict mapping score -> list of DCIDs. Each DCID appears only once, under
        the first (highest) score at which it was seen.
    """
    seen = set()
    score2svs = {}
    for hit in hits:
        score = hit['score']
        for dcid in dcids[hit['corpus_id']].split(','):
            # Prefer the top score: hits arrive best-first, so skip repeats.
            if dcid in seen:
                continue
            seen.add(dcid)
            score2svs.setdefault(score, []).append(dcid)
    return score2svs


def inference(build, query):
    """Return the StatVars closest to `query` within the embeddings of `build`.

    Args:
        build: a key of BUILDS selecting which embeddings/DCID list to search.
        query: free-text search string.

    Returns:
        pandas DataFrame with columns 'SV' and 'Cosine Score'. There is one row
        per distinct score, sorted by descending score. DCIDs that tie on a
        score are joined with ' : '. An empty or blank query yields an empty
        table.
    """
    if not query or not query.strip():
        return pd.DataFrame({'SV': [], 'Cosine Score': []})

    query_embeddings = model.encode([query])
    hits = semantic_search(query_embeddings, dataset_embeddings_maps[build], top_k=TOP_K)
    score2svs = _group_by_score(hits[0], dcid_maps[build])

    # Sort by scores, best first.
    scores = sorted(score2svs, reverse=True)
    svs = [' : '.join(score2svs[s]) for s in scores]
    return pd.DataFrame({'SV': svs, 'Cosine Score': scores})


# Create a simple search interface
title = "DC Search Demo"
description = """
Try querying for StatVars.

- "demographics300": 300 SVs with curated descriptions (http://shortn/_iJbtpD2uwF) related to demographics
- "uncurated3000": 3000 SVs with only auto-generated name related to demographics, crime, agriculture, households, housing, emissions, health
"""

# TODO: make logging work
# HF_TOKEN = os.getenv('HF_TOKEN')
# hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "dc-statvar-demo-log")

iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Dropdown(choices=BUILDS, value='uncurated3000', label='Embeddings Build'),
        gr.Textbox(label='Query', placeholder='how long do people live?'),
    ],
    outputs=gr.Dataframe(headers=['SV', 'Cosine Score'], label='Search Results'),
    title=title,
    description=description,
    allow_flagging="manual",
    flagging_options=["not at all related", "related but not ranked right"],
)
iface.launch()