# Hugging Face Spaces status at capture time: Runtime error
import gradio as gr | |
import os | |
import pandas as pd | |
import torch | |
from datasets import load_dataset | |
from sentence_transformers.util import semantic_search | |
from sentence_transformers import SentenceTransformer, util | |
# Embeddings builds selectable in the UI. Each build name maps to a CSV file
# named embeddings_<build>.csv that has a 'dcid' column; every other column of
# that file is treated as a numeric embedding value.
BUILDS = ['demographics300', 'uncurated3000']

# Download the sentence-embedding model used to encode queries in inference().
model = SentenceTransformer('all-MiniLM-L6-v2')

# Per-build lookup tables, both keyed by build name and aligned row-for-row:
#   dcid_maps[build]               -> list of DCID strings, one per CSV row
#   dataset_embeddings_maps[build] -> torch.float tensor of the remaining columns
dataset_embeddings_maps = {}
dcid_maps = {}
for build in BUILDS:
    print('Loading build ', build)
    table = load_dataset(
        'csv', data_files=f'embeddings_{build}.csv')["train"].to_pandas()
    # pop() both extracts the DCID column and removes it from the frame, so
    # only the embedding columns remain for the tensor conversion below.
    dcid_maps[build] = table.pop('dcid').values.tolist()
    dataset_embeddings_maps[build] = torch.from_numpy(
        table.to_numpy()).to(torch.float)
def inference(build, query, top_k=15):
    """Search one embeddings build for the StatVars (SVs) closest to ``query``.

    Args:
        build: Name of the embeddings build to search; a key of
            ``dataset_embeddings_maps`` and ``dcid_maps`` (one of ``BUILDS``).
        query: Free-text query typed by the user.
        top_k: Number of nearest embedding rows to retrieve before grouping.
            Defaults to 15.

    Returns:
        A pandas DataFrame with columns 'SV' and 'Cosine Score', one row per
        distinct score, highest score first. SVs that share a score are joined
        with ' : ' in the 'SV' cell. Each SV appears at most once, under the
        best score it received. A blank or missing query yields an empty
        frame with the same columns.
    """
    # Nothing meaningful to embed; skip the model call entirely.
    if not query or not query.strip():
        return pd.DataFrame({'SV': [], 'Cosine Score': []})

    query_embeddings = model.encode([query])
    hits = semantic_search(query_embeddings, dataset_embeddings_maps[build],
                           top_k=top_k)

    # Note: multiple results may map to the same DCID. As well, the same string
    # may map to multiple DCIDs with the same score.
    seen = set()
    score2svs = {}
    # semantic_search returns each query's hits sorted by decreasing score, so
    # the first time an SV is seen it carries its top score; later repeats are
    # lower-or-equal and are skipped ("prefer the top score").
    for hit in hits[0]:
        score = hit['score']
        # A single corpus row can list several comma-separated DCIDs.
        for sv in dcid_maps[build][hit['corpus_id']].split(','):
            sv = sv.strip()
            if not sv or sv in seen:
                continue
            seen.add(sv)
            score2svs.setdefault(score, []).append(sv)

    # Sort by scores, best first, and add to a Pandas frame for display.
    scores = sorted(score2svs, reverse=True)
    svs = [' : '.join(score2svs[s]) for s in scores]
    return pd.DataFrame({'SV': svs, 'Cosine Score': scores})
# Build and launch a simple search UI around inference().
title = "DC Search Demo"
description = """
Try querying for StatVars.
- "demographics300": 300 SVs with curated descriptions (http://shortn/_iJbtpD2uwF)
related to demographics
- "uncurated3000": 3000 SVs with only auto-generated name related to
demographics, crime, agriculture, households, housing, emissions, health
"""

# TODO: make logging work
# HF_TOKEN = os.getenv('HF_TOKEN')
# hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "dc-statvar-demo-log")

# Input widgets, listed below in the positional order inference(build, query)
# receives them.
build_selector = gr.Dropdown(choices=BUILDS,
                             value='uncurated3000',
                             label='Embeddings Build')
query_box = gr.Textbox(label='Query',
                       placeholder='how long do people live?')

# Output table; headers match the columns of the DataFrame inference() returns.
results_table = gr.Dataframe(headers=['SV', 'Cosine Score'],
                             label='Search Results')

# Reasons a user can pick when manually flagging a result.
flag_reasons = ["not at all related",
                "related but not ranked right"]

# NOTE(review): newer Gradio releases renamed `allow_flagging` to
# `flagging_mode`; confirm against the pinned gradio version if the Space
# fails at startup.
iface = gr.Interface(fn=inference,
                     inputs=[build_selector, query_box],
                     outputs=results_table,
                     title=title,
                     description=description,
                     allow_flagging="manual",
                     flagging_options=flag_reasons)
iface.launch()