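"""Streamlit app: upload a PDF, index it with FAISS + MiniLM sentence
embeddings, and chat over its contents through a LangChain RetrievalQA
chain backed by a local Hugging Face text2text-generation pipeline."""
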
import base64
import os

import streamlit as st
import torch
from langchain.chains import RetrievalQA
from langchain.document_loaders import PDFMinerLoader
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from streamlit_chat import message
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

st.set_page_config(layout="wide")

def process_answer(instruction, qa_chain):
    # Run the retrieval QA chain on the user's query and return the answer.
    generated_text = qa_chain.run(instruction)
    return generated_text

def get_file_size(file):
    # Report the size of an uploaded file in bytes, then rewind it so it can
    # still be read afterwards.
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)
    return file_size

def data_ingestion():
    # Walk the "docs" folder, split each PDF into chunks, embed the chunks,
    # and persist the FAISS index to disk.
    for root, dirs, files in os.walk("docs"):
        for file in files:
            if file.endswith(".pdf"):
                print(file)
                loader = PDFMinerLoader(os.path.join(root, file))
                documents = loader.load()
                # chunk_overlap must stay below chunk_size; 50 is a
                # conventional choice for 500-character chunks.
                text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
                splits = text_splitter.split_documents(documents)
                # Create the embeddings here
                embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
                vectordb = FAISS.from_documents(splits, embeddings)
                vectordb.save_local("faiss_index")
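
# Note: the index is rebuilt and saved once per PDF, so each iteration
# overwrites "faiss_index" and only the most recently processed file stays
# searchable. Collecting the splits from every file before building a single
# index would make all ingested PDFs queryable at once.
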
def initialize_qa_chain(selected_model):
    # Constants
    CHECKPOINT = selected_model
    TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
    # Load the model with default (CPU) placement; passing a torch.device as
    # device_map is not reliably supported, so it is omitted here.
    BASE_MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT, torch_dtype=torch.float32)
    pipe = pipeline(
        'text2text-generation',
        model=BASE_MODEL,
        tokenizer=TOKENIZER,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
        # device=torch.device('cpu')
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = FAISS.load_local("faiss_index", embeddings)
    # Build the QA chain
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectordb.as_retriever(),
    )
    return qa_chain
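
# Usage sketch (the query text is illustrative): the chain is invoked with a
# dict keyed by "query", which is how process_answer() calls it below:
#
#   qa_chain = initialize_qa_chain("google/flan-t5-base")
#   answer = process_answer({"query": "Worum geht es in dem Dokument?"}, qa_chain)
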
# Function to display the PDF of a given file
def display_pdf(file):
    try:
        # Open the file from its path
        with open(file, "rb") as f:
            base64_pdf = base64.b64encode(f.read()).decode('utf-8')
        # Embed the PDF in HTML
        pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="100%" height="600" type="application/pdf"></iframe>'
        # Display the file
        st.markdown(pdf_display, unsafe_allow_html=True)
    except Exception as e:
        st.error(f"Ein Fehler ist beim Anzeigen der PDF aufgetreten: {e}")
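
# The preview above inlines the whole PDF as a base64 data URI, which keeps
# the viewer self-contained; very large files may render slowly in the browser.
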
# Display the conversation history with Streamlit chat messages
def display_conversation(history):
    for i in range(len(history["generated"])):
        message(history["past"][i], is_user=True, key=f"{i}_user")
        message(history["generated"][i], key=str(i))

def main():
    # Add a sidebar for model selection
    model_options = ["hkunlp/instructor-base", "google/flan-t5-base", "google/flan-t5-small"]
    selected_model = st.sidebar.selectbox("Modell auswählen", model_options)

    st.markdown("<h1 style='text-align: center; color: blue;'>Dateiupload für Behörden und Organisationen mit Sicherheitsaufgaben 📄</h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center; color: black;'>Laden Sie Ihr PDF hoch und stellen Sie Fragen 👇</h2>", unsafe_allow_html=True)

    uploaded_file = st.file_uploader("", type=["pdf"])

    if uploaded_file is not None:
        file_details = {
            "Dateiname": uploaded_file.name,
            "Dateigröße": get_file_size(uploaded_file)
        }
        os.makedirs("docs", exist_ok=True)
        filepath = os.path.join("docs", uploaded_file.name)
        try:
            with open(filepath, "wb") as temp_file:
                temp_file.write(uploaded_file.read())

            col1, col2 = st.columns([1, 2])
            with col1:
                st.markdown("<h4 style='color:black;'>Dateidetails</h4>", unsafe_allow_html=True)
                st.json(file_details)
                st.markdown("<h4 style='color:black;'>Dateivorschau</h4>", unsafe_allow_html=True)
                # display_pdf() renders directly and returns nothing
                display_pdf(filepath)

            with col2:
                st.success(f'Modell erfolgreich ausgewählt: {selected_model}')
                with st.spinner('Embeddings werden erstellt...'):
                    # data_ingestion() writes the FAISS index to disk as a side effect
                    data_ingestion()
                st.success('Embeddings wurden erfolgreich erstellt!')
                st.markdown("<h4 style='color:black;'>Hier chatten</h4>", unsafe_allow_html=True)

                user_input = st.text_input("", key="input")
                # Initialize the session state for generated answers and past messages
                if "generated" not in st.session_state:
                    st.session_state["generated"] = ["Ich bin bereit, Ihnen zu helfen"]
                if "past" not in st.session_state:
                    st.session_state["past"] = ["Hallo!"]

                # Search the database for an answer based on the user input
                # and update the session state
                if user_input:
                    answer = process_answer({'query': user_input}, initialize_qa_chain(selected_model))
                    st.session_state["past"].append(user_input)
                    st.session_state["generated"].append(answer)

                # Display the conversation history with Streamlit messages
                if st.session_state["generated"]:
                    display_conversation(st.session_state)
        except Exception as e:
            st.error(f"Ein Fehler ist aufgetreten: {e}")

if __name__ == "__main__":
    main()
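
# To run locally (assuming this file is saved as app.py and Streamlit plus
# the libraries imported above are installed):
#   streamlit run app.py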