dc_statvar_demo / app.py
Prashanth Radhakrishnan
Deploy to HF
fc54c76
raw
history blame contribute delete
No virus
2.85 kB
import gradio as gr
import os
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers.util import semantic_search
from sentence_transformers import SentenceTransformer, util
# Embedding builds offered in the UI; each one is read from
# embeddings_<build>.csv.
BUILDS = ['demographics300', 'uncurated3000']

# Download model (used to encode search queries).
model = SentenceTransformer('all-MiniLM-L6-v2')

# Load embeddings. Both tables are keyed by build name:
#   dcid_maps[build]               -> list of 'dcid' column values, one per CSV
#                                     row (a value may hold several
#                                     comma-separated DCIDs)
#   dataset_embeddings_maps[build] -> float32 tensor of every other column,
#                                     row-aligned with dcid_maps[build]
dataset_embeddings_maps = {}
dcid_maps = {}
for build in BUILDS:
    print('Loading build ', build)
    ds = load_dataset('csv', data_files=f'embeddings_{build}.csv')
    df = ds["train"].to_pandas()
    # pop() removes the label column in place, leaving only embedding columns.
    dcid_maps[build] = df.pop('dcid').tolist()
    dataset_embeddings_maps[build] = torch.from_numpy(df.to_numpy()).float()
def inference(build, query):
    """Return the StatVars whose embeddings best match ``query``.

    Encodes ``query`` with the module-level ``model``, runs a top-15 semantic
    search against the embeddings of the selected ``build``, and groups the
    matching StatVar DCIDs by score.

    Args:
        build: Key into ``dataset_embeddings_maps`` / ``dcid_maps`` (one of
            ``BUILDS``).
        query: Free-text search query.

    Returns:
        A pandas DataFrame with columns ``'SV'`` and ``'Cosine Score'``. It has
        one row per distinct score, ordered best first. ``'SV'`` holds the
        DCIDs that share that score, joined with ``' : '``. Each DCID appears
        at most once, under its highest score.
    """
    query_embeddings = model.encode([query])
    hits = semantic_search(query_embeddings, dataset_embeddings_maps[build],
                           top_k=15)

    # Note: multiple results may map to the same DCID. As well, the same
    # string may map to multiple DCIDs (comma-separated) with the same score.
    seen = set()
    score2svs = {}
    # semantic_search returns one hit list per query; we sent a single query.
    for hit in hits[0]:
        score = hit['score']
        for dcid in dcid_maps[build][hit['corpus_id']].split(','):
            # Hits arrive best-first, so the first occurrence of a DCID carries
            # its top score. Skip later, lower-scoring duplicates.
            if dcid in seen:
                continue
            seen.add(dcid)
            score2svs.setdefault(score, []).append(dcid)

    # Sort by score, best first, and add to Pandas.
    scores = sorted(score2svs, reverse=True)
    svs = [' : '.join(score2svs[s]) for s in scores]
    return pd.DataFrame({'SV': svs, 'Cosine Score': scores})
# Create a simple search interface.
title = "DC Search Demo"
description = """
Try querying for StatVars.
- "demographics300": 300 SVs with curated descriptions (http://shortn/_iJbtpD2uwF)
related to demographics
- "uncurated3000": 3000 SVs with only auto-generated name related to
demographics, crime, agriculture, households, housing, emissions, health
"""

# TODO: make logging work, e.g. with
#   gr.HuggingFaceDatasetSaver(os.getenv('HF_TOKEN'), "dc-statvar-demo-log")

# Input widgets, in the same order as the parameters of inference(build, query).
search_inputs = [
    gr.Dropdown(choices=BUILDS, value='uncurated3000', label='Embeddings Build'),
    gr.Textbox(label='Query', placeholder='how long do people live?'),
]
# Results table; headers match the columns of the DataFrame inference returns.
search_outputs = gr.Dataframe(headers=['SV', 'Cosine Score'], label='Search Results')
# Labels offered to users when they manually flag a bad result.
flag_labels = ["not at all related", "related but not ranked right"]

iface = gr.Interface(
    fn=inference,
    inputs=search_inputs,
    outputs=search_outputs,
    title=title,
    description=description,
    allow_flagging="manual",
    flagging_options=flag_labels,
)
iface.launch()