from models.structure.Advanced_Conditional_Unet import Unet
from diffusers import DDPMScheduler
import torch
import os
import glob
from torchvision import transforms
import pathlib
from safetensors.torch import load_model, save_model
import time as tm

denoising_timesteps = 4000
image_size = 128
channels = 3

# Prefer MPS (Apple Silicon) over CUDA, and fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4, 8),
    use_convnext=False,
).to(device)

results_folder = pathlib.Path("models")
checkpoint_files_st = glob.glob(str(results_folder / "model-epoch_*.st"))
checkpoint_files_pt = glob.glob(str(results_folder / "model-epoch_*.pt"))

if checkpoint_files_st:
    # Sort the matching safetensors checkpoints by modification time (newest first)
    # and load the newest one.
    checkpoint_files_st.sort(key=lambda x: os.path.getmtime(x), reverse=True)
    checkpoint_files = checkpoint_files_st[0]
    load_model(model, checkpoint_files)
    model.eval()
    print("Loaded model from checkpoint", checkpoint_files)
elif checkpoint_files_pt:
    # No safetensors checkpoint found: fall back to the newest PyTorch checkpoint.
    checkpoint_files_pt.sort(key=lambda x: os.path.getmtime(x), reverse=True)
    checkpoint_files = checkpoint_files_pt[0]
    checkpoint = torch.load(checkpoint_files, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    epoch = checkpoint["epoch"]
    model.eval()
    print("Loaded model from checkpoint", checkpoint_files)
    # Save a safetensors copy so it can be loaded directly next time.
    if not checkpoint_files_st:
        save_model(model, str(results_folder / "model-epoch_{}.st".format(epoch)))
        print("Saved model as a safetensor in", results_folder)
else:
    raise Exception("No model files found in the folder.")


def sample(sketch, scribbles, sampling_steps, seed_nr, progress):
    torch.manual_seed(seed_nr)

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=denoising_timesteps, beta_schedule="squaredcos_cap_v2"
    )
    noise_scheduler.set_timesteps(sampling_steps, device=device)

    # Add a batch dimension to the two conditioning images.
    sketch = sketch.to(device).unsqueeze(0)
    scribbles = scribbles.to(device).unsqueeze(0)

    with torch.no_grad():
        b = sketch.shape[0]
        # Start from pure Gaussian noise and denoise it step by step.
        noise_for_plain = torch.randn_like(sketch, device=device)

        for t in progress.tqdm(
            noise_scheduler.timesteps,
            desc="Painting 🖌🖌🖌",
        ):
            noise_for_plain = noise_scheduler.scale_model_input(noise_for_plain, t).to(
                device
            )
            time = t.expand(b).to(device)

            # Predict the noise conditioned on the scribbles and the sketch.
            plain_noise_pred = model(
                x=noise_for_plain,
                time=time,
                implicit_conditioning=scribbles,
                explicit_conditioning=sketch,
            )

            # One reverse-diffusion step: remove the predicted noise.
            noise_for_plain = noise_scheduler.step(
                plain_noise_pred,
                t.long(),
                noise_for_plain,
            ).prev_sample

            tm.sleep(0.01)

        # Map the result from [-1, 1] back to [0, 1] and convert to a PIL image.
        sample = torch.clamp((noise_for_plain / 2) + 0.5, 0, 1)
        image = transforms.ToPILImage()(sample[0].cpu())

        # Make sure the output directory exists before saving.
        os.makedirs("results", exist_ok=True)
        image.save("results/sample.png")
        return image
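

# ---------------------------------------------------------------------------
# Minimal usage sketch: in normal use this module is driven by a UI that
# passes a progress object exposing `.tqdm()` (e.g. gradio.Progress). The
# shim and the input file names below are hypothetical placeholders, and the
# [-1, 1] input scaling is an assumption based on the [0, 1] rescaling
# applied to the output above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    class _PlainProgress:
        """Stand-in for a UI progress object: `.tqdm()` just returns the iterable."""

        def tqdm(self, iterable, desc=None):
            return iterable

    to_tensor = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),                       # [0, 1]
            transforms.Lambda(lambda x: x * 2.0 - 1.0),  # assumed model input range [-1, 1]
        ]
    )

    # Hypothetical input files; replace with real sketch/scribble images.
    sketch = to_tensor(Image.open("inputs/sketch.png").convert("RGB"))
    scribbles = to_tensor(Image.open("inputs/scribbles.png").convert("RGB"))

    result = sample(
        sketch, scribbles, sampling_steps=250, seed_nr=42, progress=_PlainProgress()
    )
    print("Saved sample to results/sample.png, size:", result.size)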