Spaces:
Runtime error
Runtime error
xco2
commited on
Commit
•
ebf6d7b
1
Parent(s):
687cb7c
init
Browse files- net/UNet.py +0 -96
- requirements.txt +2 -179
net/UNet.py
CHANGED
@@ -422,99 +422,3 @@ class UNet(nn.Module):
|
|
422 |
# print("decoder:")
|
423 |
# print(decoder_out.shape)
|
424 |
return decoder_out
|
425 |
-
|
426 |
-
|
427 |
-
if __name__ == '__main__':
|
428 |
-
import cv2, os
|
429 |
-
|
430 |
-
|
431 |
-
def modelSave(model, save_path, save_name):
|
432 |
-
if not os.path.exists(save_path):
|
433 |
-
os.mkdir(save_path)
|
434 |
-
torch.save(model.state_dict(), os.path.join(save_path, save_name))
|
435 |
-
|
436 |
-
|
437 |
-
def merge_images(images: np.ndarray):
|
438 |
-
"""
|
439 |
-
合并图像
|
440 |
-
:param images: 图像数组
|
441 |
-
:return: 合并后的图像数组
|
442 |
-
"""
|
443 |
-
n, h, w, c = images.shape
|
444 |
-
nn = int(np.ceil(n ** 0.5))
|
445 |
-
merged_image = np.zeros((h * nn, w * nn, 3), dtype=images.dtype)
|
446 |
-
for i in range(n):
|
447 |
-
row = i // nn
|
448 |
-
col = i % nn
|
449 |
-
merged_image[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = images[i]
|
450 |
-
|
451 |
-
merged_image = np.clip(merged_image, 0, 255)
|
452 |
-
merged_image = np.array(merged_image, dtype=np.uint8)
|
453 |
-
return merged_image
|
454 |
-
|
455 |
-
|
456 |
-
# 320,448,576,832
|
457 |
-
config = { # 模型结构相关
|
458 |
-
"en_out_c": (256, 256, 256, 320, 320, 320, 576, 576, 576, 704, 704, 704),
|
459 |
-
"en_down": (0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
|
460 |
-
"en_skip": (0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1),
|
461 |
-
"en_att_heads": (8, 8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8),
|
462 |
-
"de_out_c": (704, 576, 576, 576, 320, 320, 320, 256, 256, 256, 256),
|
463 |
-
"de_up": ("none", "subpix", "none", "none", "subpix", "none", "none", "subpix", "none", "none", "none"),
|
464 |
-
"de_skip": (1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
|
465 |
-
"de_att_heads": (8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8), # skip的地方不做self-attention
|
466 |
-
"t_out_c": 256,
|
467 |
-
"vae_c": 4,
|
468 |
-
"block_deep": 3,
|
469 |
-
}
|
470 |
-
device = "cuda"
|
471 |
-
total_step = 1000
|
472 |
-
|
473 |
-
unet = UNet(config["en_out_c"], config["en_down"], config["en_skip"], config["en_att_heads"],
|
474 |
-
config["de_out_c"], config["de_up"], config["de_skip"], config["de_att_heads"],
|
475 |
-
config["t_out_c"], config["vae_c"], config["block_deep"]).to(device)
|
476 |
-
|
477 |
-
print("总参数", sum(i.numel() for i in unet.parameters()) / 10000, "单位:万")
|
478 |
-
print("encoder", sum(i.numel() for i in unet.encoder.parameters()) / 10000, "单位:万")
|
479 |
-
print("decoder", sum(i.numel() for i in unet.decoder.parameters()) / 10000, "单位:万")
|
480 |
-
print("t", sum(i.numel() for i in unet.t_encoder.parameters()) / 10000, "单位:万")
|
481 |
-
|
482 |
-
batch_size = 2
|
483 |
-
x = np.random.random((batch_size, config["vae_c"], 32, 32))
|
484 |
-
t = np.random.uniform(1, total_step + 0.9999, size=(batch_size, 1))
|
485 |
-
t = np.array(t, dtype=np.int16)
|
486 |
-
t = t / total_step
|
487 |
-
|
488 |
-
with torch.no_grad():
|
489 |
-
x = torch.Tensor(x).to(device)
|
490 |
-
t = torch.Tensor(t).to(device)
|
491 |
-
y = unet(x, t)
|
492 |
-
print(y.shape)
|
493 |
-
|
494 |
-
z = y[0].cpu().numpy()
|
495 |
-
# z = (z - np.mean(z)) / (np.max(z) - np.min(z))
|
496 |
-
z = np.clip(np.asarray((z + 1) * 127.5), 0, 255)
|
497 |
-
z = np.asarray(z, dtype=np.uint8)
|
498 |
-
|
499 |
-
z = [np.tile(z[ii, :, :, np.newaxis], (1, 1, 3)) for ii in range(z.shape[0])]
|
500 |
-
noise = merge_images(np.array(z))
|
501 |
-
|
502 |
-
noise = cv2.resize(noise, None, fx=2, fy=2)
|
503 |
-
cv2.imshow("noise", noise)
|
504 |
-
cv2.waitKey(0)
|
505 |
-
|
506 |
-
# modelSave(unet, "./", "test.pth")
|
507 |
-
# 导出为onnx格式
|
508 |
-
torch.onnx.export(
|
509 |
-
unet,
|
510 |
-
(x, t),
|
511 |
-
'unet.onnx',
|
512 |
-
export_params=True,
|
513 |
-
opset_version=12,
|
514 |
-
)
|
515 |
-
import onnx
|
516 |
-
|
517 |
-
# 增加维度信息
|
518 |
-
model_file = 'unet.onnx'
|
519 |
-
onnx_model = onnx.load(model_file)
|
520 |
-
onnx.save(onnx.shape_inference.infer_shapes(onnx_model), model_file)
|
|
|
422 |
# print("decoder:")
|
423 |
# print(decoder_out.shape)
|
424 |
return decoder_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,186 +1,9 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
aiofiles==23.1.0
|
4 |
-
aiohttp==3.8.3
|
5 |
-
aiosignal==1.3.1
|
6 |
-
aliyun-python-sdk-core==2.13.36
|
7 |
-
aliyun-python-sdk-kms==2.16.0
|
8 |
-
altair==4.2.0
|
9 |
-
anyio==3.6.2
|
10 |
-
appdirs==1.4.4
|
11 |
-
asttokens==2.3.0
|
12 |
-
async-timeout==4.0.2
|
13 |
-
attrs==22.1.0
|
14 |
-
audioread==3.0.0
|
15 |
-
backcall==0.2.0
|
16 |
-
certifi==2022.12.7
|
17 |
-
cffi==1.15.1
|
18 |
-
charset-normalizer==2.1.1
|
19 |
-
chumpy==0.70
|
20 |
-
click==8.1.3
|
21 |
-
clip==1.0
|
22 |
-
colorama==0.4.6
|
23 |
-
commonmark==0.9.1
|
24 |
-
contourpy==1.0.6
|
25 |
-
cpm-kernels==1.0.11
|
26 |
-
crcmod==1.7
|
27 |
-
cryptography==39.0.2
|
28 |
-
cycler==0.11.0
|
29 |
-
Cython==0.29.32
|
30 |
-
datasets==2.8.0
|
31 |
-
decorator==5.1.1
|
32 |
-
decord==0.6.0
|
33 |
-
diffusers==0.20.1
|
34 |
-
dill==0.3.6
|
35 |
-
docker-pycreds==0.4.0
|
36 |
-
einops==0.6.0
|
37 |
-
entrypoints==0.4
|
38 |
-
exceptiongroup==1.1.3
|
39 |
-
executing==1.2.0
|
40 |
-
fastapi==0.88.0
|
41 |
-
ffmpy==0.3.0
|
42 |
-
filelock==3.8.2
|
43 |
-
Flask==2.0.2
|
44 |
-
Flask-Cors==3.0.10
|
45 |
-
fonttools==4.38.0
|
46 |
-
frozenlist==1.3.3
|
47 |
-
fsspec==2022.11.0
|
48 |
-
ftfy==6.1.1
|
49 |
-
gast==0.5.3
|
50 |
-
gitdb==4.0.10
|
51 |
-
GitPython==3.1.32
|
52 |
-
gradio==3.39.0
|
53 |
-
gradio_client==0.3.0
|
54 |
-
h11==0.14.0
|
55 |
-
httpcore==0.16.2
|
56 |
-
httpx==0.23.1
|
57 |
huggingface-hub==0.16.4
|
58 |
-
icetk==0.0.4
|
59 |
-
idna==3.4
|
60 |
-
importlib-metadata==5.2.0
|
61 |
-
ipython==8.15.0
|
62 |
-
itsdangerous==2.1.2
|
63 |
-
jedi==0.19.0
|
64 |
-
Jinja2==3.1.2
|
65 |
-
jmespath==0.10.0
|
66 |
-
joblib==1.2.0
|
67 |
-
json-tricks==3.16.1
|
68 |
-
jsonplus==0.8.0
|
69 |
-
jsonschema==4.17.3
|
70 |
-
kiwisolver==1.4.4
|
71 |
-
lazy_loader==0.1
|
72 |
-
librosa==0.10.0
|
73 |
-
linkify-it-py==1.0.3
|
74 |
-
lion-pytorch==0.1.2
|
75 |
-
llvmlite==0.39.1
|
76 |
-
loguru==0.6.0
|
77 |
-
Markdown==3.4.1
|
78 |
-
markdown-it-py==2.1.0
|
79 |
-
MarkupSafe==2.1.1
|
80 |
-
matplotlib==3.6.2
|
81 |
-
matplotlib-inline==0.1.6
|
82 |
-
mdit-py-plugins==0.3.3
|
83 |
-
mdurl==0.1.2
|
84 |
-
mediapipe==0.8.11
|
85 |
-
mmcv-full==1.7.0
|
86 |
-
mmdet==2.26.0
|
87 |
-
model-index==0.1.11
|
88 |
-
modelscope==1.3.2
|
89 |
-
mpmath==1.2.1
|
90 |
-
msgpack==1.0.4
|
91 |
-
multidict==6.0.3
|
92 |
-
multiprocess==0.70.14
|
93 |
-
munkres==1.1.4
|
94 |
-
networkx==3.0
|
95 |
-
numba==0.56.4
|
96 |
numpy==1.23.4
|
97 |
-
onnx==1.14.1
|
98 |
-
opencv-contrib-python==4.5.5.64
|
99 |
-
opencv-python==4.5.5.64
|
100 |
-
openmim==0.3.3
|
101 |
-
ordered-set==4.1.0
|
102 |
-
orjson==3.8.3
|
103 |
-
oss2==2.16.0
|
104 |
-
packaging==21.3
|
105 |
-
pandas==1.5.2
|
106 |
-
parso==0.8.3
|
107 |
-
pathtools==0.1.2
|
108 |
-
pickleshare==0.7.5
|
109 |
-
Pillow==9.2.0
|
110 |
-
pip==23.1.2
|
111 |
-
platformdirs==3.1.0
|
112 |
-
plotly==5.11.0
|
113 |
-
pooch==1.7.0
|
114 |
-
prodigyopt==1.0
|
115 |
-
prompt-toolkit==3.0.39
|
116 |
-
protobuf==4.24.2
|
117 |
-
psutil==5.9.5
|
118 |
-
pure-eval==0.2.2
|
119 |
-
pyarrow==11.0.0
|
120 |
-
pycocotools==2.0.6
|
121 |
-
pycparser==2.21
|
122 |
-
pycryptodome==3.16.0
|
123 |
-
pydantic==1.10.2
|
124 |
-
pydub==0.25.1
|
125 |
-
Pygments==2.13.0
|
126 |
-
pyparsing==3.0.9
|
127 |
-
pyrsistent==0.19.2
|
128 |
-
python-dateutil==2.8.2
|
129 |
-
python-multipart==0.0.5
|
130 |
-
pytorch-fid==0.3.0
|
131 |
-
pytz==2022.6
|
132 |
-
PyYAML==6.0
|
133 |
-
regex==2022.10.31
|
134 |
-
requests==2.28.1
|
135 |
-
responses==0.18.0
|
136 |
-
rfc3986==1.5.0
|
137 |
-
rich==12.6.0
|
138 |
-
safetensors==0.3.3
|
139 |
-
scikit-learn==1.2.1
|
140 |
-
scipy==1.9.3
|
141 |
-
semantic-version==2.10.0
|
142 |
-
sentencepiece==0.1.97
|
143 |
-
sentry-sdk==1.28.0
|
144 |
-
setproctitle==1.3.2
|
145 |
-
setuptools==65.5.0
|
146 |
-
simplejson==3.18.3
|
147 |
-
six==1.16.0
|
148 |
-
smmap==5.0.0
|
149 |
-
sniffio==1.3.0
|
150 |
-
sortedcontainers==2.4.0
|
151 |
-
soundfile==0.12.1
|
152 |
-
soxr==0.3.4
|
153 |
-
stack-data==0.6.2
|
154 |
-
starlette==0.22.0
|
155 |
-
sympy==1.11.1
|
156 |
-
tabulate==0.9.0
|
157 |
-
tenacity==8.1.0
|
158 |
-
terminaltables==3.1.10
|
159 |
-
threadpoolctl==3.1.0
|
160 |
-
timm==0.4.9
|
161 |
-
tokenizers==0.13.2
|
162 |
-
toolz==0.12.0
|
163 |
torch==2.0.0+cu117
|
164 |
torchaudio==2.0.1+cu117
|
165 |
torchinfo==1.7.1
|
166 |
torchvision==0.15.1+cu117
|
167 |
tqdm==4.64.1
|
168 |
-
traitlets==5.9.0
|
169 |
-
transformers==4.26.1
|
170 |
-
typing_extensions==4.4.0
|
171 |
-
uc-micro-py==1.0.1
|
172 |
-
unicodedata2==15.0.0
|
173 |
-
urllib3==1.26.12
|
174 |
-
uvicorn==0.20.0
|
175 |
-
wandb==0.15.5
|
176 |
-
wcwidth==0.2.5
|
177 |
-
websockets==10.4
|
178 |
-
Werkzeug==2.2.2
|
179 |
-
wheel==0.37.1
|
180 |
-
win32-setctime==1.1.0
|
181 |
-
wincertstore==0.2
|
182 |
-
xtcocotools==1.12
|
183 |
-
xxhash==3.2.0
|
184 |
-
yapf==0.32.0
|
185 |
-
yarl==1.8.2
|
186 |
-
zipp==3.11.0
|
|
|
1 |
+
gradio
|
2 |
+
gradio_client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
huggingface-hub==0.16.4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
numpy==1.23.4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
torch==2.0.0+cu117
|
6 |
torchaudio==2.0.1+cu117
|
7 |
torchinfo==1.7.1
|
8 |
torchvision==0.15.1+cu117
|
9 |
tqdm==4.64.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|