Adapter committed on
Commit
2254a67
1 Parent(s): ee11c4c

rebuild+depth

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +0 -0
  2. .gitignore +128 -0
  3. LICENSE +0 -0
  4. README.md +0 -0
  5. app.py +4 -2
  6. configs/stable-diffusion/app.yaml +0 -0
  7. configs/stable-diffusion/test_keypose.yaml +0 -87
  8. configs/stable-diffusion/test_mask.yaml +0 -87
  9. configs/stable-diffusion/test_mask_sketch.yaml +0 -87
  10. configs/stable-diffusion/test_sketch.yaml +0 -87
  11. configs/stable-diffusion/test_sketch_edit.yaml +0 -87
  12. configs/stable-diffusion/train_keypose.yaml +0 -87
  13. configs/stable-diffusion/train_mask.yaml +0 -87
  14. configs/stable-diffusion/train_sketch.yaml +0 -87
  15. dataset_coco.py +0 -138
  16. demo/demos.py +26 -1
  17. demo/model.py +69 -6
  18. dist_util.py +0 -91
  19. environment.yaml +0 -0
  20. examples/edit_cat/edge.png +0 -0
  21. examples/edit_cat/edge_2.png +0 -0
  22. examples/edit_cat/im.png +0 -0
  23. examples/edit_cat/mask.png +0 -0
  24. examples/keypose/iron.png +0 -0
  25. examples/seg/dinner.png +0 -0
  26. examples/seg/motor.png +0 -0
  27. examples/seg_sketch/edge.png +0 -0
  28. examples/seg_sketch/mask.png +0 -0
  29. examples/sketch/car.png +0 -0
  30. examples/sketch/girl.jpeg +0 -0
  31. examples/sketch/human.png +0 -0
  32. examples/sketch/scenery.jpg +0 -0
  33. examples/sketch/scenery2.jpg +0 -0
  34. gradio_keypose.py +0 -254
  35. gradio_sketch.py +0 -147
  36. ldm/data/__init__.py +0 -0
  37. ldm/data/base.py +0 -0
  38. ldm/data/imagenet.py +0 -0
  39. ldm/data/lsun.py +0 -0
  40. ldm/lr_scheduler.py +0 -0
  41. ldm/models/autoencoder.py +0 -0
  42. ldm/models/diffusion/__init__.py +0 -0
  43. ldm/models/diffusion/classifier.py +0 -0
  44. ldm/models/diffusion/ddim.py +0 -0
  45. ldm/models/diffusion/ddpm.py +0 -0
  46. ldm/models/diffusion/dpm_solver/__init__.py +0 -0
  47. ldm/models/diffusion/dpm_solver/dpm_solver.py +0 -0
  48. ldm/models/diffusion/dpm_solver/sampler.py +0 -0
  49. ldm/models/diffusion/plms.py +0 -0
  50. ldm/modules/attention.py +0 -0
.gitattributes CHANGED
File without changes
.gitignore ADDED
@@ -0,0 +1,128 @@
+ # ignored folders
+ models
+
+ # ignored folders
+ tmp/*
+
+ *.DS_Store
+ .idea
+
+ # ignored files
+ version.py
+
+ # ignored files with suffix
+ # *.html
+ # *.png
+ # *.jpeg
+ # *.jpg
+ # *.gif
+ # *.pth
+ # *.zip
+
+ # template
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.pyc
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
LICENSE CHANGED
File without changes
README.md CHANGED
File without changes
app.py CHANGED
@@ -8,14 +8,14 @@ os.system('mim install mmcv-full==1.7.0')
 
from demo.model import Model_all
import gradio as gr
- from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg
+ from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth
import torch
import subprocess
import shlex
from huggingface_hub import hf_hub_url
 
urls = {
- 'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth'],
+ 'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth'],
'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
'andite/anything-v4.0':['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
}
@@ -72,5 +72,7 @@ with gr.Blocks(css='style.css') as demo:
create_demo_draw(model.process_draw)
with gr.TabItem('Segmentation'):
create_demo_seg(model.process_seg)
+ with gr.TabItem('Depth'):
+ create_demo_depth(model.process_depth)
 
demo.queue().launch(debug=True, server_name='0.0.0.0')
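Note on the new checkpoint entry: the hunks above only show the `urls` dict and the `hf_hub_url` import; the download loop that consumes the dict is outside this diff. As a hedged sketch (the `hf_hub_download` call, the loop, and the `models/` target directory are illustrative assumptions, not code from this commit), resolving the dict to local files could look like:

```python
# Hypothetical sketch: resolving the `urls` dict to local checkpoint files.
# The loop and the copy into models/ are assumptions; only the dict itself
# appears in the hunks above.
import os
import shutil
from huggingface_hub import hf_hub_download

urls = {
    'TencentARC/T2I-Adapter': [
        'models/t2iadapter_keypose_sd14v1.pth',
        'models/t2iadapter_seg_sd14v1.pth',
        'models/t2iadapter_sketch_sd14v1.pth',
        'models/t2iadapter_depth_sd14v1.pth',  # new in this commit
    ],
}

os.makedirs('models', exist_ok=True)
for repo_id, filenames in urls.items():
    for filename in filenames:
        # Download (or reuse the local HF cache) and place the weight under models/,
        # which is where demo/model.py loads the adapters from.
        cached_path = hf_hub_download(repo_id=repo_id, filename=filename)
        shutil.copy(cached_path, os.path.join('models', os.path.basename(filename)))
```

Because `hf_hub_download` reuses the local cache, the newly added `t2iadapter_depth_sd14v1.pth` would only be fetched once per environment.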
configs/stable-diffusion/app.yaml CHANGED
File without changes
configs/stable-diffusion/test_keypose.yaml DELETED
@@ -1,87 +0,0 @@
1
- name: test_keypose
2
- model:
3
- base_learning_rate: 1.0e-04
4
- target: ldm.models.diffusion.ddpm.LatentDiffusion
5
- params:
6
- linear_start: 0.00085
7
- linear_end: 0.0120
8
- num_timesteps_cond: 1
9
- log_every_t: 200
10
- timesteps: 1000
11
- first_stage_key: "jpg"
12
- cond_stage_key: "txt"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false # Note: different from the one we trained before
16
- conditioning_key: crossattn
17
- monitor: val/loss_simple_ema
18
- scale_factor: 0.18215
19
- use_ema: False
20
-
21
- scheduler_config: # 10000 warmup steps
22
- target: ldm.lr_scheduler.LambdaLinearScheduler
23
- params:
24
- warm_up_steps: [ 10000 ]
25
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
- f_start: [ 1.e-6 ]
27
- f_max: [ 1. ]
28
- f_min: [ 1. ]
29
-
30
- unet_config:
31
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
- params:
33
- image_size: 32 # unused
34
- in_channels: 4
35
- out_channels: 4
36
- model_channels: 320
37
- attention_resolutions: [ 4, 2, 1 ]
38
- num_res_blocks: 2
39
- channel_mult: [ 1, 2, 4, 4 ]
40
- num_heads: 8
41
- use_spatial_transformer: True
42
- transformer_depth: 1
43
- context_dim: 768
44
- use_checkpoint: True
45
- legacy: False
46
-
47
- first_stage_config:
48
- target: ldm.models.autoencoder.AutoencoderKL
49
- params:
50
- embed_dim: 4
51
- monitor: val/rec_loss
52
- ddconfig:
53
- double_z: true
54
- z_channels: 4
55
- resolution: 256
56
- in_channels: 3
57
- out_ch: 3
58
- ch: 128
59
- ch_mult:
60
- - 1
61
- - 2
62
- - 4
63
- - 4
64
- num_res_blocks: 2
65
- attn_resolutions: []
66
- dropout: 0.0
67
- lossconfig:
68
- target: torch.nn.Identity
69
-
70
- cond_stage_config: #__is_unconditional__
71
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
- params:
73
- version: models/clip-vit-large-patch14
74
-
75
- logger:
76
- print_freq: 100
77
- save_checkpoint_freq: !!float 1e4
78
- use_tb_logger: true
79
- wandb:
80
- project: ~
81
- resume_id: ~
82
- dist_params:
83
- backend: nccl
84
- port: 29500
85
- training:
86
- lr: !!float 1e-5
87
- save_freq: 1e4
configs/stable-diffusion/test_mask.yaml DELETED
@@ -1,87 +0,0 @@
1
- name: test_mask
2
- model:
3
- base_learning_rate: 1.0e-04
4
- target: ldm.models.diffusion.ddpm.LatentDiffusion
5
- params:
6
- linear_start: 0.00085
7
- linear_end: 0.0120
8
- num_timesteps_cond: 1
9
- log_every_t: 200
10
- timesteps: 1000
11
- first_stage_key: "jpg"
12
- cond_stage_key: "txt"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false # Note: different from the one we trained before
16
- conditioning_key: crossattn
17
- monitor: val/loss_simple_ema
18
- scale_factor: 0.18215
19
- use_ema: False
20
-
21
- scheduler_config: # 10000 warmup steps
22
- target: ldm.lr_scheduler.LambdaLinearScheduler
23
- params:
24
- warm_up_steps: [ 10000 ]
25
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
- f_start: [ 1.e-6 ]
27
- f_max: [ 1. ]
28
- f_min: [ 1. ]
29
-
30
- unet_config:
31
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
- params:
33
- image_size: 32 # unused
34
- in_channels: 4
35
- out_channels: 4
36
- model_channels: 320
37
- attention_resolutions: [ 4, 2, 1 ]
38
- num_res_blocks: 2
39
- channel_mult: [ 1, 2, 4, 4 ]
40
- num_heads: 8
41
- use_spatial_transformer: True
42
- transformer_depth: 1
43
- context_dim: 768
44
- use_checkpoint: True
45
- legacy: False
46
-
47
- first_stage_config:
48
- target: ldm.models.autoencoder.AutoencoderKL
49
- params:
50
- embed_dim: 4
51
- monitor: val/rec_loss
52
- ddconfig:
53
- double_z: true
54
- z_channels: 4
55
- resolution: 256
56
- in_channels: 3
57
- out_ch: 3
58
- ch: 128
59
- ch_mult:
60
- - 1
61
- - 2
62
- - 4
63
- - 4
64
- num_res_blocks: 2
65
- attn_resolutions: []
66
- dropout: 0.0
67
- lossconfig:
68
- target: torch.nn.Identity
69
-
70
- cond_stage_config: #__is_unconditional__
71
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
- params:
73
- version: models/clip-vit-large-patch14
74
-
75
- logger:
76
- print_freq: 100
77
- save_checkpoint_freq: !!float 1e4
78
- use_tb_logger: true
79
- wandb:
80
- project: ~
81
- resume_id: ~
82
- dist_params:
83
- backend: nccl
84
- port: 29500
85
- training:
86
- lr: !!float 1e-5
87
- save_freq: 1e4
configs/stable-diffusion/test_mask_sketch.yaml DELETED
@@ -1,87 +0,0 @@
1
- name: test_mask_sketch
2
- model:
3
- base_learning_rate: 1.0e-04
4
- target: ldm.models.diffusion.ddpm.LatentDiffusion
5
- params:
6
- linear_start: 0.00085
7
- linear_end: 0.0120
8
- num_timesteps_cond: 1
9
- log_every_t: 200
10
- timesteps: 1000
11
- first_stage_key: "jpg"
12
- cond_stage_key: "txt"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false # Note: different from the one we trained before
16
- conditioning_key: crossattn
17
- monitor: val/loss_simple_ema
18
- scale_factor: 0.18215
19
- use_ema: False
20
-
21
- scheduler_config: # 10000 warmup steps
22
- target: ldm.lr_scheduler.LambdaLinearScheduler
23
- params:
24
- warm_up_steps: [ 10000 ]
25
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
- f_start: [ 1.e-6 ]
27
- f_max: [ 1. ]
28
- f_min: [ 1. ]
29
-
30
- unet_config:
31
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
- params:
33
- image_size: 32 # unused
34
- in_channels: 4
35
- out_channels: 4
36
- model_channels: 320
37
- attention_resolutions: [ 4, 2, 1 ]
38
- num_res_blocks: 2
39
- channel_mult: [ 1, 2, 4, 4 ]
40
- num_heads: 8
41
- use_spatial_transformer: True
42
- transformer_depth: 1
43
- context_dim: 768
44
- use_checkpoint: True
45
- legacy: False
46
-
47
- first_stage_config:
48
- target: ldm.models.autoencoder.AutoencoderKL
49
- params:
50
- embed_dim: 4
51
- monitor: val/rec_loss
52
- ddconfig:
53
- double_z: true
54
- z_channels: 4
55
- resolution: 256
56
- in_channels: 3
57
- out_ch: 3
58
- ch: 128
59
- ch_mult:
60
- - 1
61
- - 2
62
- - 4
63
- - 4
64
- num_res_blocks: 2
65
- attn_resolutions: []
66
- dropout: 0.0
67
- lossconfig:
68
- target: torch.nn.Identity
69
-
70
- cond_stage_config: #__is_unconditional__
71
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
- params:
73
- version: models/clip-vit-large-patch14
74
-
75
- logger:
76
- print_freq: 100
77
- save_checkpoint_freq: !!float 1e4
78
- use_tb_logger: true
79
- wandb:
80
- project: ~
81
- resume_id: ~
82
- dist_params:
83
- backend: nccl
84
- port: 29500
85
- training:
86
- lr: !!float 1e-5
87
- save_freq: 1e4
configs/stable-diffusion/test_sketch.yaml DELETED
@@ -1,87 +0,0 @@
1
- name: test_sketch
2
- model:
3
- base_learning_rate: 1.0e-04
4
- target: ldm.models.diffusion.ddpm.LatentDiffusion
5
- params:
6
- linear_start: 0.00085
7
- linear_end: 0.0120
8
- num_timesteps_cond: 1
9
- log_every_t: 200
10
- timesteps: 1000
11
- first_stage_key: "jpg"
12
- cond_stage_key: "txt"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false # Note: different from the one we trained before
16
- conditioning_key: crossattn
17
- monitor: val/loss_simple_ema
18
- scale_factor: 0.18215
19
- use_ema: False
20
-
21
- scheduler_config: # 10000 warmup steps
22
- target: ldm.lr_scheduler.LambdaLinearScheduler
23
- params:
24
- warm_up_steps: [ 10000 ]
25
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
- f_start: [ 1.e-6 ]
27
- f_max: [ 1. ]
28
- f_min: [ 1. ]
29
-
30
- unet_config:
31
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
- params:
33
- image_size: 32 # unused
34
- in_channels: 4
35
- out_channels: 4
36
- model_channels: 320
37
- attention_resolutions: [ 4, 2, 1 ]
38
- num_res_blocks: 2
39
- channel_mult: [ 1, 2, 4, 4 ]
40
- num_heads: 8
41
- use_spatial_transformer: True
42
- transformer_depth: 1
43
- context_dim: 768
44
- use_checkpoint: True
45
- legacy: False
46
-
47
- first_stage_config:
48
- target: ldm.models.autoencoder.AutoencoderKL
49
- params:
50
- embed_dim: 4
51
- monitor: val/rec_loss
52
- ddconfig:
53
- double_z: true
54
- z_channels: 4
55
- resolution: 256
56
- in_channels: 3
57
- out_ch: 3
58
- ch: 128
59
- ch_mult:
60
- - 1
61
- - 2
62
- - 4
63
- - 4
64
- num_res_blocks: 2
65
- attn_resolutions: []
66
- dropout: 0.0
67
- lossconfig:
68
- target: torch.nn.Identity
69
-
70
- cond_stage_config: #__is_unconditional__
71
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
- params:
73
- version: models/clip-vit-large-patch14
74
-
75
- logger:
76
- print_freq: 100
77
- save_checkpoint_freq: !!float 1e4
78
- use_tb_logger: true
79
- wandb:
80
- project: ~
81
- resume_id: ~
82
- dist_params:
83
- backend: nccl
84
- port: 29500
85
- training:
86
- lr: !!float 1e-5
87
- save_freq: 1e4
configs/stable-diffusion/test_sketch_edit.yaml DELETED
@@ -1,87 +0,0 @@
1
- name: test_sketch_edit
2
- model:
3
- base_learning_rate: 1.0e-04
4
- target: ldm.models.diffusion.ddpm.LatentDiffusion
5
- params:
6
- linear_start: 0.00085
7
- linear_end: 0.0120
8
- num_timesteps_cond: 1
9
- log_every_t: 200
10
- timesteps: 1000
11
- first_stage_key: "jpg"
12
- cond_stage_key: "txt"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false # Note: different from the one we trained before
16
- conditioning_key: crossattn
17
- monitor: val/loss_simple_ema
18
- scale_factor: 0.18215
19
- use_ema: False
20
-
21
- scheduler_config: # 10000 warmup steps
22
- target: ldm.lr_scheduler.LambdaLinearScheduler
23
- params:
24
- warm_up_steps: [ 10000 ]
25
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
- f_start: [ 1.e-6 ]
27
- f_max: [ 1. ]
28
- f_min: [ 1. ]
29
-
30
- unet_config:
31
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
- params:
33
- image_size: 32 # unused
34
- in_channels: 4
35
- out_channels: 4
36
- model_channels: 320
37
- attention_resolutions: [ 4, 2, 1 ]
38
- num_res_blocks: 2
39
- channel_mult: [ 1, 2, 4, 4 ]
40
- num_heads: 8
41
- use_spatial_transformer: True
42
- transformer_depth: 1
43
- context_dim: 768
44
- use_checkpoint: True
45
- legacy: False
46
-
47
- first_stage_config:
48
- target: ldm.models.autoencoder.AutoencoderKL
49
- params:
50
- embed_dim: 4
51
- monitor: val/rec_loss
52
- ddconfig:
53
- double_z: true
54
- z_channels: 4
55
- resolution: 256
56
- in_channels: 3
57
- out_ch: 3
58
- ch: 128
59
- ch_mult:
60
- - 1
61
- - 2
62
- - 4
63
- - 4
64
- num_res_blocks: 2
65
- attn_resolutions: []
66
- dropout: 0.0
67
- lossconfig:
68
- target: torch.nn.Identity
69
-
70
- cond_stage_config: #__is_unconditional__
71
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
- params:
73
- version: models/clip-vit-large-patch14
74
-
75
- logger:
76
- print_freq: 100
77
- save_checkpoint_freq: !!float 1e4
78
- use_tb_logger: true
79
- wandb:
80
- project: ~
81
- resume_id: ~
82
- dist_params:
83
- backend: nccl
84
- port: 29500
85
- training:
86
- lr: !!float 1e-5
87
- save_freq: 1e4
configs/stable-diffusion/train_keypose.yaml DELETED
@@ -1,87 +0,0 @@
1
- name: train_keypose
2
- model:
3
- base_learning_rate: 1.0e-04
4
- target: ldm.models.diffusion.ddpm.LatentDiffusion
5
- params:
6
- linear_start: 0.00085
7
- linear_end: 0.0120
8
- num_timesteps_cond: 1
9
- log_every_t: 200
10
- timesteps: 1000
11
- first_stage_key: "jpg"
12
- cond_stage_key: "txt"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false # Note: different from the one we trained before
16
- conditioning_key: crossattn
17
- monitor: val/loss_simple_ema
18
- scale_factor: 0.18215
19
- use_ema: False
20
-
21
- scheduler_config: # 10000 warmup steps
22
- target: ldm.lr_scheduler.LambdaLinearScheduler
23
- params:
24
- warm_up_steps: [ 10000 ]
25
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
- f_start: [ 1.e-6 ]
27
- f_max: [ 1. ]
28
- f_min: [ 1. ]
29
-
30
- unet_config:
31
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
- params:
33
- image_size: 32 # unused
34
- in_channels: 4
35
- out_channels: 4
36
- model_channels: 320
37
- attention_resolutions: [ 4, 2, 1 ]
38
- num_res_blocks: 2
39
- channel_mult: [ 1, 2, 4, 4 ]
40
- num_heads: 8
41
- use_spatial_transformer: True
42
- transformer_depth: 1
43
- context_dim: 768
44
- use_checkpoint: True
45
- legacy: False
46
-
47
- first_stage_config:
48
- target: ldm.models.autoencoder.AutoencoderKL
49
- params:
50
- embed_dim: 4
51
- monitor: val/rec_loss
52
- ddconfig:
53
- double_z: true
54
- z_channels: 4
55
- resolution: 256
56
- in_channels: 3
57
- out_ch: 3
58
- ch: 128
59
- ch_mult:
60
- - 1
61
- - 2
62
- - 4
63
- - 4
64
- num_res_blocks: 2
65
- attn_resolutions: []
66
- dropout: 0.0
67
- lossconfig:
68
- target: torch.nn.Identity
69
-
70
- cond_stage_config: #__is_unconditional__
71
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
- params:
73
- version: models/clip-vit-large-patch14
74
-
75
- logger:
76
- print_freq: 100
77
- save_checkpoint_freq: !!float 1e4
78
- use_tb_logger: true
79
- wandb:
80
- project: ~
81
- resume_id: ~
82
- dist_params:
83
- backend: nccl
84
- port: 29500
85
- training:
86
- lr: !!float 1e-5
87
- save_freq: 1e4
configs/stable-diffusion/train_mask.yaml DELETED
@@ -1,87 +0,0 @@
1
- name: train_mask
2
- model:
3
- base_learning_rate: 1.0e-04
4
- target: ldm.models.diffusion.ddpm.LatentDiffusion
5
- params:
6
- linear_start: 0.00085
7
- linear_end: 0.0120
8
- num_timesteps_cond: 1
9
- log_every_t: 200
10
- timesteps: 1000
11
- first_stage_key: "jpg"
12
- cond_stage_key: "txt"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false # Note: different from the one we trained before
16
- conditioning_key: crossattn
17
- monitor: val/loss_simple_ema
18
- scale_factor: 0.18215
19
- use_ema: False
20
-
21
- scheduler_config: # 10000 warmup steps
22
- target: ldm.lr_scheduler.LambdaLinearScheduler
23
- params:
24
- warm_up_steps: [ 10000 ]
25
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
- f_start: [ 1.e-6 ]
27
- f_max: [ 1. ]
28
- f_min: [ 1. ]
29
-
30
- unet_config:
31
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
- params:
33
- image_size: 32 # unused
34
- in_channels: 4
35
- out_channels: 4
36
- model_channels: 320
37
- attention_resolutions: [ 4, 2, 1 ]
38
- num_res_blocks: 2
39
- channel_mult: [ 1, 2, 4, 4 ]
40
- num_heads: 8
41
- use_spatial_transformer: True
42
- transformer_depth: 1
43
- context_dim: 768
44
- use_checkpoint: True
45
- legacy: False
46
-
47
- first_stage_config:
48
- target: ldm.models.autoencoder.AutoencoderKL
49
- params:
50
- embed_dim: 4
51
- monitor: val/rec_loss
52
- ddconfig:
53
- double_z: true
54
- z_channels: 4
55
- resolution: 256
56
- in_channels: 3
57
- out_ch: 3
58
- ch: 128
59
- ch_mult:
60
- - 1
61
- - 2
62
- - 4
63
- - 4
64
- num_res_blocks: 2
65
- attn_resolutions: []
66
- dropout: 0.0
67
- lossconfig:
68
- target: torch.nn.Identity
69
-
70
- cond_stage_config: #__is_unconditional__
71
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
- params:
73
- version: models/clip-vit-large-patch14
74
-
75
- logger:
76
- print_freq: 100
77
- save_checkpoint_freq: !!float 1e4
78
- use_tb_logger: true
79
- wandb:
80
- project: ~
81
- resume_id: ~
82
- dist_params:
83
- backend: nccl
84
- port: 29500
85
- training:
86
- lr: !!float 1e-5
87
- save_freq: 1e4
configs/stable-diffusion/train_sketch.yaml DELETED
@@ -1,87 +0,0 @@
1
- name: train_sketch
2
- model:
3
- base_learning_rate: 1.0e-04
4
- target: ldm.models.diffusion.ddpm.LatentDiffusion
5
- params:
6
- linear_start: 0.00085
7
- linear_end: 0.0120
8
- num_timesteps_cond: 1
9
- log_every_t: 200
10
- timesteps: 1000
11
- first_stage_key: "jpg"
12
- cond_stage_key: "txt"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false # Note: different from the one we trained before
16
- conditioning_key: crossattn
17
- monitor: val/loss_simple_ema
18
- scale_factor: 0.18215
19
- use_ema: False
20
-
21
- scheduler_config: # 10000 warmup steps
22
- target: ldm.lr_scheduler.LambdaLinearScheduler
23
- params:
24
- warm_up_steps: [ 10000 ]
25
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
- f_start: [ 1.e-6 ]
27
- f_max: [ 1. ]
28
- f_min: [ 1. ]
29
-
30
- unet_config:
31
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
- params:
33
- image_size: 32 # unused
34
- in_channels: 4
35
- out_channels: 4
36
- model_channels: 320
37
- attention_resolutions: [ 4, 2, 1 ]
38
- num_res_blocks: 2
39
- channel_mult: [ 1, 2, 4, 4 ]
40
- num_heads: 8
41
- use_spatial_transformer: True
42
- transformer_depth: 1
43
- context_dim: 768
44
- use_checkpoint: True
45
- legacy: False
46
-
47
- first_stage_config:
48
- target: ldm.models.autoencoder.AutoencoderKL
49
- params:
50
- embed_dim: 4
51
- monitor: val/rec_loss
52
- ddconfig:
53
- double_z: true
54
- z_channels: 4
55
- resolution: 256
56
- in_channels: 3
57
- out_ch: 3
58
- ch: 128
59
- ch_mult:
60
- - 1
61
- - 2
62
- - 4
63
- - 4
64
- num_res_blocks: 2
65
- attn_resolutions: []
66
- dropout: 0.0
67
- lossconfig:
68
- target: torch.nn.Identity
69
-
70
- cond_stage_config: #__is_unconditional__
71
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
- params:
73
- version: models/clip-vit-large-patch14
74
-
75
- logger:
76
- print_freq: 100
77
- save_checkpoint_freq: !!float 1e4
78
- use_tb_logger: true
79
- wandb:
80
- project: ~
81
- resume_id: ~
82
- dist_params:
83
- backend: nccl
84
- port: 29500
85
- training:
86
- lr: !!float 1e-5
87
- save_freq: 1e4
dataset_coco.py DELETED
@@ -1,138 +0,0 @@
1
- import torch
2
- import json
3
- import cv2
4
- import torch
5
- import os
6
- from basicsr.utils import img2tensor, tensor2img
7
- import random
8
-
9
- class dataset_coco():
10
- def __init__(self, path_json, root_path, image_size, mode='train'):
11
- super(dataset_coco, self).__init__()
12
- with open(path_json, 'r', encoding='utf-8') as fp:
13
- data = json.load(fp)
14
- data = data['images']
15
- self.paths = []
16
- self.root_path = root_path
17
- for file in data:
18
- input_path = file['filepath']
19
- if mode == 'train':
20
- if 'val' not in input_path:
21
- self.paths.append(file)
22
- else:
23
- if 'val' in input_path:
24
- self.paths.append(file)
25
-
26
- def __getitem__(self, idx):
27
- file = self.paths[idx]
28
- input_path = file['filepath']
29
- input_name = file['filename']
30
- path = os.path.join(self.root_path, input_path, input_name)
31
- im = cv2.imread(path)
32
- im = cv2.resize(im, (512,512))
33
- im = img2tensor(im, bgr2rgb=True, float32=True)/255.
34
- sentences = file['sentences']
35
- sentence = sentences[int(random.random()*len(sentences))]['raw'].strip('.')
36
- return {'im':im, 'sentence':sentence}
37
-
38
- def __len__(self):
39
- return len(self.paths)
40
-
41
-
42
- class dataset_coco_mask():
43
- def __init__(self, path_json, root_path_im, root_path_mask, image_size):
44
- super(dataset_coco_mask, self).__init__()
45
- with open(path_json, 'r', encoding='utf-8') as fp:
46
- data = json.load(fp)
47
- data = data['annotations']
48
- self.files = []
49
- self.root_path_im = root_path_im
50
- self.root_path_mask = root_path_mask
51
- for file in data:
52
- name = "%012d.png"%file['image_id']
53
- self.files.append({'name':name, 'sentence':file['caption']})
54
-
55
- def __getitem__(self, idx):
56
- file = self.files[idx]
57
- name = file['name']
58
- # print(os.path.join(self.root_path_im, name))
59
- im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg')))
60
- im = cv2.resize(im, (512,512))
61
- im = img2tensor(im, bgr2rgb=True, float32=True)/255.
62
-
63
- mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0]
64
- mask = cv2.resize(mask, (512,512))
65
- mask = img2tensor(mask, bgr2rgb=True, float32=True)[0].unsqueeze(0)#/255.
66
-
67
- sentence = file['sentence']
68
- return {'im':im, 'mask':mask, 'sentence':sentence}
69
-
70
- def __len__(self):
71
- return len(self.files)
72
-
73
-
74
- class dataset_coco_mask_color():
75
- def __init__(self, path_json, root_path_im, root_path_mask, image_size):
76
- super(dataset_coco_mask_color, self).__init__()
77
- with open(path_json, 'r', encoding='utf-8') as fp:
78
- data = json.load(fp)
79
- data = data['annotations']
80
- self.files = []
81
- self.root_path_im = root_path_im
82
- self.root_path_mask = root_path_mask
83
- for file in data:
84
- name = "%012d.png"%file['image_id']
85
- self.files.append({'name':name, 'sentence':file['caption']})
86
-
87
- def __getitem__(self, idx):
88
- file = self.files[idx]
89
- name = file['name']
90
- # print(os.path.join(self.root_path_im, name))
91
- im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg')))
92
- im = cv2.resize(im, (512,512))
93
- im = img2tensor(im, bgr2rgb=True, float32=True)/255.
94
-
95
- mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0]
96
- mask = cv2.resize(mask, (512,512))
97
- mask = img2tensor(mask, bgr2rgb=True, float32=True)/255.#[0].unsqueeze(0)#/255.
98
-
99
- sentence = file['sentence']
100
- return {'im':im, 'mask':mask, 'sentence':sentence}
101
-
102
- def __len__(self):
103
- return len(self.files)
104
-
105
- class dataset_coco_mask_color_sig():
106
- def __init__(self, path_json, root_path_im, root_path_mask, image_size):
107
- super(dataset_coco_mask_color_sig, self).__init__()
108
- with open(path_json, 'r', encoding='utf-8') as fp:
109
- data = json.load(fp)
110
- data = data['annotations']
111
- self.files = []
112
- self.root_path_im = root_path_im
113
- self.root_path_mask = root_path_mask
114
- reg = {}
115
- for file in data:
116
- name = "%012d.png"%file['image_id']
117
- if name in reg:
118
- continue
119
- self.files.append({'name':name, 'sentence':file['caption']})
120
- reg[name] = name
121
-
122
- def __getitem__(self, idx):
123
- file = self.files[idx]
124
- name = file['name']
125
- # print(os.path.join(self.root_path_im, name))
126
- im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg')))
127
- im = cv2.resize(im, (512,512))
128
- im = img2tensor(im, bgr2rgb=True, float32=True)/255.
129
-
130
- mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0]
131
- mask = cv2.resize(mask, (512,512))
132
- mask = img2tensor(mask, bgr2rgb=True, float32=True)/255.#[0].unsqueeze(0)#/255.
133
-
134
- sentence = file['sentence']
135
- return {'im':im, 'mask':mask, 'sentence':sentence, 'name': name}
136
-
137
- def __len__(self):
138
- return len(self.files)
demo/demos.py CHANGED
@@ -85,7 +85,32 @@ def create_demo_seg(process):
with gr.Row():
type_in = gr.inputs.Radio(['Segmentation', 'Image'], type="value", default='Image', label='You can input an image or a segmentation. If you choose to input a segmentation, it must correspond to the coco-stuff')
run_button = gr.Button(label="Run")
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the segmentation to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
+ con_strength = gr.Slider(label="Controling Strength (The guidance strength of the segmentation to the result)", minimum=0, maximum=1, value=1, step=0.1)
+ scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+ fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
+ base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
+ with gr.Column():
+ result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+ ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
+ run_button.click(fn=process, inputs=ips, outputs=[result])
+ return demo
+
+ def create_demo_depth(process):
+ with gr.Blocks() as demo:
+ with gr.Row():
+ gr.Markdown('## T2I-Adapter (Depth)')
+ with gr.Row():
+ with gr.Column():
+ input_img = gr.Image(source='upload', type="numpy")
+ prompt = gr.Textbox(label="Prompt")
+ neg_prompt = gr.Textbox(label="Negative Prompt",
+ value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
+ pos_prompt = gr.Textbox(label="Positive Prompt",
+ value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
+ with gr.Row():
+ type_in = gr.inputs.Radio(['Depth', 'Image'], type="value", default='Image', label='You can input an image or a depth map')
+ run_button = gr.Button(label="Run")
+ con_strength = gr.Slider(label="Controling Strength (The guidance strength of the depth map to the result)", minimum=0, maximum=1, value=1, step=0.1)
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
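For context, `create_demo_depth` mirrors the existing `create_demo_seg` builder. A minimal sketch of launching the depth tab on its own (assuming, as with the sibling builders, that `create_demo_depth` returns the `gr.Blocks` object and that `Model_all` takes a device string; neither is shown in full in this hunk):

```python
# Hypothetical standalone launch of the new depth demo; mirrors how app.py
# mounts it inside a TabItem. Model_all('cuda') and the returned Blocks are
# assumptions based on the sibling builders.
from demo.model import Model_all
from demo.demos import create_demo_depth

model = Model_all('cuda')                      # loads SD-v1.4, the adapters and MiDaS
demo = create_demo_depth(model.process_depth)  # wires the UI to the new depth pipeline
demo.queue().launch(server_name='0.0.0.0')
```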
demo/model.py CHANGED
@@ -4,7 +4,9 @@ from pytorch_lightning import seed_everything
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter
from ldm.util import instantiate_from_config
- from model_edge import pidinet
+ from ldm.modules.structure_condition.model_edge import pidinet
+ from ldm.modules.structure_condition.model_seg import seger, Colorize
+ from ldm.modules.structure_condition.midas.api import MiDaSInference
import gradio as gr
from omegaconf import OmegaConf
import mmcv
@@ -13,7 +15,6 @@ from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process
import os
import cv2
import numpy as np
- from seger import seger, Colorize
import torch.nn.functional as F
 
def preprocessing(image, device):
@@ -136,10 +137,8 @@
self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
use_conv=False).to(device)
self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
- self.model_edge = pidinet()
- ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
- self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()})
- self.model_edge.to(device)
+ self.model_edge = pidinet().to(device)
+ self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in torch.load('models/table5_pidinet.pth', map_location=device)['state_dict'].items()})
 
# segmentation part
self.model_seger = seger().to(device)
@@ -147,6 +146,11 @@
self.coler = Colorize(n=182)
self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
+ self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
+
+ # depth part
+ self.model_depth = Adapter(cin=3*64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+ self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))
 
# keypose part
self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
@@ -248,6 +252,65 @@
 
return [im_edge, x_samples_ddim]
 
+ @torch.no_grad()
+ def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
+ con_strength, base_model):
+ if self.current_base != base_model:
+ ckpt = os.path.join("models", base_model)
+ pl_sd = torch.load(ckpt, map_location="cuda")
+ if "state_dict" in pl_sd:
+ sd = pl_sd["state_dict"]
+ else:
+ sd = pl_sd
+ self.base_model.load_state_dict(sd, strict=False)
+ self.current_base = base_model
+ if 'anything' in base_model.lower():
+ self.load_vae()
+
+ con_strength = int((1 - con_strength) * 50)
+ if fix_sample == 'True':
+ seed_everything(42)
+ im = cv2.resize(input_img, (512, 512))
+
+ if type_in == 'Depth':
+ im_depth = im.copy()
+ depth = img2tensor(im).unsqueeze(0) / 255.
+ elif type_in == 'Image':
+ im = img2tensor(im).unsqueeze(0) / 127.5 - 1.0
+ depth = self.depth_model(im.to(self.device)).repeat(1, 3, 1, 1)
+ depth -= torch.min(depth)
+ depth /= torch.max(depth)
+ im_depth = tensor2img(depth)
+
+ # extract condition features
+ c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
+ nc = self.base_model.get_learned_conditioning([neg_prompt])
+ features_adapter = self.model_depth(depth.to(self.device))
+ shape = [4, 64, 64]
+
+ # sampling
+ samples_ddim, _ = self.sampler.sample(S=50,
+ conditioning=c,
+ batch_size=1,
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=nc,
+ eta=0.0,
+ x_T=None,
+ features_adapter1=features_adapter,
+ mode='sketch',
+ con_strength=con_strength)
+
+ x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ x_samples_ddim = x_samples_ddim.to('cpu')
+ x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
+ x_samples_ddim = 255. * x_samples_ddim
+ x_samples_ddim = x_samples_ddim.astype(np.uint8)
+
+ return [im_depth, x_samples_ddim]
+
@torch.no_grad()
def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
con_strength, base_model):
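The core of the new `process_depth` path is the MiDaS-based conditioning step: the input image is resized to 512x512, mapped to [-1, 1], run through `MiDaSInference(model_type='dpt_hybrid')`, min-max normalised and repeated to three channels before being handed to the depth adapter. A standalone sketch of just that step (a simplified paraphrase of the code above with a hypothetical helper name, not the exact demo function):

```python
# Sketch of the depth conditioning used by process_depth above. MiDaSInference
# and the 'dpt_hybrid' weights are taken from the diff; the helper wrapper
# itself is an illustrative assumption.
import cv2
import torch
from basicsr.utils import img2tensor, tensor2img
from ldm.modules.structure_condition.midas.api import MiDaSInference

@torch.no_grad()
def image_to_depth_condition(img_bgr, depth_model, device='cuda'):
    """Turn an HxWx3 image into the 3-channel depth map fed to the depth adapter."""
    im = cv2.resize(img_bgr, (512, 512))
    im = img2tensor(im).unsqueeze(0) / 127.5 - 1.0          # scale to [-1, 1]
    depth = depth_model(im.to(device)).repeat(1, 3, 1, 1)   # 1-channel MiDaS output -> 3 channels
    depth -= torch.min(depth)                                # min-max normalise to [0, 1]
    depth /= torch.max(depth)
    return depth, tensor2img(depth)                          # tensor for the adapter, uint8 preview

# Usage (assumed): depth_model = MiDaSInference(model_type='dpt_hybrid').to('cuda')
# depth, preview = image_to_depth_condition(input_img, depth_model)
```

If the user supplies a ready-made depth map instead of a photo, the method skips MiDaS and simply scales the uploaded map to [0, 1], as the `type_in == 'Depth'` branch above shows.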
dist_util.py DELETED
@@ -1,91 +0,0 @@
1
- # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
- import functools
3
- import os
4
- import subprocess
5
- import torch
6
- import torch.distributed as dist
7
- import torch.multiprocessing as mp
8
- from torch.nn.parallel import DataParallel, DistributedDataParallel
9
-
10
-
11
- def init_dist(launcher, backend='nccl', **kwargs):
12
- if mp.get_start_method(allow_none=True) is None:
13
- mp.set_start_method('spawn')
14
- if launcher == 'pytorch':
15
- _init_dist_pytorch(backend, **kwargs)
16
- elif launcher == 'slurm':
17
- _init_dist_slurm(backend, **kwargs)
18
- else:
19
- raise ValueError(f'Invalid launcher type: {launcher}')
20
-
21
-
22
- def _init_dist_pytorch(backend, **kwargs):
23
- rank = int(os.environ['RANK'])
24
- num_gpus = torch.cuda.device_count()
25
- torch.cuda.set_device(rank % num_gpus)
26
- dist.init_process_group(backend=backend, **kwargs)
27
-
28
-
29
- def _init_dist_slurm(backend, port=None):
30
- """Initialize slurm distributed training environment.
31
-
32
- If argument ``port`` is not specified, then the master port will be system
33
- environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
34
- environment variable, then a default port ``29500`` will be used.
35
-
36
- Args:
37
- backend (str): Backend of torch.distributed.
38
- port (int, optional): Master port. Defaults to None.
39
- """
40
- proc_id = int(os.environ['SLURM_PROCID'])
41
- ntasks = int(os.environ['SLURM_NTASKS'])
42
- node_list = os.environ['SLURM_NODELIST']
43
- num_gpus = torch.cuda.device_count()
44
- torch.cuda.set_device(proc_id % num_gpus)
45
- addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
46
- # specify master port
47
- if port is not None:
48
- os.environ['MASTER_PORT'] = str(port)
49
- elif 'MASTER_PORT' in os.environ:
50
- pass # use MASTER_PORT in the environment variable
51
- else:
52
- # 29500 is torch.distributed default port
53
- os.environ['MASTER_PORT'] = '29500'
54
- os.environ['MASTER_ADDR'] = addr
55
- os.environ['WORLD_SIZE'] = str(ntasks)
56
- os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
57
- os.environ['RANK'] = str(proc_id)
58
- dist.init_process_group(backend=backend)
59
-
60
-
61
- def get_dist_info():
62
- if dist.is_available():
63
- initialized = dist.is_initialized()
64
- else:
65
- initialized = False
66
- if initialized:
67
- rank = dist.get_rank()
68
- world_size = dist.get_world_size()
69
- else:
70
- rank = 0
71
- world_size = 1
72
- return rank, world_size
73
-
74
-
75
- def master_only(func):
76
-
77
- @functools.wraps(func)
78
- def wrapper(*args, **kwargs):
79
- rank, _ = get_dist_info()
80
- if rank == 0:
81
- return func(*args, **kwargs)
82
-
83
- return wrapper
84
-
85
- def get_bare_model(net):
86
- """Get bare model, especially under wrapping with
87
- DistributedDataParallel or DataParallel.
88
- """
89
- if isinstance(net, (DataParallel, DistributedDataParallel)):
90
- net = net.module
91
- return net
environment.yaml CHANGED
File without changes
examples/edit_cat/edge.png DELETED
Binary file (5.98 kB)
examples/edit_cat/edge_2.png DELETED
Binary file (13.3 kB)
examples/edit_cat/im.png DELETED
Binary file (508 kB)
examples/edit_cat/mask.png DELETED
Binary file (4.65 kB)
examples/keypose/iron.png DELETED
Binary file (15.6 kB)
examples/seg/dinner.png DELETED
Binary file (17.8 kB)
examples/seg/motor.png DELETED
Binary file (20.9 kB)
examples/seg_sketch/edge.png DELETED
Binary file (12.9 kB)
examples/seg_sketch/mask.png DELETED
Binary file (22.2 kB)
examples/sketch/car.png DELETED
Binary file (13.2 kB)
examples/sketch/girl.jpeg DELETED
Binary file (214 kB)
examples/sketch/human.png DELETED
Binary file (768 kB)
examples/sketch/scenery.jpg DELETED
Binary file (99.8 kB)
examples/sketch/scenery2.jpg DELETED
Binary file (144 kB)
gradio_keypose.py DELETED
@@ -1,254 +0,0 @@
1
- import os
2
- import os.path as osp
3
-
4
- import cv2
5
- import numpy as np
6
- import torch
7
- from basicsr.utils import img2tensor, tensor2img
8
- from pytorch_lightning import seed_everything
9
- from ldm.models.diffusion.plms import PLMSSampler
10
- from ldm.modules.encoders.adapter import Adapter
11
- from ldm.util import instantiate_from_config
12
- from model_edge import pidinet
13
- import gradio as gr
14
- from omegaconf import OmegaConf
15
- import mmcv
16
- from mmdet.apis import inference_detector, init_detector
17
- from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result)
18
-
19
- skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10],
20
- [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
21
-
22
- pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
23
- [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0],
24
- [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]
25
-
26
- pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
27
- [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], [255, 128, 0],
28
- [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
29
- [51, 153, 255], [51, 153, 255], [51, 153, 255]]
30
-
31
- def imshow_keypoints(img,
32
- pose_result,
33
- skeleton=None,
34
- kpt_score_thr=0.1,
35
- pose_kpt_color=None,
36
- pose_link_color=None,
37
- radius=4,
38
- thickness=1):
39
- """Draw keypoints and links on an image.
40
-
41
- Args:
42
- img (ndarry): The image to draw poses on.
43
- pose_result (list[kpts]): The poses to draw. Each element kpts is
44
- a set of K keypoints as an Kx3 numpy.ndarray, where each
45
- keypoint is represented as x, y, score.
46
- kpt_score_thr (float, optional): Minimum score of keypoints
47
- to be shown. Default: 0.3.
48
- pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
49
- the keypoint will not be drawn.
50
- pose_link_color (np.array[Mx3]): Color of M links. If None, the
51
- links will not be drawn.
52
- thickness (int): Thickness of lines.
53
- """
54
-
55
- img_h, img_w, _ = img.shape
56
- img = np.zeros(img.shape)
57
-
58
- for idx, kpts in enumerate(pose_result):
59
- if idx > 1:
60
- continue
61
- kpts = kpts['keypoints']
62
- # print(kpts)
63
- kpts = np.array(kpts, copy=False)
64
-
65
- # draw each point on image
66
- if pose_kpt_color is not None:
67
- assert len(pose_kpt_color) == len(kpts)
68
-
69
- for kid, kpt in enumerate(kpts):
70
- x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
71
-
72
- if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
73
- # skip the point that should not be drawn
74
- continue
75
-
76
- color = tuple(int(c) for c in pose_kpt_color[kid])
77
- cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
78
-
79
- # draw links
80
- if skeleton is not None and pose_link_color is not None:
81
- assert len(pose_link_color) == len(skeleton)
82
-
83
- for sk_id, sk in enumerate(skeleton):
84
- pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
85
- pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
86
-
87
- if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
88
- or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
89
- or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
90
- # skip the link that should not be drawn
91
- continue
92
- color = tuple(int(c) for c in pose_link_color[sk_id])
93
- cv2.line(img, pos1, pos2, color, thickness=thickness)
94
-
95
- return img
96
-
97
- def load_model_from_config(config, ckpt, verbose=False):
98
- print(f"Loading model from {ckpt}")
99
- pl_sd = torch.load(ckpt, map_location="cpu")
100
- if "global_step" in pl_sd:
101
- print(f"Global Step: {pl_sd['global_step']}")
102
- if "state_dict" in pl_sd:
103
- sd = pl_sd["state_dict"]
104
- else:
105
- sd = pl_sd
106
- model = instantiate_from_config(config.model)
107
- m, u = model.load_state_dict(sd, strict=False)
108
-
109
- model.cuda()
110
- model.eval()
111
- return model
112
-
113
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
114
- config = OmegaConf.load("configs/stable-diffusion/test_keypose.yaml")
115
- config.model.params.cond_stage_config.params.device = device
116
- model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device)
117
- current_base = 'sd-v1-4.ckpt'
118
- model_ad = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
119
- model_ad.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth"))
120
- sampler = PLMSSampler(model)
121
- ## mmpose
122
- det_config = 'models/faster_rcnn_r50_fpn_coco.py'
123
- det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
124
- pose_config = 'models/hrnet_w48_coco_256x192.py'
125
- pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
126
- det_cat_id = 1
127
- bbox_thr = 0.2
128
- ## detector
129
- det_config_mmcv = mmcv.Config.fromfile(det_config)
130
- det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
131
- pose_config_mmcv = mmcv.Config.fromfile(pose_config)
132
- pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
133
- W, H = 512, 512
134
-
135
-
136
- def process(input_img, type_in, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
137
- global current_base
138
- if current_base != base_model:
139
- ckpt = os.path.join("models", base_model)
140
- pl_sd = torch.load(ckpt, map_location="cpu")
141
- if "state_dict" in pl_sd:
142
- sd = pl_sd["state_dict"]
143
- else:
144
- sd = pl_sd
145
- model.load_state_dict(sd, strict=False)
146
- current_base = base_model
147
- con_strength = int((1-con_strength)*50)
148
- if fix_sample == 'True':
149
- seed_everything(42)
150
- im = cv2.resize(input_img,(W,H))
151
-
152
- if type_in == 'Keypose':
153
- im_pose = im.copy()
154
- im = img2tensor(im).unsqueeze(0)/255.
155
- elif type_in == 'Image':
156
- image = im.copy()
157
- im = img2tensor(im).unsqueeze(0)/255.
158
- mmdet_results = inference_detector(det_model, image)
159
- # keep the person class bounding boxes.
160
- person_results = process_mmdet_results(mmdet_results, det_cat_id)
161
-
162
- # optional
163
- return_heatmap = False
164
- dataset = pose_model.cfg.data['test']['type']
165
-
166
- # e.g. use ('backbone', ) to return backbone feature
167
- output_layer_names = None
168
- pose_results, returned_outputs = inference_top_down_pose_model(
169
- pose_model,
170
- image,
171
- person_results,
172
- bbox_thr=bbox_thr,
173
- format='xyxy',
174
- dataset=dataset,
175
- dataset_info=None,
176
- return_heatmap=return_heatmap,
177
- outputs=output_layer_names)
178
-
179
- # show the results
180
- im_pose = imshow_keypoints(
181
- image,
182
- pose_results,
183
- skeleton=skeleton,
184
- pose_kpt_color=pose_kpt_color,
185
- pose_link_color=pose_link_color,
186
- radius=2,
187
- thickness=2)
188
- im_pose = cv2.resize(im_pose,(W,H))
189
-
190
- with torch.no_grad():
191
- c = model.get_learned_conditioning([prompt])
192
- nc = model.get_learned_conditioning([neg_prompt])
193
- # extract condition features
194
- pose = img2tensor(im_pose, bgr2rgb=True, float32=True)/255.
195
- pose = pose.unsqueeze(0)
196
- features_adapter = model_ad(pose.to(device))
197
-
198
- shape = [4, W//8, H//8]
199
-
200
- # sampling
201
- samples_ddim, _ = sampler.sample(S=50,
202
- conditioning=c,
203
- batch_size=1,
204
- shape=shape,
205
- verbose=False,
206
- unconditional_guidance_scale=scale,
207
- unconditional_conditioning=nc,
208
- eta=0.0,
209
- x_T=None,
210
- features_adapter1=features_adapter,
211
- mode = 'sketch',
212
- con_strength = con_strength)
213
-
214
- x_samples_ddim = model.decode_first_stage(samples_ddim)
215
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
216
- x_samples_ddim = x_samples_ddim.to('cpu')
217
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
218
- x_samples_ddim = 255.*x_samples_ddim
219
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
220
-
221
- return [im_pose[:,:,::-1].astype(np.uint8), x_samples_ddim]
222
-
223
- DESCRIPTION = '''# T2I-Adapter (Keypose)
224
- [Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)
225
-
226
- This gradio demo is for keypose-guided generation. The current functions include:
227
- - Keypose to Image Generation
228
- - Image to Image Generation
229
- - Generation with **Anything** setting
230
- '''
231
- block = gr.Blocks().queue()
232
- with block:
233
- with gr.Row():
234
- gr.Markdown(DESCRIPTION)
235
- with gr.Row():
236
- with gr.Column():
237
- input_img = gr.Image(source='upload', type="numpy")
238
- prompt = gr.Textbox(label="Prompt")
239
- neg_prompt = gr.Textbox(label="Negative Prompt",
240
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
241
- with gr.Row():
242
- type_in = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a keypose map)')
243
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed to produce a fixed output)')
244
- run_button = gr.Button(label="Run")
245
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the keypose to the result)", minimum=0, maximum=1, value=1, step=0.1)
246
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1)
247
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
248
- with gr.Column():
249
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
250
- ips = [input_img, type_in, prompt, neg_prompt, fix_sample, scale, con_strength, base_model]
251
- run_button.click(fn=process, inputs=ips, outputs=[result])
252
-
253
- block.launch(server_name='0.0.0.0')
254
-
gradio_sketch.py DELETED
@@ -1,147 +0,0 @@
1
- import os
2
- import os.path as osp
3
-
4
- import cv2
5
- import numpy as np
6
- import torch
7
- from basicsr.utils import img2tensor, tensor2img
8
- from pytorch_lightning import seed_everything
9
- from ldm.models.diffusion.plms import PLMSSampler
10
- from ldm.modules.encoders.adapter import Adapter
11
- from ldm.util import instantiate_from_config
12
- from model_edge import pidinet
13
- import gradio as gr
14
- from omegaconf import OmegaConf
15
-
16
-
17
- def load_model_from_config(config, ckpt, verbose=False):
18
- print(f"Loading model from {ckpt}")
19
- pl_sd = torch.load(ckpt, map_location="cpu")
20
- if "global_step" in pl_sd:
21
- print(f"Global Step: {pl_sd['global_step']}")
22
- if "state_dict" in pl_sd:
23
- sd = pl_sd["state_dict"]
24
- else:
25
- sd = pl_sd
26
- model = instantiate_from_config(config.model)
27
- m, u = model.load_state_dict(sd, strict=False)
28
- # if len(m) > 0 and verbose:
29
- # print("missing keys:")
30
- # print(m)
31
- # if len(u) > 0 and verbose:
32
- # print("unexpected keys:")
33
- # print(u)
34
-
35
- model.cuda()
36
- model.eval()
37
- return model
38
-
39
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
40
- config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
41
- config.model.params.cond_stage_config.params.device = device
42
- model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device)
43
- current_base = 'sd-v1-4.ckpt'
44
- model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
45
- model_ad.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth"))
46
- net_G = pidinet()
47
- ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
48
- net_G.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
49
- net_G.to(device)
50
- sampler = PLMSSampler(model)
51
- save_memory=True
52
- W, H = 512, 512
53
-
54
-
55
- def process(input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
56
- global current_base
57
- if current_base != base_model:
58
- ckpt = os.path.join("models", base_model)
59
- pl_sd = torch.load(ckpt, map_location="cpu")
60
- if "state_dict" in pl_sd:
61
- sd = pl_sd["state_dict"]
62
- else:
63
- sd = pl_sd
64
- model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
65
- current_base = base_model
66
- con_strength = int((1-con_strength)*50)
67
- if fix_sample == 'True':
68
- seed_everything(42)
69
- im = cv2.resize(input_img,(W,H))
70
-
71
- if type_in == 'Sketch':
72
- if color_back == 'White':
73
- im = 255-im
74
- im_edge = im.copy()
75
- im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0)/255.
76
- im = im>0.5
77
- im = im.float()
78
- elif type_in == 'Image':
79
- im = img2tensor(im).unsqueeze(0)/255.
80
- im = net_G(im.to(device))[-1]
81
- im = im>0.5
82
- im = im.float()
83
- im_edge = tensor2img(im)
84
-
85
- with torch.no_grad():
86
- c = model.get_learned_conditioning([prompt])
87
- nc = model.get_learned_conditioning([neg_prompt])
88
- # extract condition features
89
- features_adapter = model_ad(im.to(device))
90
- shape = [4, W//8, H//8]
91
-
92
- # sampling
93
- samples_ddim, _ = sampler.sample(S=50,
94
- conditioning=c,
95
- batch_size=1,
96
- shape=shape,
97
- verbose=False,
98
- unconditional_guidance_scale=scale,
99
- unconditional_conditioning=nc,
100
- eta=0.0,
101
- x_T=None,
102
- features_adapter1=features_adapter,
103
- mode = 'sketch',
104
- con_strength = con_strength)
105
-
106
- x_samples_ddim = model.decode_first_stage(samples_ddim)
107
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
108
- x_samples_ddim = x_samples_ddim.to('cpu')
109
- x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
110
- x_samples_ddim = 255.*x_samples_ddim
111
- x_samples_ddim = x_samples_ddim.astype(np.uint8)
112
-
113
- return [im_edge, x_samples_ddim]
114
-
115
- DESCRIPTION = '''# T2I-Adapter (Sketch)
116
- [Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)
117
-
118
- This gradio demo is for sketch-guided generation. The current functions include:
119
- - Sketch to Image Generation
120
- - Image to Image Generation
121
- - Generation with **Anything** setting
122
- '''
123
- block = gr.Blocks().queue()
124
- with block:
125
- with gr.Row():
126
- gr.Markdown(DESCRIPTION)
127
- with gr.Row():
128
- with gr.Column():
129
- input_img = gr.Image(source='upload', type="numpy")
130
- prompt = gr.Textbox(label="Prompt")
131
- neg_prompt = gr.Textbox(label="Negative Prompt",
132
- value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
133
- with gr.Row():
134
- type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a sketch)')
135
- color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
136
- run_button = gr.Button(label="Run")
137
- con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
138
- scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1)
139
- fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
140
- base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
141
- with gr.Column():
142
- result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
143
- ips = [input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model]
144
- run_button.click(fn=process, inputs=ips, outputs=[result])
145
-
146
- block.launch(server_name='0.0.0.0')
147
-
ldm/data/__init__.py CHANGED
File without changes
ldm/data/base.py CHANGED
File without changes
ldm/data/imagenet.py CHANGED
File without changes
ldm/data/lsun.py CHANGED
File without changes
ldm/lr_scheduler.py CHANGED
File without changes
ldm/models/autoencoder.py CHANGED
File without changes
ldm/models/diffusion/__init__.py CHANGED
File without changes
ldm/models/diffusion/classifier.py CHANGED
File without changes
ldm/models/diffusion/ddim.py CHANGED
File without changes
ldm/models/diffusion/ddpm.py CHANGED
File without changes
ldm/models/diffusion/dpm_solver/__init__.py CHANGED
File without changes
ldm/models/diffusion/dpm_solver/dpm_solver.py CHANGED
File without changes
ldm/models/diffusion/dpm_solver/sampler.py CHANGED
File without changes
ldm/models/diffusion/plms.py CHANGED
File without changes
ldm/modules/attention.py CHANGED
File without changes