# cross-selling-ai-finder-lite-demo / get_similar_program.py
# Uploaded by Adr740 ("Upload 4 files", commit 06b126d, 2.84 kB).
import ast

import numpy as np
import pandas as pd
from openai import OpenAI

import config
# Shared OpenAI client used by every API call in this module.
# The published file had the key redacted (`api_key=` with no value is a
# SyntaxError). Take it from the local `config` module when it defines one;
# passing None lets the OpenAI client fall back to the OPENAI_API_KEY
# environment variable.
# NOTE(review): the attribute name `OPENAI_API_KEY` on `config` is assumed — confirm.
client = OpenAI(api_key=getattr(config, "OPENAI_API_KEY", None))
def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    Args:
        a: First vector (anything ``np.dot`` / ``np.linalg.norm`` accept).
        b: Second vector, same length as ``a``.

    Returns:
        ``a·b / (|a| |b|)``. A zero-norm input yields NaN (numpy division by zero);
        mismatched lengths raise from ``np.dot``.
    """
    magnitude = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / magnitude
def _get_embedding(text, model="text-embedding-3-large"):
    """Return the embedding of ``text`` from the OpenAI embeddings endpoint.

    Newlines are replaced by spaces before the request. Values whose
    ``replace`` is missing or incompatible (e.g. a NaN float read by pandas)
    are sent unchanged, as before.

    Args:
        text: Text to embed.
        model: OpenAI embedding model name.

    Returns:
        ``data[0].embedding`` of the API response for the single input.
    """
    try:
        text = text.replace("\n", " ")
    except (AttributeError, TypeError):
        # Only non-str inputs land here. The former bare `except:` also hid
        # KeyboardInterrupt/SystemExit and any unrelated bug.
        pass
    return client.embeddings.create(input=[text], model=model).data[0].embedding
def augment_user_input(user_input):
    """Expand a student profile with LLM-suggested training programs.

    Sends the profile to the chat model, asking for a detailed bullet list of
    French training programs that could suit the student.

    Args:
        user_input: Free-text description of the student.

    Returns:
        ``user_input``, a newline, then the model's reply text.
    """
    request_text = (
        "\n"
        "Based on the profile of this student, propose a highly detailed bullet point list of training programs in French that could be good for him:\n"
        f"{user_input}\n"
    )
    response = client.chat.completions.create(
        model="gpt-4-turbo-preview",
        temperature=1,
        max_tokens=400,
        messages=[{"role": "user", "content": request_text}],
    )
    suggestions = response.choices[0].message.content
    return f"{user_input}\n{suggestions}"
def _load_program_catalog(path_to_csv, filters):
    """Read the program CSV, keeping rows that have an ``Embeddings`` value.

    When ``filters`` is non-empty, keep only rows whose ``ÉCOLE`` cell equals
    ``"\\nÉCOLE: <name>"`` for one of the given names. This is the exact form
    the lookup has always used; presumably it is how the CSV stores the school.
    Confirm against the data.
    """
    catalog = pd.read_csv(path_to_csv).dropna(subset=["Embeddings"])
    # len() rather than truthiness, so numpy arrays / pandas Series still work.
    if filters is not None and len(filters) > 0:
        wanted = [f"\nÉCOLE: {school}" for school in filters]
        catalog = catalog[catalog["ÉCOLE"].isin(wanted)].reset_index(drop=True)
    return catalog


def _parse_embedding(raw_embedding):
    """Decode one serialized ``Embeddings`` cell into a 1-D float array, or None."""
    # NOTE(security): this previously used eval() on CSV content, which would
    # execute arbitrary code from the file. ast.literal_eval accepts only
    # Python literals, and any numeric list text eval() accepted parses the
    # same way here.
    try:
        values = (
            ast.literal_eval(raw_embedding)
            if isinstance(raw_embedding, str)
            else raw_embedding
        )
        vector = np.asarray(values, dtype=float)
    except (ValueError, TypeError, SyntaxError):
        return None
    return vector if vector.ndim == 1 else None


def _score_embedding(raw_embedding, query_embedding):
    """Cosine similarity of one catalog cell to the query; 0.0 if the cell is unusable."""
    vector = _parse_embedding(raw_embedding)
    if vector is None:
        # An unparseable cell used to hit a leftover breakpoint() and hang in
        # pdb. It now ranks last-ish with score 0, like other unusable vectors.
        return 0.0
    try:
        return cosine_similarity(vector, query_embedding)
    except (ValueError, TypeError):
        # e.g. the stored vector's length differs from the query embedding's.
        return 0.0


def _format_program_summary(summary):
    """Render one ``summary_french`` value as a display section.

    Keeps the first paragraph (up to the first blank line) after the first
    ``##`` marker. Then appends every later ``##``-delimited section with its
    ``##`` marker removed. Newline-plus-single-hash headings are promoted to
    ``###``, and the literal ``55555`` becomes ``###``. If the text contains no
    ``##`` at all, the whole text is used instead of raising IndexError.
    """
    text = str(summary)
    sections = text.split("##")
    if len(sections) < 2:
        body = text
    else:
        body = sections[1].split("\n\n")[0] + "".join(sections[2:])
    body = body.replace("\n# ", "\n### ").replace("55555", "###")
    return "##" + body + "\n\n------\n\n"


def search_programs(
    raw_input,
    nb_programs_to_display=10,
    augment_input=False,
    filters=None,
    path_to_csv="data_planeta_february2024.csv",
):
    """Rank catalog programs by embedding similarity to ``raw_input``.

    Args:
        raw_input: Free-text description of the student / request.
        nb_programs_to_display: Number of top-ranked programs to render
            (converted with ``int()``).
        augment_input: If true, first expand the input with
            ``augment_user_input`` (one extra chat-completion call).
        filters: Optional sequence of school names. When non-empty, only
            programs from those schools are ranked. ``None`` (the default)
            and an empty sequence both mean "no filter".
        path_to_csv: CSV with an ``Embeddings`` column, a ``summary_french``
            column, and an ``ÉCOLE`` column when filtering.

    Returns:
        A Markdown-style string with one ``##`` section per selected program,
        each followed by a ``------`` rule. Empty if no rows remain.
    """
    user_input = augment_user_input(raw_input) if augment_input else raw_input
    catalog = _load_program_catalog(path_to_csv, filters)
    query_embedding = _get_embedding(user_input, model="text-embedding-3-large")

    # Each cell is parsed exactly once. The old code evaluated every embedding
    # twice and kept an unused "embeddings" column.
    similarity = catalog["Embeddings"].apply(
        lambda raw: _score_embedding(raw, query_embedding)
    )
    ranked = (
        catalog.assign(similarity=similarity)
        .sort_values("similarity", ascending=False)
        .head(int(nb_programs_to_display))
    )
    return "".join(
        _format_program_summary(summary) for summary in ranked["summary_french"]
    )