import math
from typing import List, cast

try:
    from typing import Literal
except ImportError:  # Python < 3.8
    from typing_extensions import Literal

import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import (
    WebRtcMode,
    create_mix_track,
    create_process_track,
    webrtc_streamer,
)

from sample_utils.turn import get_ice_servers

st.markdown(
    """
Mix multiple inputs with different video filters into one stream.
"""
)
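
# Filters a user can apply to each input stream.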
VideoFilterType = Literal["noop", "cartoon", "edges", "rotate"]
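

# Build a per-frame callback that applies the selected filter; the result
# is handed to create_process_track() below.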
def make_video_frame_callback(_type: VideoFilterType):
    def callback(frame: av.VideoFrame) -> av.VideoFrame:
        img = frame.to_ndarray(format="bgr24")

        if _type == "noop":
            pass
        elif _type == "cartoon":
            # Prepare the color layer: downsample, smooth, upsample.
            img_color = cv2.pyrDown(cv2.pyrDown(img))
            for _ in range(6):
                img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
            img_color = cv2.pyrUp(cv2.pyrUp(img_color))
            # pyrDown/pyrUp round odd dimensions, so the result can differ
            # from the source size by a few pixels; resize before masking.
            img_color = cv2.resize(img_color, (img.shape[1], img.shape[0]))

            # Prepare the edge mask (the frame is BGR, hence COLOR_BGR2GRAY).
            img_edges = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img_edges = cv2.adaptiveThreshold(
                cv2.medianBlur(img_edges, 7),
                255,
                cv2.ADAPTIVE_THRESH_MEAN_C,
                cv2.THRESH_BINARY,
                9,
                2,
            )
            img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2BGR)

            # Combine the color layer with the edge mask.
            img = cv2.bitwise_and(img_color, img_edges)
        elif _type == "edges":
            # Perform edge detection.
            img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
        elif _type == "rotate":
            # Rotate around the center at 45 degrees per second of stream time.
            rows, cols, _ = img.shape
            M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
            img = cv2.warpAffine(img, M, (cols, rows))

        return av.VideoFrame.from_ndarray(img, format="bgr24")

    return callback
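

# Compose all input frames onto one 640x480 canvas, tiling them on a
# near-square grid and letterboxing each to preserve its aspect ratio.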
def mixer_callback(frames: List[av.VideoFrame]) -> av.VideoFrame:
    buf_w = 640
    buf_h = 480
    buffer = np.zeros((buf_h, buf_w, 3), dtype=np.uint8)

    n_inputs = len(frames)
    if n_inputs == 0:
        # No inputs yet; emit a black frame instead of dividing by zero below.
        return av.VideoFrame.from_ndarray(buffer, format="bgr24")

    n_cols = math.ceil(math.sqrt(n_inputs))
    n_rows = math.ceil(n_inputs / n_cols)
    grid_w = buf_w // n_cols
    grid_h = buf_h // n_rows

    for i in range(n_inputs):
        frame = frames[i]
        if frame is None:
            continue

        grid_x = (i % n_cols) * grid_w
        grid_y = (i // n_cols) * grid_h

        img = frame.to_ndarray(format="bgr24")
        src_h, src_w = img.shape[0:2]

        aspect_ratio = src_w / src_h

        # Fit the frame inside its grid cell, preserving the aspect ratio.
        window_w = min(grid_w, int(grid_h * aspect_ratio))
        window_h = min(grid_h, int(window_w / aspect_ratio))

        # Center the fitted window within the cell.
        window_offset_x = (grid_w - window_w) // 2
        window_offset_y = (grid_h - window_h) // 2

        window_x0 = grid_x + window_offset_x
        window_y0 = grid_y + window_offset_y
        window_x1 = window_x0 + window_w
        window_y1 = window_y0 + window_h

        buffer[window_y0:window_y1, window_x0:window_x1, :] = cv2.resize(
            img, (window_w, window_h)
        )

    new_frame = av.VideoFrame.from_ndarray(buffer, format="bgr24")

    return new_frame
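

# All streamer instances share one RTC configuration; sample_utils.turn
# supplies the ICE servers (a TURN server, judging by the module name, to
# relay traffic when a direct peer-to-peer connection cannot be made).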
COMMON_RTC_CONFIG = {"iceServers": get_ice_servers()}
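
# Each input is a SENDRECV stream from the browser camera. Its output track
# is wrapped in a process track that runs the chosen filter server-side.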
st.header("Input 1")
input1_ctx = webrtc_streamer(
    key="input1_ctx",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=COMMON_RTC_CONFIG,
    media_stream_constraints={"video": True, "audio": False},
)
filter1_type = st.radio(
    "Select transform type",
    ("noop", "cartoon", "edges", "rotate"),
    key="mix-filter1-type",
)
callback = make_video_frame_callback(cast(VideoFilterType, filter1_type))
input1_video_process_track = None
if input1_ctx.output_video_track:
    input1_video_process_track = create_process_track(
        input_track=input1_ctx.output_video_track,
        frame_callback=callback,
    )

st.header("Input 2")
input2_ctx = webrtc_streamer(
    key="input2_ctx",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=COMMON_RTC_CONFIG,
    media_stream_constraints={"video": True, "audio": False},
)
filter2_type = st.radio(
    "Select transform type",
    ("noop", "cartoon", "edges", "rotate"),
    key="mix-filter2-type",
)
callback = make_video_frame_callback(cast(VideoFilterType, filter2_type))
input2_video_process_track = None
if input2_ctx.output_video_track:
    input2_video_process_track = create_process_track(
        input_track=input2_ctx.output_video_track,
        frame_callback=callback,
    )

st.header("Input 3 (no filter)")
input3_ctx = webrtc_streamer(
    key="input3_ctx",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=COMMON_RTC_CONFIG,
    media_stream_constraints={"video": True, "audio": False},
)
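
# The mix track feeds mixer_callback with the latest frame from every
# registered input; the RECVONLY streamer below plays it back whenever
# at least one input is live.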
st.header("Mixed output")
mix_track = create_mix_track(kind="video", mixer_callback=mixer_callback, key="mix")
mix_ctx = webrtc_streamer(
    key="mix",
    mode=WebRtcMode.RECVONLY,
    rtc_configuration=COMMON_RTC_CONFIG,
    source_video_track=mix_track,
    desired_playing_state=input1_ctx.state.playing
    or input2_ctx.state.playing
    or input3_ctx.state.playing,
)
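
# Register each available input with the mixer: inputs 1 and 2 feed their
# filtered process tracks, input 3 its raw output track.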
if mix_ctx.source_video_track and input1_video_process_track:
    mix_ctx.source_video_track.add_input_track(input1_video_process_track)
if mix_ctx.source_video_track and input2_video_process_track:
    mix_ctx.source_video_track.add_input_track(input2_video_process_track)
if mix_ctx.source_video_track and input3_ctx.output_video_track:
    mix_ctx.source_video_track.add_input_track(input3_ctx.output_video_track)