# src/streamlit_app.py
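"""Streamlit app: semantic search over a TV-show/movie catalogue.

Pipeline, as implemented below: clean the CSV data, embed title/description
texts with a Russian SBERT model, index them in FAISS, filter and re-rank
search hits, and optionally send the top results to a Groq-hosted LLM for
RAG-style recommendations.
"""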
import streamlit as st
import os
import pandas as pd
import numpy as np
import faiss
import re
import random
from sentence_transformers import SentenceTransformer
from langchain_groq import ChatGroq
from langchain_core.messages import SystemMessage, HumanMessage

# --- Paths and constants ---
HERE = os.path.dirname(os.path.abspath(__file__))
CSV_PATH = os.path.join(HERE, "tvshows_processed2.csv")
EMB_PATH = os.path.join(HERE, "embeddings.npy")
FAISS_PATH = os.path.join(HERE, "faiss_index.index")

# --- Cleaning constants (values stay in Russian to match the dataset) ---
BASIC_GENRES = [
    "комедия", "драма", "боевик", "фэнтези", "ужасы", "триллер", "романтика",
    "научная фантастика", "приключения", "криминал", "мюзикл",
    "семейный", "детектив", "биография", "документальный",
]
BAD_ACTORS = ["я не знаю", "нет информации", "не указан", "unknown", "—", ""]
GENRE_KEYWORDS_MAP = {
    "доктор": "драма", "медицина": "драма", "врач": "драма", "школа": "драма",
    "ужас": "ужасы", "фантастика": "научная фантастика", "война": "боевик",
    "волшебство": "фэнтези", "дракон": "фэнтези",
}

# --- Infer a genre from query keywords ---
def infer_genre_from_query(query):
    query_lower = query.lower()
    words = re.findall(r'\b\w+\b', query_lower)
    for word in words:
        if word in GENRE_KEYWORDS_MAP:
            return GENRE_KEYWORDS_MAP[word]
    return None
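
# Illustrative examples (not executed by the app):
#   infer_genre_from_query("хочу ужас")      -> "ужасы"
#   infer_genre_from_query("что-то весёлое") -> None  (no mapped keyword)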

# --- Clean the actors field ---
def clean_actors_string(val):
    v = str(val).strip().lower()
    # The empty string must not take part in the substring check: "" is a
    # substring of everything and would turn every value into "Неизвестно".
    if not v or any(bad and bad in v for bad in BAD_ACTORS) or not re.search(r'[a-zа-яё]', v):
        return "Неизвестно"
    return val
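
# Illustrative examples (the second name is hypothetical):
#   clean_actors_string("нет информации") -> "Неизвестно"
#   clean_actors_string("Иван Петров")    -> "Иван Петров"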

# --- Keep only the basic genres ---
def filter_to_basic_genres(genres_str):
    if not isinstance(genres_str, str):
        return ""
    genres_lower = genres_str.lower()
    matched = [g for g in BASIC_GENRES if g in genres_lower]
    return ", ".join(matched) if matched else "Другие"

# --- Intro paragraph of a description ---
def extract_intro_paragraph(text, max_sentences=4):
    sentences = re.split(r'(?<=[.!?]) +', str(text).strip())
    return " ".join(sentences[:max_sentences])

# --- Data cleaning ---
def clean_tvshows_data(path):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Файл данных не найден: {path}")
    df = pd.read_csv(path)
    # Create missing columns up front: df.get(col, scalar) returns the bare
    # scalar when the column is absent, which would break .astype()/.fillna().
    defaults = {"actors": "", "genres": "", "year": 0, "num_seasons": 0,
                "tvshow_title": "Неизвестно", "description": ""}
    for col, default in defaults.items():
        if col not in df.columns:
            df[col] = default
    df["actors"] = df["actors"].astype(str).apply(clean_actors_string)
    df["genres"] = df["genres"].astype(str)
    df["year"] = pd.to_numeric(df["year"], errors="coerce").fillna(0).astype(int)
    df["num_seasons"] = pd.to_numeric(df["num_seasons"], errors="coerce").fillna(0).astype(int)
    df["tvshow_title"] = df["tvshow_title"].fillna("Неизвестно")
    df["description"] = df["description"].fillna("Нет описания").astype(str).str.strip()
    df = df[df["description"].str.len() > 50]
    df = df.drop_duplicates(subset=["tvshow_title", "description"])
    for col in ["image_url", "url", "rating", "language", "country"]:
        if col not in df.columns:
            df[col] = None
    df["basic_genres"] = df["genres"].apply(filter_to_basic_genres)
    df["type"] = df["num_seasons"].apply(lambda x: "Сериал" if int(x) > 1 else "Фильм")
    return df.reset_index(drop=True)
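
# Expected CSV columns (inferred from the code above): tvshow_title,
# description, genres, actors, year, num_seasons; image_url, url, rating,
# language and country are optional and filled with None when absent.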

# --- Caching ---
@st.cache_data
def cached_load_data(path):
    return clean_tvshows_data(path)


@st.cache_resource
def cached_init_embedder():
    cache_dir = os.path.join(HERE, "sbert_cache")
    os.makedirs(cache_dir, exist_ok=True)
    return SentenceTransformer("sberbank-ai/sbert_large_nlu_ru", cache_folder=cache_dir)
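
# Note: sberbank-ai/sbert_large_nlu_ru is a large Russian sentence encoder;
# the first call downloads it into ./sbert_cache, later runs reuse that copy.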

@st.cache_resource
def cached_load_embeddings_and_index():
    if not os.path.exists(EMB_PATH) or not os.path.exists(FAISS_PATH):
        st.warning("Файлы эмбеддингов или индекса не найдены. Создаём новые...")
        df = cached_load_data(CSV_PATH)
        embedder = cached_init_embedder()
        texts = df.apply(
            lambda row: f"Название: {row['tvshow_title']}. Описание: {row['description']}. "
                        f"Жанр: {row['genres']}. Актёры: {row['actors']}.",
            axis=1,
        ).tolist()
        # FAISS expects contiguous float32 arrays.
        embeddings = np.asarray(embedder.encode(texts, show_progress_bar=True), dtype="float32")
        faiss.normalize_L2(embeddings)
        np.save(EMB_PATH, embeddings)
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        faiss.write_index(index, FAISS_PATH)
        st.success("Новые эмбеддинги и индекс успешно созданы.")
        # Return the freshly built artifacts directly; an st.stop() here would
        # abort the run before the cache could store anything.
        return embeddings, index
    embeddings = np.load(EMB_PATH)
    index = faiss.read_index(FAISS_PATH)
    return embeddings, index
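
# Note: IndexFlatIP returns inner products, and because all vectors are
# L2-normalized before indexing, the scores are cosine similarities.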

# --- Groq LLM initialization ---
@st.cache_resource(ttl=3600)
def init_groq_llm():
    try:
        # ChatGroq picks the key up from the GROQ_API_KEY environment variable.
        if not os.getenv("GROQ_API_KEY"):
            return None
        return ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0, max_tokens=2000)
    except Exception as e:
        st.error(f"Ошибка инициализации Groq: {e}")
        return None
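
# GROQ_API_KEY is expected in the environment (e.g. as a Hugging Face Space
# secret); without it the app falls back to search-only mode (llm is None).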

# --- Semantic search ---
def semantic_search(query, embedder, index, df, genre=None, year=None, country=None, vtype=None, k=5):
    if not isinstance(query, str) or not query.strip():
        return pd.DataFrame()
    inferred_genre = infer_genre_from_query(query)
    if inferred_genre and (genre is None or genre == "Все"):
        genre = inferred_genre
    query_embedding = np.asarray(embedder.encode([query]), dtype="float32")
    faiss.normalize_L2(query_embedding)
    # Retrieve a large candidate pool and filter afterwards; the pool size is
    # capped at index.ntotal so FAISS does not pad results with -1 ids.
    dists, idxs = index.search(query_embedding, min(500, index.ntotal))
    res = df.iloc[idxs[0]].copy()
    res["score"] = dists[0]
    if genre and genre != "Все":
        res = res[res["basic_genres"].str.lower().str.contains(genre.lower(), na=False)]
    if year and year != "Все":
        res = res[res["year"] == int(year)]
    if country and country != "Все":
        res = res[res["country"].astype(str).str.lower().str.contains(country.lower(), na=False)]
    if vtype and vtype != "Все":
        res = res[res["type"].str.lower() == vtype.lower()]
    if res.empty:
        return res
    # Re-rank: boost exact title matches and keyword hits on top of the cosine score.
    query_lower = query.lower()
    res["exact_match_title"] = res["tvshow_title"].str.lower() == query_lower
    query_words = re.findall(r'\b\w+\b', query_lower)
    keyword_pattern = "|".join(re.escape(word) for word in query_words if len(word) > 2)
    if keyword_pattern:
        res["has_keyword"] = res.apply(
            lambda row: bool(re.search(keyword_pattern, str(row["tvshow_title"]).lower() + " " + str(row["description"]).lower())),
            axis=1,
        )
    else:
        # An empty pattern would match every row; treat it as "no keywords".
        res["has_keyword"] = False
    res["final_score"] = res["score"] + res["exact_match_title"] * 1.5 + res["has_keyword"] * 0.4
    return res.sort_values(by="final_score", ascending=False).head(k)
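
# Illustrative call: semantic_search("школа", embedder, index, df, k=3) infers
# the "драма" genre from the keyword map and returns the top 3 re-ranked hits.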

# --- Format results for the LLM prompt ---
def format_docs_for_prompt(results_df):
    if results_df.empty:
        return "Нет подходящих результатов."

    def safe(v):
        # NaN is truthy in Python, so a bare `v or '—'` would print "nan".
        return v if pd.notna(v) and v else "—"

    return "\n\n".join(
        f"Название: {row['tvshow_title']} ({row['year']})\n"
        f"Жанр: {row['basic_genres']}\n"
        f"Рейтинг: {safe(row['rating'])} | Тип: {row['type']} | Страна: {safe(row['country'])} | Сезонов: {safe(row['num_seasons'])}\n"
        f"Актёры: {row['actors']}\nСюжет: {extract_intro_paragraph(row['description'])}"
        for _, row in results_df.iterrows()
    )

# --- RAG answer ---
def generate_rag_response(user_query, search_results, llm):
    if llm is None:
        return "LLM не инициализирован."
    ctx = format_docs_for_prompt(search_results)
    prompt = f"""
Ты — эксперт по кино и сериалам. Анализируй только данные ниже.
Результаты поиска:
{ctx}
Вопрос пользователя: {user_query}
Ответ:
"""
    try:
        response = llm.invoke([
            SystemMessage(content="Ты — эксперт по кино и сериалам. Не выдумывай лишнего."),
            HumanMessage(content=prompt),
        ]).content.strip()
        return response
    except Exception as e:
        return f"Ошибка при генерации ответа LLM: {e}"

# --- Main app ---
def main():
    st.set_page_config(page_title="Поиск фильмов и сериалов + Groq AI", layout="wide")
    st.title("📽️ Поиск фильмов и сериалов с AI")
    if "df" not in st.session_state:
        st.session_state.df = cached_load_data(CSV_PATH)
    if "embedder" not in st.session_state:
        st.session_state.embedder = cached_init_embedder()
    # The key checked here must match a key written below ("index"), otherwise
    # the loader is re-invoked on every rerun.
    if "index" not in st.session_state:
        st.session_state.embeddings, st.session_state.index = cached_load_embeddings_and_index()
    if "llm" not in st.session_state:
        st.session_state.llm = init_groq_llm()
    df = st.session_state.df
    embedder = st.session_state.embedder
    index = st.session_state.index
    llm = st.session_state.llm
    # --- Filters ---
    st.sidebar.header("Фильтры")
    # Split the comma-joined genre strings and strip whitespace so that
    # "драма, триллер" does not produce a spurious " триллер" entry.
    basic_genres_list = sorted({g.strip() for g in ", ".join(df["basic_genres"].dropna().unique()).split(",") if g.strip()})
    genre_filter = st.sidebar.selectbox("Жанр", ["Все"] + basic_genres_list, index=0)
    years = ["Все"] + sorted([str(y) for y in df["year"].unique() if y != 0], reverse=True)
    year_filter = st.sidebar.selectbox("Год", years, index=0)
    countries = ["Все"] + sorted([c for c in df["country"].dropna().unique()])
    country_filter = st.sidebar.selectbox("Страна", countries, index=0)
    vtypes = ["Все"] + sorted(df["type"].dropna().unique())
    type_filter = st.sidebar.selectbox("Тип", vtypes, index=0)
    k = st.sidebar.slider("Количество результатов:", 1, 20, 5)
    # --- Search ---
    st.markdown("---")
    user_input = st.text_input("Введите ключевые слова или сюжет:")
    col_buttons = st.columns(4)
    search_clicked = col_buttons[0].button("Искать")
    random_clicked = col_buttons[1].button("Случайный")
    genre_clicked = col_buttons[2].button("ТОП по жанру")
    new_clicked = col_buttons[3].button("Новинки")
    if search_clicked and user_input.strip():
        st.session_state.results = semantic_search(user_input, embedder, index, df, genre_filter, year_filter, country_filter, type_filter, k)
        st.session_state.last_query = user_input
    elif random_clicked:
        rq = random.choice(df["tvshow_title"].tolist())
        st.session_state.results = semantic_search(rq, embedder, index, df, genre_filter, year_filter, country_filter, type_filter, k)
        st.session_state.last_query = rq
    elif genre_clicked and genre_filter != "Все":
        gq = f"Лучшие {genre_filter}"
        st.session_state.results = semantic_search(gq, embedder, index, df, genre_filter, year_filter, country_filter, type_filter, k)
        st.session_state.last_query = gq
    elif new_clicked:
        nq = f"Новинки {df['year'].max()}"
        st.session_state.results = semantic_search(nq, embedder, index, df, genre_filter, year_filter, country_filter, type_filter, k)
        st.session_state.last_query = nq
    # --- Output ---
    if "results" in st.session_state and not st.session_state.results.empty:
        st.markdown("## Результаты поиска")
        for _, row in st.session_state.results.iterrows():
            col1, col2 = st.columns([1, 3])
            with col1:
                image_url = row.get("image_url")
                # image_url may be None/NaN, which has no .startswith().
                if isinstance(image_url, str) and image_url.startswith("http"):
                    st.image(image_url, width=150)
                else:
                    st.info("Нет изображения.")
            with col2:
                st.markdown(f"### {row['tvshow_title']} ({row['year']})")
                st.caption(f"{row['basic_genres']} | {row['country'] or '—'} | {row['rating'] or '—'} | {row['type']} | {row['num_seasons']} сез.")
                st.write(extract_intro_paragraph(row["description"]))
                if row.get("actors"):
                    st.caption(f"Актёры: {row['actors']}")
                if isinstance(row.get("url"), str) and row["url"]:
                    st.markdown(f"[Подробнее]({row['url']})")
            st.divider()
        if llm and st.button("AI: рекомендации"):
            rag = generate_rag_response(st.session_state.last_query, st.session_state.results, llm)
            st.markdown("### Рекомендации AI:")
            st.write(rag)
    st.sidebar.markdown("---")
    st.sidebar.write(f"Всего записей: {len(df)}")
    st.sidebar.markdown(f"LLM: {'Готов' if llm else 'Отключён'}")

if __name__ == "__main__":
main()