Doux Thibault committed
Commit 9a30a8c
1 Parent(s): 025e412

rag to streamlit + new pdf

Modules/rag.py CHANGED
@@ -9,17 +9,24 @@ from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import WebBaseLoader
 from langchain_community.vectorstores import Chroma, FAISS
+from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain_mistralai import MistralAIEmbeddings
 from langchain import hub
+from langchain.chains import (
+    create_history_aware_retriever,
+    create_retrieval_chain,
+)
 from typing import Literal
-from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_mistralai import ChatMistralAI
 from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
 from langchain_community.tools import DuckDuckGoSearchRun
+from pathlib import Path
 
 def load_chunk_persist_pdf() -> Chroma:
-    pdf_folder_path = "data/pdf_folder/"
+
+    pdf_folder_path = os.path.join(os.getcwd(), Path("data/pdf/"))
     documents = []
     for file in os.listdir(pdf_folder_path):
         if file.endswith('.pdf'):
@@ -32,7 +39,7 @@ def load_chunk_persist_pdf() -> Chroma:
     vectorstore = Chroma.from_documents(
         documents=chunked_documents,
         embedding=MistralAIEmbeddings(),
-        persist_directory="data/chroma_store/"
+        persist_directory=os.path.join(os.getcwd(), Path("data/chroma_store/"))
     )
     vectorstore.persist()
     return vectorstore
@@ -54,26 +61,29 @@ class RouteQuery(BaseModel):
 # LLM with function call
 llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
 
-# structured_llm_router = llm.with_structured_output(RouteQuery, method="json_mode")
-
-# Prompt
-system = """You are an expert at routing a user question to a vectorstore or web search.
-The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
-Use the vectorstore for questions on these topics. For all else, use web-search."""
-route_prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", system),
-        ("human", "{question}"),
-    ]
+
+prompt = ChatPromptTemplate.from_template(
+    """
+    You are a professional AI coach specialized in fitness, bodybuilding and nutrition.
+    You must adapt to the user: if they are a beginner, use simple words. You are gentle and motivating.
+    Use the following pieces of retrieved context to answer the question.
+    If you don't know the answer, just say that you don't know, and refer them to a nutritionist or a doctor.
+    Use three sentences maximum and keep the answer concise.
+
+    Question: {question}
+
+    Context: {context}
+
+    Answer:
+    """,
 )
-prompt = hub.pull("rlm/rag-prompt")
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnablePassthrough
 
 def format_docs(docs):
     return "\n\n".join(doc.page_content for doc in docs)
 
-
+
 rag_chain = (
     {"context": retriever | format_docs, "question": RunnablePassthrough()}
     | prompt
@@ -81,6 +91,7 @@ rag_chain = (
     | StrOutputParser()
 )
 
-print(rag_chain.invoke("Build a fitness program for me. Be precise in terms of exercises"))
+
+# print(rag_chain.invoke("Build a fitness program for me. Be precise in terms of exercises"))
 
 # print(rag_chain.invoke("I am a 45-year-old woman and I have to lose weight for the summer. Provide me with a fitness program"))
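
Note: `retriever` is used by `rag_chain` but is not defined in any of the hunks shown; presumably it wraps the Chroma store that `load_chunk_persist_pdf()` builds. A minimal sketch of the assumed wiring (the `as_retriever()` line is an inference, not part of this commit):

    # Assumed glue between the pieces visible in this diff, not shown in the hunks:
    vectorstore = load_chunk_persist_pdf()   # builds and persists the Chroma index
    retriever = vectorstore.as_retriever()   # hypothetical: exposes similarity search
    # rag_chain then pipes retriever -> format_docs -> prompt -> llm -> StrOutputParser,
    # so rag_chain.invoke("...") returns a plain string.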
app.py CHANGED
@@ -6,8 +6,19 @@ from langchain_mistralai import ChatMistralAI
 from dotenv import load_dotenv
 load_dotenv() # load .env api keys
 import os
+
+from Modules.rag import rag_chain
+
 mistral_api_key = os.getenv("MISTRAL_API_KEY")
 
+def format_messages(messages):
+    formatted_messages = ""
+    for message in messages:
+        role = message["role"]
+        content = message["content"]
+        formatted_messages += f"{role}: {content}\n"
+    return formatted_messages
+
 st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
 # Create two columns
 col1, col2 = st.columns(2)
@@ -43,8 +54,10 @@ with col1:
 
     with st.chat_message("assistant"):
         # Build answer from LLM
-
-        response = llm.invoke(st.session_state.messages).content
+        response = rag_chain.invoke(
+            instruction
+        )
+        print(type(response))
         st.session_state.messages.append({"role": "assistant", "content": response})
         st.markdown(response)
 
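Note: `instruction` is not defined in the visible hunks; presumably it carries the user's latest chat input. The new `format_messages` helper is added but never called in the hunks shown; a hypothetical way it could feed the whole chat history to the chain instead:

    # Hypothetical use of the new format_messages helper (not part of this commit):
    history = format_messages(st.session_state.messages)  # "role: content" per line
    response = rag_chain.invoke(history)                  # a plain str via StrOutputParser

Since the chain ends in `StrOutputParser`, `response` is already a string, which is why `st.markdown(response)` works directly; the `print(type(response))` line reads as a leftover debug check.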
data/pdf/F12_Strength&Conditioning_Program.pdf ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b6d7c1c04d0a98433e00e4a3ce1586311164a3ac50fc0e14a8fffb65ca7356b
+size 17579128
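
The PDF is tracked with Git LFS, so the committed file is only this three-line pointer (spec version, SHA-256 oid, size in bytes, about 17.6 MB); the binary itself lives in LFS storage. A small sketch for verifying that a pulled copy matches the recorded oid (the path and hash come from this commit; the rest is Python standard library):

    import hashlib

    def sha256_of(path: str) -> str:
        # Stream in 1 MiB chunks so the whole PDF never sits in memory at once.
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                h.update(chunk)
        return h.hexdigest()

    expected = "8b6d7c1c04d0a98433e00e4a3ce1586311164a3ac50fc0e14a8fffb65ca7356b"
    assert sha256_of("data/pdf/F12_Strength&Conditioning_Program.pdf") == expected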