# levanti_en_ar / semsearch.py
# Author: Guy Mor-Lan
# Commit e35836c ("add files"), 2.51 kB
import numpy as np
import torch
import pandas as pd
import translate
import gradio as gr
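# NOTE: run_knn() reads the module-level names `data` and `embeddings`.
# The loads below are currently disabled, so both names must be defined
# some other way before run_knn() is called.
# data = pd.read_csv("./embedding_data.csv")
# embeddings = np.load("./embeddings.npy")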
def normalize_vector(v):
    """Scale ``v`` to unit Euclidean (L2) length.

    Args:
        v: array-like accepted by ``np.linalg.norm``.

    Returns:
        ``v`` divided by its norm, or ``v`` itself (the same object) when the
        norm is zero, so that no division by zero takes place.
    """
    magnitude = np.linalg.norm(v)
    return v if magnitude == 0 else v / magnitude
def embed_one(model, tokenizer, text, normalize=True):
    """Embed ``text`` by averaging the encoder's last hidden state.

    Args:
        model: object exposing ``model.encoder``; the encoder is called with
            the tokenizer output unpacked as keyword arguments and must return
            an object with a ``last_hidden_state`` tensor.
        tokenizer: callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True)``.
        text: the input handed to the tokenizer.
        normalize: when true, the vector is passed through
            ``normalize_vector`` before being returned.

    Returns:
        A NumPy array: row 0 of the hidden state averaged over axis 1
        (presumably the token axis -- confirm against the encoder's output
        layout).
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        hidden = model.model.encoder(**encoded).last_hidden_state
        pooled = hidden.mean(dim=1)
    vector = pooled.detach().numpy()[0]
    if not normalize:
        return vector
    return normalize_vector(vector)
def knn(query_embedding, embeddings, df, k=5, hebrew=True):
    """Return the ``k`` rows of ``df`` whose embeddings score highest.

    The score is the dot product between each row of ``embeddings`` and
    ``query_embedding``. (This equals cosine similarity only if both sides
    are unit-normalised -- an assumption about the callers, not checked here.)

    Args:
        query_embedding: query vector; a 1-D array or a single-row 2-D array.
        embeddings: 2-D array whose rows are positionally aligned with ``df``.
        df: DataFrame with ``arabic``, ``validated`` and ``hebrew`` /
            ``english`` columns.
        k: number of neighbours to return. Values ``<= 0`` yield an empty
            frame.
        hebrew: select the ``hebrew`` column when true, ``english`` otherwise.

    Returns:
        A DataFrame with columns ``arabic``, the chosen translation column and
        ``validated``, ordered from highest to lowest score.
    """
    sims = np.dot(embeddings, query_embedding.T)
    if k <= 0:
        # ``[-0:]`` would slice the *whole* array, so handle this explicitly.
        select = np.empty(0, dtype=int)
    else:
        # argsort is ascending: keep the last k entries, then flip them so the
        # best match comes first.
        order = np.argsort(sims, axis=0)
        select = order[-k:][::-1].ravel()
    translation = "hebrew" if hebrew else "english"
    return df.iloc[select][["arabic", translation, "validated"]]
def run_knn(text, k=5):
    """Embed ``text`` with the from-Arabic model and look up its neighbours.

    Args:
        text: query string; it is also echoed to stdout.
        k: number of neighbours requested from ``knn``.

    Returns:
        Whatever ``knn`` returns for the Hebrew view (``hebrew=True``).

    NOTE(review): relies on module-level ``embeddings`` and ``data``; confirm
    they are defined before this function is called.
    """
    print(text)
    query_vec = embed_one(
        translate.model_from_ar,
        translate.tokenizer_from_ar,
        text,
    )
    neighbours = knn(query_vec, embeddings, data, k=k, hebrew=True)
    return neighbours
def style_dataframe(df):
styled_df = df.style.set_properties(**{
'font-family': 'Arial, sans-serif',
'font-size': '20px',
'text-align': 'right',
'direction': 'rtl',
'align': 'right'
}).set_table_styles([
{'selector': 'th', 'props': [('text-align', 'right')]}
])
return styled_df
def style_dataframe(df):
return df.style.set_table_styles([
{'selector': 'thead', 'props': [('text-align', 'right')]},
{'selector': '.index_name', 'props': [('text-align', 'right')]},
]).set_properties(**{
'text-align': 'right',
}) # Replace 'column_name' with your actual column name
def update_df(hidden_arabic):
    """Build the styled results table for an Arabic query.

    Fetches 100 neighbours via ``run_knn``, turns the ``validated`` flags into
    emoji, renames the columns to Hebrew headings and wraps the styled frame
    in a visible ``gr.DataFrame``.

    Args:
        hidden_arabic: the Arabic query text forwarded to ``run_knn``.

    Returns:
        A ``gr.DataFrame`` built with ``visible=True`` whose value is the
        styled result table.
    """
    # Work on a copy so the column assignment below cannot write into (or warn
    # about) a slice of the underlying corpus frame.
    df = run_knn(hidden_arabic, 100).copy()
    # Show truthy / falsy validation flags as a check mark / cross.
    df["validated"] = df["validated"].apply(lambda flag: "✅" if flag else "❌")
    df = df.rename(columns={
        "validated": "מאומת",  # "verified"
        "arabic": "ערבית",  # "Arabic"
        "hebrew": "עברית",  # "Hebrew"
    })
    styled_df = style_dataframe(df)
    return gr.DataFrame(value=styled_df, visible=True)