import gc
from time import time_ns
import random
import numpy as np
import torch
from typing import Optional
from dataclasses import dataclass, field
from enum import Enum

# Templates used to derive the per-category start/end segment tokens below.
START_SEGMENT_TEMPLATE = 'START_{}_TOKEN'
END_SEGMENT_TEMPLATE = 'END_{}_TOKEN'


class CustomTokens(Enum):
    """Special string tokens that get registered with a tokenizer.

    Every member's value is a literal token string; `custom_tokens` returns
    them all in definition order.
    """

    EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '

    # Preprocessing tokens
    URL = 'URL_TOKEN'
    HYPHENATED_URL = 'HYPHENATED_URL_TOKEN'
    NUMBER_PERCENTAGE = 'NUMBER_PERCENTAGE_TOKEN'
    NUMBER = 'NUMBER_TOKEN'
    SHORT_HYPHENATED = 'SHORT_HYPHENATED_TOKEN'
    LONG_WORD = 'LONG_WORD_TOKEN'

    # Custom YouTube tokens
    MUSIC = '[Music]'
    APPLAUSE = '[Applause]'
    LAUGHTER = '[Laughter]'
    PROFANITY = 'PROFANITY_TOKEN'

    # Segment tokens
    NO_SEGMENT = 'NO_SEGMENT_TOKEN'
    START_SPONSOR = START_SEGMENT_TEMPLATE.format('SPONSOR')
    END_SPONSOR = END_SEGMENT_TEMPLATE.format('SPONSOR')
    START_SELFPROMO = START_SEGMENT_TEMPLATE.format('SELFPROMO')
    END_SELFPROMO = END_SEGMENT_TEMPLATE.format('SELFPROMO')
    START_INTERACTION = START_SEGMENT_TEMPLATE.format('INTERACTION')
    END_INTERACTION = END_SEGMENT_TEMPLATE.format('INTERACTION')
    BETWEEN_SEGMENTS = 'BETWEEN_SEGMENTS_TOKEN'

    @classmethod
    def custom_tokens(cls) -> list:
        """Return the string value of every member, in definition order."""
        return [e.value for e in cls]

    @classmethod
    def add_custom_tokens(cls, tokenizer) -> None:
        """Register every custom token via `tokenizer.add_tokens`.

        Args:
            tokenizer: Any object exposing an `add_tokens(list_of_str)` method.
        """
        tokenizer.add_tokens(cls.custom_tokens())


@dataclass
class OutputArguments:
    """Arguments describing where outputs, checkpoints and models live."""

    output_dir: str = field(
        default='out',
        metadata={
            'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
        },
    )
    checkpoint: Optional[str] = field(
        default=None,
        metadata={
            'help': 'Choose the checkpoint/model to train from or test with. Defaults to the latest checkpoint found in `output_dir`.'
        },
    )
    # NOTE(review): this help text is identical to `output_dir`'s and looks
    # copy-pasted; confirm the intended description of `models_dir`.
    models_dir: str = field(
        default='models',
        metadata={
            'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
        },
    )
    # classifier_dir: str = field(
    #     default='out',
    #     metadata={
    #         'help': 'The output directory where the model predictions and checkpoints will be written to and read from.'
    #     },
    # )


def seed_factory() -> int:
    """Return a time-derived seed in the range [0, 2**32 - 2]."""
    return time_ns() % (2**32 - 1)


@dataclass
class GeneralArguments:
    """General run arguments; constructing an instance seeds every RNG.

    `__post_init__` seeds Python's `random`, NumPy, and PyTorch (CPU and all
    CUDA devices) with `seed`.
    """

    seed: Optional[int] = field(default_factory=seed_factory, metadata={
        'help': 'Set seed for deterministic training and testing. By default, it uses the current time (results in essentially random results).'
    })

    def __post_init__(self):
        # The field is Optional, but `torch.manual_seed(None)` raises a
        # TypeError. Treat an explicit None like an omitted seed and store the
        # generated value so the run can still be reproduced afterwards.
        if self.seed is None:
            self.seed = seed_factory()

        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)


def device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def reset() -> None:
    """Release cached memory and, when CUDA is available, print its memory summary."""
    torch.clear_autocast_cache()

    # Collect garbage *before* emptying the CUDA cache: tensors kept alive
    # only by reference cycles are freed by the collector, and their blocks
    # can only be released from the caching allocator after that.
    gc.collect()
    torch.cuda.empty_cache()

    # Only request the summary when CUDA is usable, so CPU-only hosts never
    # call into the CUDA allocator statistics.
    if torch.cuda.is_available():
        print(torch.cuda.memory_summary(device=None, abbreviated=False))