beitrag-service / app.py
muhtasham's picture
Upload 5 files
6f7484c verified
raw
history blame
4.82 kB
import os
import tempfile
import gradio as gr
import torch
import logging
from operator import itemgetter
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores.chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain.globals import set_debug
from dotenv import load_dotenv
# Logging setup: INFO-level root logger, plus LangChain's global debug flag.
logging.basicConfig(level=logging.INFO)
set_debug(True)

# Load environment variables via python-dotenv before reading the API key.
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")

# Directory checked (and populated) by configure_retriever for the Chroma store.
persist_dir = "./chroma_db"

# Embedding model settings: run on `device` when torch reports CUDA is
# available, otherwise fall back to the CPU.
device = 'cuda:0'
model_name = "all-mpnet-base-v2"
if torch.cuda.is_available():
    model_kwargs = {'device': device}
else:
    model_kwargs = {'device': 'cpu'}
logging.info(f"Using device {model_kwargs['device']}")

# Embedding function shared by index creation and index loading.
embeddings = HuggingFaceEmbeddings(
    model_name=model_name,
    show_progress=True,
    model_kwargs=model_kwargs,
)
def _threshold_retriever(vectordb):
    """Return a retriever over *vectordb* using a 0.8 similarity score threshold."""
    return vectordb.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={'score_threshold': 0.8},
    )


def configure_retriever(local_files, chunk_size=12500, chunk_overlap=2500):
    """Build or load the Chroma vector store and return a retriever over it.

    If ``persist_dir`` already exists, the store is loaded from it and
    ``local_files`` is not read. Otherwise every file in ``local_files`` is
    loaded as a PDF, split into chunks, embedded with the module-level
    ``embeddings`` and written to ``persist_dir``.

    Args:
        local_files: PDF file names. Each name is looked up under ``docs/``
            first and, if absent there, under the current directory.
        chunk_size: Maximum chunk size handed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        A retriever using ``similarity_score_threshold`` search with a
        score threshold of 0.8.
    """
    logging.info("Configuring retriever")

    if os.path.exists(persist_dir):
        logging.info(f"Persist directory {persist_dir} exists. Loading from it.")
        # Use persist_dir rather than a second hard-coded copy of the path so
        # the load and create branches can never point at different stores.
        vectordb = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
        return _threshold_retriever(vectordb)

    logging.info(f"Persist directory {persist_dir} does not exist. Creating it.")

    # Read documents. The temporary directory is removed as soon as all PDFs
    # have been parsed; loader.load() returns the parsed pages as a list, so
    # nothing needs the copies afterwards.
    docs = []
    with tempfile.TemporaryDirectory() as temp_dir:
        for filename in local_files:
            logging.info(f"Reading file {filename}")
            source_path = os.path.join("docs", filename)
            if not os.path.exists(source_path):
                source_path = os.path.join(".", filename)
            # Context managers close both handles promptly instead of leaving
            # them to the garbage collector.
            with open(source_path, "rb") as src:
                file_content = src.read()
            temp_filepath = os.path.join(temp_dir, filename)
            with open(temp_filepath, "wb") as f:
                f.write(file_content)
            loader = PyPDFLoader(temp_filepath)
            docs.extend(loader.load())

    # Split documents
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    splits = text_splitter.split_documents(docs)

    # Create embeddings and store them in the persisted vector store.
    vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)
    return _threshold_retriever(vectordb)
# PDFs are taken from ./docs when that folder exists, otherwise from the
# current working directory.
if os.path.exists("docs"):
    directory = "docs"
else:
    directory = "."
local_files = [name for name in os.listdir(directory) if name.endswith(".pdf")]

# Setup LLM
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    openai_api_key=openai_api_key,
    temperature=0,
    streaming=True,
)

# Retriever over the PDF corpus (built or loaded by configure_retriever).
retriever = configure_retriever(local_files)

# Prompt: answer strictly from the retrieved context, in German.
template = """Answer the question based only on the following context:
{context}
Question: {question}
Answer in German language.
"""
prompt = ChatPromptTemplate.from_template(template)

# Chain: the "question" field is sent to the retriever to fill {context} and
# is also passed through unchanged to fill {question}; the prompt then goes
# to the LLM and the reply is parsed to a plain string.
_chain_inputs = {
    "context": itemgetter("question") | retriever,
    "question": itemgetter("question"),
}
chain = _chain_inputs | prompt | llm | StrOutputParser()
def predict(message, history):
    """Answer a chat message using the translate-then-retrieve pipeline.

    The message is first sent to the LLM, together with the previous turns,
    with an instruction to translate it to German. The LLM's reply is then
    used as the question for the retrieval chain.

    Args:
        message: The latest user message.
        history: Previous turns as an iterable of ``(human, ai)`` pairs
            (presumably Gradio's tuple-style chat history; confirm against
            the installed Gradio version).

    Returns:
        The string produced by the retrieval chain.
    """
    message = f"Translate the following text to German: {message}"

    history_langchain_format = []
    for human, ai in history:
        history_langchain_format.append(HumanMessage(content=human))
        history_langchain_format.append(AIMessage(content=ai))
    history_langchain_format.append(HumanMessage(content=message))

    # Calling the chat model directly (`llm(messages)`) is deprecated in
    # LangChain; `.invoke` is the supported entry point and matches the
    # `chain.invoke` call below.
    gpt_response = llm.invoke(history_langchain_format)
    return chain.invoke({"question": gpt_response.content})
# Chat UI definition. `demo` must stay bound to the ChatInterface itself:
# chaining `.launch()` onto the constructor would bind `demo` to launch()'s
# return value instead, and the guarded launch below would then fail.
demo = gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=500, show_share_button=True),
    textbox=gr.Textbox(placeholder="stell mir Fragen", container=False, scale=7),
    title="Beitrag Service",
    description="Ich bin Ihr hilfreicher KI-Assistent",
    theme="soft",
    examples=["Hello"],
    cache_examples=True,
    retry_btn="Wiederholen",
    undo_btn="Vorheriges löschen",
    clear_btn="Löschen",
)

if __name__ == "__main__":
    # Launch exactly once, only when run as a script, with the API page hidden.
    demo.launch(show_api=False)