# Streamlit app: semantic search over TV shows/films (FAISS + SBERT) with Groq-powered RAG answers.
# (Removed non-code scrape artifact: Hugging Face Spaces status banner.)
import streamlit as st | |
import os | |
import pandas as pd | |
import numpy as np | |
import faiss | |
import re | |
import ast | |
import random | |
import tempfile | |
import time | |
from sentence_transformers import SentenceTransformer | |
from langchain_groq import ChatGroq | |
from langchain_core.messages import SystemMessage, HumanMessage | |
# --- Path and constant settings ---
HERE = os.path.dirname(os.path.abspath(__file__))
CSV_PATH = os.path.join(HERE, "tvshows_processed2.csv")
EMB_PATH = os.path.join(HERE, "embeddings.npy")
FAISS_PATH = os.path.join(HERE, "faiss_index.index")
# --- Constants used by the cleaning helpers ---
# Canonical genre labels (Russian, runtime data) that raw genre strings are reduced to.
BASIC_GENRES = [
    "комедия", "драма", "боевик", "фэнтези", "ужасы", "триллер", "романтика",
    "научная фантастика", "приключения", "криминал", "мюзикл",
    "семейный", "детектив", "биография", "документальный"
]
# Placeholder markers meaning "actor information is missing".
BAD_ACTORS = ["я не знаю", "нет информации", "не указан", "unknown", "—", ""]
# Query keyword -> inferred genre; all keys are single lowercase words.
GENRE_KEYWORDS_MAP = {
    "доктор": "драма", "медицина": "драма", "врач": "драма", "школа": "драма",
    "ужас": "ужасы", "фантастика": "научная фантастика", "война": "боевик",
    "волшебство": "фэнтези", "дракон": "фэнтези"
}
# --- Keyword-based genre inference ---
def infer_genre_from_query(query):
    """Return the genre mapped to the first known keyword in *query*, or None."""
    tokens = re.findall(r'\b\w+\b', query.lower())
    return next(
        (GENRE_KEYWORDS_MAP[token] for token in tokens if token in GENRE_KEYWORDS_MAP),
        None,
    )
# --- Actor-field cleanup ---
def clean_actors_string(val):
    """Normalize an "actors" cell; return "Неизвестно" for placeholder junk.

    A value is junk when it equals/contains one of the placeholder markers in
    BAD_ACTORS, or contains no Latin/Cyrillic letters at all.

    Bug fix: BAD_ACTORS contains the empty string, and ``"" in v`` is True for
    EVERY string, so the original rejected all actor values unconditionally.
    Empty markers are now matched exactly; non-empty markers by substring.
    """
    v = str(val).strip().lower()
    is_placeholder = v in BAD_ACTORS or any(bad in v for bad in BAD_ACTORS if bad)
    if is_placeholder or not re.search(r'[a-zа-яё]', v):
        return "Неизвестно"
    return val
# --- Genre filtering ---
def filter_to_basic_genres(genres_str):
    """Reduce a raw genre string to the comma-joined subset of BASIC_GENRES.

    Returns "" for non-string input and "Другие" when nothing matches.
    """
    if not isinstance(genres_str, str):
        return ""
    haystack = genres_str.lower()
    hits = []
    for genre in BASIC_GENRES:
        if genre in haystack:
            hits.append(genre)
    if hits:
        return ", ".join(hits)
    return "Другие"
# --- Intro description ---
def extract_intro_paragraph(text, max_sentences=4):
    """Return the first *max_sentences* sentences of *text*, joined by spaces."""
    cleaned = str(text).strip()
    # Split after sentence-ending punctuation followed by whitespace.
    sentences = re.split(r'(?<=[.!?]) +', cleaned)
    head = sentences[:max_sentences]
    return " ".join(head)
# --- Data cleaning ---
def clean_tvshows_data(path):
    """Load the TV-show CSV and return a cleaned, de-duplicated DataFrame.

    Guarantees the columns used downstream (actors, genres, year, num_seasons,
    tvshow_title, description, image_url, url, rating, language, country,
    basic_genres, type) all exist, drops rows with short descriptions, and
    removes title+description duplicates.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Файл данных не найден: {path}")
    df = pd.read_csv(path)

    def _col(name, default):
        # Bug fix: df.get(name, scalar) returns the bare scalar when the column
        # is missing, so the original crashed on `"".astype(...)`.  Always
        # yield a Series aligned with df.
        if name in df.columns:
            return df[name]
        return pd.Series([default] * len(df), index=df.index)

    df["actors"] = _col("actors", "").astype(str).apply(clean_actors_string)
    df["genres"] = _col("genres", "").astype(str)
    df["year"] = pd.to_numeric(_col("year", 0), errors="coerce").fillna(0).astype(int)
    df["num_seasons"] = pd.to_numeric(_col("num_seasons", 0), errors="coerce").fillna(0).astype(int)
    df["tvshow_title"] = _col("tvshow_title", "").fillna("Неизвестно")
    df["description"] = _col("description", "").fillna("Нет описания").astype(str).str.strip()
    # Keep only rows with a meaningful description; copy() so the later
    # column assignments don't operate on a view of the original frame.
    df = df[df["description"].str.len() > 50].copy()
    df = df.drop_duplicates(subset=["tvshow_title", "description"])
    for col in ["image_url", "url", "rating", "language", "country"]:
        if col not in df.columns:
            df[col] = None
    df["basic_genres"] = df["genres"].apply(filter_to_basic_genres)
    # More than one season => series, otherwise treated as a film.
    df["type"] = df["num_seasons"].apply(lambda x: "Сериал" if int(x) > 1 else "Фильм")
    return df.reset_index(drop=True)
# --- Caching ---
@st.cache_data(show_spinner=False)
def cached_load_data(path):
    """Load and clean the dataset; the name promised caching but the original
    had no decorator, so every call re-read and re-cleaned the CSV."""
    return clean_tvshows_data(path)
@st.cache_resource(show_spinner=False)
def cached_init_embedder():
    """Load the Russian SBERT sentence encoder.

    Model weights are cached on disk under ./sbert_cache; the loaded model
    object is additionally cached in-process (st.cache_resource), which the
    function's name promised but the original never did.
    """
    cache_dir = os.path.join(HERE, "sbert_cache")
    os.makedirs(cache_dir, exist_ok=True)
    return SentenceTransformer("sberbank-ai/sbert_large_nlu_ru", cache_folder=cache_dir)
def cached_load_embeddings_and_index():
    """Load the corpus embeddings and FAISS index, building them if absent.

    Returns:
        tuple: (embeddings ndarray of shape (n_docs, dim), FAISS inner-product
        index over the L2-normalized embeddings — i.e. cosine similarity).
    """
    if not os.path.exists(EMB_PATH) or not os.path.exists(FAISS_PATH):
        st.warning("Файлы эмбеддингов или индекса не найдены. Создаём новые...")
        df = cached_load_data(CSV_PATH)
        embedder = cached_init_embedder()
        texts = df.apply(
            lambda row: f"Название: {row['tvshow_title']}. Описание: {row['description']}. Жанр: {row['genres']}. Актёры: {row['actors']}.",
            axis=1,
        ).tolist()
        embeddings = embedder.encode(texts, show_progress_bar=True)
        faiss.normalize_L2(embeddings)
        np.save(EMB_PATH, embeddings)
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        faiss.write_index(index, FAISS_PATH)
        st.success("Новые эмбеддинги и индекс успешно созданы.")
        # Bug fix: the original called st.stop() here, killing the app right
        # after a successful build; the artifacts are already in memory, so
        # return them and let the app continue.
        return embeddings, index
    embeddings = np.load(EMB_PATH)
    index = faiss.read_index(FAISS_PATH)
    return embeddings, index
# --- Groq LLM initialization ---
def init_groq_llm():
    """Build the Groq chat model, or return None when unavailable.

    Returns None (rather than raising) both when GROQ_API_KEY is unset and
    when construction fails, so the UI can degrade gracefully.
    """
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        return None
    try:
        # The key already lives in the environment (we just read it from
        # there), so the original's re-export was a no-op and is dropped;
        # the try now covers only the call that can actually fail.
        return ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0, max_tokens=2000)
    except Exception as e:
        st.error(f"Ошибка инициализации Groq: {e}")
        return None
# --- Semantic search ---
def semantic_search(query, embedder, index, df, genre=None, year=None, country=None, vtype=None, k=5):
    """Embed *query*, retrieve candidates from FAISS, filter and re-rank.

    Filters (genre/year/country/vtype) apply when set and not "Все"; a genre
    inferred from query keywords is used when no explicit genre is chosen.
    Returns the top-*k* rows sorted by cosine similarity plus boosts for an
    exact title match (+1.5) and for query keywords appearing in the
    title/description (+0.4). Returns an empty DataFrame for blank queries.
    """
    if not isinstance(query, str) or not query.strip():
        return pd.DataFrame()
    inferred_genre = infer_genre_from_query(query)
    if inferred_genre and (genre is None or genre == "Все"):
        genre = inferred_genre
    query_embedding = embedder.encode([query])
    faiss.normalize_L2(query_embedding)
    # Bug fix: never ask for more neighbours than the index holds — FAISS pads
    # short result lists with id -1, and df.iloc[-1] silently returns the LAST
    # row instead of failing.  Also drop any -1 ids defensively.
    top_n = min(500, index.ntotal)
    dists, idxs = index.search(query_embedding, top_n)
    valid = idxs[0] >= 0
    res = df.iloc[idxs[0][valid]].copy()
    res["score"] = dists[0][valid]
    if genre and genre != "Все":
        res = res[res["basic_genres"].str.lower().str.contains(genre.lower(), na=False)]
    if year and year != "Все":
        res = res[res["year"] == int(year)]
    if country and country != "Все":
        res = res[res["country"].astype(str).str.lower().str.contains(country.lower(), na=False)]
    if vtype and vtype != "Все":
        res = res[res["type"].str.lower() == vtype.lower()]
    if res.empty:
        return res
    query_lower = query.lower()
    res['exact_match_title'] = res['tvshow_title'].str.lower() == query_lower
    query_words = re.findall(r'\b\w+\b', query_lower)
    keyword_pattern = '|'.join(re.escape(word) for word in query_words if len(word) > 2)
    if keyword_pattern:
        res['has_keyword'] = res.apply(
            lambda row: bool(re.search(keyword_pattern, str(row['tvshow_title']).lower() + str(row['description']).lower())),
            axis=1,
        )
    else:
        # Bug fix: when every query word is <= 2 chars the pattern was "",
        # which matches every row and gave all candidates the same
        # meaningless +0.4 boost.
        res['has_keyword'] = False
    res['final_score'] = res['score'] + res['exact_match_title'] * 1.5 + res['has_keyword'] * 0.4
    return res.sort_values(by="final_score", ascending=False).head(k)
# --- Result formatting ---
def format_docs_for_prompt(results_df):
    """Render the search-result rows into a plain-text context for the LLM."""
    if results_df.empty:
        return "Нет подходящих результатов."
    chunks = []
    for _, row in results_df.iterrows():
        chunk = (
            f"Название: {row['tvshow_title']} ({row['year']})\n"
            f"Жанр: {row['basic_genres']}\n"
            f"Рейтинг: {row['rating'] or '—'} | Тип: {row['type']} | Страна: {row['country'] or '—'} | Сезонов: {row['num_seasons'] or '—'}\n"
            f"Актёры: {row['actors']}\nСюжет: {extract_intro_paragraph(row['description'])}"
        )
        chunks.append(chunk)
    return "\n\n".join(chunks)
# --- RAG answer ---
def generate_rag_response(user_query, search_results, llm):
    """Ask *llm* to answer *user_query* grounded only in *search_results*."""
    if llm is None:
        return "LLM не инициализирован."
    ctx = format_docs_for_prompt(search_results)
    prompt = f"""
Ты — эксперт по кино и сериалам. Анализируй только данные ниже.
Результаты поиска:
{ctx}
Вопрос пользователя: {user_query}
Ответ:
"""
    messages = [
        SystemMessage(content="Ты — эксперт по кино и сериалам. Не выдумывай лишнего."),
        HumanMessage(content=prompt),
    ]
    try:
        return llm.invoke(messages).content.strip()
    except Exception as e:
        return f"Ошибка при генерации ответа LLM: {e}"
# --- Main entry point ---
def main():
    """Render the Streamlit UI: filters, search actions, results, AI notes."""
    st.set_page_config(page_title="Поиск фильмов и сериалов + Groq AI", layout="wide")
    st.title("📽️ Поиск фильмов и сериалов с AI")
    # Heavy resources are created once per session and kept in session_state.
    if "df" not in st.session_state:
        st.session_state.df = cached_load_data(CSV_PATH)
    if "embedder" not in st.session_state:
        st.session_state.embedder = cached_init_embedder()
    if "embeddings_index" not in st.session_state:
        st.session_state.embeddings, st.session_state.index = cached_load_embeddings_and_index()
    if "llm" not in st.session_state:
        st.session_state.llm = init_groq_llm()
    df = st.session_state.df
    embedder = st.session_state.embedder
    index = st.session_state.index
    llm = st.session_state.llm
    # --- Filters ---
    st.sidebar.header("Фильтры")
    # Bug fix: split(",") leaves leading spaces on genre names (" драма"),
    # which broke the substring match in semantic_search; strip each item.
    basic_genres_list = sorted(
        {g.strip() for g in ", ".join(df["basic_genres"].dropna().unique()).split(",") if g.strip()}
    )
    genre_filter = st.sidebar.selectbox("Жанр", ["Все"] + basic_genres_list, index=0)
    years = ["Все"] + sorted([str(y) for y in df["year"].unique() if y != 0], reverse=True)
    year_filter = st.sidebar.selectbox("Год", years, index=0)
    countries = ["Все"] + sorted([str(c) for c in df["country"].dropna().unique()])
    country_filter = st.sidebar.selectbox("Страна", countries, index=0)
    vtypes = ["Все"] + sorted(df["type"].dropna().unique())
    type_filter = st.sidebar.selectbox("Тип", vtypes, index=0)
    k = st.sidebar.slider("Количество результатов:", 1, 20, 5)
    # --- Search actions ---
    st.markdown("---")
    user_input = st.text_input("Введите ключевые слова или сюжет:")
    col_buttons = st.columns(4)
    search_clicked = col_buttons[0].button("Искать")
    random_clicked = col_buttons[1].button("Случайный")
    genre_clicked = col_buttons[2].button("ТОП по жанру")
    new_clicked = col_buttons[3].button("Новинки")

    def _run_search(query):
        # Single search path shared by all four buttons.
        st.session_state.results = semantic_search(
            query, embedder, index, df, genre_filter, year_filter, country_filter, type_filter, k
        )
        st.session_state.last_query = query

    if search_clicked and user_input.strip():
        _run_search(user_input)
    elif random_clicked:
        _run_search(random.choice(df["tvshow_title"].tolist()))
    elif genre_clicked and genre_filter != "Все":
        _run_search(f"Лучшие {genre_filter}")
    elif new_clicked:
        _run_search(f"Новинки {df['year'].max()}")
    # --- Output ---
    if "results" in st.session_state and not st.session_state.results.empty:
        st.markdown("## Результаты поиска")
        for _, row in st.session_state.results.iterrows():
            col1, col2 = st.columns([1, 3])
            with col1:
                # Bug fix: image_url may be None (clean_tvshows_data fills
                # missing columns with None), and None.startswith(...) crashed.
                image_url = str(row.get("image_url") or "")
                if image_url.startswith("http"):
                    st.image(image_url, width=150)
                else:
                    st.info("Нет изображения.")
            with col2:
                st.markdown(f"### {row['tvshow_title']} ({row['year']})")
                st.caption(f"{row['basic_genres']} | {row['country'] or '—'} | {row['rating'] or '—'} | {row['type']} | {row['num_seasons']} сез.")
                st.write(extract_intro_paragraph(row["description"]))
                if row.get("actors"):
                    st.caption(f"Актёры: {row['actors']}")
                if row.get("url"):
                    st.markdown(f"[Подробнее]({row['url']})")
            st.divider()
        if llm and st.button("AI: рекомендации"):
            rag = generate_rag_response(st.session_state.last_query, st.session_state.results, llm)
            st.markdown("### Рекомендации AI:")
            st.write(rag)
    st.sidebar.markdown("---")
    st.sidebar.write(f"Всего записей: {len(df)}")
    st.sidebar.markdown(f"LLM: {'Готов' if llm else 'Отключён'}")
# Script entry point (run via `streamlit run <file>`).
if __name__ == "__main__":
    main()