Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- .env +6 -0
- app.py +138 -0
- requirements.txt +7 -0
.env
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
OPENAI_API_KEY="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"
|
2 |
+
LANGCHAIN_TRACING_V2=true
|
3 |
+
LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
|
4 |
+
LANGCHAIN_API_KEY="ls__481915cb2eaa4a53876c4bcf592457b0"
|
5 |
+
LANGCHAIN_PROJECT="Beitrag POC"
|
6 |
+
ACCESS_TOKEN_SECRET="hpr;F3H678%H"
|
app.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import gradio as gr
|
4 |
+
import torch
|
5 |
+
import logging
|
6 |
+
|
7 |
+
from operator import itemgetter
|
8 |
+
from langchain_openai import ChatOpenAI
|
9 |
+
from langchain_community.document_loaders import PyPDFLoader
|
10 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
11 |
+
from langchain_core.prompts import ChatPromptTemplate
|
12 |
+
from langchain_community.vectorstores.chroma import Chroma
|
13 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
14 |
+
from langchain.schema import AIMessage, HumanMessage
|
15 |
+
from langchain_core.output_parsers import StrOutputParser
|
16 |
+
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
17 |
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
18 |
+
from langchain.chains import create_retrieval_chain
|
19 |
+
from langchain.globals import set_debug
|
20 |
+
from dotenv import load_dotenv
|
21 |
+
|
22 |
+
# configure logging
|
23 |
+
logging.basicConfig(level=logging.INFO)
|
24 |
+
|
25 |
+
set_debug(True)
|
26 |
+
load_dotenv()
|
27 |
+
|
28 |
+
openai_api_key = os.getenv("OPENAI_API_KEY")
|
29 |
+
|
30 |
+
persist_dir = "./chroma_db"
|
31 |
+
device='cuda:0'
|
32 |
+
model_name="all-mpnet-base-v2"
|
33 |
+
model_kwargs = {'device': device if torch.cuda.is_available() else 'cpu'}
|
34 |
+
logging.info(f"Using device {model_kwargs['device']}")
|
35 |
+
# Create embeddings and store in vectordb
|
36 |
+
embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
|
37 |
+
|
38 |
+
def configure_retriever(local_files, chunk_size=12500, chunk_overlap=2500):
|
39 |
+
logging.info("Configuring retriever")
|
40 |
+
|
41 |
+
if not os.path.exists(persist_dir):
|
42 |
+
logging.info(f"Persist directory {persist_dir} does not exist. Creating it.")
|
43 |
+
# Read documents
|
44 |
+
docs = []
|
45 |
+
temp_dir = tempfile.TemporaryDirectory()
|
46 |
+
for filename in local_files:
|
47 |
+
logging.info(f"Reading file {filename}")
|
48 |
+
# Read the file once
|
49 |
+
if not os.path.exists(os.path.join("docs", filename)):
|
50 |
+
file_content = open(os.path.join(".", filename), "rb").read()
|
51 |
+
else:
|
52 |
+
file_content = open(os.path.join("docs", filename), "rb").read()
|
53 |
+
temp_filepath = os.path.join(temp_dir.name, filename)
|
54 |
+
with open(temp_filepath, "wb") as f:
|
55 |
+
f.write(file_content)
|
56 |
+
loader = PyPDFLoader(temp_filepath)
|
57 |
+
docs.extend(loader.load())
|
58 |
+
|
59 |
+
# Split documents
|
60 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
61 |
+
splits = text_splitter.split_documents(docs)
|
62 |
+
|
63 |
+
vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)
|
64 |
+
|
65 |
+
# Define retriever
|
66 |
+
retriever = vectordb.as_retriever(
|
67 |
+
search_type="similarity_score_threshold",
|
68 |
+
search_kwargs={'score_threshold': 0.8}
|
69 |
+
)
|
70 |
+
|
71 |
+
return retriever
|
72 |
+
else:
|
73 |
+
logging.info(f"Persist directory {persist_dir} exists. Loading from it.")
|
74 |
+
vectordb = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
|
75 |
+
|
76 |
+
# Define retriever
|
77 |
+
retriever = vectordb.as_retriever(
|
78 |
+
search_type="similarity_score_threshold",
|
79 |
+
search_kwargs={'score_threshold': 0.8}
|
80 |
+
)
|
81 |
+
|
82 |
+
return retriever
|
83 |
+
|
84 |
+
directory = "docs" if os.path.exists("docs") else "."
|
85 |
+
local_files = [f for f in os.listdir(directory) if f.endswith(".pdf")]
|
86 |
+
|
87 |
+
# Setup LLM
|
88 |
+
llm = ChatOpenAI(
|
89 |
+
model_name="gpt-3.5-turbo", openai_api_key=openai_api_key, temperature=0, streaming=True
|
90 |
+
)
|
91 |
+
|
92 |
+
retriever = configure_retriever(local_files)
|
93 |
+
|
94 |
+
template = """Answer the question based only on the following context:
|
95 |
+
{context}
|
96 |
+
|
97 |
+
Question: {question}
|
98 |
+
|
99 |
+
Answer in German language.
|
100 |
+
"""
|
101 |
+
|
102 |
+
prompt = ChatPromptTemplate.from_template(template)
|
103 |
+
|
104 |
+
chain = (
|
105 |
+
{
|
106 |
+
"context": itemgetter("question") | retriever,
|
107 |
+
"question": itemgetter("question"),
|
108 |
+
}
|
109 |
+
| prompt
|
110 |
+
| llm
|
111 |
+
| StrOutputParser()
|
112 |
+
)
|
113 |
+
|
114 |
+
def predict(message, history):
|
115 |
+
message = f"Translate the following text to German: {message}"
|
116 |
+
history_langchain_format = []
|
117 |
+
for human, ai in history:
|
118 |
+
history_langchain_format.append(HumanMessage(content=human))
|
119 |
+
history_langchain_format.append(AIMessage(content=ai))
|
120 |
+
history_langchain_format.append(HumanMessage(content=message))
|
121 |
+
gpt_response = llm(history_langchain_format)
|
122 |
+
return chain.invoke({"question": gpt_response.content})
|
123 |
+
|
124 |
+
demo = gr.ChatInterface(
|
125 |
+
predict,
|
126 |
+
chatbot=gr.Chatbot(height=500, show_share_button=True),
|
127 |
+
textbox=gr.Textbox(placeholder="stell mir Fragen", container=False, scale=7),
|
128 |
+
title="Beitrag Service",
|
129 |
+
description="Ich bin Ihr hilfreicher KI-Assistent",
|
130 |
+
theme="soft",
|
131 |
+
examples=["Hello"],
|
132 |
+
cache_examples=True,
|
133 |
+
retry_btn="Wiederholen",
|
134 |
+
undo_btn="Vorheriges löschen",
|
135 |
+
clear_btn="Löschen").launch(show_api= False)
|
136 |
+
|
137 |
+
if __name__ == "__main__":
|
138 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
openai==1.12.0
|
3 |
+
langchain==0.1.10
|
4 |
+
langchain-openai==0.0.8
|
5 |
+
pypdf==4.0.1
|
6 |
+
python-dotenv==1.0.1
|
7 |
+
chromadb==0.4.22
|