# streamlit-webrtc-example / pages/6_mix_multi_inputs.py
# whitphx (HF staff): Fix all the demos to use Twilio STUN/TURN servers
# 79ac659
import math
from typing import List

try:
    from typing import Literal, cast
except ImportError:
    # Only ``Literal`` can be missing from ``typing``; ``cast`` must still be
    # bound on this fallback path because the page calls it further down.
    from typing import cast
    from typing_extensions import Literal  # type: ignore

import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import (
    WebRtcMode,
    create_mix_track,
    create_process_track,
    webrtc_streamer,
)

from sample_utils.turn import get_ice_servers
# Short description rendered at the top of the page.
st.markdown("\nMix multiple inputs with different video filters into one stream.\n")

# Names of the filters that `make_video_frame_callback` knows how to apply.
VideoFilterType = Literal["noop", "cartoon", "edges", "rotate"]
def make_video_frame_callback(_type: VideoFilterType):
    """Build a per-frame callback that applies the selected video filter.

    Args:
        _type: Filter to apply: ``"noop"`` (pass through), ``"cartoon"``
            (smoothed colors masked by thresholded edges), ``"edges"``
            (Canny edge map) or ``"rotate"`` (rotation that advances with
            the frame timestamp).

    Returns:
        A function that takes an ``av.VideoFrame`` and returns a new
        ``bgr24`` ``av.VideoFrame`` holding the filtered image.
    """

    def callback(frame: av.VideoFrame) -> av.VideoFrame:
        img = frame.to_ndarray(format="bgr24")

        if _type == "noop":
            pass
        elif _type == "cartoon":
            rows, cols = img.shape[:2]

            # prepare color: smooth at reduced resolution, then scale back up
            img_color = cv2.pyrDown(cv2.pyrDown(img))
            for _ in range(6):
                img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
            img_color = cv2.pyrUp(cv2.pyrUp(img_color))
            # pyrDown rounds odd sizes up and pyrUp doubles them, so the
            # round trip overshoots whenever a dimension is not a multiple
            # of 4. Bring it back to the source size so bitwise_and below
            # receives arrays of equal shape.
            if img_color.shape[:2] != (rows, cols):
                img_color = cv2.resize(img_color, (cols, rows))

            # prepare edges (the image is BGR, so use the BGR conversion code)
            img_edges = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img_edges = cv2.adaptiveThreshold(
                cv2.medianBlur(img_edges, 7),
                255,
                cv2.ADAPTIVE_THRESH_MEAN_C,
                cv2.THRESH_BINARY,
                9,
                2,
            )
            img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2BGR)

            # combine color and edges
            img = cv2.bitwise_and(img_color, img_edges)
        elif _type == "edges":
            # perform edge detection
            img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
        elif _type == "rotate":
            # rotate image by 45 degrees per second of stream time
            rows, cols, _ = img.shape
            # frame.time is None when the frame carries no timestamp; treat
            # that as t=0 instead of failing on None * 45.
            angle = (frame.time or 0) * 45
            M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
            img = cv2.warpAffine(img, M, (cols, rows))

        return av.VideoFrame.from_ndarray(img, format="bgr24")

    return callback
def mixer_callback(frames: List[av.VideoFrame]) -> av.VideoFrame:
    """Tile the input frames into one 640x480 ``bgr24`` frame.

    The inputs are laid out on a near-square grid, filled row by row. Each
    frame is scaled to fit its grid cell while keeping its aspect ratio and
    is centered in the cell; unused area stays black.

    Args:
        frames: One entry per input track. ``None`` entries are skipped and
            leave their grid cell black.

    Returns:
        A new ``av.VideoFrame`` containing the composed image. With no
        inputs at all, an all-black frame is returned.
    """
    buf_w = 640
    buf_h = 480
    buffer = np.zeros((buf_h, buf_w, 3), dtype=np.uint8)

    n_inputs = len(frames)
    if n_inputs == 0:
        # Without this guard n_cols would be 0 and the row count below
        # would divide by zero.
        return av.VideoFrame.from_ndarray(buffer, format="bgr24")

    n_cols = math.ceil(math.sqrt(n_inputs))
    n_rows = math.ceil(n_inputs / n_cols)
    grid_w = buf_w // n_cols
    grid_h = buf_h // n_rows

    for i, frame in enumerate(frames):
        if frame is None:
            continue

        grid_x = (i % n_cols) * grid_w
        grid_y = (i // n_cols) * grid_h

        img = frame.to_ndarray(format="bgr24")
        src_h, src_w = img.shape[0:2]

        # Largest window with the source aspect ratio that fits in the cell.
        aspect_ratio = src_w / src_h
        window_w = min(grid_w, int(grid_h * aspect_ratio))
        window_h = min(grid_h, int(window_w / aspect_ratio))
        if window_w <= 0 or window_h <= 0:
            # Extreme aspect ratios (or very many inputs) can truncate the
            # window to zero pixels; cv2.resize rejects an empty target.
            continue

        # Center the window inside its grid cell.
        window_offset_x = (grid_w - window_w) // 2
        window_offset_y = (grid_h - window_h) // 2

        window_x0 = grid_x + window_offset_x
        window_y0 = grid_y + window_offset_y
        window_x1 = window_x0 + window_w
        window_y1 = window_y0 + window_h

        buffer[window_y0:window_y1, window_x0:window_x1, :] = cv2.resize(
            img, (window_w, window_h)
        )

    return av.VideoFrame.from_ndarray(buffer, format="bgr24")
# RTC configuration passed to every `webrtc_streamer` call on this page.
COMMON_RTC_CONFIG = dict(iceServers=get_ice_servers())
# --- Input 1: video-only stream with a user-selected filter ---
st.header("Input 1")
input1_ctx = webrtc_streamer(
    key="input1_ctx",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=COMMON_RTC_CONFIG,
    media_stream_constraints=dict(video=True, audio=False),
)
filter1_type = st.radio(
    "Select transform type",
    ("noop", "cartoon", "edges", "rotate"),
    key="mix-filter1-type",
)
callback = make_video_frame_callback(cast(VideoFilterType, filter1_type))
# The processed track exists only while this input has an output video track.
input1_video_process_track = (
    create_process_track(
        input_track=input1_ctx.output_video_track,
        frame_callback=callback,
    )
    if input1_ctx.output_video_track
    else None
)
# --- Input 2: video-only stream with its own user-selected filter ---
st.header("Input 2")
input2_ctx = webrtc_streamer(
    key="input2_ctx",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=COMMON_RTC_CONFIG,
    media_stream_constraints=dict(video=True, audio=False),
)
filter2_type = st.radio(
    "Select transform type",
    ("noop", "cartoon", "edges", "rotate"),
    key="mix-filter2-type",
)
callback = make_video_frame_callback(cast(VideoFilterType, filter2_type))
# The processed track exists only while this input has an output video track.
input2_video_process_track = (
    create_process_track(
        input_track=input2_ctx.output_video_track,
        frame_callback=callback,
    )
    if input2_ctx.output_video_track
    else None
)
# --- Input 3: video-only stream that is mixed in without any filter ---
st.header("Input 3 (no filter)")
input3_ctx = webrtc_streamer(
    key="input3_ctx",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=COMMON_RTC_CONFIG,
    media_stream_constraints=dict(video=True, audio=False),
)
# --- Mixed output: receive-only player fed by the mixer track ---
st.header("Mixed output")
mix_track = create_mix_track(kind="video", mixer_callback=mixer_callback, key="mix")
# Ask the mixed player to play whenever at least one input is playing.
any_input_playing = (
    input1_ctx.state.playing or input2_ctx.state.playing or input3_ctx.state.playing
)
mix_ctx = webrtc_streamer(
    key="mix",
    mode=WebRtcMode.RECVONLY,
    rtc_configuration=COMMON_RTC_CONFIG,
    source_video_track=mix_track,
    desired_playing_state=any_input_playing,
)

# Attach every available input to the mixer once its source track exists.
mixed_source = mix_ctx.source_video_track
if mixed_source:
    if input1_video_process_track:
        mixed_source.add_input_track(input1_video_process_track)
    if input2_video_process_track:
        mixed_source.add_input_track(input2_video_process_track)
    if input3_ctx.output_video_track:
        # Input3 is sourced without any filter.
        mixed_source.add_input_track(input3_ctx.output_video_track)