Anyou committed on
Commit
4a6e43e
1 Parent(s): 0ffaa52

Upload 11 files

Browse files
Files changed (11)
  1. __init__.py +0 -0
  2. config.yaml +63 -0
  3. environment.yml +271 -0
  4. fid_utils.py +41 -0
  5. main.py +537 -0
  6. pororo_100.h5 +3 -0
  7. readme-storyvisualization.md +123 -0
  8. requirements.txt +10 -0
  9. run.sh +1 -0
  10. test.py +94 -0
  11. transtoyolo.py +320 -0
__init__.py ADDED
Binary file (2 Bytes).
 
config.yaml ADDED
@@ -0,0 +1,63 @@
+ # device
+ mode: sample # train sample
+ gpu_ids: [3] # gpu ids
+ batch_size: 1 # batch size each item denotes one story
+ num_workers: 4 # number of workers
+ num_cpu_cores: -1 # number of cpu cores
+ seed: 0 # random seed
+ ckpt_dir: /root/lihui/StoryVisualization/save_ckpt_epoch5_new # checkpoint directory
+ run_name: ARLDM # name for this run
+
+ # task
+ dataset: pororo # pororo flintstones vistsis vistdii
+ task: visualization # continuation visualization
+
+ # train
+ init_lr: 1e-5 # initial learning rate
+ warmup_epochs: 1 # warmup epochs
+ max_epochs: 5 #50 # max epochs
+ train_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for resume, none for train from scratch
+ freeze_clip: True #False # whether to freeze clip
+ freeze_blip: True #False # whether to freeze blip
+ freeze_resnet: True #False # whether to freeze resnet
+
+ # sample
+ test_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for test
+ calculate_fid: True # whether to calculate FID scores
+ scheduler: ddim # ddim pndm
+ guidance_scale: 6 # guidance scale
+ num_inference_steps: 250 # number of inference steps
+ sample_output_dir: /root/lihui/StoryVisualization/save_samples_128_epoch50 # output directory
+
+ pororo:
+   hdf5_file: /root/lihui/StoryVisualization/pororo.h5
+   max_length: 85
+   new_tokens: [ "pororo", "loopy", "eddy", "harry", "poby", "tongtong", "crong", "rody", "petty" ]
+   clip_embedding_tokens: 49416
+   blip_embedding_tokens: 30530
+
+ flintstones:
+   hdf5_file: /path/to/flintstones.h5
+   max_length: 91
+   new_tokens: [ "fred", "barney", "wilma", "betty", "pebbles", "dino", "slate" ]
+   clip_embedding_tokens: 49412
+   blip_embedding_tokens: 30525
+
+ vistsis:
+   hdf5_file: /path/to/vist.h5
+   max_length: 100
+   clip_embedding_tokens: 49408
+   blip_embedding_tokens: 30524
+
+ vistdii:
+   hdf5_file: /path/to/vist.h5
+   max_length: 65
+   clip_embedding_tokens: 49408
+   blip_embedding_tokens: 30524
+
+ hydra:
+   run:
+     dir: .
+   output_subdir: null
+ hydra/job_logging: disabled
+ hydra/hydra_logging: disabled
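Note: config.yaml above is consumed through Hydra in main.py (`@hydra.main(config_path=".", config_name="config")`). A minimal sketch of how the nested keys resolve into a `DictConfig`, assuming hydra-core and omegaconf as pinned in environment.yml:

```python
# Minimal sketch only; mirrors the access pattern used in main.py.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path=".", config_name="config")
def show_config(args: DictConfig) -> None:
    print(OmegaConf.to_yaml(args))            # full resolved config
    print(args.mode, args.dataset)            # top-level keys: sample, pororo
    print(args.get(args.dataset).max_length)  # nested access used by main.py (85 for pororo)


if __name__ == "__main__":
    show_config()
```

Any value can also be overridden on the command line, e.g. `python main.py mode=train`.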
environment.yml ADDED
@@ -0,0 +1,271 @@
1
+ name: story
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - _openmp_mutex=5.1=1_gnu
9
+ - blas=1.0=mkl
10
+ - brotlipy=0.7.0=py38h27cfd23_1003
11
+ - bzip2=1.0.8=h7b6447c_0
12
+ - ca-certificates=2023.01.10=h06a4308_0
13
+ - certifi=2022.12.7=py38h06a4308_0
14
+ - cffi=1.15.1=py38h5eee18b_3
15
+ - cryptography=39.0.1=py38h9ce1e76_0
16
+ - cuda-cudart=11.7.99=0
17
+ - cuda-cupti=11.7.101=0
18
+ - cuda-libraries=11.7.1=0
19
+ - cuda-nvrtc=11.7.99=0
20
+ - cuda-nvtx=11.7.91=0
21
+ - cuda-runtime=11.7.1=0
22
+ - ffmpeg=4.3=hf484d3e_0
23
+ - flit-core=3.8.0=py38h06a4308_0
24
+ - freetype=2.12.1=h4a9f257_0
25
+ - giflib=5.2.1=h5eee18b_3
26
+ - gmp=6.2.1=h295c915_3
27
+ - gnutls=3.6.15=he1e5248_0
28
+ - idna=3.4=py38h06a4308_0
29
+ - intel-openmp=2021.4.0=h06a4308_3561
30
+ - jpeg=9e=h5eee18b_1
31
+ - lame=3.100=h7b6447c_0
32
+ - lcms2=2.12=h3be6417_0
33
+ - ld_impl_linux-64=2.38=h1181459_1
34
+ - lerc=3.0=h295c915_0
35
+ - libcublas=11.10.3.66=0
36
+ - libcufft=10.7.2.124=h4fbf590_0
37
+ - libcufile=1.6.0.25=0
38
+ - libcurand=10.3.2.56=0
39
+ - libcusolver=11.4.0.1=0
40
+ - libcusparse=11.7.4.91=0
41
+ - libdeflate=1.17=h5eee18b_0
42
+ - libffi=3.4.2=h6a678d5_6
43
+ - libgcc-ng=11.2.0=h1234567_1
44
+ - libgomp=11.2.0=h1234567_1
45
+ - libiconv=1.16=h7f8727e_2
46
+ - libidn2=2.3.2=h7f8727e_0
47
+ - libnpp=11.7.4.75=0
48
+ - libnvjpeg=11.8.0.2=0
49
+ - libpng=1.6.39=h5eee18b_0
50
+ - libstdcxx-ng=11.2.0=h1234567_1
51
+ - libtasn1=4.19.0=h5eee18b_0
52
+ - libtiff=4.5.0=h6a678d5_2
53
+ - libunistring=0.9.10=h27cfd23_0
54
+ - libwebp=1.2.4=h11a3e52_1
55
+ - libwebp-base=1.2.4=h5eee18b_1
56
+ - lz4-c=1.9.4=h6a678d5_0
57
+ - mkl=2021.4.0=h06a4308_640
58
+ - mkl-service=2.4.0=py38h7f8727e_0
59
+ - mkl_fft=1.3.1=py38hd3c417c_0
60
+ - mkl_random=1.2.2=py38h51133e4_0
61
+ - ncurses=6.4=h6a678d5_0
62
+ - nettle=3.7.3=hbbd107a_1
63
+ - numpy-base=1.23.5=py38h31eccc5_0
64
+ - openh264=2.1.1=h4ff587b_0
65
+ - openssl=1.1.1t=h7f8727e_0
66
+ - pip=23.0.1=py38h06a4308_0
67
+ - pycparser=2.21=pyhd3eb1b0_0
68
+ - pyopenssl=23.0.0=py38h06a4308_0
69
+ - pysocks=1.7.1=py38h06a4308_0
70
+ - python=3.8.16=h7a1cb2a_3
71
+ - pytorch=1.13.1=py3.8_cuda11.7_cudnn8.5.0_0
72
+ - pytorch-cuda=11.7=h778d358_3
73
+ - pytorch-mutex=1.0=cuda
74
+ - readline=8.2=h5eee18b_0
75
+ - six=1.16.0=pyhd3eb1b0_1
76
+ - sqlite=3.41.1=h5eee18b_0
77
+ - tk=8.6.12=h1ccaba5_0
78
+ - typing_extensions=4.4.0=py38h06a4308_0
79
+ - urllib3=1.26.15=py38h06a4308_0
80
+ - wheel=0.38.4=py38h06a4308_0
81
+ - xz=5.2.10=h5eee18b_1
82
+ - zlib=1.2.13=h5eee18b_0
83
+ - zstd=1.5.4=hc292b87_0
84
+ - pip:
85
+ - absl-py==1.4.0
86
+ - accelerate==0.17.1
87
+ - aiofiles==23.1.0
88
+ - aiohttp==3.8.4
89
+ - aiosignal==1.3.1
90
+ - altair==4.2.2
91
+ - antlr4-python3-runtime==4.9.3
92
+ - anyio==3.6.2
93
+ - appdirs==1.4.4
94
+ - argon2-cffi==21.3.0
95
+ - argon2-cffi-bindings==21.2.0
96
+ - arrow==1.2.3
97
+ - asttokens==2.2.1
98
+ - async-timeout==4.0.2
99
+ - attrs==22.2.0
100
+ - backcall==0.2.0
101
+ - beautifulsoup4==4.11.2
102
+ - bleach==6.0.0
103
+ - cachetools==5.3.0
104
+ - chardet==5.1.0
105
+ - charset-normalizer==3.1.0
106
+ - click==8.1.3
107
+ - comm==0.1.2
108
+ - contourpy==1.0.7
109
+ - cycler==0.11.0
110
+ - debugpy==1.6.6
111
+ - decorator==5.1.1
112
+ - defusedxml==0.7.1
113
+ - diffusers==0.9.0
114
+ - docker-pycreds==0.4.0
115
+ - entrypoints==0.4
116
+ - executing==1.2.0
117
+ - fastapi==0.95.0
118
+ - fastjsonschema==2.16.3
119
+ - ffmpy==0.3.0
120
+ - filelock==3.10.0
121
+ - fire==0.5.0
122
+ - flatbuffers==23.3.3
123
+ - fonttools==4.39.3
124
+ - fqdn==1.5.1
125
+ - frozenlist==1.3.3
126
+ - fsspec==2023.3.0
127
+ - ftfy==6.1.1
128
+ - gitdb==4.0.10
129
+ - gitpython==3.1.31
130
+ - google-auth==2.16.2
131
+ - google-auth-oauthlib==0.4.6
132
+ - gradio==3.24.1
133
+ - gradio-client==0.0.5
134
+ - grpcio==1.51.3
135
+ - h11==0.14.0
136
+ - h5py==3.8.0
137
+ - httpcore==0.16.3
138
+ - httpx==0.23.3
139
+ - huggingface-hub==0.13.2
140
+ - hydra-core==1.3.2
141
+ - importlib-metadata==6.1.0
142
+ - importlib-resources==5.12.0
143
+ - ipykernel==6.21.3
144
+ - ipython==8.11.0
145
+ - ipython-genutils==0.2.0
146
+ - ipywidgets==8.0.4
147
+ - isoduration==20.11.0
148
+ - jedi==0.18.2
149
+ - jinja2==3.1.2
150
+ - jsonpointer==2.3
151
+ - jsonschema==4.17.3
152
+ - jupyter==1.0.0
153
+ - jupyter-client==8.0.3
154
+ - jupyter-console==6.6.3
155
+ - jupyter-core==5.3.0
156
+ - jupyter-events==0.6.3
157
+ - jupyter-server==2.5.0
158
+ - jupyter-server-terminals==0.4.4
159
+ - jupyterlab-pygments==0.2.2
160
+ - jupyterlab-widgets==3.0.5
161
+ - kiwisolver==1.4.4
162
+ - lightning-bolts==0.5.0
163
+ - linkify-it-py==2.0.0
164
+ - lora-diffusion==0.1.7
165
+ - markdown==3.4.1
166
+ - markdown-it-py==2.2.0
167
+ - markupsafe==2.1.2
168
+ - matplotlib==3.7.1
169
+ - matplotlib-inline==0.1.6
170
+ - mdit-py-plugins==0.3.3
171
+ - mdurl==0.1.2
172
+ - mediapipe==0.9.1.0
173
+ - mistune==2.0.5
174
+ - multidict==6.0.4
175
+ - nbclassic==0.5.3
176
+ - nbclient==0.7.2
177
+ - nbconvert==7.2.10
178
+ - nbformat==5.7.3
179
+ - nest-asyncio==1.5.6
180
+ - notebook==6.5.3
181
+ - notebook-shim==0.2.2
182
+ - numpy==1.24.2
183
+ - oauthlib==3.2.2
184
+ - omegaconf==2.3.0
185
+ - opencv-contrib-python==4.7.0.72
186
+ - opencv-python==4.7.0.72
187
+ - orjson==3.8.9
188
+ - packaging==23.0
189
+ - pandas==1.5.3
190
+ - pandocfilters==1.5.0
191
+ - parso==0.8.3
192
+ - pathtools==0.1.2
193
+ - pexpect==4.8.0
194
+ - pickleshare==0.7.5
195
+ - pillow==9.4.0
196
+ - pkgutil-resolve-name==1.3.10
197
+ - platformdirs==3.1.1
198
+ - prometheus-client==0.16.0
199
+ - prompt-toolkit==3.0.38
200
+ - protobuf==3.20.1
201
+ - psutil==5.9.4
202
+ - ptyprocess==0.7.0
203
+ - pure-eval==0.2.2
204
+ - pyasn1==0.4.8
205
+ - pyasn1-modules==0.2.8
206
+ - pydantic==1.10.7
207
+ - pydeprecate==0.3.2
208
+ - pydub==0.25.1
209
+ - pygments==2.14.0
210
+ - pyparsing==3.0.9
211
+ - pyrsistent==0.19.3
212
+ - python-dateutil==2.8.2
213
+ - python-json-logger==2.0.7
214
+ - python-multipart==0.0.6
215
+ - pytorch-lightning==1.6.5
216
+ - pytz==2023.3
217
+ - pyyaml==6.0
218
+ - pyzmq==25.0.1
219
+ - qtconsole==5.4.1
220
+ - qtpy==2.3.0
221
+ - regex==2022.10.31
222
+ - requests==2.28.2
223
+ - requests-oauthlib==1.3.1
224
+ - rfc3339-validator==0.1.4
225
+ - rfc3986==1.5.0
226
+ - rfc3986-validator==0.1.1
227
+ - rsa==4.9
228
+ - safetensors==0.3.0
229
+ - scipy==1.10.1
230
+ - semantic-version==2.10.0
231
+ - send2trash==1.8.0
232
+ - sentry-sdk==1.17.0
233
+ - setproctitle==1.3.2
234
+ - setuptools==59.5.0
235
+ - smmap==5.0.0
236
+ - sniffio==1.3.0
237
+ - soupsieve==2.4
238
+ - stack-data==0.6.2
239
+ - starlette==0.26.1
240
+ - tensorboard==2.12.0
241
+ - tensorboard-data-server==0.7.0
242
+ - tensorboard-plugin-wit==1.8.1
243
+ - termcolor==2.2.0
244
+ - terminado==0.17.1
245
+ - timm==0.6.12
246
+ - tinycss2==1.2.1
247
+ - tokenizers==0.13.2
248
+ - toolz==0.12.0
249
+ - torch==1.9.0
250
+ - torchaudio==0.9.0
251
+ - torchmetrics==0.11.4
252
+ - torchvision==0.10.0+cu111
253
+ - tornado==6.2
254
+ - tqdm==4.65.0
255
+ - traitlets==5.9.0
256
+ - transformers==4.28.1
257
+ - typing-extensions==4.5.0
258
+ - uc-micro-py==1.0.1
259
+ - uri-template==1.2.0
260
+ - uvicorn==0.21.1
261
+ - wandb==0.14.0
262
+ - wcwidth==0.2.6
263
+ - webcolors==1.12
264
+ - webencodings==0.5.1
265
+ - websocket-client==1.5.1
266
+ - websockets==11.0
267
+ - werkzeug==2.2.3
268
+ - widgetsnbextension==4.0.5
269
+ - yarl==1.8.2
270
+ - zipp==3.15.0
271
+ prefix: /root/anaconda3/envs/story
fid_utils.py ADDED
@@ -0,0 +1,41 @@
+ import numpy as np
+ from scipy import linalg
+
+
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+     mu1 = np.atleast_1d(mu1)
+     mu2 = np.atleast_1d(mu2)
+
+     sigma1 = np.atleast_2d(sigma1)
+     sigma2 = np.atleast_2d(sigma2)
+
+     assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
+     assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'
+
+     diff = mu1 - mu2
+
+     # Product might be almost singular
+     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+     if not np.isfinite(covmean).all():
+         print('fid calculation produces singular product; adding %s to diagonal of cov estimates' % eps)
+         offset = np.eye(sigma1.shape[0]) * eps
+         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+     # Numerical error might give slight imaginary component
+     if np.iscomplexobj(covmean):
+         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+             m = np.max(np.abs(covmean.imag))
+             raise ValueError('Imaginary component {}'.format(m))
+         covmean = covmean.real
+
+     return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
+
+
+ def calculate_fid_given_features(feature1, feature2):
+     mu1 = np.mean(feature1, axis=0)
+     sigma1 = np.cov(feature1, rowvar=False)
+     mu2 = np.mean(feature2, axis=0)
+     sigma2 = np.cov(feature2, rowvar=False)
+     fid_value = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
+
+     return fid_value
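`calculate_fid_given_features` expects two feature matrices with one row per image; in main.py it is fed 2048-dimensional InceptionV3 pool features. An illustrative usage sketch with random features (small dimensionality chosen only to keep the covariances well conditioned; not real Inception activations):

```python
# Illustrative only: random Gaussian "features" stand in for Inception activations.
import numpy as np
from fid_utils import calculate_fid_given_features

rng = np.random.default_rng(0)
real_feats = rng.normal(size=(500, 64))                 # rows: images, cols: feature dims
fake_feats = rng.normal(loc=0.1, size=(500, 64))        # slightly shifted distribution
print(calculate_fid_given_features(real_feats, fake_feats))  # small positive FID value
```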
main.py ADDED
@@ -0,0 +1,537 @@
1
+ import inspect
2
+ import os
3
+
4
+ import cv2
5
+ import hydra
6
+ import numpy as np
7
+ import pytorch_lightning as pl
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from PIL import Image
12
+ from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler
13
+ from omegaconf import DictConfig
14
+ from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
15
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
16
+ from pytorch_lightning.loggers import TensorBoardLogger
17
+ from pytorch_lightning.strategies import DDPStrategy
18
+ from torch import nn
19
+ from torch.utils.data import DataLoader
20
+ from torchvision import transforms
21
+ from transformers import CLIPTokenizer, CLIPTextModel
22
+
23
+ from fid_utils import calculate_fid_given_features
24
+ from lora_diffusion import monkeypatch_or_replace_lora, tune_lora_scale
25
+
26
+ from models.blip_override.blip import blip_feature_extractor, init_tokenizer
27
+ from models.diffusers_override.unet_2d_condition import UNet2DConditionModel
28
+ from models.inception import InceptionV3
29
+ unet_target_replace_module = {"CrossAttention", "Attention", "GEGLU"}
30
+ #!/usr/bin/env python3
31
+ from transformers import CLIPProcessor
32
+ import transformers
33
+ from PIL import Image
34
+ import PIL.Image
35
+ import numpy as np
36
+ import torchvision.transforms as tvtrans
37
+ import requests
38
+ from io import BytesIO
39
+
40
+ class LightningDataset(pl.LightningDataModule):
41
+ def __init__(self, args: DictConfig):
42
+ super(LightningDataset, self).__init__()
43
+ self.kwargs = {"num_workers": args.num_workers, "persistent_workers": True if args.num_workers > 0 else False,
44
+ "pin_memory": True}
45
+ self.args = args
46
+
47
+ def setup(self, stage="fit"):
48
+ if self.args.dataset == "pororo":
49
+ import datasets.pororo as data
50
+ elif self.args.dataset == 'flintstones':
51
+ import datasets.flintstones as data
52
+ elif self.args.dataset == 'vistsis':
53
+ import datasets.vistsis as data
54
+ elif self.args.dataset == 'vistdii':
55
+ import datasets.vistdii as data
56
+ else:
57
+ raise ValueError("Unknown dataset: {}".format(self.args.dataset))
58
+ if stage == "fit":
59
+ self.train_data = data.StoryDataset("train", self.args)
60
+ self.val_data = data.StoryDataset("val", self.args)
61
+ if stage == "test":
62
+ self.test_data = data.StoryDataset("test", self.args)
63
+
64
+ def train_dataloader(self):
65
+ if not hasattr(self, 'trainloader'):
66
+ self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True, **self.kwargs)
67
+ return self.trainloader
68
+
69
+ def val_dataloader(self):
70
+ return DataLoader(self.val_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)
71
+
72
+ def test_dataloader(self):
73
+ return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)
74
+
75
+ def predict_dataloader(self):
76
+ return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)
77
+
78
+ def get_length_of_train_dataloader(self):
79
+ if not hasattr(self, 'trainloader'):
80
+ self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True, **self.kwargs)
81
+ return len(self.trainloader)
82
+
83
+
84
+ class ARLDM(pl.LightningModule):
85
+ def __init__(self, args: DictConfig, steps_per_epoch=1):
86
+ super(ARLDM, self).__init__()
87
+ self.args = args
88
+ self.steps_per_epoch = steps_per_epoch
89
+ """
90
+ Configurations
91
+ """
92
+ self.task = args.task
93
+
94
+ if args.mode == 'sample':
95
+ if args.scheduler == "pndm":
96
+ self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
97
+ skip_prk_steps=True)
98
+ elif args.scheduler == "ddim":
99
+ self.scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
100
+ clip_sample=False, set_alpha_to_one=True)
101
+ else:
102
+ raise ValueError("Scheduler not supported")
103
+ self.fid_augment = transforms.Compose([
104
+ transforms.Resize([64, 64]),
105
+ transforms.ToTensor(),
106
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
107
+ ])
108
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
109
+ self.inception = InceptionV3([block_idx])
110
+
111
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
112
+ ##############################
113
+ #self.clip_tokenizer.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/tokenizer')
114
+ self.blip_tokenizer = init_tokenizer()
115
+ self.blip_image_processor = transforms.Compose([
116
+ transforms.Resize([224, 224]),
117
+ transforms.ToTensor(),
118
+ transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
119
+ ])
120
+ self.max_length = args.get(args.dataset).max_length
121
+
122
+ blip_image_null_token = self.blip_image_processor(
123
+ Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))).unsqueeze(0).float()
124
+ clip_text_null_token = self.clip_tokenizer([""], padding="max_length", max_length=self.max_length,
125
+ return_tensors="pt").input_ids
126
+ blip_text_null_token = self.blip_tokenizer([""], padding="max_length", max_length=self.max_length,
127
+ return_tensors="pt").input_ids
128
+
129
+ self.register_buffer('clip_text_null_token', clip_text_null_token)
130
+ self.register_buffer('blip_text_null_token', blip_text_null_token)
131
+ self.register_buffer('blip_image_null_token', blip_image_null_token)
132
+
133
+ self.text_encoder = CLIPTextModel.from_pretrained('runwayml/stable-diffusion-v1-5',
134
+ subfolder="text_encoder")
135
+ ############################################
136
+ #self.text_encoder.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/text_encoder')
137
+ self.text_encoder.resize_token_embeddings(args.get(args.dataset).clip_embedding_tokens)
138
+ # resize_position_embeddings
139
+ old_embeddings = self.text_encoder.text_model.embeddings.position_embedding
140
+ new_embeddings = self.text_encoder._get_resized_embeddings(old_embeddings, self.max_length)
141
+ self.text_encoder.text_model.embeddings.position_embedding = new_embeddings
142
+ self.text_encoder.config.max_position_embeddings = self.max_length
143
+ self.text_encoder.max_position_embeddings = self.max_length
144
+ self.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_length).expand((1, -1))
145
+
146
+ self.modal_type_embeddings = nn.Embedding(2, 768)
147
+ self.time_embeddings = nn.Embedding(5, 768)
148
+ self.mm_encoder = blip_feature_extractor(
149
+ # pretrained='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth',
150
+ pretrained='/root/lihui/StoryVisualization/save_pretrained/model_large.pth',
151
+ image_size=224, vit='large')#, local_files_only=True)
152
+ self.mm_encoder.text_encoder.resize_token_embeddings(args.get(args.dataset).blip_embedding_tokens)
153
+
154
+ self.vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="vae")
155
+ self.unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet")
156
+
157
+ self.noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
158
+ num_train_timesteps=1000)
159
+ # monkeypatch_or_replace_lora(
160
+ # self.unet,
161
+ # torch.load("lora/example_loras/analog_svd_rank4.safetensors"),
162
+ # r=4,
163
+ # target_replace_module=unet_target_replace_module,
164
+ # )
165
+ #
166
+ # tune_lora_scale(self.unet, 1.00)
167
+ #tune_lora_scale(self.text_encoder, 1.00)
168
+
169
+ # torch.manual_seed(0)
170
+ ###################################
171
+ #self.vae.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/vae')
172
+ #self.unet.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/unet')
173
+
174
+ # Freeze vae and unet
175
+ self.freeze_params(self.vae.parameters())
176
+ if args.freeze_resnet:
177
+ self.freeze_params([p for n, p in self.unet.named_parameters() if "attentions" not in n])
178
+
179
+ if args.freeze_blip and hasattr(self, "mm_encoder"):
180
+ self.freeze_params(self.mm_encoder.parameters())
181
+ self.unfreeze_params(self.mm_encoder.text_encoder.embeddings.word_embeddings.parameters())
182
+
183
+ if args.freeze_clip and hasattr(self, "text_encoder"):
184
+ self.freeze_params(self.text_encoder.parameters())
185
+ self.unfreeze_params(self.text_encoder.text_model.embeddings.token_embedding.parameters())
186
+
187
+ @staticmethod
188
+ def freeze_params(params):
189
+ for param in params:
190
+ param.requires_grad = False
191
+
192
+ @staticmethod
193
+ def unfreeze_params(params):
194
+ for param in params:
195
+ param.requires_grad = True
196
+
197
+ def configure_optimizers(self):
198
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=1e-4) # optim_bits=8
199
+ scheduler = LinearWarmupCosineAnnealingLR(optimizer,
200
+ warmup_epochs=self.args.warmup_epochs * self.steps_per_epoch,
201
+ max_epochs=self.args.max_epochs * self.steps_per_epoch)
202
+ optim_dict = {
203
+ 'optimizer': optimizer,
204
+ 'lr_scheduler': {
205
+ 'scheduler': scheduler, # The LR scheduler instance (required)
206
+ 'interval': 'step', # The unit of the scheduler's step size
207
+ }
208
+ }
209
+ return optim_dict
210
+
211
+ def forward(self, batch):
212
+ if self.args.freeze_clip and hasattr(self, "text_encoder"):
213
+ self.text_encoder.eval()
214
+ if self.args.freeze_blip and hasattr(self, "mm_encoder"):
215
+ self.mm_encoder.eval()
216
+ images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_images = batch
217
+ B, V, S = captions.shape
218
+ src_V = V + 1 if self.task == 'continuation' else V
219
+ images = torch.flatten(images, 0, 1)
220
+ captions = torch.flatten(captions, 0, 1)
221
+ attention_mask = torch.flatten(attention_mask, 0, 1)
222
+ source_images = torch.flatten(source_images, 0, 1)
223
+ source_caption = torch.flatten(source_caption, 0, 1)
224
+ source_attention_mask = torch.flatten(source_attention_mask, 0, 1)
225
+ # 1 is not masked, 0 is masked
226
+
227
+ classifier_free_idx = np.random.rand(B * V) < 0.1
228
+
229
+ caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state # B * V, S, D
230
+ source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
231
+ mode='multimodal').reshape(B, src_V * S, -1)
232
+ source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
233
+ caption_embeddings[classifier_free_idx] = \
234
+ self.text_encoder(self.clip_text_null_token).last_hidden_state[0]
235
+ source_embeddings[classifier_free_idx] = \
236
+ self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token, attention_mask=None,
237
+ mode='multimodal')[0].repeat(src_V, 1)
238
+ caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
239
+ source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
240
+ source_embeddings += self.time_embeddings(
241
+ torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
242
+ encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)
243
+
244
+ attention_mask = torch.cat(
245
+ [attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
246
+ attention_mask = ~(attention_mask.bool()) # B * V, (src_V + 1) * S
247
+ attention_mask[classifier_free_idx] = False
248
+
249
+ # B, V, V, S
250
+ square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
251
+ square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
252
+ square_mask = square_mask.reshape(B * V, V * S)
253
+ attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])
254
+
255
+ latents = self.vae.encode(images).latent_dist.sample()
256
+ latents = latents * 0.18215
257
+
258
+ noise = torch.randn(latents.shape, device=self.device)
259
+ bsz = latents.shape[0]
260
+ timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=self.device).long()
261
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
262
+
263
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, attention_mask).sample
264
+ loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
265
+ return loss
266
+
267
+ def sample(self, batch):
268
+ original_images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_test_images = batch
269
+ B, V, S = captions.shape
270
+ src_V = V + 1 if self.task == 'continuation' else V
271
+ original_images = torch.flatten(original_images, 0, 1)
272
+ captions = torch.flatten(captions, 0, 1)
273
+ attention_mask = torch.flatten(attention_mask, 0, 1)
274
+ source_images = torch.flatten(source_images, 0, 1)
275
+ source_caption = torch.flatten(source_caption, 0, 1)
276
+ source_attention_mask = torch.flatten(source_attention_mask, 0, 1)
277
+
278
+ caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state # B * V, S, D
279
+ source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
280
+ mode='multimodal').reshape(B, src_V * S, -1)
281
+ caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
282
+ source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
283
+ source_embeddings += self.time_embeddings(
284
+ torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
285
+ source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
286
+ encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)
287
+
288
+ attention_mask = torch.cat(
289
+ [attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
290
+ attention_mask = ~(attention_mask.bool()) # B * V, (src_V + 1) * S
291
+ # B, V, V, S
292
+ square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
293
+ square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
294
+ square_mask = square_mask.reshape(B * V, V * S)
295
+ attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])
296
+
297
+ uncond_caption_embeddings = self.text_encoder(self.clip_text_null_token).last_hidden_state
298
+ uncond_source_embeddings = self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token,
299
+ attention_mask=None, mode='multimodal').repeat(1, src_V, 1)
300
+ uncond_caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
301
+ uncond_source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
302
+ uncond_source_embeddings += self.time_embeddings(
303
+ torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
304
+ uncond_embeddings = torch.cat([uncond_caption_embeddings, uncond_source_embeddings], dim=1)
305
+ uncond_embeddings = uncond_embeddings.expand(B * V, -1, -1)
306
+
307
+ encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])
308
+ uncond_attention_mask = torch.zeros((B * V, (src_V + 1) * S), device=self.device).bool()
309
+ uncond_attention_mask[:, -V * S:] = square_mask
310
+ attention_mask = torch.cat([uncond_attention_mask, attention_mask], dim=0)
311
+
312
+ attention_mask = attention_mask.reshape(2, B, V, (src_V + 1) * S)
313
+ images = list()
314
+ for i in range(V):
315
+ encoder_hidden_states = encoder_hidden_states.reshape(2, B, V, (src_V + 1) * S, -1)
316
+ new_image = self.diffusion(encoder_hidden_states[:, :, i].reshape(2 * B, (src_V + 1) * S, -1),
317
+ attention_mask[:, :, i].reshape(2 * B, (src_V + 1) * S),
318
+ 512, 512, self.args.num_inference_steps, self.args.guidance_scale, 0.0)
319
+ images += new_image
320
+
321
+ new_image = torch.stack([self.blip_image_processor(im) for im in new_image]).to(self.device)
322
+
323
+ new_embedding = self.mm_encoder(new_image, # B,C,H,W
324
+ source_caption.reshape(B, src_V, S)[:, i + src_V - V],
325
+ source_attention_mask.reshape(B, src_V, S)[:, i + src_V - V],
326
+ mode='multimodal') # B, S, D
327
+ new_embedding = new_embedding.repeat_interleave(V, dim=0)
328
+ new_embedding += self.modal_type_embeddings(torch.tensor(1, device=self.device))
329
+ new_embedding += self.time_embeddings(torch.tensor(i + src_V - V, device=self.device))
330
+
331
+ encoder_hidden_states = encoder_hidden_states[1].reshape(B * V, (src_V + 1) * S, -1)
332
+ encoder_hidden_states[:, (i + 1 + src_V - V) * S:(i + 2 + src_V - V) * S] = new_embedding
333
+ encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])
334
+
335
+ return original_images, images, texts, ori_test_images
336
+
337
+
338
+ def training_step(self, batch, batch_idx):
339
+ loss = self(batch)
340
+ self.log('loss/train_loss', loss, on_step=True, on_epoch=False, sync_dist=True, prog_bar=True)
341
+ return loss
342
+
343
+ def validation_step(self, batch, batch_idx):
344
+ loss = self(batch)
345
+ self.log('loss/val_loss', loss, on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)
346
+
347
+ def predict_step(self, batch, batch_idx, dataloader_idx=0):
348
+ original_images, images, texts, ori_test_images = self.sample(batch)
349
+ if self.args.calculate_fid:
350
+ original_images = original_images.cpu().numpy().astype('uint8')
351
+ original_images = [Image.fromarray(im, 'RGB') for im in original_images]
352
+
353
+ # ori_test_images = torch.stack(ori_test_images).cpu().numpy().astype('uint8')
354
+ # ori_test_images = [Image.fromarray(im, 'RGB') for im in ori_test_images]
355
+ ori = self.inception_feature(original_images).cpu().numpy()
356
+ gen = self.inception_feature(images).cpu().numpy()
357
+ else:
358
+ ori = None
359
+ gen = None
360
+
361
+ return images, ori, gen, ori_test_images, texts
362
+
363
+ def diffusion(self, encoder_hidden_states, attention_mask, height, width, num_inference_steps, guidance_scale, eta):
364
+ latents = torch.randn((encoder_hidden_states.shape[0] // 2, self.unet.in_channels, height // 8, width // 8),
365
+ device=self.device)
366
+
367
+ # set timesteps
368
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
369
+ extra_set_kwargs = {}
370
+ if accepts_offset:
371
+ extra_set_kwargs["offset"] = 1
372
+
373
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
374
+
375
+ # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
376
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
377
+ latents = latents * self.scheduler.sigmas[0]
378
+
379
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
380
+ extra_step_kwargs = {}
381
+ if accepts_eta:
382
+ extra_step_kwargs["eta"] = eta
383
+
384
+ for i, t in enumerate(self.scheduler.timesteps):
385
+ # expand the latents if we are doing classifier free guidance
386
+ latent_model_input = torch.cat([latents] * 2)
387
+
388
+ # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states).sample
389
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states, attention_mask).sample
390
+
391
+ # perform guidance
392
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
393
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
394
+
395
+ # compute the previous noisy sample x_t -> x_t-1
396
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
397
+
398
+ # scale and decode the image latents with vae
399
+ latents = 1 / 0.18215 * latents
400
+ image = self.vae.decode(latents).sample
401
+
402
+ image = (image / 2 + 0.5).clamp(0, 1)
403
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
404
+
405
+ return self.numpy_to_pil(image)
406
+
407
+ @staticmethod
408
+ def numpy_to_pil(images):
409
+ """
410
+ Convert a numpy image or a batch of images to a PIL image.
411
+ """
412
+ if images.ndim == 3:
413
+ images = images[None, ...]
414
+ images = (images * 255).round().astype("uint8")
415
+ pil_images = [Image.fromarray(image, 'RGB') for image in images]
416
+
417
+ return pil_images
418
+
419
+ def inception_feature(self, images):
420
+ images = torch.stack([self.fid_augment(image) for image in images])
421
+ images = images.type(torch.FloatTensor).to(self.device)
422
+ images = (images + 1) / 2
423
+ images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
424
+ pred = self.inception(images)[0]
425
+
426
+ if pred.shape[2] != 1 or pred.shape[3] != 1:
427
+ pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
428
+ return pred.reshape(-1, 2048)
429
+
430
+
431
+ def train(args: DictConfig) -> None:
432
+ dataloader = LightningDataset(args)
433
+ dataloader.setup('fit')
434
+ # dataloader.
435
+ model = ARLDM(args, steps_per_epoch=dataloader.get_length_of_train_dataloader())
436
+
437
+ logger = TensorBoardLogger(save_dir=os.path.join(args.ckpt_dir, args.run_name), name='log', default_hp_metric=False)
438
+
439
+ checkpoint_callback = ModelCheckpoint(
440
+ dirpath=os.path.join(args.ckpt_dir, args.run_name),
441
+ save_top_k=0,
442
+ save_last=True
443
+ )
444
+
445
+ lr_monitor = LearningRateMonitor(logging_interval='step')
446
+
447
+ callback_list = [lr_monitor, checkpoint_callback]
448
+
449
+ trainer = pl.Trainer(
450
+ accelerator='gpu',
451
+ devices=args.gpu_ids,
452
+ max_epochs=args.max_epochs,
453
+ benchmark=True,
454
+ logger=logger,
455
+ log_every_n_steps=1,
456
+ callbacks=callback_list,
457
+ strategy=DDPStrategy(find_unused_parameters=False)
458
+ )
459
+ trainer.fit(model, dataloader, ckpt_path=args.train_model_file)
460
+
461
+
462
+ def sample(args: DictConfig) -> None:
463
+
464
+ assert args.test_model_file is not None, "test_model_file cannot be None"
465
+ assert args.gpu_ids == 1 or len(args.gpu_ids) == 1, "Only one GPU is supported in test mode"
466
+ dataloader = LightningDataset(args)
467
+ dataloader.setup('test')
468
+ model = ARLDM.load_from_checkpoint(args.test_model_file, args=args, strict=False)
469
+
470
+ predictor = pl.Trainer(
471
+ accelerator='gpu',
472
+ devices=args.gpu_ids,
473
+ max_epochs=-1,
474
+ benchmark=True
475
+ )
476
+ predictions = predictor.predict(model, dataloader)
477
+ images = [elem for sublist in predictions for elem in sublist[0]]
478
+ ori_images = [elem for sublist in predictions for elem in sublist[3]]
479
+ ori_test_images = list()
480
+ if not os.path.exists(args.sample_output_dir):
481
+ try:
482
+ os.mkdir(args.sample_output_dir)
483
+ except:
484
+ pass
485
+
486
+ text_list = [elem for sublist in predictions for elem in sublist[4]]
487
+ ################################
488
+ # print(f"index: {index}")
489
+ num_images = len(images)
490
+ num_groups = (num_images + 4) // 5 # total number of 5-image groups (stories)
491
+
492
+ for g in range(num_groups):
493
+ print('Story {}:'.format(g + 1)) # print the story (group) number
494
+ start_index = g * 5 # start index of the current group
495
+ end_index = min(start_index + 5, num_images) # end index of the current group
496
+ for i in range(start_index, end_index):
497
+ print(text_list[i]) # print the corresponding caption
498
+ images[i].save(
499
+ os.path.join(args.sample_output_dir, 'group{:02d}_image{:02d}.png'.format(g + 1, i - start_index + 1)))
500
+ # ori_images[i] = ori_images[i]
501
+ ori_images_pil = Image.fromarray(np.uint8(ori_images[i].detach().cpu().squeeze().float().numpy())).convert("RGB")
502
+ ori_test_images.append(ori_images_pil)
503
+ ori_images_pil.save(
504
+ os.path.join('/root/lihui/StoryVisualization/ori_test_images_epoch10', 'group{:02d}_image{:02d}.png'.format(g + 1, i - start_index + 1)))
505
+ # for i, im in enumerate(ori_images):
506
+ # file_path = '/root/lihui/StoryVisualization/ori_test_images/image{}.png'.format(i)
507
+ # cv2.imwrite(file_path, im)
508
+
509
+
510
+ if args.calculate_fid:
511
+ ori = np.array([elem for sublist in predictions for elem in sublist[1]])
512
+ gen = np.array([elem for sublist in predictions for elem in sublist[2]])
513
+ fid = calculate_fid_given_features(ori, gen)
514
+ print('FID: {}'.format(fid))
515
+
516
+
517
+
518
+
519
+
520
+ @hydra.main(config_path=".", config_name="config")
521
+ def main(args: DictConfig) -> None:
522
+ pl.seed_everything(args.seed)
523
+ if args.num_cpu_cores > 0:
524
+ torch.set_num_threads(args.num_cpu_cores)
525
+
526
+ if args.mode == 'train':
527
+ ############################
528
+ train(args)
529
+ elif args.mode == 'sample':
530
+ # dataloader = LightningDataset(args)
531
+ # dataloader.setup('test')
532
+ sample(args)
533
+
534
+
535
+
536
+ if __name__ == '__main__':
537
+ main()
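The `diffusion()` loop in main.py applies standard classifier-free guidance: the U-Net runs on a doubled batch (unconditional and conditional halves) and the two predictions are recombined with `guidance_scale`. A minimal, self-contained sketch of just that recombination step, with toy tensors in place of real U-Net outputs:

```python
# Toy illustration of the guidance combination used in ARLDM.diffusion(); shapes are arbitrary.
import torch

guidance_scale = 6.0                         # matches config.yaml
noise_pred = torch.randn(2 * 4, 4, 64, 64)   # doubled batch: [uncond | cond]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([4, 4, 64, 64])
```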
pororo_100.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b5d47440de7abbbbb2265e1d5ecbc1c5d4d3188434db3988cb13e7ec5fa7549
+ size 69568
readme-storyvisualization.md ADDED
@@ -0,0 +1,123 @@
+ ### Part 1: Cross-Modal Sequential Image Generation from Narrative Text
+
+ ## Environment Setup
+ conda create -n arldm python=3.8
+ conda activate arldm
+ conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-lts
+ cd /root/lihui/StoryVisualization
+ pip install -r requirements.txt
+ ## Data Preparation
+ Download the PororoSV dataset here.
+ To accelerate I/O, use the following script to convert the downloaded data to HDF5:
+ python data_script/pororo_hdf5.py
+ --data_dir /path/to/pororo_data
+ --save_path /path/to/save_hdf5_file
+ ## Configuration File (config.yaml)
+
+ #device
+ mode: sample # train sample
+ ckpt_dir: /root/lihui/StoryVisualization/save_ckpt_epoch5_new # checkpoint directory
+ run_name: ARLDM # name for this run
+
+ #train
+ train_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for resume, none for train from scratch
+
+ #sample
+ test_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for test
+ sample_output_dir: /root/lihui/StoryVisualization/save_samples_128_epoch50 # output directory
+ ## Training
+ Specify your directories and device configuration in config.yaml, then run:
+ python main.py
+ ## Sampling
+ Specify your directories and device configuration in config.yaml, then run:
+ python main.py
+ ## Citation
+ @article{pan2022synthesizing,
+   title={Synthesizing Coherent Story with Auto-Regressive Latent Diffusion Models},
+   author={Pan, Xichen and Qin, Pengda and Li, Yuhong and Xue, Hui and Chen, Wenhu},
+   journal={arXiv preprint arXiv:2211.10950},
+   year={2022}
+ }
+
+
+ ### Part 2: Super-Resolution Based on Real-ESRGAN
+ Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
+ [Paper]   [Project Page]   [YouTube Video]   [Bilibili Video]   [Poster]   [PPT]
+ Xintao Wang, Liangbin Xie, Chao Dong, Ying Shan
+ Tencent ARC Lab; Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
+ ## Requirements
+ Python >= 3.7 (Anaconda or Miniconda recommended)
+ PyTorch >= 1.7
+ ## Installation
+ 1. Enter the pre-configured folder directly:
+ cd /root/lihui/StoryVisualization/Real-ESRGAN
+ 2. Or clone the project locally:
+ git clone https://github.com/xinntao/Real-ESRGAN.git && cd Real-ESRGAN
+ 3. Install the dependencies:
+ ```bash
+ # Install basicsr - https://github.com/xinntao/BasicSR
+ # We use BasicSR for both training and inference
+ pip install basicsr
+ # facexlib and gfpgan are used for face enhancement
+ pip install facexlib
+ pip install gfpgan
+ pip install -r requirements.txt
+ python setup.py develop
+ ```
+ ## Training
+ Trained model: RealESRGAN_x4plus_anime_6B
+ More information about and comparisons with waifu2x can be found in anime_model.md.
+ ## Download the Model
+ wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P weights
+ ## Inference
+ python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i inputs
+ Results are saved in the results folder.
+ ## BibTeX Citation
+ @Article{wang2021realesrgan,
+   title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
+   author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
+   journal={arXiv:2107.10833},
+   year={2021}
+ }
+
+
+ ### Part 3: Target Character Detection Based on YOLOv5
+ ## Installation
+ Clone the repo and install requirements.txt in a Python>=3.7.0 environment with PyTorch>=1.7.
+ git clone https://github.com/ultralytics/yolov5 # clone
+ cd /root/lihui/StoryVisualization
+ cd yolov5
+ pip install -r requirements.txt # install
+ ## Convert the Images
+ cd /root/lihui/StoryVisualization
+ python transtoyolo.py
+ ## Inference with detect.py
+ detect.py runs inference on a variety of sources; models are downloaded automatically from the latest YOLOv5 release and results are saved to runs/detect.
+ python detect.py --weights yolov5s.pt --source 0 # webcam
+   img.jpg # image
+   vid.mp4 # video
+   screen # screenshot
+   path/ # directory
+   list.txt # list of images
+   list.streams # list of streams
+   'path/*.jpg' # glob
+   'https://youtu.be/Zgi9g1ksQHc' # YouTube
+   'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
+ ## Training
+ The latest models and datasets are downloaded automatically from the YOLOv5 release. Training times for YOLOv5n/s/m/l/x are 1/2/4/6/8 days on a V100 GPU (multi-GPU training is faster). Use the largest --batch-size possible, or pass --batch-size -1 for YOLOv5 AutoBatch. The batch sizes shown below are for V100-16GB.
+ python train.py --data xxx.yaml --epochs 500 --weights '' --cfg yolov5l --batch-size 64
+ # xxx.yaml is the converted dataset file
+
+ ## License
+ YOLOv5 is available under two different licenses:
+ AGPL-3.0 License: see the LICENSE file for details.
+ Enterprise License: provides greater flexibility for commercial product development without the open-source requirements of AGPL-3.0. Typical use cases are embedding Ultralytics software and AI models into commercial products and applications. Apply for an Enterprise License at Ultralytics Licensing.
+
+
+ ### Part 4: Demo System
+
+ ## Specify the file directory and run:
+ cd /root/lihui/StoryVisualization/visualsystem
+ python main.py
+
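As a companion to the data-preparation step in the readme above: the converted pororo.h5 is read by split, with one pipe-separated text field and five JPEG-encoded frames per story. A minimal inspection sketch, assuming the HDF5 layout used in test.py and a placeholder path:

```python
# Minimal sketch, assuming the layout read by test.py: a 'test' group with a 'text'
# dataset of '|'-joined captions and 'image0'..'image4' encoded frames.
import cv2
import h5py

with h5py.File('/path/to/pororo.h5', 'r') as h5:
    test = h5['test']
    print(len(test['text']), 'test stories')
    captions = test['text'][0].decode('utf-8').split('|')
    print(captions)
    frame = cv2.imdecode(test['image0'][0], cv2.IMREAD_COLOR)  # decode the first frame
    print(frame.shape)
```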
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ pytorch_lightning<1.7.0
+ lightning-bolts
+ transformers==4.24.0
+ diffusers==0.7.2
+ timm
+ ftfy
+ hydra-core
+ opencv-python
+ h5py
+ scipy
run.sh ADDED
@@ -0,0 +1 @@
+ python main.py
test.py ADDED
@@ -0,0 +1,94 @@
+ import cv2
+ import h5py
+ import copy
+ import os
+ import random
+
+ import numpy
+ import numpy as np
+ from PIL import Image
+
+
+ def gettext(index):
+     with h5py.File('/root/lihui/StoryVisualization/pororo.h5', 'r') as h5:
+         story = list()
+         h5 = h5['test']
+         # read the text at the current index and decode it as UTF-8
+         texts = h5['text'][index].decode('utf-8').split('|')
+         symbol = '\n'
+         texts = symbol.join(texts)
+         texts = 'Story<' + str(index) + '> :' + '\n' + texts
+         print(texts)
+         return texts
+
+
+ # for i in range(1000):
+ #     gettext(i)
+
+ # take the first 100 samples of the dataset
+ # ###correct version##############
+ # # import h5py
+ # # import numpy as np
+ # # from PIL import Image
+ # #
+ # #
+ # # # create a subdirectory named "images" to save the images
+ # # os.makedirs("train_images", exist_ok=True)
+ # #
+ # # create an h5 file
+ # nf = h5py.File('/root/lihui/StoryVisualization/pororo_100.h5', "w")
+ # with h5py.File('/root/lihui/StoryVisualization/pororo.h5', 'r') as f:
+ #     test_group = f['test']
+ #     texts = np.array(test_group['text'][()])
+ #     ngroup = nf.create_group('test')
+ #     ntext = ngroup.create_dataset('text', (100,), dtype=h5py.string_dtype(encoding='utf-8'))
+ #     for i in range(100):
+ #         ntext[i]=texts[i]
+ #         print(f"sample {i}:")
+ #         # for j in range(5):
+ #         #     # build a fixed filename to save the image
+ #         #     # filename = os.path.join("images", f"image_{i}_{j}.png")
+ #         #     # # write the image data from the HDF5 file to that file
+ #         #     # with open(filename, "wb") as img_file:
+ #         #     #     img_file.write(test_group[f'image{j}'][i])
+ #         #     # print the text information and the filename
+ #         #     ntext[i]='|'.join(texts[i].decode('utf-8').split('|')[j])
+ #         #     print(f"image {j} saved to file: {filename}")
+ #         print(ntext[i])
+ # nf.close()
+
+ # save the test-set images by randomly cropping video frames
+ with h5py.File(r'C:\Users\zjlab\Desktop\StoryVisualization\pororo.h5', 'r') as h5:
+     h5 = h5['test']
+
+     for index in range(len(h5['text'])):  # len(h5['text'])
+         # index = int(index + 1)
+         # print(index)
+         images = list()
+         for i in range(5):
+             # read a group of images and the corresponding text from the h5 file
+             im = h5['image{}'.format(i)][index]
+             # print(im)
+             # pil_img = Image.fromarray(im)
+             # # save the image
+             # pil_img.save(os.path.join('/root/lihui/StoryVisualization/ori_test_images', '{:04d}.png'.format(i)))
+             # decode each image
+             im = cv2.imdecode(im, cv2.IMREAD_COLOR)
+             # randomly pick one 128-pixel slice of the image
+             idx = random.randint(0, im.shape[0] / 128 - 1)
+             # append the cropped slice to the images list
+             images.append(im[idx * 128: (idx + 1) * 128])
+         # deep copy, so it does not change with images afterwards
+         # ori_images = copy.deepcopy(images)
+         # save the original test images
+
+         # for i, im in enumerate(images):
+         #     file_path = 'C:/Users/zjlab/Desktop/StoryVisualization/test_images/group{:02d}_image{:02d}.png'.format(
+         #         index + 1,
+         #         i + 1)
+         #     cv2.imwrite(file_path, im)
+
+         ori_images_pil = Image.fromarray(images[i])  # numpy.uint8(images[i].detach().cpu().squeeze().float().numpy())).convert("RGB")
+         ori_images_pil.save(
+             os.path.join('C:/Users/zjlab/Desktop/StoryVisualization/test_images',
+                          'group{:02d}_image{:02d}.png'.format(index + 1, i + 1)))
transtoyolo.py ADDED
@@ -0,0 +1,320 @@
+ # -*- coding: utf-8 -*-
+
+ import os
+ import numpy as np
+ import json
+ from glob import glob
+ import cv2
+ import shutil
+ import yaml
+ from sklearn.model_selection import train_test_split
+ from tqdm import tqdm
+
+
+ # get the current working directory
+ ROOT_DIR = os.getcwd()
+
+ '''
+ Unify the image format
+ '''
+ def change_image_format(label_path=ROOT_DIR, suffix='.png'):
+     """
+     Unify the format of all images in the current folder, e.g. '.jpg'
+     :param suffix: image file suffix
+     :param label_path: current file path
+     :return:
+     """
+     externs = ['png', 'jpg', 'JPEG', 'BMP', 'bmp']
+     files = list()
+     # collect all images whose suffix is in externs
+     for extern in externs:
+         files.extend(glob(label_path + "\\*." + extern))
+     # iterate over all images and convert the image format
+     for file in files:
+         name = ''.join(file.split('.')[:-1])
+         file_suffix = file.split('.')[-1]
+         if file_suffix != suffix.split('.')[-1]:
+             # rename with the target suffix
+             new_name = name + suffix
+             # read the image
+             image = cv2.imread(file)
+             # rewrite the image in the target format
+             cv2.imwrite(new_name, image)
+             # delete the old image
+             os.remove(file)
+
+
+
+ '''
+ Read all json files and collect all classes
+ '''
+ def get_all_class(file_list, label_path=ROOT_DIR):
+     """
+     Get all classes of the current data from the json files
+     :param file_list: all file names under the current path
+     :param label_path: current file path
+     :return:
+     """
+     # initialize the class list
+     classes = list()
+     # iterate over all json files, read the label value of each shape and add it to classes
+     for filename in tqdm(file_list):
+         json_path = os.path.join(label_path, filename + '.json')
+         json_file = json.load(open(json_path, "r", encoding="utf-8"))
+         for item in json_file["shapes"]:
+             label_class = item['label']
+             if label_class not in classes:
+                 classes.append(label_class)
+     print('read file done')
+     return classes
+
+
+ '''
+ Split into training, validation and test sets
+ '''
+ def split_dataset(label_path, test_size=0.3, isUseTest=False, useNumpyShuffle=False):
+     """
+     Split the files into training, test and validation sets
+     :param useNumpyShuffle: use the numpy method to split the dataset
+     :param test_size: ratio of the test or validation split
+     :param isUseTest: whether to use a test set, default False
+     :param label_path: current file path
+     :return:
+     """
+     # get all json files
+     files = glob(label_path + "\\*.json")
+     files = [i.replace("\\", "/").split("/")[-1].split(".json")[0] for i in files]
+
+     if useNumpyShuffle:
+         file_length = len(files)
+         index = np.arange(file_length)
+         np.random.seed(32)
+         np.random.shuffle(index)  # random split
+
+         test_files = None
+         # whether there is a test set
+         if isUseTest:
+             trainval_files, test_files = np.array(files)[index[:int(file_length * (1 - test_size))]], np.array(files)[
+                 index[int(file_length * (1 - test_size)):]]
+         else:
+             trainval_files = files
+         # split into training and validation sets
+         train_files, val_files = np.array(trainval_files)[index[:int(len(trainval_files) * (1 - test_size))]], \
+             np.array(trainval_files)[index[int(len(trainval_files) * (1 - test_size)):]]
+     else:
+         test_files = None
+         if isUseTest:
+             trainval_files, test_files = train_test_split(files, test_size=test_size, random_state=55)
+         else:
+             trainval_files = files
+         train_files, val_files = train_test_split(trainval_files, test_size=test_size, random_state=55)
+
+     return train_files, val_files, test_files, files
+
+
+ '''
+ Create the folders for the yolov5 training, validation and test sets
+ '''
+ def create_save_file(label_path=ROOT_DIR):
+     """
+     Create the folders following the image and label paths used for training
+     :param label_path: current file path
+     :return:
+     """
+     # training set
+     train_image = os.path.join(label_path, 'train', 'images')
+     if not os.path.exists(train_image):
+         os.makedirs(train_image)
+     train_label = os.path.join(label_path, 'train', 'labels')
+     if not os.path.exists(train_label):
+         os.makedirs(train_label)
+     # validation set
+     val_image = os.path.join(label_path, 'valid', 'images')
+     if not os.path.exists(val_image):
+         os.makedirs(val_image)
+     val_label = os.path.join(label_path, 'valid', 'labels')
+     if not os.path.exists(val_label):
+         os.makedirs(val_label)
+     # test set
+     test_image = os.path.join(label_path, 'test', 'images')
+     if not os.path.exists(test_image):
+         os.makedirs(test_image)
+     test_label = os.path.join(label_path, 'test', 'labels')
+     if not os.path.exists(test_label):
+         os.makedirs(test_label)
+     return train_image, train_label, val_image, val_label, test_image, test_label
+
+
+
+ '''
+ Convert: given the image size, return the box centre and width/height
+ '''
+ def convert(size, box):
+     # width
+     dw = 1. / (size[0])
+     # height
+     dh = 1. / (size[1])
+
+     x = (box[0] + box[1]) / 2.0 - 1
+     y = (box[2] + box[3]) / 2.0 - 1
+     # width
+     w = box[1] - box[0]
+     # height
+     h = box[3] - box[2]
+
+     x = x * dw
+     w = w * dw
+     y = y * dh
+     h = h * dh
+     return x, y, w, h
+
+
+ '''
+ Move the images and label files into the given training, validation and test sets
+ '''
+ def push_into_file(file, images, labels, label_path=ROOT_DIR, suffix='.jpg'):
+     """
+     Move all files in the current folder, split into images and labels, into the training/validation/test folders
+     :param file: list of file names
+     :param images: path where images are stored
+     :param labels: path where labels are stored
+     :param label_path: current file path
+     :param suffix: image file suffix
+     :return:
+     """
+     # iterate over all files
+     for filename in file:
+         # image file
+         image_file = os.path.join(label_path, filename + suffix)
+         # label file
+         label_file = os.path.join(label_path, filename + '.txt')
+         # yolov5 image folder
+         if not os.path.exists(os.path.join(images, filename + suffix)):
+             try:
+                 shutil.move(image_file, images)
+             except OSError:
+                 pass
+         # yolov5 label folder
+         if not os.path.exists(os.path.join(labels, filename + suffix)):
+             try:
+                 shutil.move(label_file, labels)
+             except OSError:
+                 pass
+
+ '''
+
+ '''
+ def json2txt(classes, txt_Name='allfiles', label_path=ROOT_DIR, suffix='.png'):
+     """
+     Convert the json files into txt files and move the json files into a dedicated folder
+     :param classes: class names
+     :param txt_Name: txt file used to store the paths of all files
+     :param label_path: current file path
+     :param suffix: image file suffix
+     :return:
+     """
+     store_json = os.path.join(label_path, 'json')
+     if not os.path.exists(store_json):
+         os.makedirs(store_json)
+
+     _, _, _, files = split_dataset(label_path)
+     if not os.path.exists(os.path.join(label_path, 'tmp')):
+         os.makedirs(os.path.join(label_path, 'tmp'))
+
+     list_file = open('tmp/%s.txt' % txt_Name, 'w')
+     for json_file_ in tqdm(files):
+         json_filename = os.path.join(label_path, json_file_ + ".json")
+         imagePath = os.path.join(label_path, json_file_ + suffix)
+         list_file.write('%s\n' % imagePath)
+         out_file = open('%s/%s.txt' % (label_path, json_file_), 'w')
+         json_file = json.load(open(json_filename, "r", encoding="utf-8"))
+         if os.path.exists(imagePath):
+             height, width, channels = cv2.imread(imagePath).shape
+             for multi in json_file["shapes"]:
+                 if len(multi["points"][0]) == 0:
+                     out_file.write('')
+                     continue
+                 points = np.array(multi["points"])
+                 xmin = min(points[:, 0]) if min(points[:, 0]) > 0 else 0
+                 xmax = max(points[:, 0]) if max(points[:, 0]) > 0 else 0
+                 ymin = min(points[:, 1]) if min(points[:, 1]) > 0 else 0
+                 ymax = max(points[:, 1]) if max(points[:, 1]) > 0 else 0
+                 label = multi["label"]
+                 if xmax <= xmin:
+                     pass
+                 elif ymax <= ymin:
+                     pass
+                 else:
+                     cls_id = classes.index(label)
+                     b = (float(xmin), float(xmax), float(ymin), float(ymax))
+                     bb = convert((width, height), b)
+                     out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
+                     # print(json_filename, xmin, ymin, xmax, ymax, cls_id)
+         if not os.path.exists(os.path.join(store_json, json_file_ + '.json')):
+             try:
+                 shutil.move(json_filename, store_json)
+             except OSError:
+                 pass
+
+ '''
+ Create the yaml file
+ '''
+ def create_yaml(classes, label_path, isUseTest=False):
+     nc = len(classes)
+     if not isUseTest:
+         desired_caps = {
+             'path': label_path,
+             'train': 'train/images',
+             'val': 'valid/images',
+             'nc': nc,
+             'names': classes
+         }
+     else:
+         desired_caps = {
+             'path': label_path,
+             'train': 'train/images',
+             'val': 'valid/images',
+             'test': 'test/images',
+             'nc': nc,
+             'names': classes
+         }
+     yamlpath = os.path.join(label_path, "data" + ".yaml")
+
+     # write to the yaml file
+     with open(yamlpath, "w+", encoding="utf-8") as f:
+         for key, val in desired_caps.items():
+             yaml.dump({key: val}, f, default_flow_style=False)
+
+
+ # First make sure all images in the current folder share the same suffix, e.g. .jpg; if they use another suffix, change suffix accordingly, e.g. .png
+ def ChangeToYolo5(label_path=r"D:\storydata", suffix='.png', test_size=0.1, isUseTest=False):
+     """
+     Generate the files in the final standard format
+     :param test_size: ratio of the test or validation split
+     :param label_path: current file path
+     :param suffix: file suffix
+     :param isUseTest: whether to use a test set
+     :return:
+     """
+     # step 1: unify the image format
+     change_image_format(label_path)
+     # step 2: split into training, validation and test sets according to the json files
+     train_files, val_files, test_file, files = split_dataset(label_path, test_size=test_size, isUseTest=isUseTest)
+     # step 3: get all classes from the json files
+     classes = get_all_class(files)
+     # step 4: convert the json files into txt files and move the json files into a dedicated folder
+     json2txt(classes)
+     # step 5: create the yaml file needed for yolov5 training
+     create_yaml(classes, label_path, isUseTest=isUseTest)
+     # step 6: create the folders for the yolov5 training, validation and test sets
+     train_image, train_label, val_image, val_label, test_image, test_label = create_save_file(label_path)
+     # step 7: move all images and label files into the corresponding training, validation and test sets
+     push_into_file(train_files, train_image, train_label, suffix=suffix)  # move files into the training set
+     push_into_file(val_files, val_image, val_label, suffix=suffix)  # move files into the validation set
+     if test_file is not None:  # if a test set exists, move its files into the test set folder
+         push_into_file(test_file, test_image, test_label, suffix=suffix)
+     print('create dataset done')
+
+
+ if __name__ == "__main__":
+     ChangeToYolo5()
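For reference, a quick numeric check of the `convert()` normalization above, using hypothetical label coordinates (the numbers are illustrative, not taken from the dataset; it assumes transtoyolo.py is importable, i.e. its sklearn/tqdm/yaml dependencies are installed):

```python
# Hypothetical example: a 640x480 image with a box spanning x in [100, 300], y in [50, 200].
from transtoyolo import convert

x, y, w, h = convert((640, 480), (100, 300, 50, 200))
# centre x = (100 + 300) / 2 - 1 = 199 -> 199 / 640 ≈ 0.3109
# centre y = (50 + 200) / 2 - 1  = 124 -> 124 / 480 ≈ 0.2583
# width    = 200 / 640 = 0.3125,  height = 150 / 480 = 0.3125
print(x, y, w, h)
```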