find_my_show / resource /functions.py
IvT-DS's picture
Upload 10 files
41b0868 verified
raw
history blame
2.57 kB
import pandas as pd
import torch
import faiss
import numpy as np
from numpy import dot
from numpy.linalg import norm
def table_maker(
    df: pd.DataFrame,
    country: "list | None" = None,
    min_year: int = 1999,
    max_year: "int | None" = None,
    tagger: "set | None" = None,
    rating: bool = True,
) -> pd.DataFrame:
    """Return the rows of ``df`` that pass the rating/country/tag/year filters.

    Args:
        df: Source table. The columns ``"year"`` and, depending on which
            filters are active, ``"rating"``, ``"url"``, ``"county"`` and
            ``"tags"`` are read. ``df`` itself is never modified.
        country: Values to keep from the ``"county"`` column. ``None`` or any
            empty collection disables the country filter.
        min_year: Inclusive lower bound on ``"year"``.
        max_year: Inclusive upper bound on ``"year"``; ``None`` means no
            upper bound.
        tagger: Tags that a row must all contain. ``None`` or an empty
            collection disables the tag filter. Lists/tuples are converted to
            a set before comparison.
        rating: When true, keep only rows whose ``"rating"`` is not missing.

    Returns:
        A new DataFrame (independent copy) with the matching rows, in their
        original order and with their original index labels.

    Note:
        Every *disabled* filter falls back to "``url`` is not missing", so
        rows without a URL are dropped unless all three optional filters
        (rating, country, tags) are active. This mirrors the historical
        behaviour of the function.
    """

    def has_url() -> pd.Series:
        # Placeholder condition used by every disabled filter (see Note).
        return df["url"].notna()

    # Rating filter.
    rating_mask = df["rating"].notna() if rating else has_url()

    # Country filter. Emptiness is checked with len() so that None, [], ()
    # and empty arrays all mean "no filter".
    # NOTE(review): the column is spelled "county" - confirm this matches the
    # dataset and is not a typo for "country".
    if country is None or len(country) == 0:
        country_mask = has_url()
    else:
        country_mask = df["county"].isin(country)

    # Tag filter.
    if tagger is None or len(tagger) == 0:
        tag_mask = has_url()
    else:
        if isinstance(tagger, (list, tuple)):
            # A list/tuple would be compared element-by-element with the
            # rows instead of as one set of required tags.
            tagger = set(tagger)
        # Element-wise ``row_tags >= tagger``; presumably each cell holds a
        # set, which makes this a superset test - TODO confirm column type.
        tag_mask = df["tags"].ge(tagger)

    # Year filter: lower bound always applies, upper bound only when given.
    year_mask = df["year"] >= min_year
    if max_year is not None:
        year_mask &= df["year"] <= max_year

    keep = rating_mask & country_mask & tag_mask & year_mask
    # Copy only the surviving rows so callers can mutate the result freely.
    return df.loc[keep].copy()
class RecSys:
    """Rank the rows of a table by cosine similarity to an encoded query.

    The query ``input_`` is encoded once, at construction time, with
    ``model.encode``. Calling the instance scores every row as a weighted sum
    of the cosine similarities between the query embedding and the row's
    ``"vec"`` and ``"vec2"`` cells.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        input_,
        model,
        *,
        vec_weight: float = 0.8,
        vec2_weight: float = 0.2,
    ):
        """Store the table and encode the query.

        Args:
            df: Table with ``"vec"`` and ``"vec2"`` columns holding one
                embedding per row. It is stored by reference, not copied.
            input_: Query passed unchanged to ``model.encode``.
            model: Any object exposing ``encode(input_)``; the result must be
                usable with ``numpy.dot`` / ``numpy.linalg.norm``.
            vec_weight: Weight of the ``"vec"`` similarity in the final score.
            vec2_weight: Weight of the ``"vec2"`` similarity in the final score.
        """
        self.df = df
        self.input_ = input_
        self.model = model
        self.vec_weight = vec_weight
        self.vec2_weight = vec2_weight
        # Inference only: no autograd bookkeeping is needed while encoding.
        with torch.no_grad():
            self.emb = model.encode(self.input_)

    def __call__(self) -> pd.DataFrame:
        """Score every row and return the table sorted best-first.

        Side effect: a ``"score"`` column is written (or overwritten) on the
        DataFrame that was passed to the constructor.

        Returns:
            A new DataFrame with the same rows as ``self.df`` sorted by
            ``"score"`` in descending order.
        """
        query = self.emb
        # The query norm is the same for every row, so compute it once.
        query_norm = norm(query)

        def cosine(vector):
            return dot(vector, query) / (norm(vector) * query_norm)

        # Work on the Series directly; no temporary copy of the frame is
        # required to hold the intermediate similarities.
        primary = self.df["vec"].map(cosine)
        secondary = self.df["vec2"].map(cosine)
        self.df["score"] = primary * self.vec_weight + secondary * self.vec2_weight
        return self.df.sort_values("score", ascending=False)
class FAISS_inference:
    """Nearest-neighbour lookup over the ``"vec"`` column using a flat FAISS index."""

    def __init__(self, df, emb, k=5):
        """Build the index from ``df["vec"]``.

        Args:
            df: Table whose ``"vec"`` column holds one embedding per row; all
                embeddings must have the same number of elements.
            emb: Query embedding; it is reshaped to a single row ``(1, -1)``.
            k: Number of neighbours requested from the index on each call.

        Raises:
            ValueError: If ``df`` has no rows (there is nothing to index).
        """
        self.df = df
        self.emb = emb.reshape(1, -1)
        self.k = k
        self.vex = self._stack_vectors(df["vec"])
        # Dimensionality comes from the data instead of a hard-coded size.
        self.d = self.vex.shape[1]
        # Flat (exhaustive) index with the library's default metric.
        # NOTE(review): IndexFlatIP / IndexFlatL2 were previously tried here;
        # confirm which metric the returned "score" is meant to represent.
        self.index = faiss.IndexFlat(self.d)
        self.index.add(self.vex)

    @staticmethod
    def _stack_vectors(vectors):
        """Stack per-row embeddings into one C-contiguous float32 matrix."""
        rows = [np.ravel(vector) for vector in vectors]
        if not rows:
            raise ValueError("cannot build a FAISS index from an empty table")
        # FAISS operates on contiguous float32 data.
        return np.ascontiguousarray(np.stack(rows), dtype=np.float32)

    def __call__(self):
        """Return the nearest rows with the index's distances in ``"score"``.

        Returns:
            A new DataFrame with at most ``k`` rows of ``self.df``, ordered as
            returned by the index, plus a ``"score"`` column holding the
            values reported by ``index.search``. ``self.df`` is not modified.
        """
        distances, labels = self.index.search(self.emb, self.k)
        distances, labels = distances[0], labels[0]
        # When fewer than k vectors are available the result is padded with
        # label -1; iloc would interpret that as "last row", so drop them.
        found = labels >= 0
        nearest = self.df.iloc[labels[found]].copy()
        nearest["score"] = distances[found]
        return nearest