Songwei Ge commited on
Commit
61c1bd4
1 Parent(s): ab7db7f
Files changed (2) hide show
  1. requirements.txt +6 -0
  2. sample.py +0 -109
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==1.11.0
2
+ torchvision==0.12.0
3
+ diffusers==0.12.1
4
+ transformers==4.25.1
5
+ numpy==1.24.2
6
+ seaborn==0.12.2
sample.py DELETED
@@ -1,109 +0,0 @@
1
- import os
2
- import json
3
- import time
4
- import argparse
5
- import imageio
6
- import torch
7
- import numpy as np
8
- from torchvision import transforms
9
-
10
- from models.region_diffusion import RegionDiffusion
11
- from utils.attention_utils import get_token_maps
12
- from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
13
- get_attention_control_input, get_gradient_guidance_input
14
-
15
-
16
- def main(args, param):
17
-
18
- # Create the folder to store outputs.
19
- run_dir = args.run_dir
20
- os.makedirs(args.run_dir, exist_ok=True)
21
-
22
- # Load region diffusion model.
23
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24
- model = RegionDiffusion(device)
25
-
26
- # parse json to span attributes
27
- base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
28
- color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
29
- param['text_input'])
30
-
31
- # create control input for region diffusion
32
- region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
33
- model, base_text_prompt, style_text_prompts, footnote_text_prompts,
34
- footnote_target_tokens, color_text_prompts, color_names)
35
-
36
- # create control input for cross attention
37
- text_format_dict = get_attention_control_input(
38
- model, base_tokens, size_text_prompts_and_sizes)
39
-
40
- # create control input for region guidance
41
- text_format_dict, color_target_token_ids = get_gradient_guidance_input(
42
- model, base_tokens, color_text_prompts, color_rgbs, text_format_dict)
43
-
44
- height = param['height']
45
- width = param['width']
46
- seed = param['noise_index']
47
- negative_text = param['negative_prompt']
48
- seed_everything(seed)
49
-
50
- # get token maps from plain text to image generation.
51
- begin_time = time.time()
52
- if model.attention_maps is None:
53
- model.register_evaluation_hooks()
54
- else:
55
- model.reset_attention_maps()
56
- plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
57
- height=height, width=width, num_inference_steps=param['steps'],
58
- guidance_scale=param['guidance_weight'])
59
- fn_base = os.path.join(run_dir, 'seed%d_plain.png' % (seed))
60
- imageio.imwrite(fn_base, plain_img[0])
61
- print('time lapses to get attention maps: %.4f' % (time.time()-begin_time))
62
- color_obj_masks = get_token_maps(
63
- model.attention_maps, run_dir, width//8, height//8, color_target_token_ids, seed)
64
- model.masks = get_token_maps(
65
- model.attention_maps, run_dir, width//8, height//8, region_target_token_ids, seed, base_tokens)
66
- color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
67
- interpolation=transforms.InterpolationMode.BICUBIC,
68
- antialias=True)
69
- for color_obj_mask in color_obj_masks]
70
- text_format_dict['color_obj_atten'] = color_obj_masks
71
- model.remove_evaluation_hooks()
72
-
73
- # generate image from rich text
74
- begin_time = time.time()
75
- seed_everything(seed)
76
- rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
77
- height=height, width=width, num_inference_steps=param['steps'],
78
- guidance_scale=param['guidance_weight'], use_grad_guidance=use_grad_guidance,
79
- text_format_dict=text_format_dict)
80
- print('time lapses to generate image from rich text: %.4f' %
81
- (time.time()-begin_time))
82
- fn_style = os.path.join(run_dir, 'seed%d_rich.png' % (seed))
83
- imageio.imwrite(fn_style, rich_img[0])
84
- # imageio.imwrite(fn_cat, np.concatenate([img[0], rich_img[0]], 1))
85
-
86
-
87
- if __name__ == '__main__':
88
- parser = argparse.ArgumentParser()
89
- parser.add_argument('--run_dir', type=str, default='results/release/debug')
90
- parser.add_argument('--height', type=int, default=512)
91
- parser.add_argument('--width', type=int, default=512)
92
- parser.add_argument('--seed', type=int, default=6)
93
- parser.add_argument('--sample_steps', type=int, default=41)
94
- parser.add_argument('--rich_text_json', type=str,
95
- default='{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. There are palm trees in the background."}]}')
96
- parser.add_argument('--negative_prompt', type=str, default='')
97
- parser.add_argument('--guidance_weight', type=float, default=8.5)
98
- args = parser.parse_args()
99
- param = {
100
- 'text_input': json.loads(args.rich_text_json),
101
- 'height': args.height,
102
- 'width': args.width,
103
- 'guidance_weight': args.guidance_weight,
104
- 'steps': args.sample_steps,
105
- 'noise_index': args.seed,
106
- 'negative_prompt': args.negative_prompt,
107
- }
108
-
109
- main(args, param)