import streamlit as st
import json
import os
import asyncio

import torch
from sentence_transformers import SentenceTransformer, util
from huggingface_hub import AsyncInferenceClient

HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is not set. Please set it before running the application.")


@st.cache_data
def load_data(file_path):
    # Load the preprocessed legal-code chunks once and cache them across reruns.
    # st.cache_data (rather than cache_resource) is the appropriate cache for
    # serializable data like parsed JSON.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
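
# NOTE (assumption): the exact schema of processed_kodeksy.json is not shown
# in this file. Judging from the field accesses below ('text',
# 'metadata.nazwa', 'metadata.article'), each chunk is assumed to look like
# this hypothetical example:
#
#     {
#         "text": "Art. 415. Kto z winy swej wyrządził drugiemu szkodę...",
#         "metadata": {"nazwa": "Kodeks cywilny", "article": "415"}
#     }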


@st.cache_resource
def load_model():
    # Multilingual sentence-embedding model; its supported languages include Polish.
    return SentenceTransformer('distiluse-base-multilingual-cased-v1')


async def generate_keywords(query):
    # Ask the LLM for 3-5 comma-separated keywords describing the legal
    # topics and concepts in the user's question.
    client = AsyncInferenceClient(token=HF_TOKEN)

    prompt = (
        "Na podstawie poniższego pytania, wygeneruj 3-5 słów kluczowych, "
        "które najlepiej opisują główne tematy i koncepcje prawne zawarte "
        "w pytaniu. Podaj tylko słowa kluczowe, oddzielone przecinkami.\n\n"
        f"Pytanie: {query}\n\nSłowa kluczowe:"
    )

    response = await client.text_generation(
        prompt,
        model="Qwen/Qwen2.5-72B-Instruct",
        max_new_tokens=50,
        temperature=0.3,
        top_p=0.9,
    )

    # The model returns a single comma-separated string; split it into a list.
    return [keyword.strip() for keyword in response.split(',') if keyword.strip()]


async def generate_ai_response(query, relevant_chunks):
    # Stream an answer from the LLM, grounded in the retrieved legal context.
    client = AsyncInferenceClient(token=HF_TOKEN)

    # Build the context block from the retrieved chunks.
    context = "Kontekst prawny:\n\n"
    for chunk in relevant_chunks:
        context += f"{chunk['metadata']['nazwa']} - Artykuł {chunk['metadata']['article']}:\n"
        context += f"{chunk['text']}\n\n"

    prompt = (
        "Jesteś asystentem prawniczym. Odpowiedz na poniższe pytanie "
        "na podstawie podanego kontekstu prawnego.\n\n"
        f"Kontekst: {context}\n\nPytanie: {query}\n\nOdpowiedź:"
    )

    # With stream=True, text_generation returns an async iterator of tokens.
    async for token in await client.text_generation(
        prompt,
        model="Qwen/Qwen2.5-72B-Instruct",
        max_new_tokens=2048,
        temperature=0.5,
        top_p=0.7,
        stream=True,
    ):
        yield token


def search_relevant_chunks(keywords, chunks, model, top_k=3):
    # Embed the keywords and all chunk texts, then rank chunks by their mean
    # cosine similarity to the keywords.
    # NOTE: chunk embeddings are recomputed on every query; for a larger
    # corpus they could be precomputed and cached.
    keyword_embeddings = model.encode(keywords, convert_to_tensor=True)
    chunk_embeddings = model.encode([chunk['text'] for chunk in chunks], convert_to_tensor=True)

    # cos_scores has shape (num_keywords, num_chunks); averaging over dim 0
    # yields one relevance score per chunk.
    cos_scores = util.cos_sim(keyword_embeddings, chunk_embeddings)
    top_results = torch.topk(cos_scores.mean(dim=0), k=min(top_k, len(chunks)))

    return [chunks[idx] for idx in top_results.indices]
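

# Streamlit scripts run synchronously, but generate_ai_response is an async
# generator, which asyncio.run() cannot consume directly (it expects a
# coroutine). This helper is a minimal sketch of one way to bridge the two;
# the name and the placeholder-based rendering are our own choices, not a
# library API. It drives the token stream inside a fresh event loop and
# live-updates a Streamlit placeholder as tokens arrive.
def stream_ai_response(query, relevant_chunks, placeholder):
    async def _consume():
        full_response = ""
        async for token in generate_ai_response(query, relevant_chunks):
            full_response += token
            placeholder.markdown(full_response + "▌")  # typing-cursor effect
        return full_response

    return asyncio.run(_consume())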


def main():
    st.title("Chatbot Prawny z AI")

    # The preprocessed corpus must exist before the app can answer anything.
    data_file = "processed_kodeksy.json"
    if not os.path.exists(data_file):
        st.error(f"Plik {data_file} nie istnieje. Najpierw przetwórz dane kodeksów.")
        return

    chunks = load_data(data_file)
    model = load_model()

    # Initialise the chat history on first run.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Zadaj pytanie dotyczące prawa..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Retrieval: extract keywords, then rank chunks against them.
        with st.spinner("Analizuję pytanie i szukam odpowiednich informacji..."):
            keywords = asyncio.run(generate_keywords(prompt))
            relevant_chunks = search_relevant_chunks(keywords, chunks, model)

        # Generation: stream the answer token by token into a placeholder.
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = stream_ai_response(prompt, relevant_chunks, message_placeholder)
            message_placeholder.markdown(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})

    with st.sidebar:
        st.subheader("Opcje")
        if st.button("Wyczyść historię czatu"):
            st.session_state.messages = []
            st.rerun()  # replaces the deprecated st.experimental_rerun()

        st.subheader("Informacje o bazie danych")
        st.write(f"Liczba chunków: {len(chunks)}")
        st.write("Przykładowy chunk:")
        st.json(chunks[0] if chunks else {})


if __name__ == "__main__":
    main()
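
# Usage note (the filename below is hypothetical; it is not given in the
# source):
#
#     export HF_TOKEN=hf_xxx
#     streamlit run app.py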