Spaces:
Running
Running
import logging | |
import queue | |
from collections import deque | |
from concurrent.futures import ThreadPoolExecutor | |
import streamlit as st | |
import cv2 | |
from streamlit_webrtc import WebRtcMode, webrtc_streamer | |
from model import Predictor | |
import openai | |
# Настройки | |
DEFAULT_WIDTH = 50 | |
openai.api_key = 'sk-proj-GDxupB1DFvTTWBg38VyST3BlbkFJ7MdcACLwu3u0U1QvWeMb' | |
logger = logging.getLogger(__name__) | |
def correct_text_gpt3(input_text): | |
prompt = f"Исправь грамматические ошибки в тексте: '{input_text}'" | |
response = openai.ChatCompletion.create( | |
model="gpt-3.5-turbo", | |
messages=[ | |
{"role": "system", "content": "You are a helpful assistant that corrects grammatical errors."}, | |
{"role": "user", "content": prompt} | |
], | |
max_tokens=50, | |
n=1, | |
stop=None, | |
temperature=0.5, | |
) | |
return response.choices[0].message['content'].strip() | |
# Центрируем контент | |
width = 50 | |
side = max((100 - width) / 1.2, 0.01) | |
_, container, _ = st.columns([side, width, side]) | |
# Модель инференса | |
class SLInference: | |
def __init__(self, config_path): | |
self.config = self.load_config(config_path) | |
self.predictor = Predictor(self.config) | |
self.input_queue = deque(maxlen=32) | |
self.pred = '' | |
def load_config(self, config_path): | |
import json | |
with open(config_path, 'r') as f: | |
return json.load(f) | |
def start(self): | |
pass | |
def predict(self, frames): | |
frames_resized = [cv2.resize(frame, (224, 224)) for frame in frames] | |
while len(frames_resized) < 32: | |
frames_resized.append(frames_resized[-1]) | |
result = self.predictor.predict(frames_resized) | |
if result: | |
return result["labels"][0] | |
return 'no' | |
def process_batch(inference_thread, frames, gestures): | |
gesture = inference_thread.predict(frames) | |
if gesture not in ['no', ''] and gesture not in gestures: | |
gestures.append(gesture) | |
# Основной интерфейс | |
def main(config_path): | |
# --- Заголовок блока --- | |
st.markdown(""" | |
<div class="upload-section"> | |
<h3>🎥 Sign Language Recognition Demo</h3> | |
<p>Upload a short video clip to detect sign gestures:</p> | |
</div> | |
""", unsafe_allow_html=True) | |
# --- Скрытый лейбл uploader'а в стилизованной обёртке --- | |
with st.container(): | |
uploaded_file = st.file_uploader(" ", type=["mp4", "avi", "mov", "gif"], label_visibility="collapsed") | |
if uploaded_file is not None: | |
video_bytes = uploaded_file.read() | |
container.video(data=video_bytes) | |
inference_thread = SLInference(config_path) | |
inference_thread.start() | |
text_output = st.empty() | |
if st.button("🔍 Predict Gestures"): | |
import tempfile | |
tfile = tempfile.NamedTemporaryFile(delete=False) | |
tfile.write(video_bytes) | |
cap = cv2.VideoCapture(tfile.name) | |
gestures = [] | |
frames = [] | |
batch_size = 32 | |
def process_frames(batch): | |
process_batch(inference_thread, batch, gestures) | |
with ThreadPoolExecutor() as executor: | |
while cap.isOpened(): | |
ret, frame = cap.read() | |
if not ret: | |
break | |
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
frames.append(frame) | |
if len(frames) == batch_size: | |
executor.submit(process_frames, frames) | |
frames = [] | |
if frames: | |
executor.submit(process_frames, frames) | |
cap.release() | |
# Вывод результата | |
text_output.markdown( | |
f'<div class="section"><p style="font-size:20px">🖐️ Detected gestures: <b>{" ".join(gestures)}</b></p></div>', | |
unsafe_allow_html=True | |
) | |
# Исправление текста | |
st.text(correct_text_gpt3(" ".join(gestures))) | |
if __name__ == "__main__": | |
logging.basicConfig(level=logging.INFO) | |
main("configs/config.json") | |