|
import numpy as np |
|
import pandas as pd |
|
from openai import OpenAI |
|
import config |
|
|
|
# Shared OpenAI client used by every API call in this module.
# NOTE(review): the source read `OpenAI(api_key=)` with no value, which is a
# syntax error (presumably a redacted secret). Reading the key from the
# imported `config` module is an assumption -- confirm the attribute name.
# If the attribute is missing, None is passed and the client falls back to
# the OPENAI_API_KEY environment variable.
client = OpenAI(api_key=getattr(config, "OPENAI_API_KEY", None))
|
|
|
def cosine_similarity(a, b):
    """Return the cosine similarity between vectors *a* and *b*.

    Args:
        a: First vector (any array-like accepted by NumPy).
        b: Second vector, same length as *a*.

    Returns:
        The dot product of the inputs divided by the product of their norms,
        or ``0.0`` when either input has zero norm (the ratio is undefined
        there and would otherwise come back as NaN with a RuntimeWarning).
    """
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        # A zero vector has no direction; treat it as "not similar".
        return 0.0
    return np.dot(a, b) / denominator
|
|
|
def _get_embedding(text, model="text-embedding-3-large"):
    """Return the embedding vector for *text* from the OpenAI embeddings API.

    Newlines in string input are replaced with spaces before the request.
    Non-string input is forwarded unchanged, which keeps the best-effort
    behaviour of the previous implementation.

    Args:
        text: Text to embed.
        model: Name of the embedding model to request.

    Returns:
        The ``embedding`` field of the first item in the API response.
    """
    # An explicit type check replaces the former bare ``except:``, which
    # also swallowed KeyboardInterrupt and any unrelated error.
    if isinstance(text, str):
        text = text.replace("\n", " ")
    response = client.embeddings.create(input=[text], model=model)
    return response.data[0].embedding
|
|
|
|
|
|
|
def augment_user_input(user_input):
    """Extend a student profile with model-suggested training programs.

    The profile is sent to the chat model with a request for a detailed
    bullet list of suitable training programs in French.

    Args:
        user_input: The student's profile text.

    Returns:
        The original profile, a newline, then the model's reply.
    """
    # NOTE(review): indentation was lost in the flattened source, so any
    # leading whitespace inside the original prompt literal could not be
    # recovered -- confirm against the original file.
    request_text = (
        "\n"
        "Based on the profile of this student, propose a highly detailed "
        "bullet point list of training programs in French that could be "
        "good for him:\n"
        "\n"
        f"{user_input}\n"
    )
    completion = client.chat.completions.create(
        model="gpt-4-turbo-preview",
        temperature=1,
        max_tokens=400,
        messages=[{"role": "user", "content": request_text}],
    )
    suggestions = completion.choices[0].message.content
    return f"{user_input}\n{suggestions}"
|
|
|
def _score_embedding(raw_embedding, query_embedding):
    """Return the cosine similarity of one stored embedding to the query, or 0."""
    try:
        # NOTE(security): eval() executes whatever the CSV cell contains.
        # It is kept because the stored format is not visible here; if the
        # cells are plain list literals, prefer ast.literal_eval/json.loads.
        vector = eval(raw_embedding)
        return cosine_similarity(vector, query_embedding)
    except Exception:
        # Best effort: an unparsable or mismatched row scores 0 instead of
        # aborting the whole search (previously this hit a breakpoint()).
        return 0


def _format_program(content):
    """Turn one stored program summary into its displayed markdown block."""
    sections = content.split("##")
    if len(sections) > 1:
        # First paragraph of the first "##" section, followed by all later
        # sections concatenated without their "##" separators.
        text = sections[1].split("\n\n")[0] + "".join(sections[2:])
    else:
        # No "##" heading at all: show the text as-is rather than raising
        # IndexError.
        text = content
    text = text.replace("\n# ", "\n### ").replace("55555", "###")
    return "##" + text + "\n\n------\n\n"


def search_programs(
    raw_input,
    nb_programs_to_display=10,
    augment_input=False,
    filters=None,
    path_to_csv="data_planeta_february2024.csv",
):
    """Return the programs most similar to a student profile as markdown.

    Rows of the CSV are ranked by cosine similarity between their stored
    embedding and the embedding of the (optionally augmented) input.

    Args:
        raw_input: The student's profile text.
        nb_programs_to_display: Maximum number of programs to return; it is
            passed through ``int()``.
        augment_input: When true, the profile is first expanded with
            ``augment_user_input``.
        filters: Optional collection of school names. When non-empty, only
            rows whose "ÉCOLE" column equals "\\nÉCOLE: <name>" for one of
            the names are kept. ``None`` or an empty collection disables
            filtering.
        path_to_csv: Path of the CSV file; it must provide the "Embeddings"
            and "summary_french" columns (and "ÉCOLE" when filtering).

    Returns:
        One string holding a "##"-prefixed block per selected program, each
        followed by a "------" separator line. Empty when no row matches.
    """
    user_input = augment_user_input(raw_input) if augment_input else raw_input

    df = pd.read_csv(path_to_csv).dropna(subset=["Embeddings"])
    # len() rather than truthiness so array-like filters keep working.
    if filters is not None and len(filters) != 0:
        school_labels = [f"\nÉCOLE: {school}" for school in filters]
        df = df[df["ÉCOLE"].isin(school_labels)].reset_index(drop=True).copy()

    query_embedding = _get_embedding(user_input, model="text-embedding-3-large")
    # Each stored embedding is parsed once; the former extra "embeddings"
    # column was computed under a bare except and never read.
    df["similarity"] = df["Embeddings"].apply(
        lambda raw: _score_embedding(raw, query_embedding)
    )

    top_programs = (
        df.sort_values("similarity", ascending=False)
        .head(int(nb_programs_to_display))
        .to_dict(orient="records")
    )
    return "".join(
        _format_program(str(program["summary_french"]))
        for program in top_programs
    )
|
|