| | |
| | |
| | |
| | |
| |
|
| | import hashlib |
| | import os |
| | import shutil |
| | import tempfile |
| | from pathlib import Path |
| | from typing import Iterable, List, Optional, Tuple, Union |
| |
|
| | import av |
| | import strawberry |
| | from app_conf import ( |
| | DATA_PATH, |
| | DEFAULT_VIDEO_PATH, |
| | MAX_UPLOAD_VIDEO_DURATION, |
| | UPLOADS_PATH, |
| | UPLOADS_PREFIX, |
| | ) |
| | from data.data_types import ( |
| | AddPointsInput, |
| | CancelPropagateInVideo, |
| | CancelPropagateInVideoInput, |
| | ClearPointsInFrameInput, |
| | ClearPointsInVideo, |
| | ClearPointsInVideoInput, |
| | CloseSession, |
| | CloseSessionInput, |
| | RemoveObjectInput, |
| | RLEMask, |
| | RLEMaskForObject, |
| | RLEMaskListOnFrame, |
| | StartSession, |
| | StartSessionInput, |
| | Video, |
| | ) |
| | from data.loader import get_video |
| | from data.store import get_videos |
| | from data.transcoder import get_video_metadata, transcode, VideoMetadata |
| | from inference.data_types import ( |
| | AddPointsRequest, |
| | CancelPropagateInVideoRequest, |
| | CancelPropagateInVideoRequest, |
| | ClearPointsInFrameRequest, |
| | ClearPointsInVideoRequest, |
| | CloseSessionRequest, |
| | RemoveObjectRequest, |
| | StartSessionRequest, |
| | ) |
| | from inference.predictor import InferenceAPI |
| | from strawberry import relay |
| | from strawberry.file_uploads import Upload |
| |
|
| |
|
| | @strawberry.type |
| | class Query: |
| |
|
| | @strawberry.field |
| | def default_video(self) -> Video: |
| | """ |
| | Return the default video. |
| | |
| | The default video can be set with the DEFAULT_VIDEO_PATH environment |
| | variable. It will return the video that matches this path. If no video |
| | is found, it will return the first video. |
| | """ |
| | all_videos = get_videos() |
| |
|
| | |
| | |
| | for _, v in all_videos.items(): |
| | if v.path == DEFAULT_VIDEO_PATH: |
| | return v |
| |
|
| | |
| | return next(iter(all_videos.values())) |
| |
|
| | @relay.connection(relay.ListConnection[Video]) |
| | def videos( |
| | self, |
| | ) -> Iterable[Video]: |
| | """ |
| | Return all available videos. |
| | """ |
| | all_videos = get_videos() |
| | return all_videos.values() |
| |
|
| |
|
| | @strawberry.type |
| | class Mutation: |
| |
|
| | @strawberry.mutation |
| | def upload_video( |
| | self, |
| | file: Upload, |
| | start_time_sec: Optional[float] = None, |
| | duration_time_sec: Optional[float] = None, |
| | ) -> Video: |
| | """ |
| | Receive a video file and store it in the configured S3 bucket. |
| | """ |
| | max_time = MAX_UPLOAD_VIDEO_DURATION |
| | filepath, file_key, vm = process_video( |
| | file, |
| | max_time=max_time, |
| | start_time_sec=start_time_sec, |
| | duration_time_sec=duration_time_sec, |
| | ) |
| |
|
| | video = get_video( |
| | filepath, |
| | UPLOADS_PATH, |
| | file_key=file_key, |
| | width=vm.width, |
| | height=vm.height, |
| | generate_poster=False, |
| | ) |
| | return video |
| |
|
| | @strawberry.mutation |
| | def start_session( |
| | self, input: StartSessionInput, info: strawberry.Info |
| | ) -> StartSession: |
| | inference_api: InferenceAPI = info.context["inference_api"] |
| |
|
| | request = StartSessionRequest( |
| | type="start_session", |
| | path=f"{DATA_PATH}/{input.path}", |
| | ) |
| |
|
| | response = inference_api.start_session(request=request) |
| |
|
| | return StartSession(session_id=response.session_id) |
| |
|
| | @strawberry.mutation |
| | def close_session( |
| | self, input: CloseSessionInput, info: strawberry.Info |
| | ) -> CloseSession: |
| | inference_api: InferenceAPI = info.context["inference_api"] |
| |
|
| | request = CloseSessionRequest( |
| | type="close_session", |
| | session_id=input.session_id, |
| | ) |
| | response = inference_api.close_session(request) |
| | return CloseSession(success=response.success) |
| |
|
| | @strawberry.mutation |
| | def add_points( |
| | self, input: AddPointsInput, info: strawberry.Info |
| | ) -> RLEMaskListOnFrame: |
| | inference_api: InferenceAPI = info.context["inference_api"] |
| |
|
| | request = AddPointsRequest( |
| | type="add_points", |
| | session_id=input.session_id, |
| | frame_index=input.frame_index, |
| | object_id=input.object_id, |
| | points=input.points, |
| | labels=input.labels, |
| | clear_old_points=input.clear_old_points, |
| | ) |
| | reponse = inference_api.add_points(request) |
| |
|
| | return RLEMaskListOnFrame( |
| | frame_index=reponse.frame_index, |
| | rle_mask_list=[ |
| | RLEMaskForObject( |
| | object_id=r.object_id, |
| | rle_mask=RLEMask(counts=r.mask.counts, size=r.mask.size, order="F"), |
| | ) |
| | for r in reponse.results |
| | ], |
| | ) |
| |
|
| | @strawberry.mutation |
| | def remove_object( |
| | self, input: RemoveObjectInput, info: strawberry.Info |
| | ) -> List[RLEMaskListOnFrame]: |
| | inference_api: InferenceAPI = info.context["inference_api"] |
| |
|
| | request = RemoveObjectRequest( |
| | type="remove_object", session_id=input.session_id, object_id=input.object_id |
| | ) |
| |
|
| | response = inference_api.remove_object(request) |
| |
|
| | return [ |
| | RLEMaskListOnFrame( |
| | frame_index=res.frame_index, |
| | rle_mask_list=[ |
| | RLEMaskForObject( |
| | object_id=r.object_id, |
| | rle_mask=RLEMask( |
| | counts=r.mask.counts, size=r.mask.size, order="F" |
| | ), |
| | ) |
| | for r in res.results |
| | ], |
| | ) |
| | for res in response.results |
| | ] |
| |
|
| | @strawberry.mutation |
| | def clear_points_in_frame( |
| | self, input: ClearPointsInFrameInput, info: strawberry.Info |
| | ) -> RLEMaskListOnFrame: |
| | inference_api: InferenceAPI = info.context["inference_api"] |
| |
|
| | request = ClearPointsInFrameRequest( |
| | type="clear_points_in_frame", |
| | session_id=input.session_id, |
| | frame_index=input.frame_index, |
| | object_id=input.object_id, |
| | ) |
| |
|
| | response = inference_api.clear_points_in_frame(request) |
| |
|
| | return RLEMaskListOnFrame( |
| | frame_index=response.frame_index, |
| | rle_mask_list=[ |
| | RLEMaskForObject( |
| | object_id=r.object_id, |
| | rle_mask=RLEMask(counts=r.mask.counts, size=r.mask.size, order="F"), |
| | ) |
| | for r in response.results |
| | ], |
| | ) |
| |
|
| | @strawberry.mutation |
| | def clear_points_in_video( |
| | self, input: ClearPointsInVideoInput, info: strawberry.Info |
| | ) -> ClearPointsInVideo: |
| | inference_api: InferenceAPI = info.context["inference_api"] |
| |
|
| | request = ClearPointsInVideoRequest( |
| | type="clear_points_in_video", |
| | session_id=input.session_id, |
| | ) |
| | response = inference_api.clear_points_in_video(request) |
| | return ClearPointsInVideo(success=response.success) |
| |
|
| | @strawberry.mutation |
| | def cancel_propagate_in_video( |
| | self, input: CancelPropagateInVideoInput, info: strawberry.Info |
| | ) -> CancelPropagateInVideo: |
| | inference_api: InferenceAPI = info.context["inference_api"] |
| |
|
| | request = CancelPropagateInVideoRequest( |
| | type="cancel_propagate_in_video", |
| | session_id=input.session_id, |
| | ) |
| | response = inference_api.cancel_propagate_in_video(request) |
| | return CancelPropagateInVideo(success=response.success) |
| |
|
| |
|
| | def get_file_hash(video_path_or_file) -> str: |
| | if isinstance(video_path_or_file, str): |
| | with open(video_path_or_file, "rb") as in_f: |
| | result = hashlib.sha256(in_f.read()).hexdigest() |
| | else: |
| | video_path_or_file.seek(0) |
| | result = hashlib.sha256(video_path_or_file.read()).hexdigest() |
| | return result |
| |
|
| |
|
| | def _get_start_sec_duration_sec( |
| | start_time_sec: Union[float, None], |
| | duration_time_sec: Union[float, None], |
| | max_time: float, |
| | ) -> Tuple[float, float]: |
| | default_seek_t = int(os.environ.get("VIDEO_ENCODE_SEEK_TIME", "0")) |
| | if start_time_sec is None: |
| | start_time_sec = default_seek_t |
| |
|
| | if duration_time_sec is not None: |
| | duration_time_sec = min(duration_time_sec, max_time) |
| | else: |
| | duration_time_sec = max_time |
| | return start_time_sec, duration_time_sec |
| |
|
| |
|
| | def process_video( |
| | file: Upload, |
| | max_time: float, |
| | start_time_sec: Optional[float] = None, |
| | duration_time_sec: Optional[float] = None, |
| | ) -> Tuple[Optional[str], str, str, VideoMetadata]: |
| | """ |
| | Process file upload including video trimming and content moderation checks. |
| | |
| | Returns the filepath, s3_file_key, hash & video metaedata as a tuple. |
| | """ |
| | with tempfile.TemporaryDirectory() as tempdir: |
| | in_path = f"{tempdir}/in.mp4" |
| | out_path = f"{tempdir}/out.mp4" |
| | with open(in_path, "wb") as in_f: |
| | in_f.write(file.read()) |
| |
|
| | try: |
| | video_metadata = get_video_metadata(in_path) |
| | except av.InvalidDataError: |
| | raise Exception("not valid video file") |
| |
|
| | if video_metadata.num_video_streams == 0: |
| | raise Exception("video container does not contain a video stream") |
| | if video_metadata.width is None or video_metadata.height is None: |
| | raise Exception("video container does not contain width or height metadata") |
| |
|
| | if video_metadata.duration_sec in (None, 0): |
| | raise Exception("video container does time duration metadata") |
| |
|
| | start_time_sec, duration_time_sec = _get_start_sec_duration_sec( |
| | max_time=max_time, |
| | start_time_sec=start_time_sec, |
| | duration_time_sec=duration_time_sec, |
| | ) |
| |
|
| | |
| | |
| | transcode( |
| | in_path, |
| | out_path, |
| | video_metadata, |
| | seek_t=start_time_sec, |
| | duration_time_sec=duration_time_sec, |
| | ) |
| |
|
| | os.remove(in_path) |
| |
|
| | out_video_metadata = get_video_metadata(out_path) |
| | if out_video_metadata.num_video_frames == 0: |
| | raise Exception( |
| | "transcode produced empty video; check seek time or your input video" |
| | ) |
| |
|
| | filepath = None |
| | file_key = None |
| | with open(out_path, "rb") as file_data: |
| | file_hash = get_file_hash(file_data) |
| | file_data.seek(0) |
| |
|
| | file_key = UPLOADS_PREFIX + "/" + f"{file_hash}.mp4" |
| | filepath = os.path.join(UPLOADS_PATH, f"{file_hash}.mp4") |
| |
|
| | assert filepath is not None and file_key is not None |
| | shutil.move(out_path, filepath) |
| |
|
| | return filepath, file_key, out_video_metadata |
| |
|
| |
|
| | schema = strawberry.Schema( |
| | query=Query, |
| | mutation=Mutation, |
| | ) |
| |
|