import argparse
import logging
import os

from diffusers import AmusedPipeline, UVit2DModel
from peft import PeftModel

logger = logging.getLogger(__name__)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument("--style_descriptor", type=str, default="[V]")
    parser.add_argument(
        "--load_transformer_from",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--load_transformer_lora_from",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--write_images_to", type=str, required=True)
    args = parser.parse_args()
    return args
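
# Example invocation (illustrative only; the script filename, checkpoint id, and
# paths below are placeholders rather than values defined in this file):
#
#   python validation_images.py \
#       --pretrained_model_name_or_path amused/amused-256 \
#       --load_transformer_lora_from ./style-lora \
#       --style_descriptor "[V]" \
#       --write_images_to ./samples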

def main(args):
    prompts = [
        f"A chihuahua in {args.style_descriptor} style",
        f"A tabby cat in {args.style_descriptor} style",
        f"A portrait of chihuahua in {args.style_descriptor} style",
        f"An apple on the table in {args.style_descriptor} style",
        f"A banana on the table in {args.style_descriptor} style",
        f"A church on the street in {args.style_descriptor} style",
        f"A church in the mountain in {args.style_descriptor} style",
        f"A church in the field in {args.style_descriptor} style",
        f"A church on the beach in {args.style_descriptor} style",
        f"A chihuahua walking on the street in {args.style_descriptor} style",
        f"A tabby cat walking on the street in {args.style_descriptor} style",
        f"A portrait of tabby cat in {args.style_descriptor} style",
        f"An apple on the dish in {args.style_descriptor} style",
        f"A banana on the dish in {args.style_descriptor} style",
        f"A human walking on the street in {args.style_descriptor} style",
        f"A temple on the street in {args.style_descriptor} style",
        f"A temple in the mountain in {args.style_descriptor} style",
        f"A temple in the field in {args.style_descriptor} style",
        f"A temple on the beach in {args.style_descriptor} style",
        f"A chihuahua walking in the forest in {args.style_descriptor} style",
        f"A tabby cat walking in the forest in {args.style_descriptor} style",
        f"A portrait of human face in {args.style_descriptor} style",
        f"An apple on the ground in {args.style_descriptor} style",
        f"A banana on the ground in {args.style_descriptor} style",
        f"A human walking in the forest in {args.style_descriptor} style",
        f"A cabin on the street in {args.style_descriptor} style",
        f"A cabin in the mountain in {args.style_descriptor} style",
        f"A cabin in the field in {args.style_descriptor} style",
        f"A cabin on the beach in {args.style_descriptor} style"
    ]

    logger.warning(f"generating image for {prompts}")

    logger.warning(f"loading models")

    pipe_args = {}

    # If a fully fine-tuned transformer checkpoint was provided, load it and pass
    # it to the pipeline in place of the base model's transformer.
    if args.load_transformer_from is not None:
        pipe_args["transformer"] = UVit2DModel.from_pretrained(args.load_transformer_from)

    pipe = AmusedPipeline.from_pretrained(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        revision=args.revision, 
        variant=args.variant,
        **pipe_args
    )

    # If LoRA adapter weights were provided, wrap the pipeline's transformer with them via PEFT.
    if args.load_transformer_lora_from is not None:
        pipe.transformer = PeftModel.from_pretrained(
            pipe.transformer, args.load_transformer_lora_from, is_trainable=False
        )

    pipe.to(args.device)

    logger.warning(f"generating images")

    os.makedirs(args.write_images_to, exist_ok=True)

    # Generate images in batches and save each one, using its prompt as the filename.
    for prompt_idx in range(0, len(prompts), args.batch_size):
        images = pipe(prompts[prompt_idx : prompt_idx + args.batch_size]).images

        for image_idx, image in enumerate(images):
            prompt = prompts[prompt_idx + image_idx]
            image.save(os.path.join(args.write_images_to, prompt + ".png"))

if __name__ == "__main__":
    main(parse_args())