Spaces:
Running
Running
from models.structure.Advanced_Conditional_Unet import Unet | |
from diffusers import DDPMScheduler | |
import torch | |
import os | |
import glob | |
from torchvision import transforms | |
import pathlib | |
from safetensors.torch import load_model, save_model | |
import time as tm | |
# Number of training diffusion steps the DDPM scheduler is built with.
denoising_timesteps = 4000
# Passed to the U-Net as its base `dim`; presumably also the training
# resolution — TODO confirm.
image_size = 128
# Number of image channels the U-Net is built for.
channels = 3

# Device priority: Apple MPS, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Conditional U-Net noise predictor. These hyper-parameters presumably mirror
# the training configuration so the checkpoint's state dict matches — TODO
# confirm against the training script.
model = Unet(dim=image_size, channels=channels, dim_mults=(1, 2, 4, 8), use_convnext=False)
# nn.Module.to() moves parameters in place and returns the same module.
model = model.to(device)
# Checkpoints live in ./models as "model-epoch_<N>.st" (safetensors) or
# "model-epoch_<N>.pt" (a torch.save'd dict).
results_folder = pathlib.Path("models")
checkpoint_files_st = glob.glob(str(results_folder / "model-epoch_*.st"))
checkpoint_files_pt = glob.glob(str(results_folder / "model-epoch_*.pt"))

if checkpoint_files_st:
    # Prefer safetensors checkpoints; use the most recently modified one.
    checkpoint_files = max(checkpoint_files_st, key=os.path.getmtime)
    load_model(model, checkpoint_files)
    model.eval()
    print("Loaded model from checkpoint", checkpoint_files)
elif checkpoint_files_pt:
    # Fall back to the newest torch checkpoint; it is a dict holding at least
    # "model_state_dict" and "epoch" (the keys read below).
    checkpoint_files = max(checkpoint_files_pt, key=os.path.getmtime)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_files, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    epoch = checkpoint["epoch"]
    model.eval()
    print("Loaded model from checkpoint", checkpoint_files)
    # Convert to safetensors so the next start-up takes the branch above.
    # Check the concrete target path: Path.exists() does not expand "*",
    # so a wildcard path would never be found.
    safetensor_file = results_folder / "model-epoch_{}.st".format(epoch)
    if not safetensor_file.exists():
        save_model(model, str(safetensor_file))
        print("Saved model as a safetensor", results_folder)
else:
    # FileNotFoundError subclasses Exception, so existing handlers still work.
    raise FileNotFoundError("No model files found in the folder.")
def sample(sketch, scribbles, sampling_steps, seed_nr, progress):
    """Generate an image conditioned on a sketch and scribbles via DDPM sampling.

    Args:
        sketch: Tensor passed to the model as ``explicit_conditioning``. A batch
            dimension is added here, so pass a single unbatched sample
            (presumably (C, H, W) — TODO confirm against the caller). The
            initial noise takes this tensor's shape.
        scribbles: Tensor passed as ``implicit_conditioning``; batched the same way.
        sampling_steps: Number of inference steps given to ``set_timesteps``.
        seed_nr: Seed for ``torch.manual_seed``, making the noise reproducible.
        progress: Object exposing ``tqdm(iterable, desc=...)`` (e.g. a Gradio
            ``Progress``) used to report per-step progress.

    Returns:
        The generated PIL image. It is also written to ``results/sample.png``.
    """
    torch.manual_seed(seed_nr)
    # Built per call; set_timesteps() configures this instance for the
    # requested number of inference steps.
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=denoising_timesteps, beta_schedule="squaredcos_cap_v2"
    )
    noise_scheduler.set_timesteps(sampling_steps, device=device)

    # Add the leading batch dimension (batch size 1).
    sketch = sketch.to(device).unsqueeze(0)
    scribbles = scribbles.to(device).unsqueeze(0)

    with torch.no_grad():
        b = sketch.shape[0]
        # Start from Gaussian noise shaped like the (batched) sketch.
        noise_for_plain = torch.randn_like(sketch, device=device)
        for t in progress.tqdm(
            noise_scheduler.timesteps,
            desc="Painting πππ",
        ):
            # Keep the scaled model input separate from the running sample:
            # scheduler.step() must be given the *unscaled* sample.
            model_input = noise_scheduler.scale_model_input(noise_for_plain, t).to(
                device
            )
            t_batch = t.expand(b).to(device)
            plain_noise_pred = model(
                x=model_input,
                time=t_batch,
                implicit_conditioning=scribbles,
                explicit_conditioning=sketch,
            )
            noise_for_plain = noise_scheduler.step(
                plain_noise_pred,
                t.long(),
                noise_for_plain,
            ).prev_sample
            # Short pause per step; presumably lets the progress UI refresh.
            tm.sleep(0.01)

    # Map x -> x/2 + 0.5 and clamp to [0, 1] for ToPILImage (assumes the model
    # works in roughly [-1, 1]).
    images = torch.clamp((noise_for_plain / 2) + 0.5, 0, 1)
    image = transforms.ToPILImage()(images[0].cpu())
    # Saving fails if the output directory is missing, so create it first.
    os.makedirs("results", exist_ok=True)
    image.save("results/sample.png")
    return image