osheina committed on
Commit 9220f04 · verified · 1 Parent(s): a7b063a

Upload 3 files

pages/pages_.DS_Store ADDED
Binary file (6.15 kB).
 
pages/pages_1_Camera.py ADDED
@@ -0,0 +1,70 @@
+ import logging
+ import queue
+ from collections import deque
+
+ import streamlit as st
+ from streamlit_webrtc import WebRtcMode, webrtc_streamer
+
+ from utils import SLInference
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def main():
+     """
+     Main function of the app.
+     """
+     # Model and inference settings consumed by the SLInference thread.
+     config = {
+         "path_to_model": "S3D.onnx",
+         "threshold": 0.8,
+         "topk": 5,
+         "path_to_class_list": "RSL_class_list.txt",
+         "window_size": 8,
+         "provider": "OpenVINOExecutionProvider"
+     }
+
+     # Start the background inference thread.
+     inference_thread = SLInference(config)
+     inference_thread.start()
+
+     # Receive-only WebRTC stream: frames are pulled from the browser camera.
+     webrtc_ctx = webrtc_streamer(
+         key="video-sendonly",
+         mode=WebRtcMode.SENDONLY,
+         media_stream_constraints={"video": True},
+     )
+
+     gestures_deque = deque(maxlen=5)
+
+     # Set up Streamlit interface
+     st.title("Sign Language Recognition Demo")
+     image_place = st.empty()
+     text_output = st.empty()
+     last_5_gestures = st.empty()
+
+     while True:
+         if webrtc_ctx.video_receiver:
+             try:
+                 video_frame = webrtc_ctx.video_receiver.get_frame(timeout=1)
+             except queue.Empty:
+                 logger.warning("Queue is empty")
+                 continue
+
+             # Show the current frame and feed a 224x224 copy to the model.
+             img_rgb = video_frame.to_ndarray(format="rgb24")
+             image_place.image(img_rgb)
+             inference_thread.input_queue.append(
+                 video_frame.reformat(224, 224).to_ndarray(format="rgb24")
+             )
+
+             # Keep only the last 5 distinct recognized gestures.
+             gesture = inference_thread.pred
+             if gesture not in ['no', '']:
+                 if not gestures_deque:
+                     gestures_deque.append(gesture)
+                 elif gesture != gestures_deque[-1]:
+                     gestures_deque.append(gesture)
+
+             text_output.markdown(
+                 f'<p style="font-size:20px"> Current gesture: {gesture}</p>',
+                 unsafe_allow_html=True,
+             )
+             last_5_gestures.markdown(
+                 f'<p style="font-size:20px"> Last 5 gestures: {" ".join(gestures_deque)}</p>',
+                 unsafe_allow_html=True,
+             )
+             print(gestures_deque)
+
+
+ if __name__ == "__main__":
+     main()
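
pages_1_Camera.py relies on utils.SLInference to do the actual recognition: the page only pushes resized frames into inference_thread.input_queue and reads the latest label from inference_thread.pred. utils.py is not part of this commit, so the sketch below is a hypothetical illustration of the interface the page appears to expect (a daemon thread exposing input_queue, pred, and start()), not the repository's actual implementation; the config keys are taken from the dict above.

# Hypothetical sketch of the interface pages_1_Camera.py expects from utils.SLInference.
# The real class in utils.py may differ; only input_queue, pred, and start() are assumed here.
import threading
import time
from collections import deque


class SLInference(threading.Thread):
    def __init__(self, config):
        super().__init__(daemon=True)
        self.config = config
        # Frames appended by the UI loop; the window size comes from the config dict.
        self.input_queue = deque(maxlen=config["window_size"])
        # Latest predicted gesture label, polled by the UI loop.
        self.pred = ""

    def run(self):
        while True:
            if len(self.input_queue) == self.input_queue.maxlen:
                frames = list(self.input_queue)
                # A real implementation would run the ONNX model here (e.g. S3D.onnx)
                # and map the top score to a label from path_to_class_list.
                self.pred = self.dummy_predict(frames)
            time.sleep(0.01)

    def dummy_predict(self, frames):
        return ""  # placeholder; replace with actual model inference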
pages/pages_2_Openai.py ADDED
@@ -0,0 +1,129 @@
+ import json
+ import logging
+ import os
+ import tempfile
+ from collections import deque
+ from concurrent.futures import ThreadPoolExecutor
+
+ import cv2
+ import openai
+ import streamlit as st
+
+ from model import Predictor  # Import Predictor from your model file
+
+ DEFAULT_WIDTH = 50
+
+ # Read the OpenAI API key from the environment instead of hard-coding it in the source.
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+
+
+ def correct_text_gpt3(input_text):
+     # Russian prompt: "Fix the grammatical errors in the text: '<input_text>'".
+     prompt = f"Исправь грамматические ошибки в тексте: '{input_text}'"
+     # Uses the pre-1.0 openai client interface (openai.ChatCompletion).
+     response = openai.ChatCompletion.create(
+         model="gpt-3.5-turbo",
+         messages=[
+             {"role": "system", "content": "You are a helpful assistant that corrects grammatical errors."},
+             {"role": "user", "content": prompt}
+         ],
+         max_tokens=50,
+         n=1,
+         stop=None,
+         temperature=0.5,
+     )
+
+     corrected_text = response.choices[0].message['content'].strip()
+     return corrected_text
+
+
+ # st.set_page_config(layout="wide")
+
+ width = DEFAULT_WIDTH
+ side = max((100 - width) / 1.2, 0.01)
+
+ # Center the uploaded video in the middle column.
+ _, container, _ = st.columns([side, width, side])
+
+ logger = logging.getLogger(__name__)
+
+
+ class SLInference:
+     def __init__(self, config_path):
+         self.config = self.load_config(config_path)
+         self.predictor = Predictor(self.config)
+         self.input_queue = deque(maxlen=32)  # Queue to store 32 frames
+         self.pred = ''
+
+     def load_config(self, config_path):
+         with open(config_path, 'r') as f:
+             return json.load(f)
+
+     def start(self):
+         pass  # This method can be left empty or add initialization logic
+
+     def predict(self, frames):
+         # Resize to the model input size and pad the batch to 32 frames.
+         frames_resized = [cv2.resize(frame, (224, 224)) for frame in frames]
+         while len(frames_resized) < 32:
+             frames_resized.append(frames_resized[-1])
+         result = self.predictor.predict(frames_resized)
+         if result:
+             return result["labels"][0]
+         return 'no'
+
+
+ def process_batch(inference_thread, frames, gestures):
+     gesture = inference_thread.predict(frames)
+     if gesture not in ['no', ''] and gesture not in gestures:
+         gestures.append(gesture)
+
+
+ def main(config_path):
+     # st.set_page_config(layout="wide")
+     st.title("Sign Language Recognition Demo")
+
+     st.warning("Please upload a video file for prediction.")
+
+     uploaded_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "gif"])
+
+     if uploaded_file is not None:
+         video_bytes = uploaded_file.read()
+         container.video(data=video_bytes)
+         # st.video(video_bytes)
+
+         inference_thread = SLInference(config_path)
+         inference_thread.start()
+
+         text_output = st.empty()
+
+         if st.button("Predict"):
+             # Write the upload to a temporary file so OpenCV can read it.
+             tfile = tempfile.NamedTemporaryFile(delete=False)
+             tfile.write(video_bytes)
+             cap = cv2.VideoCapture(tfile.name)
+
+             gestures = []
+             frames = []
+             batch_size = 32
+
+             def process_frames(batch):
+                 process_batch(inference_thread, batch, gestures)
+
+             # Read the video in 32-frame batches and predict each batch in a worker thread.
+             with ThreadPoolExecutor() as executor:
+                 while cap.isOpened():
+                     ret, frame = cap.read()
+                     if not ret:
+                         break
+                     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                     frames.append(frame)
+                     if len(frames) == batch_size:
+                         executor.submit(process_frames, frames)
+                         frames = []
+
+                 if frames:
+                     executor.submit(process_frames, frames)
+
+             # All submitted batches have finished once the executor context exits.
+             cap.release()
+             text_output.markdown(
+                 f'<p style="font-size:20px"> Gestures in video: {" ".join(gestures)}</p>',
+                 unsafe_allow_html=True,
+             )
+             st.text(correct_text_gpt3(" ".join(gestures)))
+
+             print(gestures)
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(level=logging.INFO)
+     main("configs/config.json")