File size: 2,689 Bytes
954caab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
from pathlib import Path

import torch
from diffusers import DiffusionPipeline

from visual_anagrams.views import get_views
from visual_anagrams.samplers import sample_stage_1, sample_stage_2
from visual_anagrams.utils import add_args, save_illusion, save_metadata



# Parse command-line arguments; the project-specific flags (prompts, views,
# style, sampling options, ...) are registered by add_args.
parser = add_args(argparse.ArgumentParser())
args = parser.parse_args()

# Every artifact for this run lands under <save_dir>/<name>.
save_dir = Path(args.save_dir) / args.name
save_dir.mkdir(parents=True, exist_ok=True)

# Build the two DeepFloyd IF stages (64x64 base + 256x256 upscaler) in fp16.
def _load_pipeline(model_id, **overrides):
    """Load one IF stage in fp16, enable CPU offload, and move it to args.device."""
    pipeline = DiffusionPipeline.from_pretrained(
        model_id,
        variant="fp16",
        torch_dtype=torch.float16,
        **overrides,
    )
    pipeline.enable_model_cpu_offload()
    # NOTE(review): calling .to() after enable_model_cpu_offload() conflicts
    # with the offload hooks in recent diffusers releases (it raises a
    # ValueError there) — confirm the pinned diffusers version tolerates this,
    # or drop the .to() call and pass the device to the offload call instead.
    return pipeline.to(args.device)

stage_1 = _load_pipeline("DeepFloyd/IF-I-M-v1.0")
# Stage 2 reuses stage 1's prompt embeddings, so it needs no text encoder.
stage_2 = _load_pipeline("DeepFloyd/IF-II-M-v1.0", text_encoder=None)

# Encode one (prompt, negative/null) embedding pair per target prompt,
# prefixing each prompt with the shared style string.
prompt_embeds = []
negative_prompt_embeds = []
for prompt in args.prompts:
    embed, negative_embed = stage_1.encode_prompt(f'{args.style} {prompt}'.strip())
    prompt_embeds.append(embed)
    negative_prompt_embeds.append(negative_embed)  # null embeds
prompt_embeds = torch.cat(prompt_embeds)
negative_prompt_embeds = torch.cat(negative_prompt_embeds)

# Resolve the view transforms named on the command line.
views = get_views(args.views)

# Record views + args alongside the outputs for reproducibility.
save_metadata(views, args, save_dir)

# Generate args.num_samples illusions, each in its own zero-padded subfolder.
for sample_idx in range(args.num_samples):
    # Deterministic per-sample seed so individual samples can be reproduced.
    generator = torch.manual_seed(args.seed + sample_idx)
    sample_dir = save_dir / f'{sample_idx:04}'
    sample_dir.mkdir(exist_ok=True, parents=True)

    # Sampling options shared by both stages.
    shared_kwargs = dict(
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        reduction=args.reduction,
        generator=generator,
    )

    # Stage 1: sample the 64x64 base image.
    image = sample_stage_1(
        stage_1,
        prompt_embeds,
        negative_prompt_embeds,
        views,
        **shared_kwargs,
    )
    save_illusion(image, views, sample_dir)

    # Stage 2: upsample the 64x64 result to 256x256.
    image = sample_stage_2(
        stage_2,
        image,
        prompt_embeds,
        negative_prompt_embeds,
        views,
        noise_level=args.noise_level,
        **shared_kwargs,
    )
    save_illusion(image, views, sample_dir)