Spaces:
Runtime error
Runtime error
import os | |
import random | |
import re | |
import PIL | |
import torch | |
import tqdm | |
import numpy as np | |
from PIL import Image | |
from .hnutil import get_closest | |
from torch.utils.data import Dataset | |
from torchvision import transforms | |
from modules import shared, devices | |
from modules.textual_inversion.dataset import DatasetEntry, re_numbers_at_start | |
class PersonalizedBase(Dataset): | |
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): | |
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None | |
self.placeholder_token = placeholder_token | |
self.batch_size = batch_size | |
self.width = width | |
self.height = height | |
self.flip = transforms.RandomHorizontalFlip(p=flip_p) | |
self.dataset = [] | |
with open(template_file, "r") as file: | |
lines = [x.strip() for x in file.readlines()] | |
self.lines = lines | |
assert data_root, 'dataset directory not specified' | |
assert os.path.isdir(data_root), "Dataset directory doesn't exist" | |
assert os.listdir(data_root), "Dataset directory is empty" | |
cond_model = shared.sd_model.cond_stage_model | |
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] * batch_size | |
print("Preparing dataset...") | |
for path in tqdm.tqdm(self.image_paths): | |
try: | |
image = Image.open(path).convert('RGB') | |
w, h = image.size | |
r = max(1, w / self.width, h / self.height) # divide by this | |
amp = min(self.width / w, self.height / h) # if amp < 1, then ignore, else, multiply. | |
if amp > 1: | |
w, h = w * amp, h * amp | |
w, h = int(w/r), int(h/r) | |
w, h = get_closest(w), get_closest(h) | |
image = image.resize((w,h), PIL.Image.LANCZOS) | |
except Exception: | |
continue | |
text_filename = os.path.splitext(path)[0] + ".txt" | |
filename = os.path.basename(path) | |
if os.path.exists(text_filename): | |
with open(text_filename, "r", encoding="utf8") as file: | |
filename_text = file.read() | |
else: | |
filename_text = os.path.splitext(filename)[0] | |
filename_text = re.sub(re_numbers_at_start, '', filename_text) | |
if re_word: | |
tokens = re_word.findall(filename_text) | |
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) | |
npimage = np.array(image).astype(np.uint8) | |
npimage = (npimage / 127.5 - 1.0).astype(np.float32) | |
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) | |
torchdata = torch.moveaxis(torchdata, 2, 0) | |
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() | |
init_latent = init_latent.to(devices.cpu) | |
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) | |
if include_cond: | |
entry.cond_text = self.create_text(filename_text) | |
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) | |
self.dataset.append(entry) | |
assert len(self.dataset) > 0, "No images have been found in the dataset." | |
self.length = len(self.dataset) * repeats // batch_size | |
self.dataset_length = len(self.dataset) | |
self.indexes = None | |
self.random = np.random.default_rng(42) | |
self.shuffle() | |
def shuffle(self): | |
self.indexes = self.random.permutation(self.dataset_length) | |
def create_text(self, filename_text): | |
text = random.choice(self.lines) | |
text = text.replace("[name]", self.placeholder_token) | |
tags = filename_text.split(',') | |
if shared.opts.tag_drop_out != 0: | |
tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] | |
if shared.opts.shuffle_tags: | |
random.shuffle(tags) | |
text = text.replace("[filewords]", ','.join(tags)) | |
return text | |
def __len__(self): | |
return self.length | |
def __getitem__(self, i): | |
res = [] | |
for j in range(self.batch_size): | |
position = i * self.batch_size + j | |
if position % len(self.indexes) == 0: | |
self.shuffle() | |
index = self.indexes[position % len(self.indexes)] | |
entry = self.dataset[index] | |
if entry.cond is None: | |
entry.cond_text = self.create_text(entry.filename_text) | |
res.append(entry) | |
return res |