Spaces:
Running
Running
import torch | |
import numpy as np | |
import random | |
import os | |
from typing import Tuple, Optional | |
def get_prompt_template(mode: str = 'default') -> Tuple[str, int, int]: | |
''' | |
Generate a prompt template based on the specified mode. | |
Args: | |
mode (str, optional): The mode for selecting the prompt template. Default is 'default'. | |
Returns: | |
Tuple[str, int, int]: A tuple containing the generated prompt template, the position of the placeholder '{}', | |
and the length of the prompt. | |
Notes: | |
If the mode is 'random', a random prompt template is chosen from a predefined list. | |
''' | |
prompt_template = 'A photo of {}' | |
if mode == 'random': | |
prompt_templates = [ | |
'a photo of a {}', 'a photograph of a {}', 'an image of a {}', '{}', | |
'a cropped photo of a {}', 'a good photo of a {}', 'a photo of one {}', | |
'a bad photo of a {}', 'a photo of the {}', 'a photo of {}', 'a blurry photo of a {}', | |
'a picture of a {}', 'a photo of a scene where {}' | |
] | |
prompt_template = random.choice(prompt_templates) | |
# Calculate prompt length and text position | |
prompt_length = 1 + len(prompt_template.split(' ')) + 1 - 1 # eos, sos => 1 + 1, {} => -1 | |
text_pos_at_prompt = 1 + prompt_template.split(' ').index('{}') | |
return prompt_template, text_pos_at_prompt, prompt_length | |
# Reproducibility | |
def fix_seed(seed: int = 0) -> None: | |
''' | |
Set seeds for random number generators to ensure reproducibility. | |
Args: | |
seed (int, optional): The seed value. Default is 0. | |
''' | |
np.random.seed(seed) | |
random.seed(seed) | |
torch.manual_seed(seed) | |
torch.cuda.manual_seed_all(seed) # multi-GPU | |
torch.backends.cudnn.deterministic = True | |
torch.backends.cudnn.benchmark = False | |
os.environ['PYTHONHASHSEED'] = str(seed) | |
def seed_worker(worker_id: int) -> None: | |
''' | |
Set a seed for a worker process to ensure reproducibility in PyTorch DataLoader. | |
Args: | |
worker_id (int): The ID of the worker process. | |
''' | |
worker_seed = torch.initial_seed() % 2**32 | |
np.random.seed(worker_seed) | |
random.seed(worker_seed) |