# Uploaded by LittleApple-fp16 ("Upload 88 files", commit 4f8ad24).
import glob
import logging
import os
from typing import Iterator
from urllib.error import HTTPError
from tqdm.auto import tqdm
from .base import BaseDataSource, EmptySource
from ..model import ImageItem
# PyAV is an optional dependency. When it cannot be imported, ``av`` is set
# to None so that this module still imports; VideoSource.__init__ checks for
# that and raises an ImportError with installation instructions.
try:
    import av
    import av.datasets
    from av.error import InvalidDataError
except (ImportError, ModuleNotFoundError):
    av = None
class VideoSource(BaseDataSource):
    """Data source that yields the key frames of a single video as images.

    Each yielded :class:`ImageItem` wraps one decoded key frame together with
    a metadata dict holding the keys ``video_file``, ``time`` and ``index``.
    """

    def __init__(self, video_file):
        """
        :param video_file: Path of the video file to read. If no such local
            file exists, the value is handed to ``av.datasets.curated`` as a
            dataset name instead.
        :raises ImportError: If PyAV (``av``) is not installed.
        """
        if av is None:
            raise ImportError(f'pyav not installed, {self.__class__.__name__} is unavailable. '
                              f'Please install this with `pip install git+https://github.com/deepghs/waifuc.git@main#egg=waifuc[video]` to solve this problem.')
        self.video_file = video_file

    def _iter(self) -> Iterator[ImageItem]:
        """Decode the video and yield one :class:`ImageItem` per key frame.

        Problems are logged and end the iteration instead of propagating:
        a failed dataset download, undecodable data, or a container that has
        no video stream.
        """
        if os.path.isfile(self.video_file):
            # A real local file is opened as-is. Routing it through
            # ``av.datasets.curated`` would treat the path as a PyAV sample
            # dataset name, which fails for relative paths (such as the ones
            # ``from_directory`` produces for a relative directory) and may
            # trigger a pointless network request.
            content = self.video_file
        else:
            try:
                content = av.datasets.curated(self.video_file)
            except HTTPError:
                logging.error(f'Video {self.video_file!r} is invalid, skipped')
                return

        try:
            with av.open(content) as container:
                # IndexError here means the container has no video stream.
                stream = container.streams.video[0]
                # Ask the decoder to skip everything except key frames.
                stream.codec_context.skip_frame = "NONKEY"

                progress = tqdm(
                    container.decode(stream),
                    desc=f'Video Extracting - {os.path.basename(self.video_file)}',
                )
                for i, frame in enumerate(progress):
                    meta = {
                        'video_file': self.video_file,
                        'time': frame.time,
                        'index': i,
                    }
                    yield ImageItem(frame.to_image(), meta)
        except (InvalidDataError, av.error.ValueError, IndexError) as err:
            logging.warning(f'Video extraction skipped due to error - {err!r}')

    @classmethod
    def from_directory(cls, directory: str, recursive: bool = True) -> BaseDataSource:
        """Build a source covering every readable file under ``directory``.

        Files are not filtered by extension; a file that turns out not to be
        a decodable video is logged and skipped when it is iterated.

        :param directory: Directory to scan. Glob metacharacters in it are
            escaped, so it is always taken literally.
        :param recursive: When true (the default), descend into
            subdirectories as well.
        :return: The per-file sources combined with ``+``, starting from an
            :class:`EmptySource`.
        """
        base = glob.escape(directory)
        if recursive:
            files = glob.glob(os.path.join(base, '**', '*'), recursive=True)
        else:
            files = glob.glob(os.path.join(base, '*'))

        source = EmptySource()
        for file in files:
            # Skip directories and anything the current user cannot read.
            if os.path.isfile(file) and os.access(file, os.R_OK):
                source = source + cls(file)
        return source