|
|
|
|
|
"""Python script to generate TFRecords of SequenceExample from raw videos.""" |
|
|
|
import contextlib |
|
import math |
|
import os |
|
import cv2 |
|
from typing import Dict, Optional, Sequence |
|
import moviepy.editor |
|
from absl import app |
|
from absl import flags |
|
import ffmpeg |
|
import numpy as np |
|
import pandas as pd |
|
import tensorflow as tf |
|
|
|
import warnings |
|
# Silence every Python warning for the lifetime of this script.
warnings.filterwarnings('ignore')

# Command-line flags (parsed by absl when the script runs via `app.run`).
flags.DEFINE_string("csv_path", "fakeavceleb_1k.csv", "Input csv")
flags.DEFINE_string("output_path", "fakeavceleb_tfrec", "Tfrecords output path.")
# NOTE(review): `video_root_path` is defined but not read anywhere in the
# visible part of this file.
flags.DEFINE_string("video_root_path", "./",
                    "Root directory containing the raw videos.")
# The two help-string fragments are concatenated, so the first one must end
# with a space (it used to render as "meansit").
flags.DEFINE_integer(
    "num_shards", 4, "Number of shards to output, -1 means "
    "it will automatically adapt to the sqrt(num_examples).")
# NOTE(review): `decode_audio` is defined but not read anywhere in the visible
# part of this file.
flags.DEFINE_bool("decode_audio", False, "Whether or not to decode the audio")
flags.DEFINE_bool("shuffle_csv", False, "Whether or not to shuffle the csv.")
FLAGS = flags.FLAGS

# The two bytes that start a JPEG stream (start-of-image marker).
# NOTE(review): not referenced anywhere in the visible part of this file.
_JPEG_HEADER = b"\xff\xd8"
|
|
|
|
|
@contextlib.contextmanager
def _close_on_exit(writers):
    """Yield `writers` and close every one of them when the block exits.

    Each writer's `close()` is attempted even if an earlier one raises, so a
    single failing writer cannot leave the remaining ones open. An exception
    raised while closing still propagates to the caller once all writers have
    been attempted.

    Args:
      writers: Iterable of objects exposing a `close()` method.

    Yields:
      The same `writers` object that was passed in.
    """
    try:
        yield writers
    finally:
        with contextlib.ExitStack() as stack:
            # ExitStack runs its callbacks last-in-first-out; registering in
            # reverse keeps the original first-to-last closing order.
            for writer in reversed(list(writers)):
                stack.callback(writer.close)
|
|
|
|
|
def add_float_list(key: str, values: Sequence[float],
                   sequence: tf.train.SequenceExample):
    """Append one new feature to feature list `key`, holding `values` as floats."""
    new_feature = sequence.feature_lists.feature_list[key].feature.add()
    new_feature.float_list.value[:] = values
|
|
|
|
|
def add_bytes_list(key: str, values: Sequence[bytes],
                   sequence: tf.train.SequenceExample):
    """Append one new feature to feature list `key`, holding `values` as bytes."""
    new_feature = sequence.feature_lists.feature_list[key].feature.add()
    new_feature.bytes_list.value[:] = values
|
|
|
|
|
def add_int_list(key: str, values: Sequence[int],
                 sequence: tf.train.SequenceExample):
    """Append one new feature to feature list `key`, holding `values` as int64s."""
    new_feature = sequence.feature_lists.feature_list[key].feature.add()
    new_feature.int64_list.value[:] = values
|
|
|
|
|
def set_context_int_list(key: str, value: Sequence[int],
                         sequence: tf.train.SequenceExample):
    """Replace the int64 values of context feature `key` with `value`."""
    context_feature = sequence.context.feature[key]
    context_feature.int64_list.value[:] = value
|
|
|
|
|
def set_context_bytes(key: str, value: bytes,
                      sequence: tf.train.SequenceExample):
    """Store the single bytes `value` as context feature `key`, replacing any previous values."""
    context_feature = sequence.context.feature[key]
    single_value = (value,)
    context_feature.bytes_list.value[:] = single_value
|
|
|
def set_context_bytes_list(key: str, value: Sequence[bytes],
                           sequence: tf.train.SequenceExample):
    """Replace the bytes values of context feature `key` with `value`."""
    context_feature = sequence.context.feature[key]
    context_feature.bytes_list.value[:] = value
|
|
|
|
|
def set_context_float(key: str, value: float,
                      sequence: tf.train.SequenceExample):
    """Store the single float `value` as context feature `key`, replacing any previous values."""
    context_feature = sequence.context.feature[key]
    single_value = (value,)
    context_feature.float_list.value[:] = single_value
|
|
|
|
|
def set_context_int(key: str, value: int, sequence: tf.train.SequenceExample):
    """Store the single int `value` as context feature `key`, replacing any previous values."""
    context_feature = sequence.context.feature[key]
    single_value = (value,)
    context_feature.int64_list.value[:] = single_value
|
|
|
|
|
def extract_frames(video_path, fps = 10, min_resize = 256):
    """Sample frames from a video and return them as one raw RGB byte string.

    Args:
      video_path: Path of the video, handed to `cv2.VideoCapture`.
      fps: Despite its name, the NUMBER of frames to sample. The indices are
        spread evenly between the first and the last frame with
        `np.linspace`; when the video has fewer frames than requested,
        duplicate indices collapse and fewer frames come back. `None` keeps
        every frame.
      min_resize: Every kept frame is resized to `min_resize` x `min_resize`
        pixels (the aspect ratio is not preserved).

    Returns:
      The raw buffer (`ndarray.tobytes()`) of the kept frames stacked into an
      array of shape `(num_frames, min_resize, min_resize, 3)`, in RGB channel
      order. The data is NOT compressed (no JPEG/PNG encoding).
      NOTE(review): the element type is whatever OpenCV decodes to,
      presumably uint8 - confirm before decoding on the reader side.

    Raises:
      ValueError: If no frame could be decoded (unreadable, missing or empty
        video).
    """
    v_cap = cv2.VideoCapture(video_path)
    try:
        v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # A set gives O(1) membership tests in the per-frame loop below.
        if fps is None:
            wanted = set(range(v_len))
        else:
            wanted = set(np.linspace(0, v_len - 1, fps).astype(int).tolist())

        frames = []
        for j in range(v_len):
            # grab() only advances the decoder; skip frames that failed to
            # grab so retrieve() is never asked for a frame that is not there.
            if not v_cap.grab():
                continue
            if j not in wanted:
                continue

            success, frame = v_cap.retrieve()
            if not success:
                continue

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, (min_resize, min_resize))
            frames.append(frame)
    finally:
        # Release the capture even when decoding raises.
        v_cap.release()

    if not frames:
        # np.stack would raise a ValueError with an unhelpful message here.
        raise ValueError(f"No frames could be decoded from video: {video_path}")

    return np.stack(frames).tobytes()
|
|
|
def extract_audio(video_path: str,
                  sampling_rate: int = 16_000):
    """Return the first audio channel of a video as a 1-D NumPy array.

    The video is opened with `moviepy.editor.VideoFileClip` and its audio
    track is converted with `to_soundarray()`; only channel 0 is kept.

    Args:
      video_path: Path of the video file.
      sampling_rate: NOTE(review): accepted but currently NOT applied.
        `to_soundarray()` is called without an `fps` argument, so the samples
        come at whatever rate moviepy decodes the clip at. Passing
        `fps=sampling_rate` would honor this argument but would change the
        length of every stored waveform - confirm with the consumers of the
        records before switching.

    Returns:
      A 1-D `np.ndarray` with the samples of the first audio channel.

    Raises:
      ValueError: If the video has no audio track.
    """
    video = moviepy.editor.VideoFileClip(video_path)
    try:
        if video.audio is None:
            raise ValueError(f"Video has no audio track: {video_path}")
        audio = video.audio.to_soundarray()
    finally:
        # Free the clip's reader resources; otherwise they pile up, one set
        # per processed video.
        video.close()

    # Keep channel 0 only and copy it into its own array.
    return np.array(audio[:, 0])
|
|
|
|
|
|
|
|
|
def serialize_example(video_path: str, label_name: str, label_map: Optional[Dict[str, int]] = None):
    """Build the `tf.train.SequenceExample` describing one video.

    All data is stored as context features:
      * "image/encoded": the bytes returned by `extract_frames(video_path, fps=10)`.
      * "video_path": `video_path`, encoded to bytes.
      * "WAVEFORM/feature/floats": `tobytes()` of the array returned by
        `extract_audio(video_path)`.
      * "clip/label/index": `label_map[label_name]`; written only when
        `label_map` is given.
      * "clip/label/text": `label_name`, encoded to bytes.

    Args:
      video_path: Path of the video to read.
      label_name: Label text of the video.
      label_map: Optional mapping from label text to integer index. When
        `None`, the "clip/label/index" feature is omitted instead of failing.

    Returns:
      The populated `tf.train.SequenceExample`.

    Raises:
      KeyError: If `label_map` is given but does not contain `label_name`.
    """
    seq_example = tf.train.SequenceExample()

    imgs_encoded = extract_frames(video_path, fps=10)
    audio = extract_audio(video_path)

    set_context_bytes("image/encoded", imgs_encoded, seq_example)
    set_context_bytes("video_path", video_path.encode(), seq_example)
    set_context_bytes("WAVEFORM/feature/floats", audio.tobytes(), seq_example)
    # The default `label_map=None` cannot be subscripted; only write the
    # numeric index when a mapping is available.
    if label_map is not None:
        set_context_int("clip/label/index", label_map[label_name], seq_example)
    set_context_bytes("clip/label/text", label_name.encode(), seq_example)
    return seq_example
|
|
|
|
|
def main(argv):
    """Convert every video listed in the input CSV into sharded TFRecords.

    Reads `FLAGS.csv_path`, where the first column is used as the video path
    and the second column as the label text. One serialized SequenceExample is
    written per row, into the shard selected by the row's index modulo the
    number of shards. Shards are named
    `<csv basename>-<i>-of-<num_shards>` inside `FLAGS.output_path`.

    Args:
      argv: Positional command-line arguments from absl; unused.

    Raises:
      ValueError: If the resolved number of shards is smaller than 1.
    """
    del argv  # Everything is configured through FLAGS.

    input_csv = pd.read_csv(FLAGS.csv_path)
    if FLAGS.num_shards == -1:
        # At least one shard, even for a CSV with no rows.
        num_shards = max(1, int(math.sqrt(len(input_csv))))
    else:
        num_shards = FLAGS.num_shards
    if num_shards < 1:
        raise ValueError(
            f"num_shards must be -1 or a positive integer, got {num_shards}.")

    if "label" in input_csv:
        # Sorted so a label gets the same index on every run; iterating a
        # plain set of strings can change order between interpreter runs.
        unique_labels = sorted(set(input_csv["label"].values))
        l_map = {label: i for i, label in enumerate(unique_labels)}
    else:
        l_map = None

    if FLAGS.shuffle_csv:
        input_csv = input_csv.sample(frac=1)

    # The record writers do not create missing directories.
    if FLAGS.output_path:
        tf.io.gfile.makedirs(FLAGS.output_path)

    basename = os.path.splitext(os.path.basename(FLAGS.csv_path))[0]
    shard_names = [
        os.path.join(FLAGS.output_path, f"{basename}-{i:05d}-of-{num_shards:05d}")
        for i in range(num_shards)
    ]
    # Open the writers right before the guard that closes them.
    writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]

    with _close_on_exit(writers) as writers:
        for row_count, row in enumerate(input_csv.itertuples(), start=1):
            # row[0] is the DataFrame index label (the original row number
            # with the default integer index), so shuffling does not change
            # which shard a row is written to.
            index = row[0]
            v = row[1]
            if os.name == 'posix':
                # The CSV may hold Windows-style separators. `v` is a plain
                # str, so use str.replace (the pandas `.str` accessor does not
                # exist on it).
                v = v.replace('\\', '/')
            l = row[2]
            print("Processing example %d of %d   (%d%%) \r" %(row_count, len(input_csv), row_count * 100 / len(input_csv)), end="")
            seq_ex = serialize_example(video_path = v, label_name = l,label_map = l_map)
            writers[index % len(writers)].write(seq_ex.SerializeToString())
|
|
|
# Script entry point: hand `main` to absl's `app.run` when executed directly.
if __name__ == "__main__":
    app.run(main)
|
|