Rex Cheng committed on
Commit
c4dd2de
1 Parent(s): a1c80b9

fix for hf

.gitignore ADDED
@@ -0,0 +1,144 @@
+ run_*.sh
+ log/
+ saves
+ saves/
+ weights/
+ weights
+ output/
+ output
+ pretrained/
+ workspace
+ workspace/
+ ext_weights/
+ ext_weights
+ .checkpoints/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
app.py CHANGED
@@ -5,6 +5,13 @@ from pathlib import Path
  import gradio as gr
  import torch
  import torchaudio
+ import os
+
+ try:
+     import mmaudio
+ except ImportError:
+     os.system("pip install -e .")
+     import mmaudio

  from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                  setup_eval_logging)
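
Note on the app.py change: the guarded import lets the Space bootstrap itself, installing the local checkout in editable mode when mmaudio is not yet importable and then retrying the import. A minimal standalone sketch of the same pattern follows; calling pip through the running interpreter via subprocess is my own variation rather than what the commit uses, and the working directory is assumed to be the repository root.

import subprocess
import sys

try:
    import mmaudio
except ImportError:
    # Package not installed in this environment: install the local checkout
    # in editable mode with the same interpreter, then retry the import.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
    import mmaudio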
mmaudio/ext/autoencoder/autoencoder.py CHANGED
@@ -19,7 +19,7 @@ class AutoEncoderModule(nn.Module):
      super().__init__()
      self.vae: VAE = get_my_vae(mode).eval()
      vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
-     self.vae.load_state_dict(vae_state_dict)
+     self.vae.load_state_dict(vae_state_dict, strict=False)
      self.vae.remove_weight_norm()

      if mode == '16k':
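
Note on the autoencoder.py change: with strict=False, load_state_dict skips keys that are missing from (or unexpected in) the checkpoint instead of raising, and it returns those key lists for inspection. A minimal sketch with a stand-in nn.Linear (not the project's VAE) is below.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                 # stand-in module, not the real VAE
state = {'weight': torch.zeros(2, 4)}   # checkpoint that lacks the 'bias' key

# strict=False tolerates the mismatch and reports it instead of raising.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # []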
mmaudio/ext/autoencoder/vae.py CHANGED
@@ -75,11 +75,15 @@ class VAE(nn.Module):
      super().__init__()

      if data_dim == 80:
-         self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32).cuda())
-         self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32).cuda())
+         # self.data_mean = torch.tensor(DATA_MEAN_80D, dtype=torch.float32).cuda()
+         # self.data_std = torch.tensor(DATA_STD_80D, dtype=torch.float32).cuda()
+         self.register_buffer('data_mean', torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
+         self.register_buffer('data_std', torch.tensor(DATA_STD_80D, dtype=torch.float32))
      elif data_dim == 128:
-         self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32).cuda())
-         self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32).cuda())
+         # torch.tensor(DATA_MEAN_128D, dtype=torch.float32).cuda()
+         # self.data_std = torch.tensor(DATA_STD_128D, dtype=torch.float32).cuda()
+         self.register_buffer('data_mean', torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
+         self.register_buffer('data_std', torch.tensor(DATA_STD_128D, dtype=torch.float32))

      self.data_mean = self.data_mean.view(1, -1, 1)
      self.data_std = self.data_std.view(1, -1, 1)
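
Note on the vae.py change: register_buffer replaces the nn.Buffer(...).cuda() calls, so the statistics live in the module's state_dict and move with the module on .to()/.cuda() rather than being pinned to the GPU at construction time. The sketch below is illustrative only; the Normalizer class, the 80-dim zeros/ones, and the forward are placeholders, with just the buffer names taken from the diff.

import torch
import torch.nn as nn

class Normalizer(nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers are part of state_dict() and follow the module across devices.
        self.register_buffer('data_mean', torch.zeros(80))
        self.register_buffer('data_std', torch.ones(80))

    def forward(self, x):
        return (x - self.data_mean.view(1, -1, 1)) / self.data_std.view(1, -1, 1)

m = Normalizer()
print('data_mean' in m.state_dict())   # True
device = 'cuda' if torch.cuda.is_available() else 'cpu'
m = m.to(device)                       # buffers are moved together with the module
print(m.data_mean.device)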
mmaudio/model/embeddings.py CHANGED
@@ -21,12 +21,11 @@ class TimestepEmbedder(nn.Module):
      assert dim % 2 == 0, 'dim must be even.'

      with torch.autocast('cuda', enabled=False):
-         self.freqs = nn.Buffer(
+         self.freqs = (
              1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
-                            frequency_embedding_size)),
-             persistent=False)
+                            frequency_embedding_size)))
          freq_scale = 10000 / max_period
-         self.freqs = freq_scale * self.freqs
+         self.freqs = nn.Parameter(freq_scale * self.freqs)

  def timestep_embedding(self, t):
      """
mmaudio/model/networks.py CHANGED
@@ -166,8 +166,10 @@ class MMAudio(nn.Module):
          self._clip_seq_len,
          device=self.device)

-     self.latent_rot = nn.Buffer(latent_rot, persistent=False)
-     self.clip_rot = nn.Buffer(clip_rot, persistent=False)
+     # self.latent_rot = latent_rot.to(self.device)
+     # self.clip_rot = clip_rot.to(self.device)
+     self.register_buffer('latent_rot', latent_rot)
+     self.register_buffer('clip_rot', clip_rot)

  def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
      self._latent_seq_len = latent_seq_len
@@ -346,7 +348,7 @@
      if 'clip_rot' in src_dict:
          del src_dict['clip_rot']

-     self.load_state_dict(src_dict, strict=True)
+     self.load_state_dict(src_dict, strict=False)

  @property
  def device(self) -> torch.device:
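
Note on the networks.py change: register_buffer stores the precomputed rotations as persistent buffers (the old nn.Buffer call used persistent=False), so they now show up in the state_dict; combined with strict=False in the weight-loading path, checkpoints written without these keys still load. The sketch below is illustrative only; RotHolder and the tensor shape are placeholders.

import torch
import torch.nn as nn

class RotHolder(nn.Module):
    def __init__(self):
        super().__init__()
        rot = torch.randn(8, 4)                                  # stand-in for the RoPE rotations
        self.register_buffer('latent_rot', rot)                  # persistent: saved in state_dict
        self.register_buffer('clip_rot', rot, persistent=False)  # not saved in state_dict

m = RotHolder()
print('latent_rot' in m.state_dict())  # True
print('clip_rot' in m.state_dict())    # False

# An older checkpoint has neither key; strict=False skips the missing persistent buffer.
result = m.load_state_dict({}, strict=False)
print(result.missing_keys)  # ['latent_rot']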
requirements.txt CHANGED
@@ -1,4 +1,4 @@
- torch >= 2.5.1
+ torch == 2.4.0
  torchaudio
  torchvision
  python-dotenv
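
Note on the requirements.txt change: pinning torch to 2.4.0 is consistent with removing every nn.Buffer use above, since torch.nn.Buffer only exists in newer releases (2.5+, to the best of my knowledge). An illustrative runtime check, not part of the commit:

import torch
import torch.nn as nn

print(torch.__version__)
print(hasattr(nn, 'Buffer'))  # expected: False on torch 2.4.0, True on 2.5+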