import argparse
import os
import sys

import torch
from tqdm import tqdm

# make the repository root importable so the local `scheduler` package resolves
# when the script is launched from the repo root
sys.path.append(os.getcwd())

from diffusers import FluxPipeline
from scheduler.scheduling_flow_match_euler_discrete_lm import FlowMatchEulerDiscreteLMScheduler


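def main():
    parser = argparse.ArgumentParser(description="Sample Flux images for T2I-CompBench prompts.")
    parser.add_argument('--test_num', type=int, default=10, help='images (seeds) generated per prompt')
    parser.add_argument('--start_index', type=int, default=0, help='first random seed')
    parser.add_argument('--num_inference_steps', type=int, default=10, help='number of sampling steps')
    parser.add_argument('--guidance', type=float, default=7.5, help='guidance_scale passed to the pipeline')
    parser.add_argument('--sampler_type', type=str, default='fm_euler', choices=['fm_euler', 'lml_euler'])
    parser.add_argument('--model_id', type=str, default='XXX', help='Flux checkpoint: Hub id or local path')
    parser.add_argument('--save_dir', type=str, default='results/', help='root output directory')
    parser.add_argument('--lamb', type=float, default=5.0, help="LM scheduler 'lamb' (lml_euler only)")
    parser.add_argument('--kappa', type=float, default=0.0, help="LM scheduler 'kappa' (lml_euler only)")
    parser.add_argument('--freeze', type=float, default=0.0, help='parsed but currently unused')
    parser.add_argument('--dataset_category', type=str, default="color", help='reads <category>_val.txt')
    parser.add_argument('--dataset_path', type=str, default="T2I-CompBench-main", help='T2I-CompBench root')
    parser.add_argument('--dtype', type=str, default='bf16', choices=['fp32', 'fp64', 'fp16', 'bf16'])
    parser.add_argument('--device', type=str, default='cuda')

    args = parser.parse_args()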
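    # map the --dtype flag to a torch dtype; fail early instead of silently passing None
    dtype_map = {
        'fp32': torch.float32,
        'fp64': torch.float64,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }
    if args.dtype not in dtype_map:
        raise ValueError(f"unsupported --dtype '{args.dtype}', expected one of {sorted(dtype_map)}")
    dtype = dtype_map[args.dtype]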

    start_index = args.start_index
    sampler_type = args.sampler_type
    test_num = args.test_num
    guidance_scale = args.guidance
    num_inference_steps = args.num_inference_steps
    lamb = args.lamb
    freeze = args.freeze
    kappa = args.kappa
    model_id = args.model_id
    device = args.device

    # load the Flux text-to-image pipeline (it has no safety_checker component)
    sd_pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype)
    sd_pipe = sd_pipe.to(device)
    print("Flux model loaded")

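    # 'fm_euler' keeps the scheduler that ships with the checkpoint;
    # 'lml_euler' swaps in the local LM scheduler and sets its parameters
    if sampler_type == 'fm_euler':
        pass
    elif sampler_type == 'lml_euler':
        sd_pipe.scheduler = FlowMatchEulerDiscreteLMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = True
        sd_pipe.scheduler.kappa = kappa
    else:
        raise ValueError(f"invalid --sampler_type '{sampler_type}', expected 'fm_euler' or 'lml_euler'")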

    # output layout: <save_dir>/flux/<category>/<sampler>[_lamda_<lamb>]/samples
    save_dir = args.save_dir
    if sampler_type == 'lml_euler':
        save_dir = os.path.join(save_dir, "flux", args.dataset_category, f"{sampler_type}_lamda_{lamb}")
    else:
        save_dir = os.path.join(save_dir, "flux", args.dataset_category, sampler_type)
    save_dir = os.path.join(save_dir, "samples")
    os.makedirs(save_dir, exist_ok=True)

    def getT2IDataset(file_path):
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                stripped_line = line.strip()
                if stripped_line:
                    yield stripped_line

    # T2I-CompBench prompts: <dataset_path>/examples/dataset/<category>_val.txt
    dataset_path = os.path.join(args.dataset_path, 'examples/dataset', args.dataset_category + '_val.txt')
    try:
        prompts = list(getT2IDataset(dataset_path))
    except FileNotFoundError:
        print(f"dataset not found: {dataset_path}")
        return

    # one image per (prompt, seed); `count` keeps every file name unique
    count = 0
    with tqdm(total=len(prompts) * test_num, desc="Generating Images") as pbar:
        for prompt in prompts:
            for seed in range(start_index, start_index + test_num):
                # a per-image generator makes each sample reproducible from its seed
                generator = torch.Generator(device=device).manual_seed(seed)
                res = sd_pipe(prompt=prompt, num_inference_steps=num_inference_steps,
                              guidance_scale=guidance_scale, generator=generator,
                              width=512, height=512).images[0]
                res.save(os.path.join(save_dir, f"{prompt}_{count:06d}.png"))
                count += 1
                pbar.update(1)
    print(f"{dataset_path} finish")

if __name__ == '__main__':
    main()