import streamlit as st
import google.generativeai as genai
import fitz  # PyMuPDF
import numpy as np
import faiss
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Silence noisy library warnings: monkeypatch warnings.warn, then disable filters.
import warnings

def warn(*args, **kwargs):
    pass

warnings.warn = warn
warnings.filterwarnings("ignore")

# Initialize session state variables
if "messages" not in st.session_state:
    st.session_state.messages = []
if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = []
if "api_key" not in st.session_state:
    st.session_state.api_key = ""


@st.cache_resource
def get_embeddings():
    """Load the sentence-transformer embedding model once and reuse it across reruns."""
    return HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")


def extract_text_from_pdf(file):
    file.seek(0)
    pdf_bytes = file.read()
    if not pdf_bytes:
        raise ValueError(f"File {file.name} is empty or could not be read.")
    try:
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    except Exception as e:
        raise RuntimeError(f"Error opening file {file.name}: {e}")
    text = ""
    for page in doc:
        text += page.get_text()
    doc.close()
    return text


def process_files(files):
    texts = []
    for file in files:
        if file.type == "text/plain":
            texts.append(file.getvalue().decode("utf-8"))
        elif file.type == "application/pdf":
            texts.append(extract_text_from_pdf(file))
    return "\n".join(texts)


def build_index(text):
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_text(text)
    if not chunks:
        raise ValueError("No text chunks could be extracted from the uploaded files.")
    embeddings = get_embeddings()
    vectors = embeddings.embed_documents(chunks)
    dimension = len(vectors[0])
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(vectors).astype("float32"))
    return index, chunks


def retrieve_chunks(query, index, chunks, k=3):
    embeddings = get_embeddings()
    query_vector = np.array([embeddings.embed_query(query)]).astype("float32")
    distances, indices = index.search(query_vector, k)
    return [chunks[i] for i in indices[0]]


def create_sidebar():
    with st.sidebar:
        st.title("🤖 Gemini Chatbot")

        # API key input + validate button
        api_key_input = st.text_input("Google API Key:", type="password")
        if st.button("Validate API"):
            if api_key_input.strip():
                st.session_state.api_key = api_key_input.strip()
                st.success("API Key saved ✅")
            else:
                st.error("Please enter a valid API key.")

        # Show file uploader only if API key is set
        if st.session_state.api_key:
            uploaded_files = st.file_uploader(
                "📂 Upload your files (txt, pdf, etc.)",
                accept_multiple_files=True,
            )
            if uploaded_files:
                loaded_names = [f.name for f in st.session_state.uploaded_files]
                for file in uploaded_files:
                    # Avoid duplicates
                    if file.name not in loaded_names:
                        st.session_state.uploaded_files.append(file)
                st.success(f"{len(st.session_state.uploaded_files)} files loaded.")
            if st.session_state.uploaded_files:
                st.markdown("**Files currently loaded:**")
                for f in st.session_state.uploaded_files:
                    st.write(f.name)


def main():
    st.set_page_config(page_title="Gemini Chatbot")

    if "faiss_index" not in st.session_state:
        st.session_state.faiss_index = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = []

    create_sidebar()

    if not st.session_state.api_key:
        st.warning("👆 Please enter and validate your API key in the sidebar.")
        return

    genai.configure(api_key=st.session_state.api_key)
    model = genai.GenerativeModel("gemini-2.0-flash")

    # Build index if not done yet
    if st.session_state.uploaded_files and st.session_state.faiss_index is None:
        full_text = process_files(st.session_state.uploaded_files)
        if full_text:
            index, chunks = build_index(full_text)
            # Store the index first, then report its stats. (The previous version
            # checked session state *before* assigning the index, so the diagnostics
            # always warned that the index was not initialized yet.)
            st.session_state.faiss_index = index
            st.session_state.chunks = chunks
            st.write(f"Number of chunks: {len(chunks)}")
            st.write(f"FAISS dimension: {index.d}")
            st.write(f"FAISS index size: {index.ntotal}")

    st.title("💬 Gemini Chatbot")

    # Show chat history
    chat_container = st.container()
    with chat_container:
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.markdown(msg["content"])

    # Text input always at the bottom
    prompt = st.text_area("Type your question here...", key="input")
    if st.button("Send"):
        if not prompt.strip():
            st.warning("Please write a message before sending.")
            return

        # Append and display the user message
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    # Retrieve relevant chunks from the uploaded files
                    if st.session_state.faiss_index is not None:
                        relevant_chunks = retrieve_chunks(
                            prompt, st.session_state.faiss_index, st.session_state.chunks, k=3
                        )
                        context = "\n---\n".join(
                            chunk if isinstance(chunk, str) else chunk.page_content
                            for chunk in relevant_chunks
                        )
                        prompt_with_context = (
                            f"Use the following context to answer the question:\n"
                            f"{context}\n\nQuestion: {prompt}"
                        )
                    else:
                        prompt_with_context = prompt  # fallback: no documents indexed

                    response = model.generate_content(prompt_with_context)
                    if response.text:
                        st.markdown(response.text)
                        st.session_state.messages.append({"role": "assistant", "content": response.text})
                    else:
                        st.markdown("Sorry, I couldn't get a response.")
                        st.session_state.messages.append(
                            {"role": "assistant", "content": "Sorry, I couldn't get a response."}
                        )
                except Exception as e:
                    st.error(f"Error: {e}")
                    st.session_state.messages.append({"role": "assistant", "content": f"Error: {e}"})


if __name__ == "__main__":
    main()
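
# Usage note (a sketch, assuming this script is saved as app.py -- the filename
# is arbitrary): install the dependencies implied by the imports above, then
# launch the app with the Streamlit CLI:
#
#   pip install streamlit google-generativeai pymupdf faiss-cpu numpy \
#       langchain langchain-community sentence-transformers
#   streamlit run app.py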