Commit 168a510 (verified) by dmolino
Parent: 8773294

Upload 276 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50):
  1. configs/model/codi_2.yaml +21 -0
  2. configs/model/openai_unet.yaml +3 -1
  3. configs/model/openai_unet_2.yaml +87 -0
  4. configs/model/optimus.yaml +3 -4
  5. configs/model/prova.yaml +1 -1
  6. core/__pycache__/cfg_helper.cpython-311.pyc +0 -0
  7. core/__pycache__/cfg_helper.cpython-38.pyc +0 -0
  8. core/__pycache__/cfg_holder.cpython-311.pyc +0 -0
  9. core/__pycache__/cfg_holder.cpython-38.pyc +0 -0
  10. core/__pycache__/sync.cpython-311.pyc +0 -0
  11. core/__pycache__/sync.cpython-38.pyc +0 -0
  12. core/common/__pycache__/utils.cpython-311.pyc +0 -0
  13. core/common/__pycache__/utils.cpython-38.pyc +0 -0
  14. core/common/utils.py +0 -3
  15. core/models/__pycache__/__init__.cpython-311.pyc +0 -0
  16. core/models/__pycache__/codi.cpython-311.pyc +0 -0
  17. core/models/__pycache__/codi_2.cpython-311.pyc +0 -0
  18. core/models/__pycache__/dani_model.cpython-311.pyc +0 -0
  19. core/models/__pycache__/ema.cpython-311.pyc +0 -0
  20. core/models/__pycache__/ema.cpython-38.pyc +0 -0
  21. core/models/__pycache__/model_module_infer.cpython-311.pyc +0 -0
  22. core/models/__pycache__/model_module_infer.cpython-38.pyc +0 -0
  23. core/models/__pycache__/sd.cpython-311.pyc +0 -0
  24. core/models/__pycache__/sd.cpython-38.pyc +0 -0
  25. core/models/codi.py +5 -4
  26. core/models/codi_2.py +226 -221
  27. core/models/common/__pycache__/get_model.cpython-311.pyc +0 -0
  28. core/models/common/__pycache__/get_model.cpython-38.pyc +0 -0
  29. core/models/common/__pycache__/get_optimizer.cpython-311.pyc +0 -0
  30. core/models/common/__pycache__/get_optimizer.cpython-38.pyc +0 -0
  31. core/models/common/__pycache__/get_scheduler.cpython-311.pyc +0 -0
  32. core/models/common/__pycache__/get_scheduler.cpython-38.pyc +0 -0
  33. core/models/common/__pycache__/utils.cpython-311.pyc +0 -0
  34. core/models/common/__pycache__/utils.cpython-38.pyc +0 -0
  35. core/models/dani_model.py +3 -1
  36. core/models/ddim/__pycache__/ddim.cpython-311.pyc +0 -0
  37. core/models/ddim/__pycache__/ddim.cpython-38.pyc +0 -0
  38. core/models/ddim/__pycache__/ddim_vd.cpython-311.pyc +0 -0
  39. core/models/ddim/__pycache__/ddim_vd.cpython-38.pyc +0 -0
  40. core/models/ddim/__pycache__/diffusion_utils.cpython-311.pyc +0 -0
  41. core/models/ddim/__pycache__/diffusion_utils.cpython-38.pyc +0 -0
  42. core/models/ddim/ddim.py +10 -1
  43. core/models/ddim/ddim_vd.py +1 -1
  44. core/models/encoders/__pycache__/clap.cpython-311.pyc +0 -0
  45. core/models/encoders/__pycache__/clip.cpython-311.pyc +0 -0
  46. core/models/encoders/__pycache__/clip.cpython-38.pyc +0 -0
  47. core/models/encoders/clap_modules/__pycache__/__init__.cpython-311.pyc +0 -0
  48. core/models/encoders/clap_modules/open_clip/__pycache__/__init__.cpython-311.pyc +0 -0
  49. core/models/encoders/clap_modules/open_clip/__pycache__/factory.cpython-311.pyc +0 -0
  50. core/models/encoders/clap_modules/open_clip/__pycache__/feature_fusion.cpython-311.pyc +0 -0
configs/model/codi_2.yaml ADDED
@@ -0,0 +1,21 @@
+########
+# CoDi #
+########
+
+codi_2:
+  type: codi_2
+  symbol: codi_2
+  find_unused_parameters: true
+  args:
+    autokl_cfg: MODEL(sd_autoencoder)
+    optimus_cfg: MODEL(optimus_vae)
+    clip_cfg: MODEL(clip_frozen)
+    unet_config: MODEL(openai_unet_codi_2)
+    beta_linear_start: 0.00085
+    beta_linear_end: 0.012
+    timesteps: 1000
+    vision_scale_factor: 0.18215
+    text_scale_factor: 4.3108
+    audio_scale_factor: 0.9228
+    use_ema: false
+    parameterization : "eps"
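Note: each MODEL(...) reference in this new config is resolved into the corresponding sub-config and instantiated through the registry in core/models/common/get_model.py, the same get_model()(cfg) pattern the codi_2 constructor uses for its sub-models. A minimal sketch of that usage, assuming an already-resolved config node (the helper name and variable names below are illustrative, not part of the commit):

from core.models.common.get_model import get_model

def build_codi_2(resolved_cfg):
    # resolved_cfg is assumed to be the codi_2 node from configs/model/codi_2.yaml
    # after the project's cfg helper has expanded the MODEL(...) references into
    # sub-configs (autokl_cfg, optimus_cfg, clip_cfg, unet_config, ...).
    # get_model()(cfg) looks up the registered type ('codi_2') and instantiates it,
    # exactly as codi_2.py itself does for its sub-models.
    return get_model()(resolved_cfg)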
configs/model/openai_unet.yaml CHANGED
@@ -82,4 +82,6 @@ openai_unet_codi:
     unet_image_cfg: MODEL(openai_unet_2d)
     unet_text_cfg: MODEL(openai_unet_0dmd)
     unet_audio_cfg: MODEL(openai_unet_2d_audio)
-    model_type: ['video', 'image', 'text']
+    # model_type: ['video', 'image']
+    # model_type: ['text']
+    model_type: ['audio', 'image', 'video', 'text']
configs/model/openai_unet_2.yaml ADDED
@@ -0,0 +1,87 @@
+openai_unet_sd:
+  type: openai_unet
+  args:
+    image_size: null # no use
+    in_channels: 4
+    out_channels: 4
+    model_channels: 320
+    attention_resolutions: [ 4, 2, 1 ]
+    num_res_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    num_heads: 8
+    use_spatial_transformer: True
+    transformer_depth: 1
+    context_dim: 768
+    use_checkpoint: True
+    legacy: False
+
+openai_unet_dual_context:
+  super_cfg: openai_unet_sd
+  type: openai_unet_dual_context
+
+########################
+# Code cleaned version #
+########################
+
+openai_unet_2d_audio:
+  type: openai_unet_2d
+  args:
+    input_channels: 8
+    model_channels: 192
+    output_channels: 8
+    num_noattn_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    with_attn: [true, true, true, false]
+    channel_mult_connector: [1, 2, 4]
+    num_noattn_blocks_connector: [1, 1, 1]
+    with_connector: [True, True, True, False]
+    connector_output_channel: 1280
+    num_heads: 8
+    context_dim: 768
+    use_checkpoint: False
+
+openai_unet_2d:
+  type: openai_unet_2d
+  args:
+    input_channels: 4
+    model_channels: 320
+    output_channels: 4
+    num_noattn_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    with_attn: [true, true, true, false]
+    channel_mult_connector: [1, 2, 4]
+    num_noattn_blocks_connector: [1, 1, 1]
+    with_connector: [True, True, True, False]
+    connector_output_channel: 1280
+    num_heads: 8
+    context_dim: 768
+    use_checkpoint: True
+    use_video_architecture: True
+
+openai_unet_0dmd:
+  type: openai_unet_0dmd
+  args:
+    input_channels: 768
+    model_channels: 320
+    output_channels: 768
+    num_noattn_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    second_dim: [ 4, 4, 4, 4 ]
+    with_attn: [true, true, true, false]
+    num_noattn_blocks_connector: [1, 1, 1]
+    second_dim_connector: [4, 4, 4]
+    with_connector: [True, True, True, False]
+    connector_output_channel: 1280
+    num_heads: 8
+    context_dim: 768
+    use_checkpoint: True
+
+openai_unet_codi_2:
+  type: openai_unet_codi_2
+  args:
+    unet_frontal_cfg: MODEL(openai_unet_2d)
+    unet_lateral_cfg: MODEL(openai_unet_2d)
+    unet_text_cfg: MODEL(openai_unet_0dmd)
+    # model_type: ['lateral', 'text']
+    # model_type: ['text']
+    model_type: ['frontal', 'lateral', 'text']
configs/model/optimus.yaml CHANGED
@@ -100,8 +100,7 @@ optimus_vae:
   tokenizer_decoder: MODEL(optimus_gpt2_tokenizer)
   args:
     latent_size: 768
-    beta: 1.0
-    fb_mode: 0
-    length_weighted_loss: false
+    beta : 1.0
+    fb_mode : 0
+    length_weighted_loss : false
     dim_target_kl : 3.0
-
configs/model/prova.yaml CHANGED
@@ -82,4 +82,4 @@ prova:
     unet_frontal_cfg: MODEL(openai_unet_2d)
     unet_lateral_cfg: MODEL(openai_unet_2d)
     unet_text_cfg: MODEL(openai_unet_0dmd)
-    model_type: ['text']
+    model_type: ['frontal', 'lateral', 'text']
core/__pycache__/cfg_helper.cpython-311.pyc ADDED
Binary file (31.2 kB).
 
core/__pycache__/cfg_helper.cpython-38.pyc CHANGED
Binary files a/core/__pycache__/cfg_helper.cpython-38.pyc and b/core/__pycache__/cfg_helper.cpython-38.pyc differ
 
core/__pycache__/cfg_holder.cpython-311.pyc ADDED
Binary file (1.7 kB).
 
core/__pycache__/cfg_holder.cpython-38.pyc CHANGED
Binary files a/core/__pycache__/cfg_holder.cpython-38.pyc and b/core/__pycache__/cfg_holder.cpython-38.pyc differ
 
core/__pycache__/sync.cpython-311.pyc ADDED
Binary file (11.8 kB).
 
core/__pycache__/sync.cpython-38.pyc CHANGED
Binary files a/core/__pycache__/sync.cpython-38.pyc and b/core/__pycache__/sync.cpython-38.pyc differ
 
core/common/__pycache__/utils.cpython-311.pyc ADDED
Binary file (23.2 kB).
 
core/common/__pycache__/utils.cpython-38.pyc CHANGED
Binary files a/core/common/__pycache__/utils.cpython-38.pyc and b/core/common/__pycache__/utils.cpython-38.pyc differ
 
core/common/utils.py CHANGED
@@ -99,7 +99,6 @@ def remove_duplicate_word(tx):
 
 
 def regularize_image(x, image_size=512):
-    BICUBIC = T.InterpolationMode.BICUBIC
     if isinstance(x, str):
         x = Image.open(x)
         size = min(x.size)
@@ -111,7 +110,6 @@ def regularize_image(x, image_size=512):
         size = min(x.size)
     elif isinstance(x, torch.Tensor):
         # normalize to [0, 1]
-        x = x/255.0
         size = min(x.size()[1:])
     else:
         assert False, 'Unknown image type'
@@ -126,7 +124,6 @@ def regularize_image(x, image_size=512):
         T.ToTensor(),
     ])
     x = transforms(x)
-
     assert (x.shape[1] == image_size) & (x.shape[2] == image_size), \
         'Wrong image size'
     """
core/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (466 Bytes).
 
core/models/__pycache__/codi.cpython-311.pyc ADDED
Binary file (15.4 kB).
 
core/models/__pycache__/codi_2.cpython-311.pyc ADDED
Binary file (14.4 kB).
 
core/models/__pycache__/dani_model.cpython-311.pyc ADDED
Binary file (8.96 kB).
 
core/models/__pycache__/ema.cpython-311.pyc ADDED
Binary file (5.51 kB).
 
core/models/__pycache__/ema.cpython-38.pyc CHANGED
Binary files a/core/models/__pycache__/ema.cpython-38.pyc and b/core/models/__pycache__/ema.cpython-38.pyc differ
 
core/models/__pycache__/model_module_infer.cpython-311.pyc ADDED
Binary file (9.05 kB).
 
core/models/__pycache__/model_module_infer.cpython-38.pyc CHANGED
Binary files a/core/models/__pycache__/model_module_infer.cpython-38.pyc and b/core/models/__pycache__/model_module_infer.cpython-38.pyc differ
 
core/models/__pycache__/sd.cpython-311.pyc ADDED
Binary file (19.7 kB).
 
core/models/__pycache__/sd.cpython-38.pyc CHANGED
Binary files a/core/models/__pycache__/sd.cpython-38.pyc and b/core/models/__pycache__/sd.cpython-38.pyc differ
 
core/models/codi.py CHANGED
@@ -75,16 +75,16 @@ class CoDi(DDPM):
     @torch.no_grad()
     def optimus_encode(self, text):
         if isinstance(text, List):
-            tokenizer = self.optimus.tokenizer_encoder
-            token = [tokenizer.tokenize(sentence.lower()) for sentence in text]
+            token = [self.optimus.tokenizer_encoder.tokenize(sentence.lower()) for sentence in text]
             token_id = []
             for tokeni in token:
-                token_sentence = [tokenizer._convert_token_to_id(i) for i in tokeni]
-                token_sentence = tokenizer.add_special_tokens_single_sentence(token_sentence)
+                token_sentence = [self.optimus.tokenizer_encoder._convert_token_to_id(i) for i in tokeni]
+                token_sentence = self.optimus.tokenizer_encoder.add_special_tokens_single_sentence(token_sentence)
                 token_id.append(torch.LongTensor(token_sentence))
             token_id = torch._C._nn.pad_sequence(token_id, batch_first=True, padding_value=0.0)[:, :512]
         else:
             token_id = text
+        token_id = token_id.to(self.device)
         z = self.optimus.encoder(token_id, attention_mask=(token_id > 0))[1]
         z_mu, z_logvar = self.optimus.encoder.linear(z).chunk(2, -1)
         return z_mu.squeeze(1) * self.text_scale_factor
@@ -92,6 +92,7 @@ class CoDi(DDPM):
     @torch.no_grad()
     def optimus_decode(self, z, temperature=1.0, max_length=30):
         z = 1.0 / self.text_scale_factor * z
+        z = z.to(self.device)
         return self.optimus.decode(z, temperature, max_length=max_length)
 
     @torch.no_grad()
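The edits above move the token ids (and, in optimus_decode, the latent z) onto the model's device before they reach the Optimus encoder/decoder, so inputs built on the CPU no longer cause a device mismatch when the model sits on the GPU. A minimal usage sketch under that assumption (the variable name and prompt text are illustrative only, not taken from the repository):

import torch

# `codi` is assumed to be an already-constructed CoDi instance (core/models/codi.py)
# moved to the GPU; the prompt string is a placeholder.
codi = codi.cuda().eval()

with torch.no_grad():
    # optimus_encode tokenizes the raw strings itself and now places the padded
    # token ids on codi.device before calling the Optimus encoder.
    z_text = codi.optimus_encode(["a chest x-ray with no acute findings"])

    # optimus_decode rescales by 1/text_scale_factor and moves z to codi.device
    # before decoding back to text.
    report = codi.optimus_decode(z_text, temperature=1.0, max_length=30)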
core/models/codi_2.py CHANGED
@@ -1,221 +1,226 @@
 from typing import Dict, List
 import os
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 import numpy.random as npr
 import copy
 from functools import partial
 from contextlib import contextmanager
 
 from .common.get_model import get_model, register
 from .sd import DDPM
 
 version = '0'
 symbol = 'thesis_model'
 
 
 @register('thesis_model', version)
 class CoDi(DDPM):
     def __init__(self,
                  autokl_cfg=None,
                  optimus_cfg=None,
                  clip_cfg=None,
                  vision_scale_factor=0.1812,
                  text_scale_factor=4.3108,
                  audio_scale_factor=0.9228,
                  scale_by_std=False,
                  *args,
                  **kwargs):
         super().__init__(*args, **kwargs)
 
         if autokl_cfg is not None:
             self.autokl = get_model()(autokl_cfg)
 
         if optimus_cfg is not None:
             self.optimus = get_model()(optimus_cfg)
 
         if clip_cfg is not None:
             self.clip = get_model()(clip_cfg)
 
         if not scale_by_std:
             self.vision_scale_factor = vision_scale_factor
             self.text_scale_factor = text_scale_factor
             self.audio_scale_factor = audio_scale_factor
         else:
             self.register_buffer("text_scale_factor", torch.tensor(text_scale_factor))
             self.register_buffer("audio_scale_factor", torch.tensor(audio_scale_factor))
             self.register_buffer('vision_scale_factor', torch.tensor(vision_scale_factor))
 
     @property
     def device(self):
         return next(self.parameters()).device
 
     @torch.no_grad()
     def autokl_encode(self, image):
         encoder_posterior = self.autokl.encode(image)
         z = encoder_posterior.sample().to(image.dtype)
         return self.vision_scale_factor * z
 
     @torch.no_grad()
     def autokl_decode(self, z):
         z = 1. / self.vision_scale_factor * z
         return self.autokl.decode(z)
 
     @torch.no_grad()
     def optimus_encode(self, text):
         if isinstance(text, List):
             tokenizer = self.optimus.tokenizer_encoder
             token = [tokenizer.tokenize(sentence.lower()) for sentence in text]
             token_id = []
             for tokeni in token:
                 token_sentence = [tokenizer._convert_token_to_id(i) for i in tokeni]
                 token_sentence = tokenizer.add_special_tokens_single_sentence(token_sentence)
                 token_id.append(torch.LongTensor(token_sentence))
             token_id = torch._C._nn.pad_sequence(token_id, batch_first=True, padding_value=0.0)[:, :512]
         else:
             token_id = text
+        token_id = token_id.to(self.device)
         z = self.optimus.encoder(token_id, attention_mask=(token_id > 0))[1]
         z_mu, z_logvar = self.optimus.encoder.linear(z).chunk(2, -1)
         return z_mu.squeeze(1) * self.text_scale_factor
 
     @torch.no_grad()
-    def optimus_decode(self, z, temperature=1.0):
+    def optimus_decode(self, z, temperature=1.0, max_length=30):
         z = 1.0 / self.text_scale_factor * z
-        return self.optimus.decode(z, temperature)
+        z = z.to(self.device)
+        return self.optimus.decode(z, temperature, max_length=max_length)
 
     @torch.no_grad()
     def clip_encode_text(self, text, encode_type='encode_text'):
         swap_type = self.clip.encode_type
         self.clip.encode_type = encode_type
         embedding = self.clip(text, encode_type)
         self.clip.encode_type = swap_type
         return embedding
 
     @torch.no_grad()
     def clip_encode_vision(self, vision, encode_type='encode_vision'):
         swap_type = self.clip.encode_type
         self.clip.encode_type = encode_type
         embedding = self.clip(vision, encode_type)
         self.clip.encode_type = swap_type
         return embedding
 
     @torch.no_grad()
     def clap_encode_audio(self, audio):
         embedding = self.clap(audio)
         return embedding
 
     def forward(self, x=None, c=None, noise=None, xtype='frontal', ctype='text', u=None, return_algined_latents=False, env_enc=False):
         if isinstance(x, list):
             t = torch.randint(0, self.num_timesteps, (x[0].shape[0],), device=x[0].device).long()
         else:
             t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long()
         return self.p_losses(x, c, t, noise, xtype, ctype, u, return_algined_latents, env_enc)
 
     def apply_model(self, x_noisy, t, cond, xtype='frontal', ctype='text', u=None, return_algined_latents=False, env_enc=False):
         return self.model.diffusion_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents, env_enc=env_enc)
 
     def get_pixel_loss(self, pred, target, mean=True):
         if self.loss_type == 'l1':
             loss = (target - pred).abs()
             if mean:
                 loss = loss.mean()
         elif self.loss_type == 'l2':
             if mean:
                 loss = torch.nn.functional.mse_loss(target, pred)
             else:
                 loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
         else:
             raise NotImplementedError("unknown loss type '{loss_type}'")
         loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=-0.0)
         return loss
 
     def get_text_loss(self, pred, target):
         if self.loss_type == 'l1':
             loss = (target - pred).abs()
         elif self.loss_type == 'l2':
             loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
         loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0)
         return loss
 
     def p_losses(self, x_start, cond, t, noise=None, xtype='frontal', ctype='text', u=None,
                  return_algined_latents=False, env_enc=False):
         if isinstance(x_start, list):
             noise = [torch.randn_like(x_start_i) for x_start_i in x_start] if noise is None else noise
             x_noisy = [self.q_sample(x_start=x_start_i, t=t, noise=noise_i) for x_start_i, noise_i in
                        zip(x_start, noise)]
             if not env_enc:
                 model_output = self.apply_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents, env_enc)
             else:
                 model_output, h_con = self.apply_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents, env_enc)
             if return_algined_latents:
                 return model_output
 
             loss_dict = {}
 
             if self.parameterization == "x0":
                 target = x_start
             elif self.parameterization == "eps":
                 target = noise
             else:
                 raise NotImplementedError()
 
             loss = 0.0
             for model_output_i, target_i, xtype_i in zip(model_output, target, xtype):
                 if xtype_i == 'frontal':
                     loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3])
                 elif xtype_i == 'text':
                     loss_simple = self.get_text_loss(model_output_i, target_i).mean([1])
                 elif xtype_i == 'lateral':
                     loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3])
                 loss += loss_simple.mean()
 
+
             # Check whether the model also returned h_con.
             # If so, we have the latent representations of the two modalities
             # extracted by the environmental encoders; since these are two tensors of size batch_size x 1 x 1280,
             # we can use them to also compute a contrastive loss term (cross-entropy, as in CLIP).
             if h_con is not None:
                 def similarity(z_a, z_b):
                     return F.cosine_similarity(z_a, z_b)
 
                 z_a, z_b = h_con
 
                 z_a = z_a / z_a.norm(dim=-1, keepdim=True)
                 z_b = z_b / z_b.norm(dim=-1, keepdim=True)
 
                 logits_a = z_a.squeeze() @ z_b.squeeze().t()
                 logits_b = z_a.squeeze() @ z_b.squeeze().t()
 
                 labels = torch.arange(len(z_a)).to(z_a.device)
 
                 loss_a = F.cross_entropy(logits_a, labels)
                 loss_b = F.cross_entropy(logits_b, labels)
 
                 loss_con = (loss_a + loss_b) / 2
                 loss += loss_con
+
+
             return loss / len(xtype)
 
         else:
             noise = torch.randn_like(x_start) if noise is None else noise
             x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
             model_output = self.apply_model(x_noisy, t, cond, xtype, ctype)
 
             loss_dict = {}
 
             if self.parameterization == "x0":
                 target = x_start
             elif self.parameterization == "eps":
                 target = noise
             else:
                 raise NotImplementedError()
 
             if xtype == 'frontal':
                 loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3])
             elif xtype == 'text':
                 loss_simple = self.get_text_loss(model_output, target).mean([1])
             elif xtype == 'lateral':
                 loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3])
             loss = loss_simple.mean()
             return loss
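For reference, the contrastive term described in the comment above follows the CLIP recipe: a symmetric cross-entropy over the batch-wise similarity matrix of the two environmental-encoder latents. A self-contained sketch of that standard formulation (the temperature and the transposed logits for the second direction are part of the usual CLIP setup; the function name is illustrative and not taken from the repository):

import torch
import torch.nn.functional as F

def clip_style_contrastive_loss(z_a, z_b, temperature=0.07):
    # z_a, z_b: paired embeddings of shape (batch, dim), e.g. the squeezed
    # (batch, 1280) environmental-encoder outputs mentioned in the comment.
    z_a = F.normalize(z_a, dim=-1)
    z_b = F.normalize(z_b, dim=-1)

    # (batch, batch) cosine-similarity logits; entry (i, i) is the true pair.
    logits = z_a @ z_b.t() / temperature
    labels = torch.arange(z_a.shape[0], device=z_a.device)

    # Symmetric cross-entropy: rows for the a->b direction, transposed rows for b->a.
    loss_a = F.cross_entropy(logits, labels)
    loss_b = F.cross_entropy(logits.t(), labels)
    return (loss_a + loss_b) / 2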
core/models/common/__pycache__/get_model.cpython-311.pyc ADDED
Binary file (4.81 kB).
 
core/models/common/__pycache__/get_model.cpython-38.pyc CHANGED
Binary files a/core/models/common/__pycache__/get_model.cpython-38.pyc and b/core/models/common/__pycache__/get_model.cpython-38.pyc differ
 
core/models/common/__pycache__/get_optimizer.cpython-311.pyc ADDED
Binary file (3.2 kB).
 
core/models/common/__pycache__/get_optimizer.cpython-38.pyc CHANGED
Binary files a/core/models/common/__pycache__/get_optimizer.cpython-38.pyc and b/core/models/common/__pycache__/get_optimizer.cpython-38.pyc differ
 
core/models/common/__pycache__/get_scheduler.cpython-311.pyc ADDED
Binary file (16.5 kB).
 
core/models/common/__pycache__/get_scheduler.cpython-38.pyc CHANGED
Binary files a/core/models/common/__pycache__/get_scheduler.cpython-38.pyc and b/core/models/common/__pycache__/get_scheduler.cpython-38.pyc differ
 
core/models/common/__pycache__/utils.cpython-311.pyc ADDED
Binary file (18.3 kB).
 
core/models/common/__pycache__/utils.cpython-38.pyc CHANGED
Binary files a/core/models/common/__pycache__/utils.cpython-38.pyc and b/core/models/common/__pycache__/utils.cpython-38.pyc differ
 
core/models/dani_model.py CHANGED
@@ -160,7 +160,9 @@ class dani_model(pl.LightningModule):
            condition_types=condition_types,
            eta=ddim_eta,
            verbose=False,
-            mix_weight=mix_weight)
+            mix_weight=mix_weight,
+            progress_bar=None
+            )
 
        out_all = []
        for i, xtype_i in enumerate(xtype):
core/models/ddim/__pycache__/ddim.cpython-311.pyc ADDED
Binary file (12.9 kB).
 
core/models/ddim/__pycache__/ddim.cpython-38.pyc CHANGED
Binary files a/core/models/ddim/__pycache__/ddim.cpython-38.pyc and b/core/models/ddim/__pycache__/ddim.cpython-38.pyc differ
 
core/models/ddim/__pycache__/ddim_vd.cpython-311.pyc ADDED
Binary file (8.39 kB).
 
core/models/ddim/__pycache__/ddim_vd.cpython-38.pyc CHANGED
Binary files a/core/models/ddim/__pycache__/ddim_vd.cpython-38.pyc and b/core/models/ddim/__pycache__/ddim_vd.cpython-38.pyc differ
 
core/models/ddim/__pycache__/diffusion_utils.cpython-311.pyc ADDED
Binary file (16.1 kB).
 
core/models/ddim/__pycache__/diffusion_utils.cpython-38.pyc CHANGED
Binary files a/core/models/ddim/__pycache__/diffusion_utils.cpython-38.pyc and b/core/models/ddim/__pycache__/diffusion_utils.cpython-38.pyc differ
 
core/models/ddim/ddim.py CHANGED
@@ -7,6 +7,7 @@ from functools import partial
 
 from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
 
+import streamlit as st
 
 class DDIMSampler(object):
     def __init__(self, model, schedule="linear", **kwargs):
@@ -136,7 +137,8 @@ class DDIMSampler(object):
                score_corrector=None,
                corrector_kwargs=None,
                unconditional_guidance_scale=1.,
-               unconditional_conditioning=None,):
+               unconditional_conditioning=None,
+               progress_bar=None,):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
@@ -157,7 +159,11 @@ class DDIMSampler(object):
 
        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
 
+       if progress_bar is not None:
+           progress_bar.text("Generating samples...")
        for i, step in enumerate(iterator):
+           if progress_bar is not None:
+               progress_bar.progress(i/total_steps)
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
 
@@ -180,6 +186,9 @@
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)
 
+       if progress_bar is not None:
+           progress_bar.success("Sampling complete.")
+
        return img, intermediates
 
    @torch.no_grad()
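The new progress_bar argument is only touched through three calls inside the sampling loop: .text(...), .progress(fraction), and .success(...). Any object exposing that small interface can be passed in by the caller (dani_model.py currently passes None). The Streamlit adapter below is a hypothetical illustration of one way a caller might satisfy it; the class and its wiring are not part of the commit:

import streamlit as st

class StreamlitProgress:
    # Hypothetical adapter exposing the three methods the DDIM sampler calls on
    # `progress_bar`: text(), progress(), and success().
    def __init__(self):
        self._status = st.empty()   # placeholder for status messages
        self._bar = st.progress(0)  # progress bar widget

    def text(self, msg):
        self._status.text(msg)

    def progress(self, fraction):
        # st.progress accepts a float in [0.0, 1.0]
        self._bar.progress(min(max(float(fraction), 0.0), 1.0))

    def success(self, msg):
        self._status.success(msg)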
core/models/ddim/ddim_vd.py CHANGED
@@ -184,4 +184,4 @@ class DDIMSampler_VD(DDIMSampler):
                x_prev_i = a_prev.sqrt() * pred_x0_i + dir_xt + noise
                x_prev.append(x_prev_i)
                pred_x0.append(pred_x0_i)
-        return x_prev, pred_x0
+        return x_prev, pred_x0
core/models/encoders/__pycache__/clap.cpython-311.pyc CHANGED
Binary files a/core/models/encoders/__pycache__/clap.cpython-311.pyc and b/core/models/encoders/__pycache__/clap.cpython-311.pyc differ
 
core/models/encoders/__pycache__/clip.cpython-311.pyc CHANGED
Binary files a/core/models/encoders/__pycache__/clip.cpython-311.pyc and b/core/models/encoders/__pycache__/clip.cpython-311.pyc differ
 
core/models/encoders/__pycache__/clip.cpython-38.pyc CHANGED
Binary files a/core/models/encoders/__pycache__/clip.cpython-38.pyc and b/core/models/encoders/__pycache__/clip.cpython-38.pyc differ
 
core/models/encoders/clap_modules/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/core/models/encoders/clap_modules/__pycache__/__init__.cpython-311.pyc and b/core/models/encoders/clap_modules/__pycache__/__init__.cpython-311.pyc differ
 
core/models/encoders/clap_modules/open_clip/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/core/models/encoders/clap_modules/open_clip/__pycache__/__init__.cpython-311.pyc and b/core/models/encoders/clap_modules/open_clip/__pycache__/__init__.cpython-311.pyc differ
 
core/models/encoders/clap_modules/open_clip/__pycache__/factory.cpython-311.pyc CHANGED
Binary files a/core/models/encoders/clap_modules/open_clip/__pycache__/factory.cpython-311.pyc and b/core/models/encoders/clap_modules/open_clip/__pycache__/factory.cpython-311.pyc differ
 
core/models/encoders/clap_modules/open_clip/__pycache__/feature_fusion.cpython-311.pyc CHANGED
Binary files a/core/models/encoders/clap_modules/open_clip/__pycache__/feature_fusion.cpython-311.pyc and b/core/models/encoders/clap_modules/open_clip/__pycache__/feature_fusion.cpython-311.pyc differ