import math
from typing import List

try:
    from typing import Literal, cast
except ImportError:
    # Python < 3.8: ``Literal`` comes from typing_extensions, but ``cast`` has
    # always lived in ``typing`` -- import it separately so it stays defined
    # (the original fallback dropped ``cast``, causing a NameError below).
    from typing import cast
    from typing_extensions import Literal  # type: ignore

import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import (
    WebRtcMode,
    create_mix_track,
    create_process_track,
    webrtc_streamer,
)

from sample_utils.turn import get_ice_servers

st.markdown(
    """
Mix multiple inputs with different video filters into one stream.
"""
)

# Closed set of filter names offered by the radio widgets below.
VideoFilterType = Literal["noop", "cartoon", "edges", "rotate"]


def make_video_frame_callback(_type: VideoFilterType):
    """Build a per-frame callback that applies the selected video filter.

    Parameters
    ----------
    _type:
        One of ``"noop"``, ``"cartoon"``, ``"edges"`` or ``"rotate"``.

    Returns
    -------
    A callable ``(av.VideoFrame) -> av.VideoFrame`` suitable for
    ``create_process_track(frame_callback=...)``.
    """

    def callback(frame: av.VideoFrame) -> av.VideoFrame:
        # Frames are decoded as BGR (OpenCV's native channel order).
        img = frame.to_ndarray(format="bgr24")

        if _type == "noop":
            pass
        elif _type == "cartoon":
            # prepare color: downsample, smooth heavily, upsample back
            img_color = cv2.pyrDown(cv2.pyrDown(img))
            for _ in range(6):
                img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
            img_color = cv2.pyrUp(cv2.pyrUp(img_color))

            # prepare edges
            # NOTE: the buffer is BGR, so use the BGR conversion codes
            # (the original used COLOR_RGB2GRAY, which applies the wrong
            # luma weights to BGR data).
            img_edges = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img_edges = cv2.adaptiveThreshold(
                cv2.medianBlur(img_edges, 7),
                255,
                cv2.ADAPTIVE_THRESH_MEAN_C,
                cv2.THRESH_BINARY,
                9,
                2,
            )
            img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2BGR)

            # combine color and edges
            img = cv2.bitwise_and(img_color, img_edges)
        elif _type == "edges":
            # perform edge detection; Canny returns single-channel, expand back
            img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
        elif _type == "rotate":
            # rotate image around its center; the angle advances with the
            # frame's presentation time (45 degrees per second).
            rows, cols, _ = img.shape
            M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
            img = cv2.warpAffine(img, M, (cols, rows))

        return av.VideoFrame.from_ndarray(img, format="bgr24")

    return callback


def mixer_callback(frames: List[av.VideoFrame]) -> av.VideoFrame:
    """Compose the input frames into one 640x480 grid frame.

    Inputs are laid out on a near-square grid; each frame is letterboxed
    (aspect-preserving resize, centered) inside its cell. ``None`` entries
    (inputs with no frame yet) leave their cell black.
    """
    buf_w = 640
    buf_h = 480
    buffer = np.zeros((buf_h, buf_w, 3), dtype=np.uint8)

    n_inputs = len(frames)
    if n_inputs == 0:
        # No inputs connected yet: emit a black frame instead of dividing
        # by zero in the grid computation below.
        return av.VideoFrame.from_ndarray(buffer, format="bgr24")

    n_cols = math.ceil(math.sqrt(n_inputs))
    n_rows = math.ceil(n_inputs / n_cols)
    grid_w = buf_w // n_cols
    grid_h = buf_h // n_rows

    for i in range(n_inputs):
        frame = frames[i]
        if frame is None:
            continue

        grid_x = (i % n_cols) * grid_w
        grid_y = (i // n_cols) * grid_h

        img = frame.to_ndarray(format="bgr24")
        src_h, src_w = img.shape[0:2]

        # Letterbox: scale to fit the cell while preserving aspect ratio.
        aspect_ratio = src_w / src_h
        window_w = min(grid_w, int(grid_h * aspect_ratio))
        window_h = min(grid_h, int(window_w / aspect_ratio))
        if window_w <= 0 or window_h <= 0:
            # Degenerate aspect ratio would make cv2.resize raise; skip it.
            continue
        window_offset_x = (grid_w - window_w) // 2
        window_offset_y = (grid_h - window_h) // 2
        window_x0 = grid_x + window_offset_x
        window_y0 = grid_y + window_offset_y
        window_x1 = window_x0 + window_w
        window_y1 = window_y0 + window_h

        buffer[window_y0:window_y1, window_x0:window_x1, :] = cv2.resize(
            img, (window_w, window_h)
        )

    new_frame = av.VideoFrame.from_ndarray(buffer, format="bgr24")
    return new_frame


COMMON_RTC_CONFIG = {"iceServers": get_ice_servers()}

st.header("Input 1")
input1_ctx = webrtc_streamer(
    key="input1_ctx",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=COMMON_RTC_CONFIG,
    media_stream_constraints={"video": True, "audio": False},
)
filter1_type = st.radio(
    "Select transform type",
    ("noop", "cartoon", "edges", "rotate"),
    key="mix-filter1-type",
)
callback = make_video_frame_callback(cast(VideoFilterType, filter1_type))
input1_video_process_track = None
if input1_ctx.output_video_track:
    input1_video_process_track = create_process_track(
        input_track=input1_ctx.output_video_track,
        frame_callback=callback,
    )

st.header("Input 2")
input2_ctx = webrtc_streamer(
    key="input2_ctx",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=COMMON_RTC_CONFIG,
    media_stream_constraints={"video": True, "audio": False},
)
filter2_type = st.radio(
    "Select transform type",
    ("noop", "cartoon", "edges", "rotate"),
    key="mix-filter2-type",
)
callback = make_video_frame_callback(cast(VideoFilterType, filter2_type))
input2_video_process_track = None
if input2_ctx.output_video_track:
    input2_video_process_track = create_process_track(
        input_track=input2_ctx.output_video_track, frame_callback=callback
    )

st.header("Input 3 (no filter)")
input3_ctx = webrtc_streamer(
    key="input3_ctx",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=COMMON_RTC_CONFIG,
    media_stream_constraints={"video": True, "audio": False},
)
st.header("Mixed output")
mix_track = create_mix_track(kind="video", mixer_callback=mixer_callback, key="mix")
mix_ctx = webrtc_streamer(
    key="mix",
    mode=WebRtcMode.RECVONLY,
    rtc_configuration=COMMON_RTC_CONFIG,
    source_video_track=mix_track,
    # Play the mixed stream whenever at least one input is live.
    desired_playing_state=(
        input1_ctx.state.playing
        or input2_ctx.state.playing
        or input3_ctx.state.playing
    ),
)

# Feed every available source into the mixer.
output_track = mix_ctx.source_video_track
if output_track:
    if input1_video_process_track:
        output_track.add_input_track(input1_video_process_track)
    if input2_video_process_track:
        output_track.add_input_track(input2_video_process_track)
    if input3_ctx.output_video_track:
        # Input 3 is mixed in as-is, without any filter.
        output_track.add_input_track(input3_ctx.output_video_track)