csyxwei commited on
Commit
2e7bc51
·
1 Parent(s): 5e6a78f

readme and new requirements

Browse files
.gitignore CHANGED
@@ -5,7 +5,6 @@ _sc.py
5
  *.ckpt
6
  *.bin
7
 
8
- checkpoints
9
  .idea
10
  .idea/workspace.xml
11
  .DS_Store
 
5
  *.ckpt
6
  *.bin
7
 
 
8
  .idea
9
  .idea/workspace.xml
10
  .DS_Store
2_gpu.json DELETED
@@ -1,11 +0,0 @@
1
- {
2
- "compute_environment": "LOCAL_MACHINE",
3
- "distributed_type": "MULTI_GPU",
4
- "fp16": false,
5
- "machine_rank": 0,
6
- "main_process_ip": null,
7
- "main_process_port": null,
8
- "main_training_function": "main",
9
- "num_machines": 1,
10
- "num_processes": 2
11
- }
 
 
 
 
 
 
 
 
 
 
 
 
3_gpu.json DELETED
@@ -1,11 +0,0 @@
1
- {
2
- "compute_environment": "LOCAL_MACHINE",
3
- "distributed_type": "MULTI_GPU",
4
- "fp16": false,
5
- "machine_rank": 0,
6
- "main_process_ip": null,
7
- "main_process_port": null,
8
- "main_training_function": "main",
9
- "num_machines": 1,
10
- "num_processes": 3
11
- }
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,3 +1,155 @@
1
- # ELITE
2
 
3
- The detailed README is coming soom.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ELITE: Encoding Visual Concepts into Textual Embeddings for Customized Text-to-Image Generation
2
 
3
+
4
+ <a href="https://arxiv.org/pdf/2302.13848.pdf"><img src="https://img.shields.io/badge/arXiv-2302.13848-b31b1b.svg" height=22.5></a>
5
+ <a href="https://huggingface.co/spaces/ELITE-library/ELITE"><img src="https://img.shields.io/static/v1?label=HuggingFace&message=gradio demo&color=darkgreen" height=22.5></a>
6
+
7
+ ## Getting Started
8
+
9
+ ----
10
+
11
+ ### Environment Setup
12
+
13
+ ```shell
14
+ git clone https://github.com/csyxwei/ELITE.git
15
+ cd ELITE
16
+ conda create -n elite python=3.9
17
+ conda activate elite
18
+ pip install -r requirements.txt
19
+ ```
20
+
21
+ ### Pretrained Models
22
+
23
+ We provide the pretrained checkpoints in [Google Drive](https://drive.google.com/drive/folders/1VkiVZzA_i9gbfuzvHaLH2VYh7kOTzE0x?usp=sharing). One can download them and save to the directory `checkpoints`.
24
+
25
+ ### Setting up Diffusers
26
+
27
+ Our code is built on the [diffusers](https://github.com/huggingface/diffusers/), and you can follow the guideline [here](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#cat-toy-example) to set it.
28
+
29
+ ### Customized Generation
30
+
31
+ We provide the testing dataset in [test_datasets](./test_datasets), which contains both images and object masks. For testing, you can run,
32
+ ```
33
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
34
+ export DATA_DIR='./test_datasets/'
35
+ CUDA_VISIBLE_DEVICES=0 python inference_local.py \
36
+ --pretrained_model_name_or_path=$MODEL_NAME \
37
+ --test_data_dir=$DATA_DIR \
38
+ --output_dir="./outputs/local_mapping" \
39
+ --suffix="object" \
40
+ --template="a photo of a S" \
41
+ --llambda="0.8" \
42
+ --global_mapper_path="./checkpoints/global_mapper.pt" \
43
+ --local_mapper_path="./checkpoints/local_mapper.pt"
44
+ ```
45
+ or you can use the shell script:
46
+ ```
47
+ bash inference_local.sh
48
+ ```
49
+ If you want to test your customized dataset, you should align the image to ensure the object is at the center of image, and also provide the corresponding object mask. The object mask can be obtained by [image-matting-app](https://huggingface.co/spaces/SankarSrin/image-matting-app), or other image matting methods.
50
+
51
+ ## Training
52
+
53
+ ----
54
+
55
+ ### Preparing Dataset
56
+
57
+ We use the **test** dataset of Open-Images V6 to train our ELITE. You can prepare the dataset as follows:
58
+
59
+ - Download Open-Images test dataset from [CVDF's site](https://github.com/cvdfoundation/open-images-dataset#download-images-with-bounding-boxes-annotations) and unzip it to the directory `datasets/Open_Images/images/test`.
60
+ - Download attribute names file `oidv6-attributes-description.csv` of Open-Images test dataset from [Open-Images official site](https://storage.googleapis.com/openimages/web/download_v7.html#download-manually) and save it to the directory `datasets/Open_Images/annotations/`.
61
+ - Download bbox annotations file `test-annotations-bbox.csv` of Open-Images test dataset from [Open-Images official site](https://storage.googleapis.com/openimages/web/download_v7.html#download-manually) and save it to the directory `datasets/Open_Images/annotations/`.
62
+ - Download segmentation annotations of Open-Images test dataset from [Open-Images official site](https://storage.googleapis.com/openimages/web/download_v7.html#download-manually) and unzip them to the directory `datasets/Open_Images/segs/test`. And put the `test-annotations-object-segmentation.csv` into `datasets/Open_Images/annotations/`.
63
+ - Obtain the mask bbox by running the following command:
64
+ ```shell
65
+ python data_scripts/cal_bbox_by_seg.py
66
+ ```
67
+
68
+ The final data structure is like this:
69
+
70
+ ```
71
+ datasets
72
+ ├── Open_Images
73
+ │ ├── annotations
74
+ │ │ ├── oidv6-class-descriptions.csv
75
+ │ │ ├── test-annotations-object-segmentation.csv
76
+ │ │ ├── test-annotations-bbox.csv
77
+ │ ├── images
78
+ │ │ ├── test
79
+ │ │ │ ├── xxx.jpg
80
+ │ │ │ ├── ...
81
+ │ ├── segs
82
+ │ │ ├── test
83
+ │ │ │ ├── xxx.png
84
+ │ │ │ ├── ...
85
+ │ │ ├── test_bbox_dict.npy
86
+ ```
87
+
88
+ ### Training Global Mapping Network
89
+
90
+ To train the global mapping network, run the following command:
91
+
92
+ ```Shell
93
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
94
+ export DATA_DIR='./datasets/Open_Images/'
95
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file 4_gpu.json --main_process_port 25656 train_global.py \
96
+ --pretrained_model_name_or_path=$MODEL_NAME \
97
+ --train_data_dir=$DATA_DIR \
98
+ --placeholder_token="S" \
99
+ --resolution=512 \
100
+ --train_batch_size=4 \
101
+ --gradient_accumulation_steps=4 \
102
+ --max_train_steps=200000 \
103
+ --learning_rate=1e-06 --scale_lr \
104
+ --lr_scheduler="constant" \
105
+ --lr_warmup_steps=0 \
106
+ --output_dir="./elite_experiments/global_mapping" \
107
+ --save_steps 200
108
+ ```
109
+ or you can use the shell script:
110
+ ```shell
111
+ bash train_global.sh
112
+ ```
113
+
114
+ ### Training Local Mapping Network
115
+
116
+ After the global mapping is trained, you can train the local mapping by running the following command:
117
+
118
+ ```Shell
119
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
120
+ export DATA_DIR='/home/weiyuxiang/datasets/Open_Images/'
121
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file 4_gpu.json --main_process_port 25657 train_local.py \
122
+ --pretrained_model_name_or_path=$MODEL_NAME \
123
+ --train_data_dir=$DATA_DIR \
124
+ --placeholder_token="S" \
125
+ --resolution=512 \
126
+ --train_batch_size=2 \
127
+ --gradient_accumulation_steps=4 \
128
+ --max_train_steps=200000 \
129
+ --learning_rate=1e-5 --scale_lr \
130
+ --lr_scheduler="constant" \
131
+ --lr_warmup_steps=0 \
132
+ --global_mapper_path "./elite_experiments/global_mapping/mapper_070000.pt" \
133
+ --output_dir="./elite_experiments/local_mapping" \
134
+ --save_steps 200
135
+ ```
136
+ or you can use the shell script:
137
+ ```shell
138
+ bash train_local.sh
139
+ ```
140
+
141
+
142
+ ## Citation
143
+
144
+ ```
145
+ @article{wei2023elite,
146
+ title={ELITE: Encoding Visual Concepts into Textual Embeddings for Customized Text-to-Image Generation},
147
+ author={Wei, Yuxiang and Zhang, Yabo and Ji, Zhilong and Bai, Jinfeng and Zhang, Lei and Zuo, Wangmeng},
148
+ journal={arXiv preprint arXiv:2302.13848},
149
+ year={2023}
150
+ }
151
+ ```
152
+
153
+ ## Acknowledgements
154
+
155
+ This code is built on [diffusers](https://github.com/huggingface/diffusers/). We thank the authors for sharing the codes.
data_scripts/cal_bbox_by_seg.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from os.path import join
3
+ import os
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ dir = './datasets/Open_Images/'
8
+ mode = 'test'
9
+
10
+ image_dir = join(dir, 'images', mode)
11
+ seg_dir = join(dir, 'segs', mode)
12
+
13
+ files = os.listdir(seg_dir)
14
+
15
+ data_dict = {}
16
+
17
+ for file in tqdm(files):
18
+ seg_path = join(seg_dir, file)
19
+ image_path = join(image_dir, file.split('_')[0] + '.jpg')
20
+ seg = cv2.imread(seg_path)
21
+ image = cv2.imread(image_path)
22
+ seg = cv2.resize(seg, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
23
+
24
+ seg = seg[:, :, 0]
25
+
26
+ # obtain contours point set: contours
27
+ contours = cv2.findContours(seg, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
28
+ contours = contours[0] if len(contours) == 2 else contours[1]
29
+
30
+ if len(contours) > 1:
31
+ cntr = np.vstack(contours)
32
+ elif len(contours) == 1:
33
+ cntr = contours[0]
34
+ else:
35
+ continue
36
+
37
+ if len(cntr) < 2:
38
+ continue
39
+
40
+ hs, he = np.min(cntr[:, :, 1]), np.max(cntr[:, :, 1])
41
+ ws, we = np.min(cntr[:, :, 0]), np.max(cntr[:, :, 0])
42
+
43
+ h, w = seg.shape
44
+
45
+ if (he - hs) % 2 == 1 and (he + 1) <= h:
46
+ he = he + 1
47
+ if (he - hs) % 2 == 1 and (hs - 1) >= 0:
48
+ hs = hs - 1
49
+ if (we - ws) % 2 == 1 and (we + 1) <= w:
50
+ we = we + 1
51
+ if (we - ws) % 2 == 1 and (ws - 1) >= 0:
52
+ ws = ws - 1
53
+
54
+ if he - hs < 2 or we - ws < 2:
55
+ continue
56
+
57
+ data_dict[file] = [cntr, hs, he, ws, we]
58
+
59
+ np.save(join(dir, 'segs', f'{mode}_bbox_dict.npy'), data_dict)
datasets.py CHANGED
@@ -141,7 +141,9 @@ class CustomDatasetWithBG(Dataset):
141
  image = Image.open(self.image_paths[i % self.num_images])
142
 
143
  mask_path = self.image_paths[i % self.num_images].replace('.jpeg', '.png').replace('.jpg', '.png').replace('.JPEG', '.png')[:-4] + '_bg.png'
144
- mask = np.array(Image.open(mask_path)) / 255.0
 
 
145
 
146
  if not image.mode == "RGB":
147
  image = image.convert("RGB")
 
141
  image = Image.open(self.image_paths[i % self.num_images])
142
 
143
  mask_path = self.image_paths[i % self.num_images].replace('.jpeg', '.png').replace('.jpg', '.png').replace('.JPEG', '.png')[:-4] + '_bg.png'
144
+ mask = np.array(Image.open(mask_path))
145
+
146
+ mask = np.where(mask > 0, 1, 0)
147
 
148
  if not image.mode == "RGB":
149
  image = image.convert("RGB")
elite.yaml DELETED
@@ -1,147 +0,0 @@
1
- name: elite
2
- channels:
3
- - defaults
4
- dependencies:
5
- - _libgcc_mutex=0.1=main
6
- - ca-certificates=2022.10.11=h06a4308_0
7
- - certifi=2022.9.24=py39h06a4308_0
8
- - ld_impl_linux-64=2.38=h1181459_1
9
- - libffi=3.3=he6710b0_2
10
- - libgcc-ng=9.1.0=hdf63c60_0
11
- - libstdcxx-ng=9.1.0=hdf63c60_0
12
- - ncurses=6.3=h7f8727e_2
13
- - openssl=1.1.1s=h7f8727e_0
14
- - pip=22.2.2=py39h06a4308_0
15
- - python=3.9.12=h12debd9_1
16
- - readline=8.1.2=h7f8727e_1
17
- - sqlite=3.38.5=hc218d9a_0
18
- - tk=8.6.12=h1ccaba5_0
19
- - wheel=0.37.1=pyhd3eb1b0_0
20
- - xz=5.2.5=h7f8727e_1
21
- - zlib=1.2.12=h7f8727e_2
22
- - pip:
23
- - absl-py==1.3.0
24
- - accelerate==0.15.0
25
- - aiohttp==3.8.3
26
- - aiosignal==1.3.1
27
- - albumentations==1.1.0
28
- - altair==4.2.0
29
- - antlr4-python3-runtime==4.8
30
- - async-timeout==4.0.2
31
- - attrs==22.1.0
32
- - blinker==1.5
33
- - cachetools==5.2.0
34
- - charset-normalizer==2.1.1
35
- - click==8.1.3
36
- - commonmark==0.9.1
37
- - contourpy==1.0.6
38
- - cycler==0.11.0
39
- - cython==0.29.33
40
- - decorator==5.1.1
41
- - diffusers==0.11.1
42
- - einops==0.4.1
43
- - emoji==2.2.0
44
- - entrypoints==0.4
45
- - faiss-gpu==1.7.2
46
- - filelock==3.8.0
47
- - fonttools==4.38.0
48
- - frozenlist==1.3.3
49
- - fsspec==2022.11.0
50
- - ftfy==6.1.1
51
- - future==0.18.2
52
- - gitdb==4.0.9
53
- - gitpython==3.1.29
54
- - google-auth==2.14.1
55
- - google-auth-oauthlib==0.4.6
56
- - grpcio==1.50.0
57
- - huggingface-hub==0.11.0
58
- - idna==3.4
59
- - imageio==2.14.1
60
- - imageio-ffmpeg==0.4.7
61
- - importlib-metadata==5.0.0
62
- - jinja2==3.1.2
63
- - joblib==1.2.0
64
- - jsonschema==4.17.0
65
- - kiwisolver==1.4.4
66
- - kornia==0.6.0
67
- - markdown==3.4.1
68
- - markupsafe==2.1.1
69
- - matplotlib==3.6.2
70
- - multidict==6.0.2
71
- - networkx==2.8.8
72
- - nltk==3.7
73
- - numpy==1.23.4
74
- - oauthlib==3.2.2
75
- - omegaconf==2.1.1
76
- - opencv-python==4.6.0.66
77
- - opencv-python-headless==4.6.0.66
78
- - packaging==21.3
79
- - pandas==1.5.1
80
- - pillow==9.0.1
81
- - protobuf==3.20.1
82
- - psutil==5.9.4
83
- - pudb==2019.2
84
- - pyarrow==10.0.0
85
- - pyasn1==0.4.8
86
- - pyasn1-modules==0.2.8
87
- - pycocotools==2.0.6
88
- - pydeck==0.8.0
89
- - pydensecrf==1.0rc2
90
- - pydeprecate==0.3.2
91
- - pygments==2.13.0
92
- - pympler==1.0.1
93
- - pyparsing==3.0.9
94
- - pyrsistent==0.19.2
95
- - python-dateutil==2.8.2
96
- - python-dotenv==0.21.0
97
- - pytorch-lightning==1.6.5
98
- - pytz==2022.6
99
- - pytz-deprecation-shim==0.1.0.post0
100
- - pywavelets==1.4.1
101
- - pyyaml==6.0
102
- - qudida==0.0.4
103
- - regex==2022.10.31
104
- - requests==2.28.1
105
- - requests-oauthlib==1.3.1
106
- - rich==12.6.0
107
- - rsa==4.9
108
- - sacremoses==0.0.53
109
- - scikit-image==0.19.3
110
- - scikit-learn==1.1.3
111
- - scipy==1.9.3
112
- - semver==2.13.0
113
- - setuptools==59.5.0
114
- - six==1.16.0
115
- - smmap==5.0.0
116
- - stanza==1.4.2
117
- - streamlit==1.15.0
118
- - tensorboard==2.11.0
119
- - tensorboard-data-server==0.6.1
120
- - tensorboard-plugin-wit==1.8.1
121
- - test-tube==0.7.5
122
- - threadpoolctl==3.1.0
123
- - tifffile==2022.10.10
124
- - timm==0.6.12
125
- - tokenizers==0.12.1
126
- - toml==0.10.2
127
- - toolz==0.12.0
128
- - torch==1.12.1+cu116
129
- - torch-fidelity==0.3.0
130
- - torchaudio==0.12.1+cu116
131
- - torchmetrics==0.6.0
132
- - torchvision==0.13.1+cu116
133
- - tornado==6.2
134
- - tqdm==4.64.1
135
- - transformers==4.25.1
136
- - typing-extensions==4.4.0
137
- - tzdata==2022.6
138
- - tzlocal==4.2
139
- - urllib3==1.26.12
140
- - urwid==2.1.2
141
- - validators==0.20.0
142
- - watchdog==2.1.9
143
- - wcwidth==0.2.5
144
- - werkzeug==2.2.2
145
- - yarl==1.8.1
146
- - zipp==3.10.0
147
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference_global.py CHANGED
@@ -170,6 +170,12 @@ def parse_args():
170
  help="Data index. -1 for all.",
171
  )
172
 
 
 
 
 
 
 
173
  args = parser.parse_args()
174
  return args
175
 
@@ -204,11 +210,7 @@ if __name__ == "__main__":
204
  batch["input_ids"] = batch["input_ids"].to("cuda:0")
205
  batch["index"] = batch["index"].to("cuda:0").long()
206
  print(step, batch['text'])
207
- seeds = [0, 42, 10086, 777, 555, 222, 111, 999, 327, 283, 190, 218, 2371, 9329, 2938, 2073, 27367, 293,
208
- 8269, 87367, 29379, 4658, 39, 598]
209
- seeds = sorted(seeds)
210
- for seed in seeds:
211
- syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, vae, batch["pixel_values_clip"].device, 5,
212
- token_index=args.token_index, seed=seed)
213
- concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
214
- Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(seed).zfill(5)}.jpg'))
 
170
  help="Data index. -1 for all.",
171
  )
172
 
173
+ parser.add_argument(
174
+ "--seed",
175
+ type=int,
176
+ default=None,
177
+ help="A seed for testing.",
178
+ )
179
  args = parser.parse_args()
180
  return args
181
 
 
210
  batch["input_ids"] = batch["input_ids"].to("cuda:0")
211
  batch["index"] = batch["index"].to("cuda:0").long()
212
  print(step, batch['text'])
213
+ syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, vae, batch["pixel_values_clip"].device, 5,
214
+ token_index=args.token_index, seed=args.seed)
215
+ concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
216
+ Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(args.seed).zfill(5)}.jpg'))
 
 
 
 
inference_global.sh CHANGED
@@ -1,12 +1,13 @@
1
  export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
  export DATA_DIR='./test_datasets/'
3
 
4
- CUDA_VISIBLE_DEVICES=6 python inference_global.py \
5
  --pretrained_model_name_or_path=$MODEL_NAME \
6
  --test_data_dir=$DATA_DIR \
7
  --output_dir="./outputs/global_mapping" \
8
  --suffix="object" \
9
  --token_index="0" \
10
- --template="a photo of a {}" \
11
- --global_mapper_path="./checkpoints/global_mapper.pt"
 
12
 
 
1
  export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
  export DATA_DIR='./test_datasets/'
3
 
4
+ CUDA_VISIBLE_DEVICES=7 python inference_global.py \
5
  --pretrained_model_name_or_path=$MODEL_NAME \
6
  --test_data_dir=$DATA_DIR \
7
  --output_dir="./outputs/global_mapping" \
8
  --suffix="object" \
9
  --token_index="0" \
10
+ --template="a photo of a S" \
11
+ --global_mapper_path="./checkpoints/global_mapper.pt" \
12
+ --seed 42
13
 
inference_local.py CHANGED
@@ -199,6 +199,13 @@ def parse_args():
199
  help="Lambda for fuse the global and local feature.",
200
  )
201
 
 
 
 
 
 
 
 
202
  args = parser.parse_args()
203
  return args
204
 
@@ -236,12 +243,8 @@ if __name__ == "__main__":
236
  batch["input_ids"] = batch["input_ids"].to("cuda:0")
237
  batch["index"] = batch["index"].to("cuda:0").long()
238
  print(step, batch['text'])
239
- seeds = [0, 42, 10086, 777, 555, 222, 111, 999, 327, 283, 190, 218, 2371, 9329, 2938, 2073, 27367, 293,
240
- 8269, 87367, 29379, 4658, 39, 598]
241
- seeds = sorted(seeds)
242
- for seed in seeds:
243
- syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, mapper_local, vae,
244
- batch["pixel_values_clip"].device, 5,
245
- seed=seed, llambda=float(args.llambda))
246
- concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
247
- Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(seed).zfill(5)}.jpg'))
 
199
  help="Lambda for fuse the global and local feature.",
200
  )
201
 
202
+ parser.add_argument(
203
+ "--seed",
204
+ type=int,
205
+ default=None,
206
+ help="A seed for testing.",
207
+ )
208
+
209
  args = parser.parse_args()
210
  return args
211
 
 
243
  batch["input_ids"] = batch["input_ids"].to("cuda:0")
244
  batch["index"] = batch["index"].to("cuda:0").long()
245
  print(step, batch['text'])
246
+ syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, mapper_local, vae,
247
+ batch["pixel_values_clip"].device, 5,
248
+ seed=args.seed, llambda=float(args.llambda))
249
+ concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
250
+ Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(args.seed).zfill(5)}.jpg'))
 
 
 
 
inference_local.sh CHANGED
@@ -5,8 +5,9 @@ CUDA_VISIBLE_DEVICES=7 python inference_local.py \
5
  --test_data_dir=$DATA_DIR \
6
  --output_dir="./outputs/local_mapping" \
7
  --suffix="object" \
8
- --template="a photo of a {}" \
9
  --llambda="0.8" \
10
  --global_mapper_path="./checkpoints/global_mapper.pt" \
11
- --local_mapper_path="./checkpoints/local_mapper.pt"
 
12
 
 
5
  --test_data_dir=$DATA_DIR \
6
  --output_dir="./outputs/local_mapping" \
7
  --suffix="object" \
8
+ --template="a photo of a S" \
9
  --llambda="0.8" \
10
  --global_mapper_path="./checkpoints/global_mapper.pt" \
11
+ --local_mapper_path="./checkpoints/local_mapper.pt" \
12
+ --seed 42
13
 
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.16.0
2
+ albumentations==1.3.0
3
+ diffusers==0.11.1
4
+ gradio==3.20.1
5
+ huggingface-hub==0.13.0
6
+ opencv-python-headless==4.7.0.68
7
+ Pillow==9.4.0
8
+ torch==1.13.1
9
+ torchvision==0.14.1
10
+ tqdm==4.65.0
11
+ transformers==4.26.1
train_global.py CHANGED
@@ -11,7 +11,6 @@ import torch.nn.functional as F
11
  import torch.utils.checkpoint
12
  from torch.utils.data import Dataset
13
 
14
- import PIL
15
  from accelerate import Accelerator
16
  from accelerate.logging import get_logger
17
  from accelerate.utils import set_seed
@@ -31,7 +30,7 @@ from PIL import Image
31
  from tqdm.auto import tqdm
32
  from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
33
 
34
- from typing import Any, Optional, Tuple, Union
35
  from datasets import OpenImagesDataset
36
 
37
 
@@ -362,7 +361,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
362
  else:
363
  return f"{organization}/{model_id}"
364
 
365
-
366
  def freeze_params(params):
367
  for param in params:
368
  param.requires_grad = False
 
11
  import torch.utils.checkpoint
12
  from torch.utils.data import Dataset
13
 
 
14
  from accelerate import Accelerator
15
  from accelerate.logging import get_logger
16
  from accelerate.utils import set_seed
 
30
  from tqdm.auto import tqdm
31
  from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
32
 
33
+ from typing import Optional, Tuple, Union
34
  from datasets import OpenImagesDataset
35
 
36
 
 
361
  else:
362
  return f"{organization}/{model_id}"
363
 
 
364
  def freeze_params(params):
365
  for param in params:
366
  param.requires_grad = False
train_global.sh CHANGED
@@ -1,6 +1,6 @@
1
  export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
- export DATA_DIR='/home/weiyuxiang/datasets/Open_Images/'
3
- CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config_file 4_gpu.json --main_process_port 25656 train_global.py \
4
  --pretrained_model_name_or_path=$MODEL_NAME \
5
  --train_data_dir=$DATA_DIR \
6
  --placeholder_token="S" \
@@ -11,5 +11,5 @@ CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config_file 4_gpu.json --main_p
11
  --learning_rate=1e-06 --scale_lr \
12
  --lr_scheduler="constant" \
13
  --lr_warmup_steps=0 \
14
- --output_dir="./elite_experiments/global_mapping" \
15
  --save_steps 200
 
1
  export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
+ export DATA_DIR='./datasets/Open_Images/'
3
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file 4_gpu.json --main_process_port 25656 train_global.py \
4
  --pretrained_model_name_or_path=$MODEL_NAME \
5
  --train_data_dir=$DATA_DIR \
6
  --placeholder_token="S" \
 
11
  --learning_rate=1e-06 --scale_lr \
12
  --lr_scheduler="constant" \
13
  --lr_warmup_steps=0 \
14
+ --output_dir="./elite_experiments/global_mapping_new" \
15
  --save_steps 200
train_local.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import argparse
3
  import itertools
4
  import math
@@ -16,15 +15,13 @@ import PIL
16
  from accelerate import Accelerator
17
  from accelerate.logging import get_logger
18
  from accelerate.utils import set_seed
19
- from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
20
  from diffusers.optimization import get_scheduler
21
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
22
  from huggingface_hub import HfFolder, Repository, whoami
23
 
24
- # TODO: remove and import from diffusers.utils when the new version of diffusers is released
25
  from PIL import Image
26
  from tqdm.auto import tqdm
27
- from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
28
 
29
 
30
  from typing import Optional
 
 
1
  import argparse
2
  import itertools
3
  import math
 
15
  from accelerate import Accelerator
16
  from accelerate.logging import get_logger
17
  from accelerate.utils import set_seed
18
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, LMSDiscreteScheduler
19
  from diffusers.optimization import get_scheduler
 
20
  from huggingface_hub import HfFolder, Repository, whoami
21
 
 
22
  from PIL import Image
23
  from tqdm.auto import tqdm
24
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
25
 
26
 
27
  from typing import Optional
train_local.sh CHANGED
@@ -1,6 +1,6 @@
1
  export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
- export DATA_DIR='/home/weiyuxiang/datasets/Open_Images/'
3
- CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file 4_gpu.json --main_process_port 25657 train_local.py \
4
  --pretrained_model_name_or_path=$MODEL_NAME \
5
  --train_data_dir=$DATA_DIR \
6
  --placeholder_token="S" \
 
1
  export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
+ export DATA_DIR='./datasets/Open_Images/'
3
+ CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config_file 4_gpu.json --main_process_port 25657 train_local.py \
4
  --pretrained_model_name_or_path=$MODEL_NAME \
5
  --train_data_dir=$DATA_DIR \
6
  --placeholder_token="S" \