# test_call2vec / app.py
# Last commit: 6779746 "Added examples" (FreshP)
# File size: 4.87 kB
import gradio as gr
import numpy as np
import pandas as pd
from datetime import datetime
import os
import re
from huggingface_hub import hf_hub_url, cached_download
from gensim.models.fasttext import load_facebook_model
# Optional secret read from the environment (None when unset). When the first
# normalized token of a search query equals it, semantic_search shows the
# query log instead of search results.
ACCESS_KEY = os.environ.get('ACCESS_KEY')

# Setup model: resolve the Hub URL of the fastText binary, fetch it with a
# single cached_download call (which yields a local file path) and load it
# once at import time.
# NOTE(review): cached_download is deprecated in recent huggingface_hub
# releases in favour of hf_hub_download(repo_id=..., filename=...); consider
# migrating when the installed version is upgraded.
url = hf_hub_url(repo_id="simonschoe/call2vec", filename="model.bin")
model = load_facebook_model(cached_download(url))
# Column layout of the search-result table (also used for empty results).
_RESULT_COLUMNS = ['Token', 'Cosine Similarity', 'Corpus Frequency']
# Column layout of the query-log table returned for the access key.
_LOG_COLUMNS = ['Time', 'Prompt']
# Append-only query log: one "<timestamp>+++<token>___<token>..." line per query.
_LOG_FILE = 'log.txt'
_LOG_SEP = '+++'
# CSV export offered for download. The path is fixed, so concurrent requests
# overwrite each other's export.
_RESULT_FILE = 'result.csv'


def _parse_query(text):
    """Split raw query text into normalized tokens.

    Entries are separated by comma, semicolon or newline. Each entry is
    lower-cased and its internal whitespace is collapsed into single
    underscores, so "Machine  Learning" becomes "machine_learning".
    Entries that are empty after stripping are dropped.
    """
    entries = re.split(r'[,;\n]', text or '')
    tokens = ('_'.join(entry.lower().split()) for entry in entries)
    return [token for token in tokens if token]


def _log_query(tokens):
    """Append one timestamped record for *tokens* to the query log."""
    with open(_LOG_FILE, 'a', encoding='utf-8') as log:
        log.write(f"{datetime.now()}{_LOG_SEP}{'___'.join(tokens)}\n")


def _read_log():
    """Return the query log as a (Time, Prompt) DataFrame.

    A missing log file yields an empty table. Each record is split on the
    first separator only, so prompts that themselves contain "+++" stay in
    the Prompt column instead of breaking the two-column layout.
    """
    try:
        with open(_LOG_FILE, encoding='utf-8') as log:
            rows = [line.strip().split(_LOG_SEP, 1) for line in log if line.strip()]
    except FileNotFoundError:
        rows = []
    return pd.DataFrame(rows, columns=_LOG_COLUMNS)


def _nearest_neighbours(tokens, n):
    """Return the *n* nearest vocabulary tokens to the query as a DataFrame.

    A single token is handed to gensim as a key. Several tokens are first
    averaged into one query vector.
    """
    # NOTE(review): gensim presumably excludes a key passed as `positive` from
    # its own results, whereas a raw vector excludes nothing, so in multi
    # search the input tokens may appear among the neighbours. Confirm this
    # asymmetry is intended.
    if len(tokens) > 1:
        query = np.stack([model.wv[token] for token in tokens], axis=0).mean(axis=0)
    else:
        query = tokens[0]
    neighbours = model.wv.most_similar(positive=query, topn=n)
    rows = [
        (token, similarity, model.wv.get_vecattr(token, 'count'))
        for token, similarity in neighbours
    ]
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)


def semantic_search(_input, n):
    """Perform semantic search for a query and export the result table.

    Args:
        _input: Raw text from the search box. Entries are separated by comma,
            semicolon or newline.
        n: Number of neighbours to return. It is converted with ``int()``
            because the slider value may arrive as a float.

    Returns:
        Tuple ``(result, csv_path, normalized_query)``:
        - ``result``: the result DataFrame.
        - ``csv_path``: the path of its CSV export.
        - ``normalized_query``: the normalized tokens joined by newlines,
          which is written back into the search box.

    Special cases:
        - If the first token equals ``ACCESS_KEY``, the query log is returned
          instead of search results, and that query is not logged.
        - A query with no usable tokens returns an empty result table.
    """
    tokens = _parse_query(_input)
    if not tokens:
        result = pd.DataFrame(columns=_RESULT_COLUMNS)
    elif tokens[0] == ACCESS_KEY:
        # Access-key path: show the log and skip the (pointless) search.
        result = _read_log()
    else:
        # Log before searching so the query is recorded even if the lookup fails.
        _log_query(tokens)
        result = _nearest_neighbours(tokens, int(n))
    result.to_csv(_RESULT_FILE)
    return result, _RESULT_FILE, '\n'.join(tokens)
# Markdown shown in the left column. gr.Markdown renders these strings as-is.
_PROJECT_DESCRIPTION = """
#### Project Description

Call2Vec is a [fastText](https://fasttext.cc/) word embedding model trained via [Gensim](https://radimrehurek.com/gensim/). It maps each token in the vocabulary into a dense, 300-dimensional vector space, designed for performing semantic search.

The model is trained on a large sample of quarterly earnings conference calls, held by U.S. firms during the 2006-2022 period. In particular, the training data is restricted to the (rather spontaneous) executives' remarks of the Q&A section of the call. The data has been preprocessed prior to model training via stop word removal, lemmatization, named entity masking, and co-occurrence modeling.
"""

_APP_USAGE = """
#### App usage

The model is intended to be used for **semantic search**: It encodes the search query (entered in the textbox on the right) in a dense vector space and finds semantic neighbours, i.e., tokens which frequently occur within similar contexts in the underlying training data.

The model allows for two use cases:

1. *Single Search:* The input query consists of a single word. When provided a bi-, tri-, or even four-gram, the quality of the model output depends on the presence of the query token in the model's vocabulary. N-grams should be concatenated by an underscore (e.g., "machine_learning" or "artificial_intelligence").
2. *Multi Search:* The input query may consist of several words or n-grams, separated by comma, semicolon or newline. It then computes the average vector over all inputs and performs semantic search based on that average vector.
"""

# Footer: credits line plus visitor badge (well-formed HTML).
_FOOTER = """
<div style='text-align: center;'>Call2Vec by X and Y</div>
<p class="aligncenter"><img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=simonschoe.call2vec&left_color=green&right_color=blue" /></p>
"""

app = gr.Blocks()
with app:
    gr.Markdown("# Call2Vec")
    gr.Markdown("## Semantic Search in Quarterly Earnings Conference Calls")
    with gr.Row():
        # Left column: project description and usage notes.
        with gr.Column():
            gr.Markdown(_PROJECT_DESCRIPTION)
            gr.Markdown(_APP_USAGE)
        # Right column: query input, controls and results.
        with gr.Column():
            text_in = gr.Textbox(lines=1, placeholder="Insert text", label="Search Query")
            with gr.Row():
                n = gr.Slider(value=50, minimum=5, maximum=250, step=5, label="Number of Neighbours")
                compute_bt = gr.Button("Start\nSearch")
            df_out = gr.Dataframe(interactive=False)
            f_out = gr.File(interactive=False, label="Download")
    # cache_examples=True makes Gradio run semantic_search on these examples
    # ahead of time, so they are also appended to the query log.
    # NOTE(review): the first two examples use n=3, below the slider's minimum of 5.
    gr.Examples(
        examples=[["transformation", 3], ["climate_change", 3], ["risk, political_risk, uncertainty", 5]],
        inputs=[text_in, n],
        outputs=[df_out, f_out, text_in],
        fn=semantic_search,
        cache_examples=True,
    )
    gr.Markdown(_FOOTER)
    # The third output writes the normalized query back into the search box.
    compute_bt.click(semantic_search, inputs=[text_in, n], outputs=[df_out, f_out, text_in])

if __name__ == "__main__":
    app.launch()