# anime_diffusion / load_model.py
# Last commit 5086590 by pawlo2013: "fixed code readability" (3.49 kB)
from models.structure.Advanced_Conditional_Unet import Unet
from diffusers import DDPMScheduler
import torch
import os
import glob
from torchvision import transforms
import pathlib
from safetensors.torch import load_model, save_model
import time as tm
# Diffusion / image hyper-parameters.
denoising_timesteps = 4000  # number of training timesteps given to the DDPM scheduler
image_size = 128  # image side length; also used as the U-Net base width
channels = 3  # number of image channels

# Choose the compute device: Apple MPS takes priority when available,
# then CUDA, falling back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Conditional U-Net denoiser. The architecture arguments presumably have to
# match the configuration the checkpoint was trained with, since the weights
# loaded below are applied with default (strict) key matching.
model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4, 8),
    use_convnext=False,
)
# nn.Module.to moves parameters in place and returns the same module.
model = model.to(device)
# Directory holding training checkpoints. This is a relative path, so it is
# resolved against the current working directory, not this file's location.
results_folder = pathlib.Path("models")
checkpoint_files_st = glob.glob(str(results_folder / "model-epoch_*.st"))
checkpoint_files_pt = glob.glob(str(results_folder / "model-epoch_*.pt"))

if checkpoint_files_st:
    # Prefer safetensors checkpoints; use the most recently modified one.
    # Despite its plural name, `checkpoint_files` holds a single path.
    checkpoint_files = max(checkpoint_files_st, key=os.path.getmtime)
    load_model(model, checkpoint_files)
    model.eval()
    print("Loaded model from checkpoint", checkpoint_files)
elif checkpoint_files_pt:
    # Fall back to the newest PyTorch checkpoint: a dict read here for its
    # "model_state_dict" and "epoch" entries.
    checkpoint_files = max(checkpoint_files_pt, key=os.path.getmtime)
    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code; only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_files, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    epoch = checkpoint["epoch"]
    model.eval()
    print("Loaded model from checkpoint", checkpoint_files)
    # Convert to safetensors so later runs take the first branch. No
    # existence check is needed: reaching this branch already means the
    # "model-epoch_*.st" glob above matched nothing.
    save_model(model, str(results_folder / "model-epoch_{}.st".format(epoch)))
    print("Saved model as a safetensor", results_folder)
else:
    # FileNotFoundError subclasses Exception, so existing handlers still work.
    raise FileNotFoundError("No model files found in the folder.")
def sample(
    sketch,
    scribbles,
    sampling_steps,
    seed_nr,
    progress,
    output_path="results/sample.png",
):
    """Generate an image from a sketch and color scribbles with the loaded model.

    Runs the reverse DDPM process using the module-level ``model``. Every
    denoising step is conditioned on ``sketch`` (explicit conditioning) and
    ``scribbles`` (implicit conditioning).

    Args:
        sketch: Tensor for a single sample without a batch dimension; a batch
            dimension of 1 is added here. Presumably (C, H, W) scaled to
            [-1, 1] -- TODO confirm against the caller's preprocessing.
        scribbles: Conditioning tensor for the same sample, also unbatched.
        sampling_steps: Number of inference timesteps for the scheduler.
        seed_nr: Seed passed to ``torch.manual_seed`` so the starting noise
            is reproducible.
        progress: Object exposing ``tqdm(iterable, desc=...)``, presumably a
            Gradio ``Progress``, used to report per-step progress.
        output_path: File the resulting image is saved to. Missing parent
            directories are created.

    Returns:
        The generated image as a PIL image.
    """
    torch.manual_seed(seed_nr)
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=denoising_timesteps, beta_schedule="squaredcos_cap_v2"
    )
    noise_scheduler.set_timesteps(sampling_steps, device=device)

    # Move conditioning to the model's device and add a batch dimension of 1.
    sketch = sketch.to(device).unsqueeze(0)
    scribbles = scribbles.to(device).unsqueeze(0)

    with torch.no_grad():
        batch_size = sketch.shape[0]
        # Start from Gaussian noise shaped like the batched sketch.
        noisy = torch.randn_like(sketch, device=device)
        for t in progress.tqdm(noise_scheduler.timesteps, desc="Painting 🖌🖌🖌"):
            noisy = noise_scheduler.scale_model_input(noisy, t).to(device)
            # The model takes one timestep value per batch element.
            timesteps = t.expand(batch_size).to(device)
            noise_pred = model(
                x=noisy,
                time=timesteps,
                implicit_conditioning=scribbles,
                explicit_conditioning=sketch,
            )
            noisy = noise_scheduler.step(noise_pred, t.long(), noisy).prev_sample
            # NOTE(review): short pause per step, presumably so the progress
            # UI can refresh; it does not affect the result.
            tm.sleep(0.01)

    # Rescale from the presumed [-1, 1] model range to [0, 1] and clip.
    image_tensor = torch.clamp((noisy / 2) + 0.5, 0, 1)
    image = transforms.ToPILImage()(image_tensor[0].cpu())

    # PIL does not create missing directories; without this, saving fails
    # when e.g. "results/" does not exist yet.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    image.save(output_path)
    return image