Doux Thibault commited on
Commit
c104abf
1 Parent(s): ed250fe

add workout plan generator

Browse files
Files changed (4) hide show
  1. Modules/rag.py +7 -1
  2. Modules/router.py +4 -3
  3. Modules/workout_plan.py +139 -0
  4. app.py +20 -2
Modules/rag.py CHANGED
@@ -63,6 +63,12 @@ prompt = ChatPromptTemplate.from_template(
63
  Use the following pieces of retrieved context to answer the question.
64
  If you don't know the answer, use your common knowledge.
65
  Use three sentences maximum and keep the answer concise.
 
 
 
 
 
 
66
 
67
  Question: {question}
68
 
@@ -86,6 +92,6 @@ rag_chain = (
86
 
87
 
88
 
89
- print(rag_chain.invoke("WHi I'm Susan. Can you make a fitness program for me please?"))
90
 
91
  # print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program, and a nutrition program"))
 
63
  Use the following pieces of retrieved context to answer the question.
64
  If you don't know the answer, use your common knowledge.
65
  Use three sentences maximum and keep the answer concise.
66
+ If the user asks you a full program workout, structure your response in this way (this is an example):
67
+ - First workout : Lower body (1 hour)
68
+ 1. Barbelle squat / 4 sets of 8 reps / 2'30 recovery
69
+ 2. Lunges / 4 sets of 10 reps / 2'recovery
70
+ 3. etc
71
+ - Second workout .... and so on.
72
 
73
  Question: {question}
74
 
 
92
 
93
 
94
 
95
+ # print(rag_chain.invoke("WHi I'm Susan. Can you make a fitness program for me please?"))
96
 
97
  # print(rag_chain.invoke("I am a 45 years old woman and I have to loose weight for the summer. Provide me with a fitness program, and a nutrition program"))
Modules/router.py CHANGED
@@ -10,9 +10,10 @@ from langchain_core.output_parsers import StrOutputParser
10
  router_chain = (
11
  ChatPromptTemplate.from_template(
12
  """Given the user question below, classify it as either being about :
13
- - `fitness_advices` if the user query is about nutrition or fitness program, exercices
14
- - `movement_analysis` if the user asks to analyse or give advice on his exercice execution?
15
- - `smalltalk` if other.
 
16
 
17
  Do not respond with more than one word.
18
 
 
10
  router_chain = (
11
  ChatPromptTemplate.from_template(
12
  """Given the user question below, classify it as either being about :
13
+ - 'fitness_advices`' if the user query is about nutrition or fitness strategies, exercices
14
+ - 'workout_plan' if the user asks for a detailed workout plan or a full fitness program
15
+ - 'movement_analysis' if the user asks to analyse or give advice on his exercice execution?
16
+ - 'smalltalk' if other.
17
 
18
  Do not respond with more than one word.
19
 
Modules/workout_plan.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
3
+ from dotenv import load_dotenv
4
+ load_dotenv() # load .env api keys
5
+
6
+ mistral_api_key = os.getenv("MISTRAL_API_KEY")
7
+ print("mistral_api_key", mistral_api_key)
8
+ import pandas as pd
9
+ from langchain.output_parsers import PandasDataFrameOutputParser
10
+ from langchain_community.document_loaders import PyPDFLoader
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from langchain_community.vectorstores import Chroma
13
+ from langchain_mistralai import MistralAIEmbeddings
14
+ from langchain import hub
15
+ from langchain_core.output_parsers import StrOutputParser
16
+ from langchain_core.runnables import RunnablePassthrough
17
+ from typing import Literal
18
+ from langchain_core.prompts import PromptTemplate
19
+ from langchain_mistralai import ChatMistralAI
20
+ from pathlib import Path
21
+ from langchain.retrievers import (
22
+ MergerRetriever,
23
+ )
24
+ import pprint
25
+ from typing import Any, Dict
26
+ from huggingface_hub import login
27
+ login(token=os.getenv("HUGGING_FACE_TOKEN"))
28
+
29
+ def load_chunk_persist_pdf(task) -> Chroma:
30
+
31
+ pdf_folder_path = os.path.join(os.getcwd(),Path(f"data/pdf/{task}"))
32
+ documents = []
33
+ for file in os.listdir(pdf_folder_path):
34
+ if file.endswith('.pdf'):
35
+ pdf_path = os.path.join(pdf_folder_path, file)
36
+ loader = PyPDFLoader(pdf_path)
37
+ documents.extend(loader.load())
38
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
39
+ chunked_documents = text_splitter.split_documents(documents)
40
+ os.makedirs("data/chroma_store/", exist_ok=True)
41
+ vectorstore = Chroma.from_documents(
42
+ documents=chunked_documents,
43
+ embedding=MistralAIEmbeddings(),
44
+ persist_directory= os.path.join(os.getcwd(),Path("data/chroma_store/"))
45
+ )
46
+ vectorstore.persist()
47
+ return vectorstore
48
+
49
+ df = pd.DataFrame(
50
+ {
51
+ "exercise": ["Squat","Bench Press","Lunges","Pull ups"],
52
+ "sets": [4, 4, 3, 3],
53
+ "repetitions": [10, 8, 8, 8],
54
+ "rest":["2:30","2:00","1:30","2:00"]
55
+ }
56
+ )
57
+
58
+ # parser = PandasDataFrameOutputParser(dataframe=df)
59
+
60
+ # personal_info_vectorstore = load_chunk_persist_pdf("personal_info")
61
+ # zero2hero_vectorstore = load_chunk_persist_pdf("zero2hero")
62
+ # bodyweight_vectorstore = load_chunk_persist_pdf("bodyweight")
63
+ # nutrition_vectorstore = load_chunk_persist_pdf("nutrition")
64
+ # workout_vectorstore = load_chunk_persist_pdf("workout")
65
+ # zero2hero_retriever = zero2hero_vectorstore.as_retriever()
66
+ # nutrition_retriever = nutrition_vectorstore.as_retriever()
67
+ # bodyweight_retriever = bodyweight_vectorstore.as_retriever()
68
+ # workout_retriever = workout_vectorstore.as_retriever()
69
+ # personal_info_retriever = personal_info_vectorstore.as_retriever()
70
+
71
+ llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
72
+
73
+ # prompt = PromptTemplate(
74
+ # template="""
75
+ # You are a professional AI coach specialized in building fitness plans, full workout programs.
76
+ # You must adapt to the user according to personal informations in the context. A You are gentle and motivative.
77
+ # Use the following pieces of retrieved context to answer the user's query.
78
+
79
+ # Context: {context}
80
+
81
+ # \n{format_instructions}\n{question}\n
82
+ # """,
83
+ # input_variables=["question","context"],
84
+ # partial_variables={"format_instructions": parser.get_format_instructions()},
85
+ # )
86
+
87
+ # def format_docs(docs):
88
+ # return "\n\n".join(doc.page_content for doc in docs)
89
+
90
+ # def format_parser_output(parser_output: Dict[str, Any]) -> None:
91
+ # for key in parser_output.keys():
92
+ # parser_output[key] = parser_output[key].to_dict()
93
+ # return pprint.PrettyPrinter(width=4, compact=True).pprint(parser_output)
94
+
95
+ # retriever = MergerRetriever(retrievers=[zero2hero_retriever, bodyweight_retriever, nutrition_retriever, workout_retriever, personal_info_retriever])
96
+
97
+ # chain = (
98
+ # {"context": zero2hero_retriever | format_docs, "question": RunnablePassthrough()}
99
+ # | prompt
100
+ # | llm
101
+ # | parser
102
+ # )
103
+
104
+ # # chain = prompt | llm | parser
105
+ # format_parser_output(chain.invoke("Build me a full body workout plan for summer body."))
106
+
107
+
108
+ from pydantic import BaseModel, Field
109
+ from typing import List
110
+ from langchain_core.output_parsers import JsonOutputParser
111
+
112
+ class Exercise(BaseModel):
113
+ exercice: str = Field(description="Name of the exercise")
114
+ nombre_series: int = Field(description="Number of sets for the exercise")
115
+ nombre_repetitions: int = Field(description="Number of repetitions for the exercise")
116
+ temps_repos: str = Field(description="Rest time between sets")
117
+
118
+ class MusculationProgram(BaseModel):
119
+ exercises: List[Exercise]
120
+
121
+
122
+ from langchain.prompts import PromptTemplate
123
+
124
+ # Define your query to get a musculation program.
125
+ musculation_query = "Provide a musculation program with exercises, number of sets, number of repetitions, and rest time between sets."
126
+
127
+ # Set up a parser + inject instructions into the prompt template.
128
+ parser = JsonOutputParser(pydantic_object=MusculationProgram)
129
+
130
+ prompt = PromptTemplate(
131
+ template="Answer the user query.\n{format_instructions}\n{query}\n",
132
+ input_variables=["query"],
133
+ partial_variables={"format_instructions": parser.get_format_instructions()},
134
+ )
135
+
136
+ # Set up a chain to invoke the language model with the prompt and parser.
137
+ workout_chain = prompt | llm | parser
138
+
139
+
app.py CHANGED
@@ -4,12 +4,15 @@ from Modules.Speech2Text.transcribe import transcribe
4
  import base64
5
  from langchain_mistralai import ChatMistralAI
6
  from langchain_core.prompts import ChatPromptTemplate
 
 
7
  from dotenv import load_dotenv
8
  load_dotenv() # load .env api keys
9
  import os
10
 
11
  from Modules.rag import rag_chain
12
  from Modules.router import router_chain
 
13
  # from Modules.PoseEstimation.pose_agent import agent_executor
14
 
15
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
@@ -27,11 +30,13 @@ prompt = ChatPromptTemplate.from_template(
27
  You are having a conversation with your client, which is either a beginner or an advanced athlete.
28
  You must be gentle, kind, and motivative.
29
  Always try to answer concisely to the queries.
 
30
  User: {question}
31
  AI Coach:"""
32
  )
33
  base_chain = prompt | llm
34
 
 
35
  # First column containers
36
  with col1:
37
 
@@ -52,6 +57,8 @@ with col1:
52
  with st.chat_message("assistant"):
53
  # Build answer from LLM
54
  direction = router_chain.invoke({"question":prompt})
 
 
55
  if direction=='fitness_advices':
56
  response = rag_chain.invoke(
57
  prompt
@@ -60,15 +67,26 @@ with col1:
60
  response = base_chain.invoke(
61
  {"question":prompt}
62
  ).content
63
- # elif direction =='movement_analysis':
 
64
  # response = agent_executor.invoke(
65
  # {"input" : instruction}
66
  # )["output"]
 
 
 
 
 
 
 
 
67
  print(type(response))
68
  st.session_state.messages.append({"role": "assistant", "content": response})
69
  st.markdown(response)
70
 
71
- st.subheader("Movement Analysis")
 
 
72
  # TO DO
73
  # Second column containers
74
  with col2:
 
4
  import base64
5
  from langchain_mistralai import ChatMistralAI
6
  from langchain_core.prompts import ChatPromptTemplate
7
+ import pandas as pd
8
+ import json
9
  from dotenv import load_dotenv
10
  load_dotenv() # load .env api keys
11
  import os
12
 
13
  from Modules.rag import rag_chain
14
  from Modules.router import router_chain
15
+ from Modules.workout_plan import workout_chain
16
  # from Modules.PoseEstimation.pose_agent import agent_executor
17
 
18
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
 
30
  You are having a conversation with your client, which is either a beginner or an advanced athlete.
31
  You must be gentle, kind, and motivative.
32
  Always try to answer concisely to the queries.
33
+
34
  User: {question}
35
  AI Coach:"""
36
  )
37
  base_chain = prompt | llm
38
 
39
+ display_workout = False
40
  # First column containers
41
  with col1:
42
 
 
57
  with st.chat_message("assistant"):
58
  # Build answer from LLM
59
  direction = router_chain.invoke({"question":prompt})
60
+ print(type(direction))
61
+ print(direction)
62
  if direction=='fitness_advices':
63
  response = rag_chain.invoke(
64
  prompt
 
67
  response = base_chain.invoke(
68
  {"question":prompt}
69
  ).content
70
+ elif direction =='movement_analysis':
71
+ response = "I can't do that for the moment"
72
  # response = agent_executor.invoke(
73
  # {"input" : instruction}
74
  # )["output"]
75
+ # elif direction == 'workout_plan':
76
+ else:
77
+ response = "Sure! I just made a workout for you. Check on the table I just provided you."
78
+ json_output = workout_chain.invoke({"query":prompt})
79
+ exercises_list = json_output['exercises']
80
+ workout_df = pd.DataFrame(exercises_list)
81
+ workout_df.columns = ["exercice", "nombre_series", "nombre_repetitions", "temps_repos"]
82
+ display_workout=True
83
  print(type(response))
84
  st.session_state.messages.append({"role": "assistant", "content": response})
85
  st.markdown(response)
86
 
87
+ if display_workout:
88
+ st.subheader("Workout")
89
+ st.data_editor(workout_df)
90
  # TO DO
91
  # Second column containers
92
  with col2: