# LD3/models/condition_loader.py
# Uploaded via huggingface_hub (commit d382778)
import random
import torch
import json
class RandomNumberIterator:
    """Endless iterator yielding conditioning for uniformly random class labels.

    Each step draws ``batch_size`` class ids in ``[0, n_classes - 1]`` and
    returns ``(conditioning, unconditional)`` where ``unconditional`` is
    ``None`` unless ``scale != 1.0`` (classifier-free-guidance case).
    """

    def __init__(self, model, scale, batch_size, n_classes=1000):
        self.model = model
        self.scale = scale
        self.batch_size = batch_size
        self.n_classes = n_classes

    def __iter__(self):
        return self

    def __next__(self):
        # Draw one random class id per batch element.
        draws = [random.randint(0, self.n_classes - 1) for _ in range(self.batch_size)]
        label = torch.LongTensor(draws).to(self.model.device)
        conditioning = self.model.get_learned_conditioning({self.model.cond_stage_key: label})
        unconditional = None
        if self.scale != 1.0:
            # The extra id == n_classes appears to act as the "null" label for
            # the unconditional branch (classifier-free guidance) — confirm with model.
            null_label = torch.LongTensor([self.n_classes] * self.batch_size).to(self.model.device)
            unconditional = self.model.get_learned_conditioning({self.model.cond_stage_key: null_label})
        return conditioning, unconditional
class UniformNumberIterator:
    """Endless iterator emitting class ids in order, cycling 0..n_classes-1.

    Each batch repeats the current class id ``batch_size`` times; after
    (approximately) ``num_samples_per_class`` samples the class id advances.
    Returns ``(conditioning, unconditional)``; ``unconditional`` is ``None``
    unless ``scale != 1.0`` (classifier-free-guidance case).
    """

    def __init__(self, model, scale, batch_size, num_samples_per_class, n_classes=1000):
        self.model = model
        self.scale = scale
        self.batch_size = batch_size
        self.num_samples_per_class = num_samples_per_class
        self.n_classes = n_classes
        self.current_value = 0
        self.current_num_cls_sample = 0

    def __iter__(self):
        return self

    def __next__(self):
        cls_id = self.current_value
        batch = [cls_id] * self.batch_size
        # Advance the per-class counter; once the quota is met, move to the
        # next class id, wrapping around at n_classes.
        self.current_num_cls_sample += self.batch_size
        if self.current_num_cls_sample >= self.num_samples_per_class:
            self.current_value = (cls_id + 1) % self.n_classes
            self.current_num_cls_sample = 0
        label = torch.LongTensor(batch).to(self.model.device)
        conditioning = self.model.get_learned_conditioning({self.model.cond_stage_key: label})
        unconditional = None
        if self.scale != 1.0:
            # id == n_classes appears to be the "null" label for the
            # unconditional branch — confirm with the model's embedding table.
            null_label = torch.LongTensor([self.n_classes] * self.batch_size).to(self.model.device)
            unconditional = self.model.get_learned_conditioning({self.model.cond_stage_key: null_label})
        return conditioning, unconditional
class TextFileIterator:
    """Iterates over text prompts from a file, yielding conditioning batches.

    Supports plain text files (one prompt per line) and MSCOCO-style JSON
    annotation files (reads ``annotations[*].caption``). Each ``__next__``
    returns ``(conditioning, unconditional)`` for up to ``batch_size`` prompts
    and raises ``StopIteration`` when the prompts are exhausted.

    Note: ``scale`` is stored but never used here — the unconditional
    embedding (empty prompt) is always computed and returned.
    """

    def __init__(self, model, scale, file_path, batch_size, max_prompts=None, n_samples_per_prompt=1):
        self.model = model
        self.scale = scale
        # Cache the empty-prompt conditioning once; every batch reuses it.
        self.unconditional_conditioning = self.model.get_learned_conditioning([""])
        self.file_path = file_path
        self.batch_size = batch_size
        self.max_prompts = max_prompts
        self.n_samples_per_prompt = n_samples_per_prompt
        self.prompt_index = 0
        self.prompts = self._load_prompts()

    def __iter__(self):
        return self

    def __next__(self):
        if self.prompt_index >= len(self.prompts):
            raise StopIteration
        # The final batch may be shorter than batch_size.
        batch_prompts = self.prompts[self.prompt_index:self.prompt_index + self.batch_size]
        self.prompt_index += len(batch_prompts)
        conditioning = self.model.get_learned_conditioning(batch_prompts)
        # Tile the cached empty-prompt embedding to match the batch length.
        conditioned_unconditioning = self.unconditional_conditioning.repeat(len(batch_prompts), 1, 1)
        return conditioning, conditioned_unconditioning

    def _load_prompts(self):
        """Load all prompts from ``file_path``; return ``[]`` on read/parse errors.

        Fix over the original: the text branch used to re-open and re-read the
        entire file once per line (nested ``open()`` inside a loop), which was
        O(n^2) and leaked file handles. Both branches now read the file exactly
        once inside a ``with`` block.
        """
        try:
            if self.file_path.endswith('json'):
                with open(self.file_path, 'r', encoding='utf-8') as file:
                    mscoco_data = json.load(file)
                prompts = [annotation['caption'] for annotation in mscoco_data['annotations']]
            else:
                with open(self.file_path, 'r', encoding='utf-8') as file:
                    prompts = [line.strip() for line in file]
            if self.max_prompts is not None:
                prompts = prompts[:self.max_prompts]
            # Repeat each prompt n_samples_per_prompt times, preserving order.
            return [prompt for prompt in prompts for _ in range(self.n_samples_per_prompt)]
        except FileNotFoundError:
            print(f"File not found: {self.file_path}")
            return []
        except IOError as e:
            print(f"Error reading file {self.file_path}: {e}")
            return []
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON in file {self.file_path}: {e}")
            return []