Spaces:
Runtime error
Runtime error
import pickle | |
from PTI.utils.ImagesDataset import ImagesDataset, Image2Dataset | |
import torch | |
from PTI.utils.models_utils import load_old_G | |
from PTI.utils.alignment import align_face | |
from PTI.training.coaches.single_id_coach import SingleIDCoach | |
from PTI.configs import global_config, paths_config | |
import dlib | |
import os | |
from torchvision.transforms import transforms | |
from torch.utils.data import DataLoader | |
from string import ascii_uppercase | |
import sys | |
from pathlib import Path | |
# Add the current working directory to the module search path.
# NOTE(review): this executes *after* the `PTI.*` imports at the top of the
# file, so it cannot help resolve them; to have that effect it would need to
# come before those imports. Left in place because moving it changes the
# order of module-level side effects.
sys.path.append(".")
# sys.path.append('PTI/')
# sys.path.append('PTI/training/')
def run_PTI(img, run_name):
    """Run Pivotal Tuning Inversion (PTI) on a single face image.

    The face in ``img`` is aligned with a dlib landmark predictor, resized to
    the base generator's resolution, wrapped in a one-item dataset and handed
    to ``SingleIDCoach`` for training.

    Args:
        img: Input image passed to ``align_face``. Presumably a PIL image
            (the aligned result is resized with a PIL-style ``.resize``) -
            TODO confirm against ``align_face``.
        run_name: Stored in ``global_config.run_name`` before training.

    Returns:
        The ``(new_G, w_pivot)`` tuple returned by ``SingleIDCoach.train()``.

    Side effects:
        Mutates the module-level ``global_config`` (run name and step
        counters) and creates the embedding output directory.
    """
    # Uncomment to pin the GPU used by this process:
    # os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    # os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices
    global_config.run_name = run_name
    # Reset the global step counters for this run.
    global_config.pivotal_training_steps = 1
    global_config.training_step = 1

    embedding_dir_path = (
        f"{paths_config.embedding_base_dir}/{paths_config.input_data_id}"
        f"/{paths_config.pti_results_keyword}"
    )
    os.makedirs(embedding_dir_path, exist_ok=True)

    # The base generator is needed here only for its output resolution.
    # Drop the reference before training so it is not kept alive (possibly
    # on the GPU) for the whole training run.
    G = load_old_G()
    resolution = G.img_resolution
    del G

    align_size = 1024
    predictor = dlib.shape_predictor(paths_config.dlib)
    aligned_image = align_face(img, predictor=predictor, output_size=align_size)
    face = aligned_image.resize((resolution, resolution))

    dataset = Image2Dataset(face)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    coach = SingleIDCoach(dataloader, use_wandb=False)
    new_G, w_pivot = coach.train()
    return new_G, w_pivot
def export_updated_pickle(new_G, out_path, run_name):
    """Export a pickle whose ``G_ema`` is the tuned generator.

    ``G`` and ``D`` are taken from the pretrained pickle at
    ``paths_config.stylegan2_ada_ffhq``. ``G_ema`` is replaced by the
    generator stored in
    ``{paths_config.checkpoints_dir}/model_{run_name}_customIMG.pt``.
    That checkpoint file is deleted after the new pickle has been written.

    Args:
        new_G: Currently unused. The tuned generator is re-read from the
            checkpoint file above instead. Kept for backward compatibility.
        out_path: Destination path of the exported pickle.
        run_name: Run name embedded in the checkpoint filename (presumably
            the same value passed to ``run_PTI``) - TODO confirm.
    """
    image_name = "customIMG"
    checkpoint = Path(f"{paths_config.checkpoints_dir}/model_{run_name}_{image_name}.pt")

    # Everything is exported on the CPU, so load straight there instead of
    # round-tripping through the GPU (which also fails on CPU-only hosts).
    # NOTE(review): this unpickles a full module object; newer torch versions
    # default torch.load to weights_only=True, which rejects such files.
    with open(checkpoint, "rb") as f_new:
        tuned_G = torch.load(f_new, map_location="cpu")

    print("Exporting large updated pickle based off new generator and ffhq.pkl")
    # Load the (large) pretrained pickle once. pickle.load can execute
    # arbitrary code, so this path must point at a trusted file.
    with open(paths_config.stylegan2_ada_ffhq, "rb") as f:
        ffhq = pickle.load(f)

    export = {
        # "G" keeps the pretrained G_ema weights; only "G_ema" is replaced.
        "G": ffhq["G_ema"].eval().requires_grad_(False).cpu(),
        "G_ema": tuned_G.eval().requires_grad_(False).cpu(),
        "D": ffhq["D"].eval().requires_grad_(False).cpu(),
        "training_set_kwargs": None,
        "augment_pipe": None,
    }
    with open(out_path, "wb") as f:
        pickle.dump(export, f)

    # The exported pickle now contains the tuned generator, so the
    # intermediate checkpoint is removed.
    checkpoint.unlink()
# Example usage (kept commented out):
# if __name__ == '__main__':
#     from PIL import Image
#     img = Image.open('PTI/test/test.jpg')
#     run_name = 'test'
#     new_G, w_pivot = run_PTI(img, run_name)
#     out_path = 'checkpoints/stylegan2_custom_512_pytorch.pkl'
#     export_updated_pickle(new_G, out_path, run_name)