|
import logging |
|
import logging.handlers |
|
import queue |
|
import urllib.request |
|
from pathlib import Path |
|
|
|
try: |
|
from typing import Literal |
|
except ImportError: |
|
from typing_extensions import Literal |
|
|
|
import av |
|
import cv2 |
|
import numpy as np |
|
import PIL |
|
import streamlit as st |
|
from aiortc.contrib.media import MediaPlayer |
|
|
|
from streamlit_webrtc import ( |
|
ClientSettings, |
|
VideoTransformerBase, |
|
WebRtcMode, |
|
webrtc_streamer, |
|
) |
|
|
|
# Directory containing this script; the demos store downloaded model and
# media files in sub-directories of it.
HERE = Path(__file__).parent

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
def download_file(url, download_to: Path, expected_size=None):
    """Download ``url`` to ``download_to``, showing progress in Streamlit.

    The data is first written to a sibling ``<name>.part`` file and only
    renamed to ``download_to`` once the transfer has finished, so an
    interrupted download never leaves a truncated file at the destination.

    Args:
        url: Address of the file to fetch.
        download_to: Destination path; missing parent directories are created.
        expected_size: Optional size in bytes. When given, an existing file of
            exactly this size is kept and nothing is downloaded; an existing
            file of a different size is downloaded again. When omitted and the
            file already exists, the user is asked whether to download again.

    Raises:
        Whatever ``urllib.request.urlopen`` or the file operations raise; the
        Streamlit placeholders and the partial file are cleaned up first.
    """
    # Don't download the file twice.
    # (If possible, verify the download using the file length.)
    if download_to.exists():
        if expected_size:
            if download_to.stat().st_size == expected_size:
                return
        else:
            st.info(f"{url} is already downloaded.")
            if not st.button("Download again?"):
                return

    download_to.parent.mkdir(parents=True, exist_ok=True)

    partial_path = download_to.with_name(download_to.name + ".part")
    megabytes = 2.0 ** 20.0

    # These are handles to two visual elements to animate.
    weights_warning, progress_bar = None, None
    try:
        weights_warning = st.warning("Downloading %s..." % url)
        progress_bar = st.progress(0)
        with urllib.request.urlopen(url) as response:
            # The server may omit Content-Length; treat that as "unknown"
            # instead of failing on int(None).
            content_length = response.info()["Content-Length"]
            length = int(content_length) if content_length else 0
            counter = 0.0
            with open(partial_path, "wb") as output_file:
                while True:
                    data = response.read(8192)
                    if not data:
                        break
                    counter += len(data)
                    output_file.write(data)

                    # We perform animation by overwriting the elements.
                    if length > 0:
                        weights_warning.warning(
                            "Downloading %s... (%6.2f/%6.2f MB)"
                            % (url, counter / megabytes, length / megabytes)
                        )
                        progress_bar.progress(min(counter / length, 1.0))
                    else:
                        # Total size unknown: report only what has arrived.
                        weights_warning.warning(
                            "Downloading %s... (%6.2f MB)"
                            % (url, counter / megabytes)
                        )

        # Both files are closed here; publish the completed download.
        partial_path.replace(download_to)
    finally:
        # Finally, we remove these visual elements by calling .empty().
        if weights_warning is not None:
            weights_warning.empty()
        if progress_bar is not None:
            progress_bar.empty()
        # After a successful rename the partial file is already gone; after a
        # failure it is removed so no half-written data is left behind.
        try:
            partial_path.unlink()
        except FileNotFoundError:
            pass
|
|
|
|
|
def main():
    """Render the page header and run the demo selected in the sidebar."""
    st.header("WebRTC demo")

    # Maps each sidebar label to the function rendering that demo. The
    # insertion order is also the order of the options in the select box.
    pages = {
        "Real time object detection (sendrecv)": app_object_detection,
        "Real time video transform with simple OpenCV filters (sendrecv)": (
            app_video_filters
        ),
        "Consuming media files on server-side and streaming it to browser (recvonly)": (
            app_streaming
        ),
        "WebRTC is sendonly and images are shown via st.image() (sendonly)": (
            app_sendonly
        ),
        "Simple video loopback (sendrecv)": app_loopback,
    }

    app_mode = st.sidebar.selectbox("Choose the app mode", list(pages))
    st.subheader(app_mode)

    render_page = pages.get(app_mode)
    if render_page is not None:
        render_page()
|
|
|
|
|
def app_loopback():
    """Simple video loopback: a send-and-receive stream with no transformer."""
    streamer_options = dict(
        key="loopback",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_factory=None,
    )
    webrtc_streamer(**streamer_options)
|
|
|
|
|
def app_video_filters():
    """Video transforms with OpenCV.

    Starts a send-and-receive stream whose frames are passed through
    ``OpenCVVideoTransformer`` and lets the user choose the filter with a
    radio button.
    """

    class OpenCVVideoTransformer(VideoTransformerBase):
        # Name of the filter applied in transform(); reassigned from the
        # Streamlit script below whenever the radio selection is read.
        type: Literal["noop", "cartoon", "edges", "rotate"]

        def __init__(self) -> None:
            self.type = "noop"

        def transform(self, frame: av.VideoFrame) -> np.ndarray:
            """Return the frame as a BGR array with the selected filter applied."""
            img = frame.to_ndarray(format="bgr24")

            if self.type == "noop":
                pass
            elif self.type == "cartoon":
                rows, cols = img.shape[:2]

                # prepare color
                img_color = cv2.pyrDown(cv2.pyrDown(img))
                for _ in range(6):
                    img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
                img_color = cv2.pyrUp(cv2.pyrUp(img_color))
                # pyrDown rounds odd sizes up and pyrUp doubles them, so the
                # round trip only restores the size when both dimensions are
                # multiples of 4. Resize so bitwise_and gets equal shapes.
                if img_color.shape[:2] != (rows, cols):
                    img_color = cv2.resize(img_color, (cols, rows))

                # prepare edges (the frame is BGR, so use the BGR conversions)
                img_edges = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                img_edges = cv2.adaptiveThreshold(
                    cv2.medianBlur(img_edges, 7),
                    255,
                    cv2.ADAPTIVE_THRESH_MEAN_C,
                    cv2.THRESH_BINARY,
                    9,
                    2,
                )
                img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2BGR)

                # combine color and edges
                img = cv2.bitwise_and(img_color, img_edges)
            elif self.type == "edges":
                # perform edge detection
                img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
            elif self.type == "rotate":
                # rotate image around its centre by an angle that grows with
                # the frame timestamp
                rows, cols, _ = img.shape
                M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
                img = cv2.warpAffine(img, M, (cols, rows))

            return img

    webrtc_ctx = webrtc_streamer(
        key="opencv-filter",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_factory=OpenCVVideoTransformer,
        async_transform=True,
    )

    transform_type = st.radio(
        "Select transform type", ("noop", "cartoon", "edges", "rotate")
    )
    # The transformer is only present while a stream is active.
    if webrtc_ctx.video_transformer:
        webrtc_ctx.video_transformer.type = transform_type

    st.markdown(
        "This demo is based on "
        "https://github.com/aiortc/aiortc/blob/2362e6d1f0c730a0f8c387bbea76546775ad2fe8/examples/server/server.py#L34. "
        "Many thanks to the project."
    )
|
|
|
|
|
def app_object_detection():
    """Run MobileNet SSD object detection on the incoming video stream.

    The model and code are based on
    https://github.com/robmarkcole/object-detection-app
    """
    MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"
    MODEL_LOCAL_PATH = HERE / "./models/MobileNetSSD_deploy.caffemodel"
    PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"
    PROTOTXT_LOCAL_PATH = HERE / "./models/MobileNetSSD_deploy.prototxt.txt"

    # Class names, indexed by the class id found in each detection row.
    CLASSES = [
        "background",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    ]
    # One random drawing colour per class.
    COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

    # Fetch the model files unless copies of the expected size already exist.
    download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
    download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)

    DEFAULT_CONFIDENCE_THRESHOLD = 0.5

    class NNVideoTransformer(VideoTransformerBase):
        # Detections at or below this confidence are not drawn; the value is
        # reassigned from the slider in the Streamlit script below.
        confidence_threshold: float

        def __init__(self) -> None:
            self._net = cv2.dnn.readNetFromCaffe(
                str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH)
            )
            self.confidence_threshold = DEFAULT_CONFIDENCE_THRESHOLD

        def _annotate_image(self, image, detections):
            """Draw a box and label per accepted detection; return (image, labels)."""
            height, width = image.shape[:2]
            # Scales the four box values to pixel coordinates of this frame
            # (presumably they are relative coordinates - confirm against the model).
            box_scale = np.array([width, height, width, height])
            labels = []

            # loop over the detections
            for i in range(detections.shape[2]):
                confidence = detections[0, 0, i, 2]
                if not confidence > self.confidence_threshold:
                    continue

                # Column 1 holds the class id, columns 3..6 the bounding box.
                idx = int(detections[0, 0, i, 1])
                scaled_box = detections[0, 0, i, 3:7] * box_scale
                startX, startY, endX, endY = scaled_box.astype("int")

                # display the prediction
                label = f"{CLASSES[idx]}: {round(confidence * 100, 2)}%"
                labels.append(label)
                color = COLORS[idx]
                cv2.rectangle(image, (startX, startY), (endX, endY), color, 2)

                # Put the label above the box unless that is too close to
                # the top edge, in which case put it just below the top side.
                if startY - 15 > 15:
                    y = startY - 15
                else:
                    y = startY + 15
                cv2.putText(
                    image,
                    label,
                    (startX, y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    2,
                )
            return image, labels

        def transform(self, frame: av.VideoFrame) -> np.ndarray:
            """Run the network on the frame and return the annotated BGR image."""
            image = frame.to_ndarray(format="bgr24")
            resized = cv2.resize(image, (300, 300))
            blob = cv2.dnn.blobFromImage(resized, 0.007843, (300, 300), 127.5)
            self._net.setInput(blob)
            network_output = self._net.forward()
            # The labels are computed but not used by this demo.
            annotated_image, _labels = self._annotate_image(image, network_output)
            return annotated_image

    webrtc_ctx = webrtc_streamer(
        key="object-detection",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_factory=NNVideoTransformer,
        async_transform=True,
    )

    confidence_threshold = st.slider(
        "Confidence threshold", 0.0, 1.0, DEFAULT_CONFIDENCE_THRESHOLD, 0.05
    )
    transformer = webrtc_ctx.video_transformer
    if transformer:
        transformer.confidence_threshold = confidence_threshold

    st.markdown(
        "This demo uses a model and code from "
        "https://github.com/robmarkcole/object-detection-app. "
        "Many thanks to the project."
    )
|
|
|
|
|
def app_streaming():
    """Stream a media file from the server to the browser (recvonly).

    The user picks one of the sample files; it is downloaded next to this
    script if necessary and then played through a ``MediaPlayer``.
    """
    MEDIAFILES = {
        "big_buck_bunny_720p_2mb.mp4": {
            "url": "https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_2mb.mp4",
            "local_file_path": HERE / "data/big_buck_bunny_720p_2mb.mp4",
            "type": "video",
        },
        "big_buck_bunny_720p_10mb.mp4": {
            "url": "https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_10mb.mp4",
            "local_file_path": HERE / "data/big_buck_bunny_720p_10mb.mp4",
            "type": "video",
        },
        "file_example_MP3_700KB.mp3": {
            "url": "https://file-examples-com.github.io/uploads/2017/11/file_example_MP3_700KB.mp3",
            "local_file_path": HERE / "data/file_example_MP3_700KB.mp3",
            "type": "audio",
        },
        "file_example_MP3_5MG.mp3": {
            "url": "https://file-examples-com.github.io/uploads/2017/11/file_example_MP3_5MG.mp3",
            "local_file_path": HERE / "data/file_example_MP3_5MG.mp3",
            "type": "audio",
        },
    }
    media_file_label = st.radio(
        "Select a media file to stream", tuple(MEDIAFILES.keys())
    )
    media_file_info = MEDIAFILES[media_file_label]
    download_file(media_file_info["url"], media_file_info["local_file_path"])

    def create_player():
        return MediaPlayer(str(media_file_info["local_file_path"]))

    # Request only the kind of track the selected file provides. The key must
    # be "media_stream_constraints" - the same key WEBRTC_CLIENT_SETTINGS is
    # created with - otherwise the setting is stored under an unknown name.
    WEBRTC_CLIENT_SETTINGS.update(
        {
            "media_stream_constraints": {
                "video": media_file_info["type"] == "video",
                "audio": media_file_info["type"] == "audio",
            }
        }
    )

    webrtc_streamer(
        # The key includes the file label, so each file gets its own key.
        key=f"media-streaming-{media_file_label}",
        mode=WebRtcMode.RECVONLY,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        player_factory=create_player,
    )
|
|
|
|
|
def app_sendonly():
    """A sample to use WebRTC in sendonly mode to transfer frames
    from the browser to the server and to render frames via `st.image`.

    Frames are read from the receiver's queue and drawn into one placeholder
    until no frame arrives for a second, at which point the receiver is
    stopped.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule; import it
    # explicitly so this function does not depend on another package having
    # imported it first.
    import PIL.Image

    webrtc_ctx = webrtc_streamer(
        # Unique key: "loopback" is already used by app_loopback().
        key="sendonly",
        mode=WebRtcMode.SENDONLY,
        client_settings=WEBRTC_CLIENT_SETTINGS,
    )

    if webrtc_ctx.video_receiver:
        image_loc = st.empty()
        while True:
            try:
                frame = webrtc_ctx.video_receiver.frames_queue.get(timeout=1)
            except queue.Empty:
                logger.warning("Queue is empty. Stop the loop.")
                webrtc_ctx.video_receiver.stop()
                break

            # Frames are decoded as BGR; convert to RGB before building the
            # PIL image.
            img = frame.to_ndarray(format="bgr24")
            img = PIL.Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            image_loc.image(img)
|
|
|
|
|
# Client settings passed to every webrtc_streamer() call in this file: one
# STUN server for ICE, and both video and audio requested by default.
# It is defined after the functions that use it, which works because they
# look the name up only when called. app_streaming() updates it in place.
WEBRTC_CLIENT_SETTINGS = ClientSettings(
    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
    media_stream_constraints={"video": True, "audio": True},
)
|
|
|
if __name__ == "__main__":
    # Configure the root logger; force=True applies this configuration even
    # if the root logger already has handlers.
    logging.basicConfig(
        format=(
            "[%(asctime)s] %(levelname)7s from %(name)s "
            "in %(filename)s:%(lineno)d: %(message)s"
        ),
        force=True,
    )

    # Emit debug output from the streamlit_webrtc package and from this script.
    st_webrtc_logger = logging.getLogger("streamlit_webrtc")
    st_webrtc_logger.setLevel(logging.DEBUG)
    logger.setLevel(logging.DEBUG)

    main()
|
|