Upload 11 files
Browse files
- __init__.py +0 -0
- config.yaml +63 -0
- environment.yml +271 -0
- fid_utils.py +41 -0
- main.py +537 -0
- pororo_100.h5 +3 -0
- readme-storyvisualization.md +123 -0
- requirements.txt +10 -0
- run.sh +1 -0
- test.py +94 -0
- transtoyolo.py +320 -0
__init__.py
ADDED
Binary file (2 Bytes).
config.yaml
ADDED
@@ -0,0 +1,63 @@
```yaml
# device
mode: sample # train sample
gpu_ids: [3] # gpu ids
batch_size: 1 # batch size, each item denotes one story
num_workers: 4 # number of workers
num_cpu_cores: -1 # number of cpu cores
seed: 0 # random seed
ckpt_dir: /root/lihui/StoryVisualization/save_ckpt_epoch5_new # checkpoint directory
run_name: ARLDM # name for this run

# task
dataset: pororo # pororo flintstones vistsis vistdii
task: visualization # continuation visualization

# train
init_lr: 1e-5 # initial learning rate
warmup_epochs: 1 # warmup epochs
max_epochs: 5 #50 # max epochs
train_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for resume, none for train from scratch
freeze_clip: True #False # whether to freeze clip
freeze_blip: True #False # whether to freeze blip
freeze_resnet: True #False # whether to freeze resnet

# sample
test_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for test
calculate_fid: True # whether to calculate FID scores
scheduler: ddim # ddim pndm
guidance_scale: 6 # guidance scale
num_inference_steps: 250 # number of inference steps
sample_output_dir: /root/lihui/StoryVisualization/save_samples_128_epoch50 # output directory

pororo:
  hdf5_file: /root/lihui/StoryVisualization/pororo.h5
  max_length: 85
  new_tokens: [ "pororo", "loopy", "eddy", "harry", "poby", "tongtong", "crong", "rody", "petty" ]
  clip_embedding_tokens: 49416
  blip_embedding_tokens: 30530

flintstones:
  hdf5_file: /path/to/flintstones.h5
  max_length: 91
  new_tokens: [ "fred", "barney", "wilma", "betty", "pebbles", "dino", "slate" ]
  clip_embedding_tokens: 49412
  blip_embedding_tokens: 30525

vistsis:
  hdf5_file: /path/to/vist.h5
  max_length: 100
  clip_embedding_tokens: 49408
  blip_embedding_tokens: 30524

vistdii:
  hdf5_file: /path/to/vist.h5
  max_length: 65
  clip_embedding_tokens: 49408
  blip_embedding_tokens: 30524

hydra:
  run:
    dir: .
  output_subdir: null
hydra/job_logging: disabled
hydra/hydra_logging: disabled
```
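As a minimal sketch (not part of this upload) of how main.py consumes this file: the config is loaded through Hydra into an OmegaConf DictConfig, and the dataset-specific block is reached via the top-level `dataset` key, mirroring the `args.get(args.dataset)` pattern used below.

```python
# Hypothetical standalone reader for config.yaml above; key names mirror the file.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path=".", config_name="config")
def show(args: DictConfig) -> None:
    # Nested dataset settings hang off the dataset key, e.g. args.get('pororo').max_length == 85
    dataset_cfg = args.get(args.dataset)
    print(args.mode, args.scheduler, dataset_cfg.hdf5_file, dataset_cfg.max_length)


if __name__ == "__main__":
    show()
```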
environment.yml
ADDED
@@ -0,0 +1,271 @@
```yaml
name: story
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotlipy=0.7.0=py38h27cfd23_1003
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.01.10=h06a4308_0
  - certifi=2022.12.7=py38h06a4308_0
  - cffi=1.15.1=py38h5eee18b_3
  - cryptography=39.0.1=py38h9ce1e76_0
  - cuda-cudart=11.7.99=0
  - cuda-cupti=11.7.101=0
  - cuda-libraries=11.7.1=0
  - cuda-nvrtc=11.7.99=0
  - cuda-nvtx=11.7.91=0
  - cuda-runtime=11.7.1=0
  - ffmpeg=4.3=hf484d3e_0
  - flit-core=3.8.0=py38h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - giflib=5.2.1=h5eee18b_3
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - idna=3.4=py38h06a4308_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - jpeg=9e=h5eee18b_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libcublas=11.10.3.66=0
  - libcufft=10.7.2.124=h4fbf590_0
  - libcufile=1.6.0.25=0
  - libcurand=10.3.2.56=0
  - libcusolver=11.4.0.1=0
  - libcusparse=11.7.4.91=0
  - libdeflate=1.17=h5eee18b_0
  - libffi=3.4.2=h6a678d5_6
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.2=h7f8727e_0
  - libnpp=11.7.4.75=0
  - libnvjpeg=11.8.0.2=0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.0=h6a678d5_2
  - libunistring=0.9.10=h27cfd23_0
  - libwebp=1.2.4=h11a3e52_1
  - libwebp-base=1.2.4=h5eee18b_1
  - lz4-c=1.9.4=h6a678d5_0
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py38h7f8727e_0
  - mkl_fft=1.3.1=py38hd3c417c_0
  - mkl_random=1.2.2=py38h51133e4_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - numpy-base=1.23.5=py38h31eccc5_0
  - openh264=2.1.1=h4ff587b_0
  - openssl=1.1.1t=h7f8727e_0
  - pip=23.0.1=py38h06a4308_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyopenssl=23.0.0=py38h06a4308_0
  - pysocks=1.7.1=py38h06a4308_0
  - python=3.8.16=h7a1cb2a_3
  - pytorch=1.13.1=py3.8_cuda11.7_cudnn8.5.0_0
  - pytorch-cuda=11.7=h778d358_3
  - pytorch-mutex=1.0=cuda
  - readline=8.2=h5eee18b_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.41.1=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - typing_extensions=4.4.0=py38h06a4308_0
  - urllib3=1.26.15=py38h06a4308_0
  - wheel=0.38.4=py38h06a4308_0
  - xz=5.2.10=h5eee18b_1
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.4=hc292b87_0
  - pip:
    - absl-py==1.4.0
    - accelerate==0.17.1
    - aiofiles==23.1.0
    - aiohttp==3.8.4
    - aiosignal==1.3.1
    - altair==4.2.2
    - antlr4-python3-runtime==4.9.3
    - anyio==3.6.2
    - appdirs==1.4.4
    - argon2-cffi==21.3.0
    - argon2-cffi-bindings==21.2.0
    - arrow==1.2.3
    - asttokens==2.2.1
    - async-timeout==4.0.2
    - attrs==22.2.0
    - backcall==0.2.0
    - beautifulsoup4==4.11.2
    - bleach==6.0.0
    - cachetools==5.3.0
    - chardet==5.1.0
    - charset-normalizer==3.1.0
    - click==8.1.3
    - comm==0.1.2
    - contourpy==1.0.7
    - cycler==0.11.0
    - debugpy==1.6.6
    - decorator==5.1.1
    - defusedxml==0.7.1
    - diffusers==0.9.0
    - docker-pycreds==0.4.0
    - entrypoints==0.4
    - executing==1.2.0
    - fastapi==0.95.0
    - fastjsonschema==2.16.3
    - ffmpy==0.3.0
    - filelock==3.10.0
    - fire==0.5.0
    - flatbuffers==23.3.3
    - fonttools==4.39.3
    - fqdn==1.5.1
    - frozenlist==1.3.3
    - fsspec==2023.3.0
    - ftfy==6.1.1
    - gitdb==4.0.10
    - gitpython==3.1.31
    - google-auth==2.16.2
    - google-auth-oauthlib==0.4.6
    - gradio==3.24.1
    - gradio-client==0.0.5
    - grpcio==1.51.3
    - h11==0.14.0
    - h5py==3.8.0
    - httpcore==0.16.3
    - httpx==0.23.3
    - huggingface-hub==0.13.2
    - hydra-core==1.3.2
    - importlib-metadata==6.1.0
    - importlib-resources==5.12.0
    - ipykernel==6.21.3
    - ipython==8.11.0
    - ipython-genutils==0.2.0
    - ipywidgets==8.0.4
    - isoduration==20.11.0
    - jedi==0.18.2
    - jinja2==3.1.2
    - jsonpointer==2.3
    - jsonschema==4.17.3
    - jupyter==1.0.0
    - jupyter-client==8.0.3
    - jupyter-console==6.6.3
    - jupyter-core==5.3.0
    - jupyter-events==0.6.3
    - jupyter-server==2.5.0
    - jupyter-server-terminals==0.4.4
    - jupyterlab-pygments==0.2.2
    - jupyterlab-widgets==3.0.5
    - kiwisolver==1.4.4
    - lightning-bolts==0.5.0
    - linkify-it-py==2.0.0
    - lora-diffusion==0.1.7
    - markdown==3.4.1
    - markdown-it-py==2.2.0
    - markupsafe==2.1.2
    - matplotlib==3.7.1
    - matplotlib-inline==0.1.6
    - mdit-py-plugins==0.3.3
    - mdurl==0.1.2
    - mediapipe==0.9.1.0
    - mistune==2.0.5
    - multidict==6.0.4
    - nbclassic==0.5.3
    - nbclient==0.7.2
    - nbconvert==7.2.10
    - nbformat==5.7.3
    - nest-asyncio==1.5.6
    - notebook==6.5.3
    - notebook-shim==0.2.2
    - numpy==1.24.2
    - oauthlib==3.2.2
    - omegaconf==2.3.0
    - opencv-contrib-python==4.7.0.72
    - opencv-python==4.7.0.72
    - orjson==3.8.9
    - packaging==23.0
    - pandas==1.5.3
    - pandocfilters==1.5.0
    - parso==0.8.3
    - pathtools==0.1.2
    - pexpect==4.8.0
    - pickleshare==0.7.5
    - pillow==9.4.0
    - pkgutil-resolve-name==1.3.10
    - platformdirs==3.1.1
    - prometheus-client==0.16.0
    - prompt-toolkit==3.0.38
    - protobuf==3.20.1
    - psutil==5.9.4
    - ptyprocess==0.7.0
    - pure-eval==0.2.2
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pydantic==1.10.7
    - pydeprecate==0.3.2
    - pydub==0.25.1
    - pygments==2.14.0
    - pyparsing==3.0.9
    - pyrsistent==0.19.3
    - python-dateutil==2.8.2
    - python-json-logger==2.0.7
    - python-multipart==0.0.6
    - pytorch-lightning==1.6.5
    - pytz==2023.3
    - pyyaml==6.0
    - pyzmq==25.0.1
    - qtconsole==5.4.1
    - qtpy==2.3.0
    - regex==2022.10.31
    - requests==2.28.2
    - requests-oauthlib==1.3.1
    - rfc3339-validator==0.1.4
    - rfc3986==1.5.0
    - rfc3986-validator==0.1.1
    - rsa==4.9
    - safetensors==0.3.0
    - scipy==1.10.1
    - semantic-version==2.10.0
    - send2trash==1.8.0
    - sentry-sdk==1.17.0
    - setproctitle==1.3.2
    - setuptools==59.5.0
    - smmap==5.0.0
    - sniffio==1.3.0
    - soupsieve==2.4
    - stack-data==0.6.2
    - starlette==0.26.1
    - tensorboard==2.12.0
    - tensorboard-data-server==0.7.0
    - tensorboard-plugin-wit==1.8.1
    - termcolor==2.2.0
    - terminado==0.17.1
    - timm==0.6.12
    - tinycss2==1.2.1
    - tokenizers==0.13.2
    - toolz==0.12.0
    - torch==1.9.0
    - torchaudio==0.9.0
    - torchmetrics==0.11.4
    - torchvision==0.10.0+cu111
    - tornado==6.2
    - tqdm==4.65.0
    - traitlets==5.9.0
    - transformers==4.28.1
    - typing-extensions==4.5.0
    - uc-micro-py==1.0.1
    - uri-template==1.2.0
    - uvicorn==0.21.1
    - wandb==0.14.0
    - wcwidth==0.2.6
    - webcolors==1.12
    - webencodings==0.5.1
    - websocket-client==1.5.1
    - websockets==11.0
    - werkzeug==2.2.3
    - widgetsnbextension==4.0.5
    - yarl==1.8.2
    - zipp==3.15.0
prefix: /root/anaconda3/envs/story
```
fid_utils.py
ADDED
@@ -0,0 +1,41 @@
```python
import numpy as np
from scipy import linalg


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        print('fid calculation produces singular product; adding %s to diagonal of cov estimates' % eps)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)


def calculate_fid_given_features(feature1, feature2):
    mu1 = np.mean(feature1, axis=0)
    sigma1 = np.cov(feature1, rowvar=False)
    mu2 = np.mean(feature2, axis=0)
    sigma2 = np.cov(feature2, rowvar=False)
    fid_value = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)

    return fid_value
```
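Illustrative usage only (not part of the upload): FID between two random feature sets. In main.py the features are 2048-dimensional pooled InceptionV3 activations; small 64-dimensional random features are used here just to keep the example fast.

```python
import numpy as np
from fid_utils import calculate_fid_given_features

feat_real = np.random.randn(256, 64)
feat_fake = np.random.randn(256, 64) + 0.5  # shifted distribution -> clearly non-zero FID
print(calculate_fid_given_features(feat_real, feat_fake))
```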
main.py
ADDED
@@ -0,0 +1,537 @@
```python
import inspect
import os

import cv2
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image
from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler
from omegaconf import DictConfig
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel

from fid_utils import calculate_fid_given_features
from lora_diffusion import monkeypatch_or_replace_lora, tune_lora_scale

from models.blip_override.blip import blip_feature_extractor, init_tokenizer
from models.diffusers_override.unet_2d_condition import UNet2DConditionModel
from models.inception import InceptionV3

unet_target_replace_module = {"CrossAttention", "Attention", "GEGLU"}

from transformers import CLIPProcessor
import transformers
from PIL import Image
import PIL.Image
import numpy as np
import torchvision.transforms as tvtrans
import requests
from io import BytesIO


class LightningDataset(pl.LightningDataModule):
    def __init__(self, args: DictConfig):
        super(LightningDataset, self).__init__()
        self.kwargs = {"num_workers": args.num_workers, "persistent_workers": True if args.num_workers > 0 else False,
                       "pin_memory": True}
        self.args = args

    def setup(self, stage="fit"):
        if self.args.dataset == "pororo":
            import datasets.pororo as data
        elif self.args.dataset == 'flintstones':
            import datasets.flintstones as data
        elif self.args.dataset == 'vistsis':
            import datasets.vistsis as data
        elif self.args.dataset == 'vistdii':
            import datasets.vistdii as data
        else:
            raise ValueError("Unknown dataset: {}".format(self.args.dataset))
        if stage == "fit":
            self.train_data = data.StoryDataset("train", self.args)
            self.val_data = data.StoryDataset("val", self.args)
        if stage == "test":
            self.test_data = data.StoryDataset("test", self.args)

    def train_dataloader(self):
        if not hasattr(self, 'trainloader'):
            self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True, **self.kwargs)
        return self.trainloader

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)

    def predict_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.args.batch_size, shuffle=False, **self.kwargs)

    def get_length_of_train_dataloader(self):
        if not hasattr(self, 'trainloader'):
            self.trainloader = DataLoader(self.train_data, batch_size=self.args.batch_size, shuffle=True, **self.kwargs)
        return len(self.trainloader)


class ARLDM(pl.LightningModule):
    def __init__(self, args: DictConfig, steps_per_epoch=1):
        super(ARLDM, self).__init__()
        self.args = args
        self.steps_per_epoch = steps_per_epoch
        """
        Configurations
        """
        self.task = args.task

        if args.mode == 'sample':
            if args.scheduler == "pndm":
                self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                               skip_prk_steps=True)
            elif args.scheduler == "ddim":
                self.scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                               clip_sample=False, set_alpha_to_one=True)
            else:
                raise ValueError("Scheduler not supported")
            self.fid_augment = transforms.Compose([
                transforms.Resize([64, 64]),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.inception = InceptionV3([block_idx])

        self.clip_tokenizer = CLIPTokenizer.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="tokenizer")
        # self.clip_tokenizer.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/tokenizer')
        self.blip_tokenizer = init_tokenizer()
        self.blip_image_processor = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        ])
        self.max_length = args.get(args.dataset).max_length

        blip_image_null_token = self.blip_image_processor(
            Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))).unsqueeze(0).float()
        clip_text_null_token = self.clip_tokenizer([""], padding="max_length", max_length=self.max_length,
                                                   return_tensors="pt").input_ids
        blip_text_null_token = self.blip_tokenizer([""], padding="max_length", max_length=self.max_length,
                                                   return_tensors="pt").input_ids

        self.register_buffer('clip_text_null_token', clip_text_null_token)
        self.register_buffer('blip_text_null_token', blip_text_null_token)
        self.register_buffer('blip_image_null_token', blip_image_null_token)

        self.text_encoder = CLIPTextModel.from_pretrained('runwayml/stable-diffusion-v1-5',
                                                          subfolder="text_encoder")
        # self.text_encoder.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/text_encoder')
        self.text_encoder.resize_token_embeddings(args.get(args.dataset).clip_embedding_tokens)
        # resize_position_embeddings
        old_embeddings = self.text_encoder.text_model.embeddings.position_embedding
        new_embeddings = self.text_encoder._get_resized_embeddings(old_embeddings, self.max_length)
        self.text_encoder.text_model.embeddings.position_embedding = new_embeddings
        self.text_encoder.config.max_position_embeddings = self.max_length
        self.text_encoder.max_position_embeddings = self.max_length
        self.text_encoder.text_model.embeddings.position_ids = torch.arange(self.max_length).expand((1, -1))

        self.modal_type_embeddings = nn.Embedding(2, 768)
        self.time_embeddings = nn.Embedding(5, 768)
        self.mm_encoder = blip_feature_extractor(
            # pretrained='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth',
            pretrained='/root/lihui/StoryVisualization/save_pretrained/model_large.pth',
            image_size=224, vit='large')  # , local_files_only=True)
        self.mm_encoder.text_encoder.resize_token_embeddings(args.get(args.dataset).blip_embedding_tokens)

        self.vae = AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet")

        self.noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                             num_train_timesteps=1000)
        # monkeypatch_or_replace_lora(
        #     self.unet,
        #     torch.load("lora/example_loras/analog_svd_rank4.safetensors"),
        #     r=4,
        #     target_replace_module=unet_target_replace_module,
        # )
        #
        # tune_lora_scale(self.unet, 1.00)
        # tune_lora_scale(self.text_encoder, 1.00)

        # torch.manual_seed(0)
        # self.vae.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/vae')
        # self.unet.save_pretrained('/root/lihui/StoryVisualization/save_pretrained/unet')

        # Freeze vae and unet
        self.freeze_params(self.vae.parameters())
        if args.freeze_resnet:
            self.freeze_params([p for n, p in self.unet.named_parameters() if "attentions" not in n])

        if args.freeze_blip and hasattr(self, "mm_encoder"):
            self.freeze_params(self.mm_encoder.parameters())
            self.unfreeze_params(self.mm_encoder.text_encoder.embeddings.word_embeddings.parameters())

        if args.freeze_clip and hasattr(self, "text_encoder"):
            self.freeze_params(self.text_encoder.parameters())
            self.unfreeze_params(self.text_encoder.text_model.embeddings.token_embedding.parameters())

    @staticmethod
    def freeze_params(params):
        for param in params:
            param.requires_grad = False

    @staticmethod
    def unfreeze_params(params):
        for param in params:
            param.requires_grad = True

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=1e-4)  # optim_bits=8
        scheduler = LinearWarmupCosineAnnealingLR(optimizer,
                                                  warmup_epochs=self.args.warmup_epochs * self.steps_per_epoch,
                                                  max_epochs=self.args.max_epochs * self.steps_per_epoch)
        optim_dict = {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,  # The LR scheduler instance (required)
                'interval': 'step',  # The unit of the scheduler's step size
            }
        }
        return optim_dict

    def forward(self, batch):
        if self.args.freeze_clip and hasattr(self, "text_encoder"):
            self.text_encoder.eval()
        if self.args.freeze_blip and hasattr(self, "mm_encoder"):
            self.mm_encoder.eval()
        images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_images = batch
        B, V, S = captions.shape
        src_V = V + 1 if self.task == 'continuation' else V
        images = torch.flatten(images, 0, 1)
        captions = torch.flatten(captions, 0, 1)
        attention_mask = torch.flatten(attention_mask, 0, 1)
        source_images = torch.flatten(source_images, 0, 1)
        source_caption = torch.flatten(source_caption, 0, 1)
        source_attention_mask = torch.flatten(source_attention_mask, 0, 1)
        # 1 is not masked, 0 is masked

        classifier_free_idx = np.random.rand(B * V) < 0.1

        caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state  # B * V, S, D
        source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
                                            mode='multimodal').reshape(B, src_V * S, -1)
        source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
        caption_embeddings[classifier_free_idx] = \
            self.text_encoder(self.clip_text_null_token).last_hidden_state[0]
        source_embeddings[classifier_free_idx] = \
            self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token, attention_mask=None,
                            mode='multimodal')[0].repeat(src_V, 1)
        caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
        source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
        source_embeddings += self.time_embeddings(
            torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
        encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)

        attention_mask = torch.cat(
            [attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
        attention_mask = ~(attention_mask.bool())  # B * V, (src_V + 1) * S
        attention_mask[classifier_free_idx] = False

        # B, V, V, S
        square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
        square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
        square_mask = square_mask.reshape(B * V, V * S)
        attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])

        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * 0.18215

        noise = torch.randn(latents.shape, device=self.device)
        bsz = latents.shape[0]
        timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=self.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, attention_mask).sample
        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
        return loss

    def sample(self, batch):
        original_images, captions, attention_mask, source_images, source_caption, source_attention_mask, texts, ori_test_images = batch
        B, V, S = captions.shape
        src_V = V + 1 if self.task == 'continuation' else V
        original_images = torch.flatten(original_images, 0, 1)
        captions = torch.flatten(captions, 0, 1)
        attention_mask = torch.flatten(attention_mask, 0, 1)
        source_images = torch.flatten(source_images, 0, 1)
        source_caption = torch.flatten(source_caption, 0, 1)
        source_attention_mask = torch.flatten(source_attention_mask, 0, 1)

        caption_embeddings = self.text_encoder(captions, attention_mask).last_hidden_state  # B * V, S, D
        source_embeddings = self.mm_encoder(source_images, source_caption, source_attention_mask,
                                            mode='multimodal').reshape(B, src_V * S, -1)
        caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
        source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
        source_embeddings += self.time_embeddings(
            torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
        source_embeddings = source_embeddings.repeat_interleave(V, dim=0)
        encoder_hidden_states = torch.cat([caption_embeddings, source_embeddings], dim=1)

        attention_mask = torch.cat(
            [attention_mask, source_attention_mask.reshape(B, src_V * S).repeat_interleave(V, dim=0)], dim=1)
        attention_mask = ~(attention_mask.bool())  # B * V, (src_V + 1) * S
        # B, V, V, S
        square_mask = torch.triu(torch.ones((V, V), device=self.device)).bool()
        square_mask = square_mask.unsqueeze(0).unsqueeze(-1).expand(B, V, V, S)
        square_mask = square_mask.reshape(B * V, V * S)
        attention_mask[:, -V * S:] = torch.logical_or(square_mask, attention_mask[:, -V * S:])

        uncond_caption_embeddings = self.text_encoder(self.clip_text_null_token).last_hidden_state
        uncond_source_embeddings = self.mm_encoder(self.blip_image_null_token, self.blip_text_null_token,
                                                   attention_mask=None, mode='multimodal').repeat(1, src_V, 1)
        uncond_caption_embeddings += self.modal_type_embeddings(torch.tensor(0, device=self.device))
        uncond_source_embeddings += self.modal_type_embeddings(torch.tensor(1, device=self.device))
        uncond_source_embeddings += self.time_embeddings(
            torch.arange(src_V, device=self.device).repeat_interleave(S, dim=0))
        uncond_embeddings = torch.cat([uncond_caption_embeddings, uncond_source_embeddings], dim=1)
        uncond_embeddings = uncond_embeddings.expand(B * V, -1, -1)

        encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])
        uncond_attention_mask = torch.zeros((B * V, (src_V + 1) * S), device=self.device).bool()
        uncond_attention_mask[:, -V * S:] = square_mask
        attention_mask = torch.cat([uncond_attention_mask, attention_mask], dim=0)

        attention_mask = attention_mask.reshape(2, B, V, (src_V + 1) * S)
        images = list()
        for i in range(V):
            encoder_hidden_states = encoder_hidden_states.reshape(2, B, V, (src_V + 1) * S, -1)
            new_image = self.diffusion(encoder_hidden_states[:, :, i].reshape(2 * B, (src_V + 1) * S, -1),
                                       attention_mask[:, :, i].reshape(2 * B, (src_V + 1) * S),
                                       512, 512, self.args.num_inference_steps, self.args.guidance_scale, 0.0)
            images += new_image

            new_image = torch.stack([self.blip_image_processor(im) for im in new_image]).to(self.device)

            new_embedding = self.mm_encoder(new_image,  # B,C,H,W
                                            source_caption.reshape(B, src_V, S)[:, i + src_V - V],
                                            source_attention_mask.reshape(B, src_V, S)[:, i + src_V - V],
                                            mode='multimodal')  # B, S, D
            new_embedding = new_embedding.repeat_interleave(V, dim=0)
            new_embedding += self.modal_type_embeddings(torch.tensor(1, device=self.device))
            new_embedding += self.time_embeddings(torch.tensor(i + src_V - V, device=self.device))

            encoder_hidden_states = encoder_hidden_states[1].reshape(B * V, (src_V + 1) * S, -1)
            encoder_hidden_states[:, (i + 1 + src_V - V) * S:(i + 2 + src_V - V) * S] = new_embedding
            encoder_hidden_states = torch.cat([uncond_embeddings, encoder_hidden_states])

        return original_images, images, texts, ori_test_images

    def training_step(self, batch, batch_idx):
        loss = self(batch)
        self.log('loss/train_loss', loss, on_step=True, on_epoch=False, sync_dist=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(batch)
        self.log('loss/val_loss', loss, on_step=False, on_epoch=True, sync_dist=True, prog_bar=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        original_images, images, texts, ori_test_images = self.sample(batch)
        if self.args.calculate_fid:
            original_images = original_images.cpu().numpy().astype('uint8')
            original_images = [Image.fromarray(im, 'RGB') for im in original_images]

            # ori_test_images = torch.stack(ori_test_images).cpu().numpy().astype('uint8')
            # ori_test_images = [Image.fromarray(im, 'RGB') for im in ori_test_images]
            ori = self.inception_feature(original_images).cpu().numpy()
            gen = self.inception_feature(images).cpu().numpy()
        else:
            ori = None
            gen = None

        return images, ori, gen, ori_test_images, texts

    def diffusion(self, encoder_hidden_states, attention_mask, height, width, num_inference_steps, guidance_scale, eta):
        latents = torch.randn((encoder_hidden_states.shape[0] // 2, self.unet.in_channels, height // 8, width // 8),
                              device=self.device)

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2)

            # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states).sample
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states, attention_mask).sample

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        return self.numpy_to_pil(image)

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image, 'RGB') for image in images]

        return pil_images

    def inception_feature(self, images):
        images = torch.stack([self.fid_augment(image) for image in images])
        images = images.type(torch.FloatTensor).to(self.device)
        images = (images + 1) / 2
        images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
        pred = self.inception(images)[0]

        if pred.shape[2] != 1 or pred.shape[3] != 1:
            pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))
        return pred.reshape(-1, 2048)


def train(args: DictConfig) -> None:
    dataloader = LightningDataset(args)
    dataloader.setup('fit')
    model = ARLDM(args, steps_per_epoch=dataloader.get_length_of_train_dataloader())

    logger = TensorBoardLogger(save_dir=os.path.join(args.ckpt_dir, args.run_name), name='log', default_hp_metric=False)

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(args.ckpt_dir, args.run_name),
        save_top_k=0,
        save_last=True
    )

    lr_monitor = LearningRateMonitor(logging_interval='step')

    callback_list = [lr_monitor, checkpoint_callback]

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=args.gpu_ids,
        max_epochs=args.max_epochs,
        benchmark=True,
        logger=logger,
        log_every_n_steps=1,
        callbacks=callback_list,
        strategy=DDPStrategy(find_unused_parameters=False)
    )
    trainer.fit(model, dataloader, ckpt_path=args.train_model_file)


def sample(args: DictConfig) -> None:
    assert args.test_model_file is not None, "test_model_file cannot be None"
    assert args.gpu_ids == 1 or len(args.gpu_ids) == 1, "Only one GPU is supported in test mode"
    dataloader = LightningDataset(args)
    dataloader.setup('test')
    model = ARLDM.load_from_checkpoint(args.test_model_file, args=args, strict=False)

    predictor = pl.Trainer(
        accelerator='gpu',
        devices=args.gpu_ids,
        max_epochs=-1,
        benchmark=True
    )
    predictions = predictor.predict(model, dataloader)
    images = [elem for sublist in predictions for elem in sublist[0]]
    ori_images = [elem for sublist in predictions for elem in sublist[3]]
    ori_test_images = list()
    if not os.path.exists(args.sample_output_dir):
        try:
            os.mkdir(args.sample_output_dir)
        except:
            pass

    text_list = [elem for sublist in predictions for elem in sublist[4]]
    # print(f"index: {index}")
    num_images = len(images)
    num_groups = (num_images + 4) // 5  # total number of 5-image groups (stories)

    for g in range(num_groups):
        print('Story {}:'.format(g + 1))  # print the story (group) number
        start_index = g * 5  # start index of the current group
        end_index = min(start_index + 5, num_images)  # end index of the current group
        for i in range(start_index, end_index):
            print(text_list[i])  # print the corresponding caption
            images[i].save(
                os.path.join(args.sample_output_dir, 'group{:02d}_image{:02d}.png'.format(g + 1, i - start_index + 1)))
            # ori_images[i] = ori_images[i]
            ori_images_pil = Image.fromarray(np.uint8(ori_images[i].detach().cpu().squeeze().float().numpy())).convert("RGB")
            ori_test_images.append(ori_images_pil)
            ori_images_pil.save(
                os.path.join('/root/lihui/StoryVisualization/ori_test_images_epoch10', 'group{:02d}_image{:02d}.png'.format(g + 1, i - start_index + 1)))
    # for i, im in enumerate(ori_images):
    #     file_path = '/root/lihui/StoryVisualization/ori_test_images/image{}.png'.format(i)
    #     cv2.imwrite(file_path, im)

    if args.calculate_fid:
        ori = np.array([elem for sublist in predictions for elem in sublist[1]])
        gen = np.array([elem for sublist in predictions for elem in sublist[2]])
        fid = calculate_fid_given_features(ori, gen)
        print('FID: {}'.format(fid))


@hydra.main(config_path=".", config_name="config")
def main(args: DictConfig) -> None:
    pl.seed_everything(args.seed)
    if args.num_cpu_cores > 0:
        torch.set_num_threads(args.num_cpu_cores)

    if args.mode == 'train':
        train(args)
    elif args.mode == 'sample':
        # dataloader = LightningDataset(args)
        # dataloader.setup('test')
        sample(args)


if __name__ == '__main__':
    main()
```
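A stand-alone sketch (not part of the upload) of the classifier-free guidance step performed inside `ARLDM.diffusion()`: the latent batch is doubled (unconditional plus conditional), the two noise predictions are split apart, and the conditional direction is amplified by `guidance_scale`. The tensor shapes below are dummy values for illustration only.

```python
import torch

guidance_scale = 6.0
noise_pred = torch.randn(2 * 4, 4, 64, 64)                    # U-Net output for 2*B latents
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)       # split uncond / cond halves
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)                                            # torch.Size([4, 4, 64, 64])
```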
pororo_100.h5
ADDED
@@ -0,0 +1,3 @@
```text
version https://git-lfs.github.com/spec/v1
oid sha256:6b5d47440de7abbbbb2265e1d5ecbc1c5d4d3188434db3988cb13e7ec5fa7549
size 69568
```
readme-storyvisualization.md
ADDED
@@ -0,0 +1,123 @@
### 1. Cross-modal sequential image generation from narrative text

## Environment setup
```bash
conda create -n arldm python=3.8
conda activate arldm
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch-lts
cd /root/lihui/StoryVisualization
pip install -r requirements.txt
```
## Data preparation
Download the PororoSV dataset here.
To accelerate I/O, use the following script to convert the downloaded data to HDF5:
```bash
python data_script/pororo_hdf5.py
  --data_dir /path/to/pororo_data
  --save_path /path/to/save_hdf5_file
```
## Configuration file config.yaml
```yaml
# device
mode: sample # train sample
ckpt_dir: /root/lihui/StoryVisualization/save_ckpt_epoch5_new # checkpoint directory
run_name: ARLDM # name for this run

# train
train_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for resume, none for train from scratch

# sample
test_model_file: /root/lihui/StoryVisualization/save_ckpt_3last50/ARLDM/last.ckpt # model file for test
sample_output_dir: /root/lihui/StoryVisualization/save_samples_128_epoch50 # output directory
```
## Training
Specify your directories and device configuration in config.yaml, then run:
```bash
python main.py
```
## Sampling
Specify your directories and device configuration in config.yaml, then run:
```bash
python main.py
```
## Citation
@article{pan2022synthesizing,
  title={Synthesizing Coherent Story with Auto-Regressive Latent Diffusion Models},
  author={Pan, Xichen and Qin, Pengda and Li, Yuhong and Xue, Hui and Chen, Wenhu},
  journal={arXiv preprint arXiv:2211.10950},
  year={2022}
}


### 2. Super-resolution with Real-ESRGAN
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
[Paper] [Project page] [YouTube video] [Bilibili video] [Poster] [PPT]
Xintao Wang, Liangbin Xie, Chao Dong, Ying Shan
Tencent ARC Lab; Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences
## Requirements
Python >= 3.7 (Anaconda or Miniconda recommended)
PyTorch >= 1.7
## Installation
1. Use the already-configured folder directly:
```bash
cd /root/lihui/StoryVisualization/Real-ESRGAN
```
2. Or clone the project locally:
```bash
git clone https://github.com/xinntao/Real-ESRGAN.git
cd Real-ESRGAN
```
3. Install the dependencies:
```bash
# Install basicsr - https://github.com/xinntao/BasicSR
# BasicSR is used for both training and inference
pip install basicsr
# facexlib and gfpgan are used for face enhancement
pip install facexlib
pip install gfpgan
pip install -r requirements.txt
python setup.py develop
```
## Training
Pre-trained model: RealESRGAN_x4plus_anime_6B
More information and comparisons with waifu2x can be found in anime_model.md.
## Download the model
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P weights
```
## Inference
```bash
python inference_realesrgan.py -n RealESRGAN_x4plus_anime_6B -i inputs
```
Results are written to the results folder.
## BibTeX citation
@Article{wang2021realesrgan,
  title={Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data},
  author={Xintao Wang and Liangbin Xie and Chao Dong and Ying Shan},
  journal={arXiv:2107.10833},
  year={2021}
}


### 3. Target character detection with YOLOv5
## Installation
Clone the repo and install requirements.txt in a Python>=3.7.0 environment with PyTorch>=1.7.
```bash
git clone https://github.com/ultralytics/yolov5  # clone
cd /root/lihui/StoryVisualization
cd yolov5
pip install -r requirements.txt  # install
```
## Convert the images
```bash
cd /root/lihui/StoryVisualization
python transtoyolo.py
```
## Inference with detect.py
detect.py runs inference on a variety of sources; models are downloaded automatically from the latest YOLOv5 release and results are saved to runs/detect.
```bash
python detect.py --weights yolov5s.pt --source 0              # webcam
                                               img.jpg        # image
                                               vid.mp4        # video
                                               screen         # screenshot
                                               path/          # directory
                                               list.txt       # list of images
                                               list.streams   # list of streams
                                               'path/*.jpg'   # glob
                                               'https://youtu.be/Zgi9g1ksQHc'   # YouTube
                                               'rtsp://example.com/media.mp4'   # RTSP, RTMP, HTTP stream
```
## Training
The latest models and datasets are downloaded automatically from the YOLOv5 release. Training YOLOv5n/s/m/l/x takes 1/2/4/6/8 days on a V100 GPU (multi-GPU training is faster). Use the largest --batch-size you can, or let YOLOv5 pick one automatically with --batch-size -1. The batch size shown below is for a V100-16GB.
```bash
python train.py --data xxx.yaml --epochs 500 --weights '' --cfg yolov5l --batch-size 64
# the xxx.yaml file describes the converted data
```
## License
YOLOv5 is available under two different licenses:
AGPL-3.0 License: see the LICENSE file for details.
Enterprise License: offers greater flexibility for commercial product development without the AGPL-3.0 open-source requirements. The typical use case is embedding Ultralytics software and AI models in commercial products and applications. Apply for an Enterprise License at Ultralytics Licensing.


### 4. Demo system

## Specify the file directory and run:
```bash
cd /root/lihui/StoryVisualization/visualsystem
python main.py
```
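A minimal sketch (not part of the upload) of inspecting the HDF5 file produced by the data-preparation step above. The layout assumed here, a 'test' group holding a 'text' dataset plus encoded frames in 'image0' to 'image4', is inferred from test.py further down; the path is a placeholder.

```python
import h5py

with h5py.File('/path/to/pororo.h5', 'r') as h5:
    test = h5['test']
    print(len(test['text']), 'stories in the test split')
    print(test['text'][0].decode('utf-8').split('|'))  # five captions per story
```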
requirements.txt
ADDED
@@ -0,0 +1,10 @@
```text
pytorch_lightning<1.7.0
lightning-bolts
transformers==4.24.0
diffusers==0.7.2
timm
ftfy
hydra-core
opencv-python
h5py
scipy
```
run.sh
ADDED
@@ -0,0 +1 @@
```bash
python main.py
```
test.py
ADDED
@@ -0,0 +1,94 @@
```python
import cv2
import h5py
import copy
import os
import random

import numpy
import numpy as np
from PIL import Image


def gettext(index):
    with h5py.File('/root/lihui/StoryVisualization/pororo.h5', 'r') as h5:
        story = list()
        h5 = h5['test']
        # Read the text at the current index and decode it as UTF-8
        texts = h5['text'][index].decode('utf-8').split('|')
        symbol = '\n'
        texts = symbol.join(texts)
        texts = 'Story<' + str(index) + '> :' + '\n' + texts
        print(texts)
        return texts


# for i in range(1000):
#     gettext(i)

# Keep only the first 100 stories of the dataset
# ### working version ###
# # import h5py
# # import numpy as np
# # from PIL import Image
# #
# # # Create a subdirectory named "images" to save the images
# # os.makedirs("train_images", exist_ok=True)
# #
# # Create an h5 file
# nf = h5py.File('/root/lihui/StoryVisualization/pororo_100.h5', "w")
# with h5py.File('/root/lihui/StoryVisualization/pororo.h5', 'r') as f:
#     test_group = f['test']
#     texts = np.array(test_group['text'][()])
#     ngroup = nf.create_group('test')
#     ntext = ngroup.create_dataset('text', (100,), dtype=h5py.string_dtype(encoding='utf-8'))
#     for i in range(100):
#         ntext[i] = texts[i]
#         print(f"sample {i}:")
#         # for j in range(5):
#         #     # Build a fixed file name to save the image
#         #     # filename = os.path.join("images", f"image_{i}_{j}.png")
#         #     # # Write the image data from the HDF5 file to disk
#         #     # with open(filename, "wb") as img_file:
#         #     #     img_file.write(test_group[f'image{j}'][i])
#         #     # Print the text information and the file name
#         #     ntext[i] = '|'.join(texts[i].decode('utf-8').split('|')[j])
#         #     print(f"image {j} saved to file: {filename}")
#         print(ntext[i])
# nf.close()

# Save the test-set images, cropping a random video frame
with h5py.File(r'C:\Users\zjlab\Desktop\StoryVisualization\pororo.h5', 'r') as h5:
    h5 = h5['test']

    for index in range(len(h5['text'])):  # len(h5['text'])
        # index = int(index + 1)
        # print(index)
        images = list()
        for i in range(5):
            # Read one group of images and the corresponding text from the h5 file
            im = h5['image{}'.format(i)][index]
            # print(im)
            # pil_img = Image.fromarray(im)
            # # Save the image
            # pil_img.save(os.path.join('/root/lihui/StoryVisualization/ori_test_images', '{:04d}.png'.format(i)))
            # Decode each image
            im = cv2.imdecode(im, cv2.IMREAD_COLOR)
            # Randomly choose one 128-pixel slice of the image
            idx = random.randint(0, im.shape[0] // 128 - 1)
            # Append the sliced image to the images list
            images.append(im[idx * 128: (idx + 1) * 128])
            # Deep copy so that later changes to images do not propagate
            # ori_images = copy.deepcopy(images)
            # Save the original test images

            # for i, im in enumerate(images):
            #     file_path = 'C:/Users/zjlab/Desktop/StoryVisualization/test_images/group{:02d}_image{:02d}.png'.format(
            #         index + 1,
            #         i + 1)
            #     cv2.imwrite(file_path, im)

            ori_images_pil = Image.fromarray(images[i])  # numpy.uint8(images[i].detach().cpu().squeeze().float().numpy())).convert("RGB")
            ori_images_pil.save(
                os.path.join('C:/Users/zjlab/Desktop/StoryVisualization/test_images',
                             'group{:02d}_image{:02d}.png'.format(index + 1, i + 1)))
```
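A stand-alone illustration (not part of the upload) of the frame-slicing logic above, run on a dummy encoded image instead of the HDF5 contents: the stored bytes are decoded with cv2.imdecode and one random 128-pixel-high strip is kept.

```python
import random

import cv2
import numpy as np

frame = np.random.randint(0, 255, (128 * 4, 128, 3), dtype=np.uint8)  # tall dummy "video strip"
ok, buf = cv2.imencode('.png', frame)                                  # stand-in for h5['image{i}'][index]
im = cv2.imdecode(buf, cv2.IMREAD_COLOR)
idx = random.randint(0, im.shape[0] // 128 - 1)
crop = im[idx * 128:(idx + 1) * 128]
print(crop.shape)  # (128, 128, 3)
```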
transtoyolo.py
ADDED
@@ -0,0 +1,320 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
from glob import glob
|
7 |
+
import cv2
|
8 |
+
import shutil
|
9 |
+
import yaml
|
10 |
+
from sklearn.model_selection import train_test_split
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
# 获取当前路径
|
15 |
+
ROOT_DIR = os.getcwd()
|
16 |
+
|
17 |
+
'''
|
18 |
+
统一图像格式
|
19 |
+
'''
|
20 |
+
def change_image_format(label_path=ROOT_DIR, suffix='.png'):
|
21 |
+
"""
|
22 |
+
统一当前文件夹下所有图像的格式,如'.jpg'
|
23 |
+
:param suffix: 图像文件后缀
|
24 |
+
:param label_path:当前文件路径
|
25 |
+
:return:
|
26 |
+
"""
|
27 |
+
externs = ['png', 'jpg', 'JPEG', 'BMP', 'bmp']
|
28 |
+
files = list()
|
29 |
+
# 获取尾缀在ecterns中的所有图像
|
30 |
+
for extern in externs:
|
31 |
+
files.extend(glob(label_path + "\\*." + extern))
|
32 |
+
# 遍历所有图像,转换图像格式
|
33 |
+
for file in files:
|
34 |
+
name = ''.join(file.split('.')[:-1])
|
35 |
+
file_suffix = file.split('.')[-1]
|
36 |
+
if file_suffix != suffix.split('.')[-1]:
|
37 |
+
# 重命名为jpg
|
38 |
+
new_name = name + suffix
|
39 |
+
# 读取图像
|
40 |
+
image = cv2.imread(file)
|
41 |
+
# 重新存图为jpg格式
|
42 |
+
cv2.imwrite(new_name, image)
|
43 |
+
# 删除旧图像
|
44 |
+
os.remove(file)
|
45 |
+
|


'''
Read every json file and collect all class names
'''
def get_all_class(file_list, label_path=ROOT_DIR):
    """
    Collect all classes present in the dataset from the json files.
    :param file_list: file name stems under the current folder
    :param label_path: current folder path
    :return:
    """
    # initialise the class list
    classes = list()
    # read the 'label' field of every shape in every json and add unseen labels
    for filename in tqdm(file_list):
        json_path = os.path.join(label_path, filename + '.json')
        json_file = json.load(open(json_path, "r", encoding="utf-8"))
        for item in json_file["shapes"]:
            label_class = item['label']
            if label_class not in classes:
                classes.append(label_class)
    print('read file done')
    return classes


'''
Split into training, validation and test sets
'''
def split_dataset(label_path, test_size=0.3, isUseTest=False, useNumpyShuffle=False):
    """
    Split the files into training, validation and (optionally) test sets.
    :param useNumpyShuffle: split with a numpy shuffle instead of sklearn
    :param test_size: fraction used for the test/validation split
    :param isUseTest: whether to create a test set, default False
    :param label_path: current folder path
    :return:
    """
    # collect all json files
    files = glob(label_path + "\\*.json")
    files = [i.replace("\\", "/").split("/")[-1].split(".json")[0] for i in files]

    if useNumpyShuffle:
        file_length = len(files)
        index = np.arange(file_length)
        np.random.seed(32)
        np.random.shuffle(index)  # random permutation

        test_files = None
        # optionally carve out a test set first
        if isUseTest:
            trainval_files, test_files = np.array(files)[index[:int(file_length * (1 - test_size))]], np.array(files)[
                index[int(file_length * (1 - test_size)):]]
        else:
            trainval_files = files
        # split the remainder into training and validation sets
        train_files, val_files = np.array(trainval_files)[index[:int(len(trainval_files) * (1 - test_size))]], \
                                 np.array(trainval_files)[index[int(len(trainval_files) * (1 - test_size)):]]
    else:
        test_files = None
        if isUseTest:
            trainval_files, test_files = train_test_split(files, test_size=test_size, random_state=55)
        else:
            trainval_files = files
        train_files, val_files = train_test_split(trainval_files, test_size=test_size, random_state=55)

    return train_files, val_files, test_files, files
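# Illustrative check (not part of the original script): with the default sklearn path and
# test_size=0.1, ten annotation stems split 9:1 into training and validation.
def _split_example():
    stems = ['img%02d' % i for i in range(10)]  # hypothetical stems taken from *.json files
    train, val = train_test_split(stems, test_size=0.1, random_state=55)
    return train, val  # 9 training stems, 1 validation stem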


'''
Create the folders for the yolov5 training, validation and test sets
'''
def create_save_file(label_path=ROOT_DIR):
    """
    Create the image and label folders laid out the way yolov5 expects for training.
    :param label_path: current folder path
    :return:
    """
    # training set
    train_image = os.path.join(label_path, 'train', 'images')
    if not os.path.exists(train_image):
        os.makedirs(train_image)
    train_label = os.path.join(label_path, 'train', 'labels')
    if not os.path.exists(train_label):
        os.makedirs(train_label)
    # validation set
    val_image = os.path.join(label_path, 'valid', 'images')
    if not os.path.exists(val_image):
        os.makedirs(val_image)
    val_label = os.path.join(label_path, 'valid', 'labels')
    if not os.path.exists(val_label):
        os.makedirs(val_label)
    # test set
    test_image = os.path.join(label_path, 'test', 'images')
    if not os.path.exists(test_image):
        os.makedirs(test_image)
    test_label = os.path.join(label_path, 'test', 'labels')
    if not os.path.exists(test_label):
        os.makedirs(test_label)
    return train_image, train_label, val_image, val_label, test_image, test_label
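# For reference (illustrative, assuming label_path=r"D:\storydata"): the call above creates
#   D:\storydata\train\images   D:\storydata\train\labels
#   D:\storydata\valid\images   D:\storydata\valid\labels
#   D:\storydata\test\images    D:\storydata\test\labels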



'''
Convert: given the image size, return the box centre and width/height
'''
def convert(size, box):
    # width scale
    dw = 1. / (size[0])
    # height scale
    dh = 1. / (size[1])

    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    # width
    w = box[1] - box[0]
    # height
    h = box[3] - box[2]

    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return x, y, w, h
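# Illustrative check (added for clarity, not part of the original script): convert() maps a
# box given as (xmin, xmax, ymin, ymax) on an image of size (width, height) to YOLO's
# normalized (x_center, y_center, w, h).
def _convert_example():
    # box spanning x = 100..300, y = 200..400 on a 640x480 image
    x, y, w, h = convert((640, 480), (100.0, 300.0, 200.0, 400.0))
    # x = ((100 + 300) / 2 - 1) / 640 ≈ 0.311, y = ((200 + 400) / 2 - 1) / 480 ≈ 0.623
    # w = 200 / 640 = 0.3125,                  h = 200 / 480 ≈ 0.417
    return x, y, w, h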


'''
Move the images and label files into the chosen training, validation and test folders
'''
def push_into_file(file, images, labels, label_path=ROOT_DIR, suffix='.jpg'):
    """
    Move every image/label pair generated in the current folder into the
    train/valid/test folders expected by yolov5.
    :param file: list of file name stems
    :param images: destination folder for images
    :param labels: destination folder for labels
    :param label_path: current folder path
    :param suffix: image file suffix
    :return:
    """
    # iterate over all files
    for filename in file:
        # image file
        image_file = os.path.join(label_path, filename + suffix)
        # label file
        label_file = os.path.join(label_path, filename + '.txt')
        # yolov5 image folder
        if not os.path.exists(os.path.join(images, filename + suffix)):
            try:
                shutil.move(image_file, images)
            except OSError:
                pass
        # yolov5 label folder
        if not os.path.exists(os.path.join(labels, filename + '.txt')):  # check the .txt name, not the image suffix
            try:
                shutil.move(label_file, labels)
            except OSError:
                pass

'''
Convert the json annotations into yolo txt label files
'''
def json2txt(classes, txt_Name='allfiles', label_path=ROOT_DIR, suffix='.png'):
    """
    Convert every json file into a txt label file and move the json into a separate folder.
    :param classes: class names
    :param txt_Name: name of the txt file that lists all image paths
    :param label_path: current folder path
    :param suffix: image file suffix
    :return:
    """
    store_json = os.path.join(label_path, 'json')
    if not os.path.exists(store_json):
        os.makedirs(store_json)

    _, _, _, files = split_dataset(label_path)
    if not os.path.exists(os.path.join(label_path, 'tmp')):
        os.makedirs(os.path.join(label_path, 'tmp'))

    list_file = open('tmp/%s.txt' % txt_Name, 'w')
    for json_file_ in tqdm(files):
        json_filename = os.path.join(label_path, json_file_ + ".json")
        imagePath = os.path.join(label_path, json_file_ + suffix)
        list_file.write('%s\n' % imagePath)
        out_file = open('%s/%s.txt' % (label_path, json_file_), 'w')
        json_file = json.load(open(json_filename, "r", encoding="utf-8"))
        if os.path.exists(imagePath):
            height, width, channels = cv2.imread(imagePath).shape
            for multi in json_file["shapes"]:
                if len(multi["points"][0]) == 0:
                    out_file.write('')
                    continue
                points = np.array(multi["points"])
                xmin = min(points[:, 0]) if min(points[:, 0]) > 0 else 0
                xmax = max(points[:, 0]) if max(points[:, 0]) > 0 else 0
                ymin = min(points[:, 1]) if min(points[:, 1]) > 0 else 0
                ymax = max(points[:, 1]) if max(points[:, 1]) > 0 else 0
                label = multi["label"]
                if xmax <= xmin:
                    pass
                elif ymax <= ymin:
                    pass
                else:
                    cls_id = classes.index(label)
                    b = (float(xmin), float(xmax), float(ymin), float(ymax))
                    bb = convert((width, height), b)
                    out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
                    # print(json_filename, xmin, ymin, xmax, ymax, cls_id)
        if not os.path.exists(os.path.join(store_json, json_file_ + '.json')):
            try:
                shutil.move(json_filename, store_json)
            except OSError:
                pass

'''
Create the yaml file
'''
def create_yaml(classes, label_path, isUseTest=False):
    nc = len(classes)
    if not isUseTest:
        desired_caps = {
            'path': label_path,
            'train': 'train/images',
            'val': 'valid/images',
            'nc': nc,
            'names': classes
        }
    else:
        desired_caps = {
            'path': label_path,
            'train': 'train/images',
            'val': 'valid/images',
            'test': 'test/images',
            'nc': nc,
            'names': classes
        }
    yamlpath = os.path.join(label_path, "data" + ".yaml")

    # write the yaml file
    with open(yamlpath, "w+", encoding="utf-8") as f:
        for key, val in desired_caps.items():
            yaml.dump({key: val}, f, default_flow_style=False)

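# For reference (illustrative, assuming two classes named "pororo" and "loopy" and
# label_path=r"D:\storydata"): the generated data.yaml would look roughly like
#   path: D:\storydata
#   train: train/images
#   val: valid/images
#   nc: 2
#   names:
#   - pororo
#   - loopy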

# First make sure every image in the folder shares one suffix, e.g. '.jpg'; if your images
# use a different suffix, set `suffix` accordingly, e.g. '.png'.
def ChangeToYolo5(label_path=r"D:\storydata", suffix='.png', test_size=0.1, isUseTest=False):
    """
    Generate the final dataset in the standard yolov5 layout.
    :param test_size: fraction used for the test/validation split
    :param label_path: current folder path
    :param suffix: image file suffix
    :param isUseTest: whether to create a test set
    :return:
    """
    # step 1: unify the image format
    change_image_format(label_path)
    # step 2: split into training, validation and test sets based on the json files
    train_files, val_files, test_file, files = split_dataset(label_path, test_size=test_size, isUseTest=isUseTest)
    # step 3: collect all classes from the json files
    classes = get_all_class(files, label_path=label_path)
    # step 4: convert the json files to txt label files and move the json files away
    json2txt(classes, label_path=label_path, suffix=suffix)
    # step 5: create the yaml file yolov5 trains from
    create_yaml(classes, label_path, isUseTest=isUseTest)
    # step 6: create the yolov5 train/valid/test folders
    train_image, train_label, val_image, val_label, test_image, test_label = create_save_file(label_path)
    # step 7: move every image and label file into the matching folder
    push_into_file(train_files, train_image, train_label, label_path=label_path, suffix=suffix)  # training set
    push_into_file(val_files, val_image, val_label, label_path=label_path, suffix=suffix)  # validation set
    if test_file is not None:  # only if a test set was requested
        push_into_file(test_file, test_image, test_label, label_path=label_path, suffix=suffix)
    print('create dataset done')


if __name__ == "__main__":
    ChangeToYolo5()
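# A minimal usage sketch (assumptions: labelme-style .json annotations and .png images
# sit together in D:\storydata; the paths are examples and should be adapted):
#
#     ChangeToYolo5(label_path=r"D:\storydata", suffix='.png', test_size=0.1, isUseTest=False)
#
# Afterwards, the generated data.yaml can be passed to a yolov5 training run, e.g.
#     python train.py --data D:/storydata/data.yaml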