File size: 4,298 Bytes
473c0a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d28b08
 
 
 
 
 
 
473c0a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d28b08
b98d63e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import logging
import os
import queue
from collections import deque
from concurrent.futures import ThreadPoolExecutor

import cv2
import openai
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer

from model import Predictor

# --- Configuration ---
DEFAULT_WIDTH = 50  # page-content column width, in percent

# SECURITY: an API key was previously hardcoded here and committed to source
# control — that key must be considered compromised and revoked. Read the key
# from the environment instead; empty string disables API calls gracefully.
openai.api_key = os.getenv("OPENAI_API_KEY", "")
logger = logging.getLogger(__name__)

def correct_text_gpt3(input_text):
    """Ask GPT-3.5-turbo to fix grammatical errors in *input_text*.

    Performs a network call via the configured ``openai.api_key``.
    Returns the model's corrected text with surrounding whitespace stripped.
    """
    prompt = f"Исправь грамматические ошибки в тексте: '{input_text}'"
    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant that corrects grammatical errors."},
        {"role": "user", "content": prompt},
    ]
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=chat_messages,
        max_tokens=50,
        n=1,
        stop=None,
        temperature=0.5,
    )
    return completion.choices[0].message['content'].strip()

# Center the page content: a middle column of DEFAULT_WIDTH percent flanked
# by two spacer columns. Previously the width literal 50 was duplicated here
# instead of reusing the DEFAULT_WIDTH constant defined above.
width = DEFAULT_WIDTH
side = max((100 - width) / 1.2, 0.01)
_, container, _ = st.columns([side, width, side])

# Inference model wrapper
class SLInference:
    """Wraps the gesture-recognition model: loads config, buffers frames,
    and runs batch prediction."""

    def __init__(self, config_path):
        # Load model configuration and build the predictor once up front.
        self.config = self.load_config(config_path)
        self.predictor = Predictor(self.config)
        # Rolling window of the most recent frames (kept for external users).
        self.input_queue = deque(maxlen=32)
        self.pred = ''  # last predicted label

    def load_config(self, config_path):
        """Read and return the JSON model configuration at *config_path*."""
        import json
        with open(config_path, 'r') as f:
            return json.load(f)

    def start(self):
        # Kept for interface compatibility; inference here is synchronous.
        pass

    def predict(self, frames):
        """Predict a gesture label for a batch of frames.

        Frames are resized to 224x224 and the batch is padded to 32 entries
        by repeating the last frame. Returns the top predicted label, or
        'no' when the predictor yields nothing or *frames* is empty.
        """
        if not frames:
            # Guard: the padding loop below indexes frames_resized[-1],
            # which would raise IndexError on an empty batch.
            return 'no'
        frames_resized = [cv2.resize(frame, (224, 224)) for frame in frames]
        while len(frames_resized) < 32:
            frames_resized.append(frames_resized[-1])
        result = self.predictor.predict(frames_resized)
        if result:
            return result["labels"][0]
        return 'no'

def process_batch(inference_thread, frames, gestures):
    """Predict one batch of frames and record any new, meaningful gesture.

    Appends to *gestures* in place; labels 'no' and '' are ignored, as are
    labels already present in the list (deduplication).
    """
    label = inference_thread.predict(frames)
    is_meaningful = label not in ('no', '')
    if is_meaningful and label not in gestures:
        gestures.append(label)

# Main user interface
def main(config_path):
    """Render the Streamlit demo: upload a video, run gesture recognition,
    and show the (GPT-corrected) recognized gesture sequence.

    config_path: path to the JSON model configuration used by SLInference.
    """
    # --- Section header ---
    st.markdown("""
    <div class="upload-section">
        <h3>🎥 Sign Language Recognition Demo</h3>
        <p>Upload a short video clip to detect sign gestures:</p>
    </div>
    """, unsafe_allow_html=True)

    # File uploader with a hidden label inside a styled wrapper.
    with st.container():
        uploaded_file = st.file_uploader(" ", type=["mp4", "avi", "mov", "gif"], label_visibility="collapsed")

    if uploaded_file is not None:
        video_bytes = uploaded_file.read()
        container.video(data=video_bytes)

        inference_thread = SLInference(config_path)
        inference_thread.start()

        text_output = st.empty()

        if st.button("🔍 Predict Gestures"):
            import tempfile
            # cv2.VideoCapture needs a real file path, so spill the upload to
            # a temp file. Fix: close the file BEFORE opening it with OpenCV
            # (required on Windows, where an open file cannot be reopened),
            # and remove it afterwards — previously it leaked on disk.
            tfile = tempfile.NamedTemporaryFile(delete=False)
            try:
                tfile.write(video_bytes)
                tfile.close()
                cap = cv2.VideoCapture(tfile.name)

                gestures = []
                frames = []
                batch_size = 32  # the model consumes fixed windows of 32 frames

                def process_frames(batch):
                    # Runs in a worker thread; list.append on `gestures` is
                    # atomic under the GIL, so no extra locking is needed.
                    process_batch(inference_thread, batch, gestures)

                with ThreadPoolExecutor() as executor:
                    while cap.isOpened():
                        ret, frame = cap.read()
                        if not ret:
                            break
                        # OpenCV decodes BGR; the model expects RGB.
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frames.append(frame)
                        if len(frames) == batch_size:
                            executor.submit(process_frames, frames)
                            # Rebind (don't clear) so the submitted batch
                            # keeps its contents.
                            frames = []

                    if frames:
                        # Flush the final, possibly short batch.
                        executor.submit(process_frames, frames)

                cap.release()
            finally:
                os.unlink(tfile.name)

            # Show the recognized gesture sequence.
            text_output.markdown(
                f'<div class="section"><p style="font-size:20px">🖐️ Detected gestures: <b>{" ".join(gestures)}</b></p></div>',
                unsafe_allow_html=True
            )

            # Post-process the concatenated gestures with GPT grammar correction.
            st.text(correct_text_gpt3(" ".join(gestures)))

if __name__ == "__main__":
    # Script entry point: configure root logging, then launch the UI with
    # the default model configuration path.
    logging.basicConfig(level=logging.INFO)
    main("configs/config.json")