muhtasham committed
Commit 2d02398
1 Parent(s): 86b1799

Update app.py

Files changed (1)
  1. app.py +51 -28
app.py CHANGED
@@ -5,7 +5,7 @@ import torch
 import logging
 
 from operator import itemgetter
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_core.prompts import ChatPromptTemplate
@@ -13,9 +13,6 @@ from langchain_community.vectorstores.chroma import Chroma
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.schema import AIMessage, HumanMessage
 from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnableLambda, RunnablePassthrough
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain.chains import create_retrieval_chain
 from langchain.globals import set_debug
 from dotenv import load_dotenv
 
@@ -26,16 +23,27 @@ set_debug(True)
 load_dotenv()
 
 openai_api_key = os.getenv("OPENAI_API_KEY")
+langchain_api_key = os.getenv("LANGCHAIN_API_KEY")
+langchain_endpoint = os.getenv("LANGCHAIN_ENDPOINT")
+langchain_project_id = os.getenv("LANGCHAIN_PROJECT")
+access_key = os.getenv("ACCESS_TOKEN_SECRET")
 
 persist_dir = "./chroma_db"
-device='cuda:0'
-model_name="all-mpnet-base-v2"
-model_kwargs = {'device': device if torch.cuda.is_available() else 'cpu'}
+device = 'cuda:0'
+model_name = "all-mpnet-base-v2"
+model_kwargs = {'device': device if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"}
 logging.info(f"Using device {model_kwargs['device']}")
-# Create embeddings and store in vectordb
-embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
+embed_money = False
 
-def configure_retriever(local_files, chunk_size=12500, chunk_overlap=2500):
+# Create embeddings and store in vectordb
+if embed_money:
+    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
+    logging.info(f"Using OpenAI embeddings")
+else:
+    embeddings = HuggingFaceEmbeddings(model_name=model_name, show_progress=True, model_kwargs=model_kwargs)
+    logging.info(f"Using HuggingFace embeddings")
+
+def configure_retriever(local_files, chunk_size=15000, chunk_overlap=2500):
     logging.info("Configuring retriever")
 
     if not os.path.exists(persist_dir):
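
Two things change here: device selection now falls back from CUDA to Apple MPS to CPU, and a new embed_money flag picks between OpenAI's paid text-embedding-3-small and the free local all-mpnet-base-v2. Note that an existing ./chroma_db built with one embedding model is not comparable to vectors from the other, so the persisted store should be rebuilt whenever the flag is flipped. A behavior-equivalent sketch of the one-line device fallback (the helper name pick_device is hypothetical):

import torch

def pick_device() -> str:
    # Prefer CUDA, then Apple-silicon MPS, then CPU, in the same order
    # as the nested conditional added above.
    if torch.cuda.is_available():
        return "cuda:0"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

model_kwargs = {"device": pick_device()}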
@@ -63,10 +71,8 @@ def configure_retriever(local_files, chunk_size=12500, chunk_overlap=2500):
         vectordb = Chroma.from_documents(splits, embeddings, persist_directory=persist_dir)
 
         # Define retriever
-        retriever = vectordb.as_retriever(
-            search_type="similarity_score_threshold",
-            search_kwargs={'score_threshold': 0.8}
-        )
+        retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})
+
 
         return retriever
     else:
@@ -74,10 +80,7 @@ def configure_retriever(local_files, chunk_size=12500, chunk_overlap=2500):
         vectordb = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
 
         # Define retriever
-        retriever = vectordb.as_retriever(
-            search_type="similarity_score_threshold",
-            search_kwargs={'score_threshold': 0.8}
-        )
+        retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})
 
         return retriever
 
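
Both branches of configure_retriever now use maximal-marginal-relevance (MMR) search instead of the old 0.8 similarity threshold: k caps the result at 6 chunks, and lambda_mult trades relevance against diversity (1 is pure relevance, 0 maximal diversity), so 0.25 leans strongly toward diverse chunks. A minimal sketch of querying the configured retriever (the query string is illustrative):

# Returns up to k=6 chunks chosen for relevance plus mutual diversity.
docs = retriever.get_relevant_documents("Anforderungen an das Change Management")
for doc in docs:
    print(doc.metadata.get("source"), doc.page_content[:80])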
@@ -86,7 +89,11 @@ local_files = [f for f in os.listdir(directory) if f.endswith(".pdf")]
 
 # Setup LLM
 llm = ChatOpenAI(
-    model_name="gpt-3.5-turbo", openai_api_key=openai_api_key, temperature=0, streaming=True
+    model_name="gpt-4-0125-preview", openai_api_key=openai_api_key, temperature=0.1, streaming=True
+)
+
+llm_translate = ChatOpenAI(
+    model_name="gpt-3.5-turbo", openai_api_key=openai_api_key, temperature=0.0
 )
 
 retriever = configure_retriever(local_files)
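
Model duties are now split: the streaming gpt-4-0125-preview instance answers over the retrieved context, while the cheaper, non-streaming gpt-3.5-turbo instance only translates user input. A sketch of the translator on its own (the sentence is illustrative):

# ChatOpenAI implements the Runnable interface, so a bare string works as input.
resp = llm_translate.invoke("Translate this sentence to English: Wie läuft ein IT-Audit ab?")
print(resp.content)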
@@ -96,7 +103,7 @@ template = """Answer the question based only on the following context:
 
 Question: {question}
 
-Answer in German language.
+Answer in German Language. If the question is not related to the context, answer with "I don't know" in German.
 """
 
 prompt = ChatPromptTemplate.from_template(template)
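
Combined with the translation step in the next hunk, the flow becomes: German question, English translation, retrieval and answering over the English context, German reply, with a German "I don't know" when the context does not cover the question. A sketch of the non-streaming path, assuming chain keeps its {"question": ...} input shape:

# chain_translate (defined in the next hunk) returns a plain string.
question_en = chain_translate.invoke("Translate this sentence to English: Was ist ein IT-Audit?")
answer_de = chain.invoke({"question": question_en})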
@@ -111,28 +118,44 @@ chain = (
     | StrOutputParser()
 )
 
+chain_translate = (llm_translate
+    | StrOutputParser()
+)
+
+
 def predict(message, history):
-    message = f"Translate the following text to German: {message}"
+    message = chain_translate.invoke(f"Translate this sentence to English: {message}")
     history_langchain_format = []
     for human, ai in history:
         history_langchain_format.append(HumanMessage(content=human))
         history_langchain_format.append(AIMessage(content=ai))
     history_langchain_format.append(HumanMessage(content=message))
     gpt_response = llm(history_langchain_format)
-    return chain.invoke({"question": gpt_response.content})
+    for chunk in chain.stream({"question": gpt_response.content}):  # Stream the response
+        yield chunk
+
+
+image_path = "./ui/logo.png" if os.path.exists("./ui/logo.png") else "./logo.png"
 
-demo = gr.ChatInterface(
+with gr.Blocks() as demo:
+    gr.Image(image_path)
+    gr.ChatInterface(
         predict,
         chatbot=gr.Chatbot(height=500, show_share_button=True),
         textbox=gr.Textbox(placeholder="stell mir Fragen", container=False, scale=7),
         title="Beitrag Service",
         description="Ich bin Ihr hilfreicher KI-Assistent",
         theme="soft",
-        examples=["Hello"],
+        examples=[
+            "Generate auditing questions about Change Management",
+            "Generate auditing questions about Software Maintenance",
+            "Generate auditing questions about Data Protection",
+            "Generate auditing questions about IT",
+            "Generate auditing questions about control systems",
+            "Generate auditing questions about GDPR compliance",
+        ],
         cache_examples=True,
-        retry_btn="Wiederholen",
-        undo_btn="Vorheriges löschen",
-        clear_btn="Löschen").launch(show_api=False)
+    ).launch(show_api=False)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
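
One caveat with the new streaming handler: gr.ChatInterface treats each yielded value as the full message so far, replacing what is displayed, so yielding raw chunks would show only the latest fragment. A minimal accumulating variant of the loop added above:

partial = ""
for chunk in chain.stream({"question": gpt_response.content}):
    partial += chunk  # accumulate, since each yield replaces the shown text
    yield partial

Separately, llm(history_langchain_format) relies on the deprecated __call__ interface; llm.invoke(history_langchain_format) is the current equivalent.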
 