{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Startup Search Engine\n",
    "\n",
    "Semantic search over a table of startups: the query text is embedded with a sentence-transformers model, the closest rows are retrieved with a scikit-learn nearest-neighbour index, and the hits are narrowed by target, stage and size in a small Gradio UI."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.system('pip install openpyxl')\n",
    "# os.system('pip install sentence-transformers')\n",
    "\n",
    "import gradio as gr\n",
    "import pandas as pd\n",
    "from sentence_transformers import SentenceTransformer, util\n",
    "from sklearn.neighbors import NearestNeighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_PATH = 'df_encoded3.parquet'\n",
    "MAX_NEIGHBORS = 5000  # upper bound on the candidates retrieved per query\n",
    "MAX_RESULTS = 100  # rows returned to the UI\n",
    "RESULT_COLUMNS = ['name', 'raised', 'target', 'size', 'stage', 'country', 'source', 'description', 'tags']\n",
    "\n",
    "model = SentenceTransformer('all-mpnet-base-v2')  # alternative: all-MiniLM-L6-v2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and clean the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_raised(x):\n",
    "    \"\"\"Convert a raised-amount label into a number of millions.\n",
    "\n",
    "    The first character is skipped (presumably a currency symbol -- confirm\n",
    "    against the data) and the last character is read as the magnitude\n",
    "    suffix: 'K' (thousands), 'M' (millions) or 'B' (billions).\n",
    "\n",
    "    Returns 0 for 'Undisclosed' and for any value that cannot be parsed, so\n",
    "    the column never ends up holding None.\n",
    "    \"\"\"\n",
    "    if not isinstance(x, str) or not x or x == 'Undisclosed':\n",
    "        return 0\n",
    "    quantifier = x[-1]\n",
    "    try:\n",
    "        amount = float(x[1:-1])\n",
    "    except ValueError:\n",
    "        return 0\n",
    "    if quantifier == 'K':\n",
    "        return amount / 1000\n",
    "    if quantifier == 'M':\n",
    "        return amount\n",
    "    if quantifier == 'B':\n",
    "        return amount * 1000\n",
    "    # unknown suffix: treat like an undisclosed amount instead of returning None\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_parquet(DATA_PATH)\n",
    "df['tags'] = df['tags'].apply(str)\n",
    "df['raised'] = df['raised'].apply(parse_raised)\n",
    "df['stage'] = df['stage'].str.lower()\n",
    "df = df.reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Nearest-neighbour index and filters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The index is built once; search() uses the returned indices as row positions into this same df.\n",
    "nbrs = NearestNeighbors(\n",
    "    n_neighbors=min(MAX_NEIGHBORS, len(df)),  # kneighbors() raises if n_neighbors > number of fitted samples\n",
    "    algorithm='ball_tree',\n",
    ").fit(df['text_vector_'].values.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def search(df, query):\n",
    "    \"\"\"Return the rows of `df` closest to `query` in embedding space.\n",
    "\n",
    "    `df` must be the frame the module-level `nbrs` index was fitted on,\n",
    "    because the neighbour indices are used as row positions into it.\n",
    "    A copy is returned so callers can modify the result without touching `df`.\n",
    "    \"\"\"\n",
    "    product = model.encode(query).tolist()\n",
    "    _, indices = nbrs.kneighbors([product])  # one query vector -> one row of indices\n",
    "    return df.iloc[indices[0]][RESULT_COLUMNS].copy()\n",
    "\n",
    "\n",
    "def filter_df(df, column_name, filter_type, filter_value, minimum_acceptable_size=0):\n",
    "    \"\"\"Filter `df` on one column, falling back to `df` when too little is left.\n",
    "\n",
    "    filter_type is one of '==', '>=', '<=' or 'contains' (literal substring\n",
    "    match on a string column; missing values never match).\n",
    "\n",
    "    Returns the filtered frame if its `.size` is at least\n",
    "    `minimum_acceptable_size`, otherwise the unfiltered `df`.\n",
    "    Raises ValueError for an unsupported filter_type.\n",
    "    \"\"\"\n",
    "    if filter_type == '==':\n",
    "        mask = df[column_name] == filter_value\n",
    "    elif filter_type == '>=':\n",
    "        mask = df[column_name] >= filter_value\n",
    "    elif filter_type == '<=':\n",
    "        mask = df[column_name] <= filter_value\n",
    "    elif filter_type == 'contains':\n",
    "        mask = df[column_name].str.contains(filter_value, regex=False, na=False)\n",
    "    else:\n",
    "        raise ValueError(f'Unsupported filter_type: {filter_type!r}')\n",
    "\n",
    "    df_filtered = df[mask]\n",
    "    # NOTE: DataFrame.size counts cells (rows x columns), so the threshold is in cells, not rows.\n",
    "    if df_filtered.size >= minimum_acceptable_size:\n",
    "        return df_filtered\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gradio app"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def greet(size, target, stage, query):\n",
    "    \"\"\"Gradio callback: search for `query`, then narrow by target, stage and size.\n",
    "\n",
    "    Each filter is applied through filter_df, which hands back the wider,\n",
    "    unfiltered set when the filtered one is below its threshold. A stage of\n",
    "    'ALL' skips the stage filter. Returns at most MAX_RESULTS rows.\n",
    "    \"\"\"\n",
    "    df_knn = search(df, query)\n",
    "    # 0 is the value parse_raised stores for an undisclosed amount\n",
    "    df_knn['raised'] = df_knn['raised'].apply(lambda x: 'Undisclosed' if x == 0 else x)\n",
    "\n",
    "    df_target = filter_df(df_knn, 'target', 'contains', target, 500)\n",
    "    if stage != 'ALL':\n",
    "        df_stage = filter_df(df_target, 'stage', '==', stage.lower(), 40)\n",
    "    else:\n",
    "        # we bypass the filter\n",
    "        df_stage = df_target\n",
    "\n",
    "    # NOTE(review): the default '11-500+' presumably matches no row, so the size filter falls back to df_stage -- confirm\n",
    "    df_size = filter_df(df_stage, 'size', '==', size, 20)\n",
    "\n",
    "    # rows are already ordered by similarity; any other sorting is left for last\n",
    "    return df_size.head(MAX_RESULTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks(theme=gr.themes.Soft(primary_hue='amber', secondary_hue='gray', neutral_hue='amber')) as demo:\n",
    "    gr.Markdown('# Startup Search Engine')\n",
    "    # Radio is single-choice by design; it takes no `multiselect` argument (that belongs to Dropdown).\n",
    "    size = gr.Radio(['1-10', '11-50', '51-200', '201-500', '500+', '11-500+'], value='11-500+', label='size')\n",
    "    target = gr.Radio(['B2B', 'B2C', 'B2G', 'B2B2C'], value='B2B', label='target')\n",
    "    stage = gr.Radio(['pre-seed', 'A', 'B', 'C', 'exit', 'ALL'], value='ALL', label='stage')\n",
    "    # TODO: add a minimum-raised slider (in millions) once greet() filters on it\n",
    "    query = gr.Textbox(label='Describe the Startup you are searching for', value='age reversing')\n",
    "    btn = gr.Button(value='Search for a Startup')\n",
    "    output1 = gr.Dataframe(label='value')\n",
    "    btn.click(greet, [size, target, stage, query], [output1])\n",
    "\n",
    "demo.launch(share=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scratch: sentence similarity check\n",
    "\n",
    "Small sanity check of the embedding model on hand-written sentences, independent of the startup data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define database of sentences\n",
    "sentences = pd.Series(['The quick brown fox jumps over the lazy dog',\n",
    "                       'A quick brown dog jumps over the lazy fox',\n",
    "                       'The lazy dog jumps over the quick brown fox',\n",
    "                       'The quick brown fox jumps over the lazy cat',\n",
    "                       'The quick brown cat jumps over the lazy dog'])\n",
    "\n",
    "# Encode sentences (encode() expects a str or a list of str)\n",
    "sentence_embeddings = model.encode(sentences.tolist())\n",
    "\n",
    "# Define query sentence\n",
    "query = 'A lazy dog jumps over the quick brown fox'\n",
    "\n",
    "# Encode query\n",
    "query_embedding = model.encode(query)\n",
    "\n",
    "# Search for similar sentences\n",
    "cosine_scores = util.pytorch_cos_sim(query_embedding, sentence_embeddings)\n",
    "# argmax() returns a tensor; convert it to a plain int and index by position\n",
    "most_similar_sentence = sentences.iloc[int(cosine_scores.argmax())]\n",
    "most_similar_sentence"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}