Théo Rousseaux commited on
Commit
a755c90
1 Parent(s): 025e412

début pose agent

Browse files
Modules/PoseEstimation/__init__.py ADDED
File without changes
Modules/PoseEstimation/pose_agent.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Modules.PoseEstimation.pose_estimator import calculate_angle, joints_id_dict, model
2
+ from langchain.tools import tool
3
+ from langchain.agents import AgentExecutor, create_tool_calling_agent
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_core.messages import HumanMessage
6
+ from langchain_mistralai.chat_models import ChatMistralAI
7
+
8
+ # If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
9
+ llm = ChatMistralAI(model='mistral-large-latest', api_key="i5jSJkCFNGKfgIztloxTMjfckiFbYBj4")
10
+
11
+ @tool
12
+ def compute_right_knee_angle(pose: list) -> float:
13
+
14
+ """
15
+ Computes the knee angle.
16
+
17
+ Args:
18
+ pose (list): list of keypoints
19
+
20
+ Returns:
21
+ knee_angle (float): knee angle
22
+ """
23
+
24
+ right_hip = pose[joints_id_dict['right_hip']]
25
+ right_knee = pose[joints_id_dict['right_knee']]
26
+ right_ankle = pose[joints_id_dict['right_ankle']]
27
+
28
+ knee_angle = calculate_angle(right_hip, right_knee, right_ankle)
29
+
30
+ print(knee_angle)
31
+
32
+ return str(knee_angle)
33
+
34
+ @tool
35
+ def get_keypoints_from_path(video_path: str):
36
+ """
37
+ Get keypoints from a video.
38
+
39
+ Args:
40
+ video_path (str): path to the video
41
+ model (YOLO): model to use
42
+
43
+ Returns:
44
+ keypoints (list): list of keypoints
45
+ """
46
+
47
+ keypoints = []
48
+ results = model(video_path, save=True, show_conf=False, show_boxes=False)
49
+ for frame in results:
50
+ tensor = frame.keypoints.xy[0]
51
+ keypoints.append(tensor.tolist())
52
+
53
+ return keypoints
54
+
55
+
56
+
57
+ tools = [compute_right_knee_angle]
58
+
59
+ prompt = ChatPromptTemplate.from_messages(
60
+ [
61
+ (
62
+ "system",
63
+ "You are a helpful assistant. Make sure to use the compute_right_knee_angle tool for information.",
64
+ ),
65
+ ("placeholder", "{chat_history}"),
66
+ ("human", "{input}"),
67
+ ("placeholder", "{agent_scratchpad}"),
68
+ ]
69
+ )
70
+
71
+ # Construct the Tools agent
72
+ agent = create_tool_calling_agent(llm, tools, prompt)
73
+
74
+ agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
Modules/PoseEstimation/pose_estimation.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
Modules/PoseEstimation/pose_estimator.py CHANGED
@@ -31,8 +31,6 @@ def get_keypoints_from_keypoints(model, video_path):
31
 
32
  return keypoints
33
 
34
- keypoints = get_keypoints_from_keypoints(model, '../../data/pose/squat.mp4')
35
-
36
  def calculate_angle(a, b, c):
37
 
38
  """
@@ -112,5 +110,4 @@ def moving_average(data, window_size):
112
  for i in range(len(data) - window_size + 1):
113
  avg.append(sum(data[i:i + window_size]) / window_size)
114
 
115
- return avg
116
-
 
31
 
32
  return keypoints
33
 
 
 
34
  def calculate_angle(a, b, c):
35
 
36
  """
 
110
  for i in range(len(data) - window_size + 1):
111
  avg.append(sum(data[i:i + window_size]) / window_size)
112
 
113
+ return avg
 
app.py CHANGED
@@ -7,6 +7,8 @@ from dotenv import load_dotenv
7
  load_dotenv() # load .env api keys
8
  import os
9
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
 
 
10
 
11
  st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
12
  # Create two columns
@@ -59,12 +61,17 @@ with col2:
59
  if video_uploaded is None:
60
  video_uploaded = ask_video.file_uploader("Choose a video file", type=["mp4", "ogg", "webm"])
61
  if video_uploaded:
 
62
  ask_video.empty()
63
- with st.spin("Processing video"):
64
- pass # TO DO
65
  _left, mid, _right = st.columns(3)
66
  with mid:
67
  st.video(video_uploaded)
 
 
 
 
 
 
68
 
69
  st.subheader("Graph Displayer")
70
  # TO DO
 
7
  load_dotenv() # load .env api keys
8
  import os
9
  mistral_api_key = os.getenv("MISTRAL_API_KEY")
10
+ from Modules.PoseEstimation import pose_estimator
11
+ from utils import save_uploaded_file
12
 
13
  st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
14
  # Create two columns
 
61
  if video_uploaded is None:
62
  video_uploaded = ask_video.file_uploader("Choose a video file", type=["mp4", "ogg", "webm"])
63
  if video_uploaded:
64
+ video_uploaded = save_uploaded_file(video_uploaded)
65
  ask_video.empty()
 
 
66
  _left, mid, _right = st.columns(3)
67
  with mid:
68
  st.video(video_uploaded)
69
+ apply_pose = st.button("Apply Pose Estimation")
70
+
71
+ if apply_pose:
72
+ with st.spinner("Processing video"):
73
+ keypoints = pose_estimator.get_keypoints_from_keypoints(pose_estimator.model, video_uploaded)
74
+
75
 
76
  st.subheader("Graph Displayer")
77
  # TO DO
utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+
4
+ def save_uploaded_file(uploaded_file):
5
+ try:
6
+ file_path = os.path.join('uploaded', uploaded_file.name)
7
+ with open(file_path, 'wb') as f:
8
+ f.write(uploaded_file.getvalue())
9
+ return file_path
10
+ except Exception as e:
11
+ st.error(f"Error: {e}")
12
+ return None