Spaces:
Sleeping
new unlooper ui (#9)
Browse files- first commit (50f034f90d55446d572c869278b664ed381dd9ab)
- readme cleanup (5582d2e7b94dfab12e7dc90c1fa384f02af09b41)
- refactor (fc839a6a78647ed5e706d0e5b2eee512e8f8a8ec)
- refactor bugfixes (534a89cc9cf61b96dd2bcc82f7898cf77eb94bc8)
- remove wavenet, readability (04c5b94d12624e6a884f6e06a90c1abfb8afeae2)
- fix: sample prefix suffix (326b5bbb45c9ceeef030192c83e1d09035a7f509)
- remove library refs (b8622751c80761f8fd2ebdff5cfa6c58bfe34533)
- add save_epochs (225770674d206c769ab4300f2990634a01febd49)
- loudnorm, (dde5c212d420bfaba0a7d4100fc02c5fc53e200c)
- update reqs (d6b9d5b546a9d85fa14d74e94eebf6a17b24b8ae)
- fix sampling logic for paella (9439b644fe82f14614630b16e4cd6f1921d95f0f)
- fix random seeds for train! (79bcce65c6d7e0f60730055a9919866b48a6d070)
- Merge branch 'main' of github.com:descriptinc/lyrebird-vampnet into main (3d0828584db269031562ab23743bb0228543f2b8)
- add a coarse2fine eval script (260b46dfc267c9e2ef7807c18ad4379f6e0609f1)
- rm old eval script (cc3a37b0704fe38f4dedeff555ad598c9db5da88)
- save each metric on its own (275afd0ca5725a99558b5f65fcf697ff923d2530)
- interface (b54865d690d9f96e5ecceb3a539401c8ac86d339)
- fix volumenorm t -24 (34fcef96617aefac478da2667eb2000fc1bbecae)
- remove seq_len (920f55c77de892f058e26b98c1f8dd8d1acaafc3)
- fix transforms (d6a029bc0fa273ebf2229432f8bb1e8e1f611b5c)
- add opt for no prefix and no suffix (183d21c9728a3563625eaba80097b3a78580db11)
- update confs (b688797f68df14fb502f03049d5ab389b4a02cee)
- confs (f1ccdc1d6b867ce985af26725e0b1669be401ab8)
- interface, cleanup imputation code (4a2dc41ab59e9afdc5b2ce71a4527bd7bf08f4ba)
- interface improvements (a63cce029ca1ac0fd0c6326c935cc31f54f975b1)
- towards beat tracking in the interface (91f86380efebdb549e42686d5c025e523a33f54a)
- beat tracker bugfixes (5a0a80a25fda79285bf740b2158f303cbe577472)
- rm comment (554c010ce3fa91983548c4729de3cae18781eae5)
- c2f (e4e3c4e936cf0f127f7c021fb30f0b5e0e2052af)
- Merge branch 'main' of https://github.com/descriptinc/lyrebird-vampnet into main (b9277bd5d60c4f7e79d6a0074fa6197dfd12f771)
- upgrade to new codec ckpt (d48dcc40be3114122307de4e0a5eba44e8baee37)
- remove cancerous subprocess call (0a036acd9ed7b6d25eeaa0b54af54ce4435869e6)
- per-instrument models (4687dd93844cb3c6718085890238c44bc20faffa)
- exps (6f6fd13dfa58b2500b674821702865f7d1c85df9)
- typical filtering set to true (22a680a5c76cd779407a15e5eb52cae3a2f00ef9)
- coarse eval (57047e5fcc7fd5cc03de8c88b8fbb7cd24063f91)
- exps (9fbfaa60fffd03918045b3fc91ed1a546c63a16a)
- fix baseline cond (84d4ed6e7a531509140aba6b96dc347182a65556)
- c2f can be None (e3c7f4691a33dceac82b9d36c5654c606e6d94c2)
- use items instead of tensors (c1b9ba0d82891213ce63ebe02a20a3a7f100eb13)
- revert docker image (972000ef6b970bf834c88db486378974817683cd)
- negate sisdr (1fc975729b1babae43899a4e96d8450021841099)
- rm clmr (39847d40d8fd2d82f05442550d4f0cb6c8efca06)
- fix dockerfile (bcc33054a49ecb649755e7d202a5623ad0f40537)
- changes (ac059f495a6e84db76b20e97eb597886f0bb3bd4)
- eval, demo (3815be3a1c5368b233d430f7cf7a661481245cfc)
- gooood outputs (f3f463449cd952f0ebeee104c4a89ab9e270923a)
- better sampling defaults (03f09ee3138b5b6ee4a03bc5ef94cc9cf2b9e7c7)
- demo (128981df2a5a90b1350c2752ca6b350b7d29b10e)
- more tweaks (93b48cbe204776ce6946237d13272dd708f8f290)
- maestro interface (8544bbfee05e914106d8fbe74431dad70deadf4e)
- moving out (322cc3a20235540596bf36d06635ce8b75dc6636)
- reqs (4908bb47c041426e738c85c0d5139322f04fd947)
- interface for max (fa490b8a8ed1fb74ed956d3fbc36145d2bb6b53a)
- lora prep (1a5973b052a24dbcb9c02af05708462d41a7f83e)
- looks like it's working? (63015d58fe3e9e4b6a7813c64083400fb3890586)
- confs (910b45f2864f1e63471f6b12f0b6d7ed447e0a59)
- constructions (f4c9665b03e75ef98bc504ebba460d7761ca5b58)
- the refactor begins (5a343f4b6619bd99f5887d19e9a55870033c3611)
- basic readme stuff (99122c4357d54bd8ee98cbbbc8bb33f98d40dfb4)
- readme (7aa3063cad5ad723808e1d1f4306d22bc9771908)
- refactor masking, interface, demo (e3ca5f7b568fd1da332209583e9ff595bbb637ed)
- update readme (6fcf6a46dc9da33ade982190623c8f8c9cfb5a32)
- lac (bad2d3fe7ac6ed21963bc29f4366546c0058a5f0)
- add model ckpt path (6f55a79b36c42f37b501f063e97143d4b43aa299)
- demo cleanup, onset masks, pitch shifting (881d56d533c7fb9d7b96668ae1ad69603b257ec4)
- fix dropout bug for masks, refactor interfaces, add finetune setup script (c940f256e750010f780abcefb8a06fa0d49f9fd4)
- gamelan xeno canto (85e8a86c2b936d65aed99ba4056edf00f326a9c0)
- tiny sampling refactor (4d0cbfe60c83a9efd3aa7fc670a6c90baa832ad0)
- more demo ctrls (b61e699ad77c0d0d468e4c2a5bbca1e39d6f7308)
- confs (13b04cf217b9c4dc46055ac1e6684ae318df706b)
- efficient lora ckpts (75a71694824cf17fd51cce0705c062009a583660)
- critical sampling fix, two demoes for comparing old and new sampling (3f6f517b4fd2f14c7443d813a55af56ccbd79b3b)
- settling down on the new sampling routine (09b9691a666738aa47618e869266a5b282efff0c)
- update splits, reqs (cf172ac4841de1161ba130b62504d02d2f4f42b3)
- dariacore (b3caf82e70e770b17a1cceff1572ae173797f257)
- maestro script (c068a295a6fe239c94253716693a8cbf3b50c808)
- cleanup (d98455c3df585892baf81b4d0ea0130ab7ba51da)
- Update README.md (9da46f9203f655e935868ba0068f56608d4e3bf1)
- Update README.md (a84c25c272432ce9aa05e04c9bf1b2cda8d916c7)
- fix setup reqs (45390f991ea29f76afb711df3d7c32b923126221)
- update setup reqs (22b423a67d6f44dcb2e0bd0cbdda7710e84dead9)
- Create LICENSE (3445a716606e07f38aa0687d331c587f4eb7a9e0)
- more sampling fixes (33469209487504d1193f5be0a94b5edd3a6f9b5e)
- interface (2f3fb3279b5defd34899668c0731ba2d62b99ddb)
- add fig plots (e5dcb5f3962516056c96673794d5cc1d57f2c640)
- cleanup (e9fd215995de88f5c98bbe965f9e909d77f9c194)
- update readme (74cced76b6748ee5de93e176b16ed5be65a8199a)
- license updates! (eaa691b3ff238a4b8bb65aea182da776a1ae6515)
- update license (e251e23d39fdd4b0d375979094cb701d1fadfe51)
- pin audiotools version! (bfacd003d6729c963d62f9a410724ff6e0099ec1)
- add loralib (fd975c2086845ed0df7b0a929e26e9c62d0215ff)
- pip install from commit hangs? (04f1577f31421027dc71d0a0df846027ffb1d0ee)
- old audiotools commit can't find an html file (f373fd1cd920efeafaae9a0b9ceab8e3d72aad2f)
- update readme (b1cbc10df865c2eae3ae50fe4ba73b97fe275aa4)
- python version (c039932823574158f8772803e2d3a8a55b81c74b)
- cleaning up (03107fd63345468cfa32dacdb558e74b4cc5d6ae)
- host weights in zenodo (fff28a2ecf290953a2948bfd39c3127786ca5d61)
- pin numy (4c6c719ffde2395c8fc82e63e5f9ea360fb0c7e2)
- readme (3efca1407013ac845c232d9d1d7f41310a1780bc)
- disable wavebeat for now (d51eb6d10e0acb9586e547a5db3c417c1ab5899a)
- HF space prep (91e3ceb923bbd34a2cc78e1ccaf8088b4d9c7e87)
- example audio (5bd16c27e643e899e849a1a1d0f6793afbbdf1bd)
- demo (a004369aab5bee7125b564a124b6709811ba8bd7)
- LAC! (93ca7213c4745221b1b83e805ca961f5c8c6f055)
- pin to my audiotools fork for now (419ccdcace423e2eec652da5ee0f4a376d4a9be9)
- point to audiotools fork (4d1e39c1c90cb092950f60f146ef30e483b6a623)
- add wavebeat ckpt again (fed03a187e43a2578c294641222e7fa9c03614ee)
- add wavebeat (e2a08a917c691abaa5474fa78249148967599f69)
- lac :( (e32ae3e0f6f507c6d339a904ae7375d892efb1f4)
- update setup, reade (82c5a3634286fe4840e423d3cea64344cd5b5c34)
- lac again (bdcf1d8642781f084c35d01065a0fe03b7155126)
- update reqs (7aee7dbef2d44453620ff280735040b6175d0fc9)
- update presets (4b2f92ac2ceabbce0edf630f1794f747a0a1126e)
- update zenodo (09e4bee9d3edec4600b932f94108e80f742d4bb1)
- add example.wav (2e62788cc05a3655b50b47fd8bb71e7b1bf98339)
- fix readme (df7025d6bb4e5e9c9af901a0ce56f7525ed9089b)
- fixes #3 (5dafbacc8d6235b38f8313977dea3e5604821f01)
- finetune confs should remain local (fca32330734bd8e96ddf8c32c11e94ea72eb0d34)
- Update README.md (3277bd0d4774ae31e2869102ed4871a83e81f36d)
- update audiotools version, update recipe (457f9d164c4bf6da8a62d8a133fd5b7d88367328)
- Merge branch 'ismir' of https://github.com/hugofloresgarcia/vampnet into ismir (2ddbddc8e05199f2d5059641086cf2d1040bfb65)
- taking prompt c2f tokens into account (03217458d72ba6b01496686b891262e759d52dc1)
- rm breakpoint (ed6898a9dc81dd8f1c1efc725f397a30e9ca1f68)
- add wavebeat ckpt to finetuned confs, rm temperature typing (07aa55eb84083deb8652a596d484b0498a0b3bb3)
- num steps should propagate to coarse2fine too (88c78e13d18d1077c0c67e7446f029d6ec1dd766)
- use torch.compile for training (405226b57221c9768a6e7b1aab45e42fa49a40c4)
- sampling tricks! (9496f0ed33bea61ded6dc0fa0fcf0cf413192644)
- runs (d25da2d5dba1c300a1ddc677c9f33990fbd30cf0)
- runs (308d855638cba98b15f20ff2f7e711bb5bbf2477)
- fix reqs (3a5996b8da04acc9eed5b5fe42a5437a25d46de8)
- sampling cutoff trick (793d060d3267753792d859b7ba54b47a01f4be94)
- defaults (dac02f71acb8aba7a1a72df1a945a522c6a95bd9)
- fix interface mask b4 (9fdab572e9f6c9914f4b395f28ec00b9b1f7c965)
- beat prompt ctx (8c3b3e7d11675cd6ecb9294008f5e1c8430ef4aa)
- lora interface (bf35d4569094d2677a098d53dee904b248a4431e)
- sample cutoff (46c305cba0c4a26bcf933f97e19f2f83fcdb41b7)
- pass sampling steps to coarse2fine (62f49b0dd37f40e7f8457da020f2352f81829480)
- better onset detection!!!!! (7b88c072f4d731fda11c37650f763f63c98ff718)
- disable lora choice for now (de03185fe03fed648d7e47a46b70ae77cfebd495)
- Merge branch 'lora-app' into ismir (a66dc9cb8aa8494f8d8ed53ac1e5bf99a6d6483e)
- dropping torch.compile for now (31b771cb7c14d18bb03c47330519e63329e7cf6b)
- compile is back (421682e85050b2caacb8667247f4a0febd3e41ef)
- install_requires now has all packages needed for the app to run out of the box. (be9045c3bab0dd4c9cade57f385796f2d8203779)
- Removed requirements.txt as it contains redundant information and is not used if installing via pip. (66e2202fa81e7f9f3bc6e52d8f96e7a6f2a90eb1)
- Merge pull request #16 from orrp/ismir (df78974865f945ec40452d5568c8963441c4ec9d)
- require python 3.9 in readme (#15) (238e0c702e5d97cc057d46a07b8ddad9233130a6)
- fix script naming in readme (a876330c8994540da268ca898aa1bc655f4951dd)
- drop tensorboardX (980f2e7e74bceb5002907a7ca295b720b9023e0c)
- lora ckpts aren't working?! (7f524eaea9475bb022b3d0266d9a0faede9873cf)
- lora bloat (925e3e04e2eff4b5639831351ba10e21bcf2f088)
- scripts (ebb11735979242cf8b12d8f0ca7b6fa6b9683724)
- Merge branch 'ismir' into pr/9 (e288b6f8ea806b8e38587a6f3c40cd2d18ebe0b4)
- interface oo
- .gitignore +0 -3
- README.md +11 -2
- app.py +106 -23
- conf/lora/lora.yml +4 -2
- requirements.txt +3 -2
- scripts/exp/fine_tune.py +2 -4
- scripts/exp/train.py +9 -3
- scripts/utils/augment.py +38 -24
- scripts/utils/remove_quiet_files.py +29 -0
- scripts/utils/split_long_audio_file.py +34 -0
- scripts/utils/xeno-canto-dl.py +234 -0
- setup.py +3 -2
- vampnet/interface.py +2 -3
- vampnet/mask.py +38 -20
- vampnet/modules/transformer.py +16 -17
@@ -179,9 +179,6 @@ models/
|
|
179 |
samples*/
|
180 |
models-all/
|
181 |
models.zip
|
182 |
-
audiotools/
|
183 |
-
descript-audio-codec/
|
184 |
-
# *.pth
|
185 |
.git-old
|
186 |
conf/generated/*
|
187 |
runs*/
|
|
|
179 |
samples*/
|
180 |
models-all/
|
181 |
models.zip
|
|
|
|
|
|
|
182 |
.git-old
|
183 |
conf/generated/*
|
184 |
runs*/
|
@@ -7,6 +7,7 @@ sdk: gradio
|
|
7 |
sdk_version: 3.36.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
# VampNet
|
@@ -18,7 +19,15 @@ you can try vampnet in a co-creative looper called unloop. see this link: https:
|
|
18 |
|
19 |
# Setting up
|
20 |
|
21 |
-
Requires Python 3.9
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
|
24 |
install VampNet
|
@@ -91,7 +100,7 @@ python scripts/exp/train.py --args.load conf/<fine_tune_name>/c2f.yml
|
|
91 |
|
92 |
launch the interface:
|
93 |
```bash
|
94 |
-
python
|
95 |
```
|
96 |
|
97 |
|
|
|
7 |
sdk_version: 3.36.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
python_version: 3.9
|
11 |
---
|
12 |
|
13 |
# VampNet
|
|
|
19 |
|
20 |
# Setting up
|
21 |
|
22 |
+
**Requires Python 3.9**.
|
23 |
+
|
24 |
+
you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
|
25 |
+
|
26 |
+
(for example, using conda)
|
27 |
+
```bash
|
28 |
+
conda create -n vampnet python=3.9
|
29 |
+
conda activate vampnet
|
30 |
+
```
|
31 |
|
32 |
|
33 |
install VampNet
|
|
|
100 |
|
101 |
launch the interface:
|
102 |
```bash
|
103 |
+
python app.py --args.load conf/generated/<fine_tune_name>/interface.yml
|
104 |
```
|
105 |
|
106 |
|
@@ -1,3 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from pathlib import Path
|
2 |
from typing import Tuple
|
3 |
import yaml
|
@@ -18,16 +24,35 @@ from vampnet import mask as pmask
|
|
18 |
# Interface = argbind.bind(Interface)
|
19 |
# AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
|
20 |
|
21 |
-
|
22 |
-
coarse_ckpt="./models/vampnet/coarse.pth",
|
23 |
-
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
24 |
-
codec_ckpt="./models/vampnet/codec.pth",
|
25 |
-
wavebeat_ckpt="./models/wavebeat.pth",
|
26 |
-
device="cuda" if torch.cuda.is_available() else "cpu",
|
27 |
-
)
|
28 |
|
29 |
# loader = AudioLoader()
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# dataset = at.data.datasets.AudioDataset(
|
33 |
# loader,
|
@@ -50,7 +75,7 @@ def load_audio(file):
|
|
50 |
)
|
51 |
sig = interface.preprocess(sig)
|
52 |
|
53 |
-
out_dir = OUT_DIR / str(uuid.uuid4())
|
54 |
out_dir.mkdir(parents=True, exist_ok=True)
|
55 |
sig.write(out_dir / "input.wav")
|
56 |
return sig.path_to_file
|
@@ -68,6 +93,10 @@ def _vamp(data, return_mask=False):
|
|
68 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
69 |
out_dir.mkdir()
|
70 |
sig = at.AudioSignal(data[input_audio])
|
|
|
|
|
|
|
|
|
71 |
|
72 |
z = interface.encode(sig)
|
73 |
|
@@ -107,7 +136,27 @@ def _vamp(data, return_mask=False):
|
|
107 |
mask = pmask.codebook_unmask(mask, ncc)
|
108 |
|
109 |
|
110 |
-
print(data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
_top_p = data[top_p] if data[top_p] > 0 else None
|
112 |
# save the mask as a txt file
|
113 |
np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
|
@@ -126,6 +175,7 @@ def _vamp(data, return_mask=False):
|
|
126 |
top_p=_top_p,
|
127 |
gen_fn=interface.coarse.generate,
|
128 |
seed=_seed,
|
|
|
129 |
)
|
130 |
|
131 |
if use_coarse2fine:
|
@@ -134,7 +184,8 @@ def _vamp(data, return_mask=False):
|
|
134 |
mask_temperature=data[masktemp]*10,
|
135 |
sampling_temperature=data[sampletemp],
|
136 |
mask=mask,
|
137 |
-
sampling_steps=data[num_steps],
|
|
|
138 |
seed=_seed,
|
139 |
)
|
140 |
|
@@ -183,6 +234,7 @@ def save_vamp(data):
|
|
183 |
"use_coarse2fine": data[use_coarse2fine],
|
184 |
"stretch_factor": data[stretch_factor],
|
185 |
"seed": data[seed],
|
|
|
186 |
}
|
187 |
|
188 |
# save with yaml
|
@@ -265,29 +317,41 @@ with gr.Blocks() as demo:
|
|
265 |
"beat_mask_downbeats": False,
|
266 |
},
|
267 |
"slight periodic variation": {
|
268 |
-
"periodic_p":
|
269 |
-
"onset_mask_width":
|
270 |
"beat_mask_width": 0,
|
271 |
"beat_mask_downbeats": False,
|
272 |
},
|
273 |
-
"
|
274 |
"periodic_p": 13,
|
275 |
"onset_mask_width": 5,
|
276 |
"beat_mask_width": 0,
|
277 |
"beat_mask_downbeats": False,
|
278 |
},
|
279 |
-
"
|
280 |
"periodic_p": 17,
|
281 |
"onset_mask_width": 5,
|
282 |
"beat_mask_width": 0,
|
283 |
"beat_mask_downbeats": False,
|
284 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
"beat-driven variation": {
|
286 |
"periodic_p": 0,
|
287 |
"onset_mask_width": 0,
|
288 |
-
"beat_mask_width":
|
289 |
"beat_mask_downbeats": False,
|
290 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
"beat-driven variation (downbeats only, strong)": {
|
292 |
"periodic_p": 0,
|
293 |
"onset_mask_width": 0,
|
@@ -309,14 +373,14 @@ with gr.Blocks() as demo:
|
|
309 |
minimum=0,
|
310 |
maximum=128,
|
311 |
step=1,
|
312 |
-
value=
|
313 |
)
|
314 |
|
315 |
|
316 |
onset_mask_width = gr.Slider(
|
317 |
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
318 |
minimum=0,
|
319 |
-
maximum=
|
320 |
step=1,
|
321 |
value=5,
|
322 |
)
|
@@ -334,6 +398,14 @@ with gr.Blocks() as demo:
|
|
334 |
|
335 |
|
336 |
with gr.Accordion("extras ", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
rand_mask_intensity = gr.Slider(
|
338 |
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
339 |
minimum=0.0,
|
@@ -396,14 +468,15 @@ with gr.Blocks() as demo:
|
|
396 |
masktemp = gr.Slider(
|
397 |
label="mask temperature",
|
398 |
minimum=0.0,
|
399 |
-
maximum=
|
400 |
value=1.5
|
401 |
)
|
402 |
sampletemp = gr.Slider(
|
403 |
label="sample temperature",
|
404 |
minimum=0.1,
|
405 |
-
maximum=
|
406 |
-
value=1.0
|
|
|
407 |
)
|
408 |
|
409 |
|
@@ -419,7 +492,7 @@ with gr.Blocks() as demo:
|
|
419 |
label="typical filtering ",
|
420 |
value=False
|
421 |
)
|
422 |
-
typical_mass = gr.Slider(
|
423 |
label="typical mass (should probably stay between 0.1 and 0.5)",
|
424 |
minimum=0.01,
|
425 |
maximum=0.99,
|
@@ -432,6 +505,13 @@ with gr.Blocks() as demo:
|
|
432 |
step=1,
|
433 |
value=64
|
434 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
|
436 |
use_coarse2fine = gr.Checkbox(
|
437 |
label="use coarse2fine",
|
@@ -506,8 +586,11 @@ with gr.Blocks() as demo:
|
|
506 |
typical_mass,
|
507 |
typical_min_tokens,
|
508 |
beat_mask_width,
|
|
|
509 |
seed,
|
510 |
-
|
|
|
|
|
511 |
}
|
512 |
|
513 |
# connect widgets
|
|
|
1 |
+
# huggingface space exclusive
|
2 |
+
import os
|
3 |
+
|
4 |
+
os.system('pip install cython')
|
5 |
+
os.system('pip install madmom')
|
6 |
+
|
7 |
from pathlib import Path
|
8 |
from typing import Tuple
|
9 |
import yaml
|
|
|
24 |
# Interface = argbind.bind(Interface)
|
25 |
# AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
|
26 |
|
27 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
# loader = AudioLoader()
|
30 |
+
# AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
|
31 |
+
|
32 |
+
conf = argbind.parse_args()
|
33 |
+
|
34 |
+
|
35 |
+
from torch_pitch_shift import pitch_shift, get_fast_shifts
|
36 |
+
def shift_pitch(signal, interval: int):
|
37 |
+
signal.samples = pitch_shift(
|
38 |
+
signal.samples,
|
39 |
+
shift=interval,
|
40 |
+
sample_rate=signal.sample_rate
|
41 |
+
)
|
42 |
+
return signal
|
43 |
+
|
44 |
+
def load_interface():
|
45 |
+
interface = Interface(
|
46 |
+
coarse_ckpt="./models/vampnet/coarse.pth",
|
47 |
+
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
48 |
+
codec_ckpt="./models/vampnet/codec.pth",
|
49 |
+
wavebeat_ckpt="./models/wavebeat.pth",
|
50 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
51 |
+
)
|
52 |
+
return interface
|
53 |
+
|
54 |
+
|
55 |
+
interface = load_interface()
|
56 |
|
57 |
# dataset = at.data.datasets.AudioDataset(
|
58 |
# loader,
|
|
|
75 |
)
|
76 |
sig = interface.preprocess(sig)
|
77 |
|
78 |
+
out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
|
79 |
out_dir.mkdir(parents=True, exist_ok=True)
|
80 |
sig.write(out_dir / "input.wav")
|
81 |
return sig.path_to_file
|
|
|
93 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
94 |
out_dir.mkdir()
|
95 |
sig = at.AudioSignal(data[input_audio])
|
96 |
+
sig = interface.preprocess(sig)
|
97 |
+
|
98 |
+
if data[pitch_shift_amt] != 0:
|
99 |
+
sig = shift_pitch(sig, data[pitch_shift_amt])
|
100 |
|
101 |
z = interface.encode(sig)
|
102 |
|
|
|
136 |
mask = pmask.codebook_unmask(mask, ncc)
|
137 |
|
138 |
|
139 |
+
print(f"dropout {data[dropout]}")
|
140 |
+
print(f"masktemp {data[masktemp]}")
|
141 |
+
print(f"sampletemp {data[sampletemp]}")
|
142 |
+
print(f"top_p {data[top_p]}")
|
143 |
+
print(f"prefix_s {data[prefix_s]}")
|
144 |
+
print(f"suffix_s {data[suffix_s]}")
|
145 |
+
print(f"rand_mask_intensity {data[rand_mask_intensity]}")
|
146 |
+
print(f"num_steps {data[num_steps]}")
|
147 |
+
print(f"periodic_p {data[periodic_p]}")
|
148 |
+
print(f"periodic_w {data[periodic_w]}")
|
149 |
+
print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
|
150 |
+
print(f"use_coarse2fine {data[use_coarse2fine]}")
|
151 |
+
print(f"onset_mask_width {data[onset_mask_width]}")
|
152 |
+
print(f"beat_mask_width {data[beat_mask_width]}")
|
153 |
+
print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
|
154 |
+
print(f"stretch_factor {data[stretch_factor]}")
|
155 |
+
print(f"seed {data[seed]}")
|
156 |
+
print(f"pitch_shift_amt {data[pitch_shift_amt]}")
|
157 |
+
print(f"sample_cutoff {data[sample_cutoff]}")
|
158 |
+
|
159 |
+
|
160 |
_top_p = data[top_p] if data[top_p] > 0 else None
|
161 |
# save the mask as a txt file
|
162 |
np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
|
|
|
175 |
top_p=_top_p,
|
176 |
gen_fn=interface.coarse.generate,
|
177 |
seed=_seed,
|
178 |
+
sample_cutoff=data[sample_cutoff],
|
179 |
)
|
180 |
|
181 |
if use_coarse2fine:
|
|
|
184 |
mask_temperature=data[masktemp]*10,
|
185 |
sampling_temperature=data[sampletemp],
|
186 |
mask=mask,
|
187 |
+
sampling_steps=data[num_steps],
|
188 |
+
sample_cutoff=data[sample_cutoff],
|
189 |
seed=_seed,
|
190 |
)
|
191 |
|
|
|
234 |
"use_coarse2fine": data[use_coarse2fine],
|
235 |
"stretch_factor": data[stretch_factor],
|
236 |
"seed": data[seed],
|
237 |
+
"samplecutoff": data[sample_cutoff],
|
238 |
}
|
239 |
|
240 |
# save with yaml
|
|
|
317 |
"beat_mask_downbeats": False,
|
318 |
},
|
319 |
"slight periodic variation": {
|
320 |
+
"periodic_p": 5,
|
321 |
+
"onset_mask_width": 5,
|
322 |
"beat_mask_width": 0,
|
323 |
"beat_mask_downbeats": False,
|
324 |
},
|
325 |
+
"moderate periodic variation": {
|
326 |
"periodic_p": 13,
|
327 |
"onset_mask_width": 5,
|
328 |
"beat_mask_width": 0,
|
329 |
"beat_mask_downbeats": False,
|
330 |
},
|
331 |
+
"strong periodic variation": {
|
332 |
"periodic_p": 17,
|
333 |
"onset_mask_width": 5,
|
334 |
"beat_mask_width": 0,
|
335 |
"beat_mask_downbeats": False,
|
336 |
},
|
337 |
+
"very strong periodic variation": {
|
338 |
+
"periodic_p": 21,
|
339 |
+
"onset_mask_width": 5,
|
340 |
+
"beat_mask_width": 0,
|
341 |
+
"beat_mask_downbeats": False,
|
342 |
+
},
|
343 |
"beat-driven variation": {
|
344 |
"periodic_p": 0,
|
345 |
"onset_mask_width": 0,
|
346 |
+
"beat_mask_width": 50,
|
347 |
"beat_mask_downbeats": False,
|
348 |
},
|
349 |
+
"beat-driven variation (downbeats only)": {
|
350 |
+
"periodic_p": 0,
|
351 |
+
"onset_mask_width": 0,
|
352 |
+
"beat_mask_width": 50,
|
353 |
+
"beat_mask_downbeats": True,
|
354 |
+
},
|
355 |
"beat-driven variation (downbeats only, strong)": {
|
356 |
"periodic_p": 0,
|
357 |
"onset_mask_width": 0,
|
|
|
373 |
minimum=0,
|
374 |
maximum=128,
|
375 |
step=1,
|
376 |
+
value=5,
|
377 |
)
|
378 |
|
379 |
|
380 |
onset_mask_width = gr.Slider(
|
381 |
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
382 |
minimum=0,
|
383 |
+
maximum=100,
|
384 |
step=1,
|
385 |
value=5,
|
386 |
)
|
|
|
398 |
|
399 |
|
400 |
with gr.Accordion("extras ", open=False):
|
401 |
+
pitch_shift_amt = gr.Slider(
|
402 |
+
label="pitch shift amount (semitones)",
|
403 |
+
minimum=-12,
|
404 |
+
maximum=12,
|
405 |
+
step=1,
|
406 |
+
value=0,
|
407 |
+
)
|
408 |
+
|
409 |
rand_mask_intensity = gr.Slider(
|
410 |
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
411 |
minimum=0.0,
|
|
|
468 |
masktemp = gr.Slider(
|
469 |
label="mask temperature",
|
470 |
minimum=0.0,
|
471 |
+
maximum=100.0,
|
472 |
value=1.5
|
473 |
)
|
474 |
sampletemp = gr.Slider(
|
475 |
label="sample temperature",
|
476 |
minimum=0.1,
|
477 |
+
maximum=10.0,
|
478 |
+
value=1.0,
|
479 |
+
step=0.001
|
480 |
)
|
481 |
|
482 |
|
|
|
492 |
label="typical filtering ",
|
493 |
value=False
|
494 |
)
|
495 |
+
typical_mass = gr.Slider(
|
496 |
label="typical mass (should probably stay between 0.1 and 0.5)",
|
497 |
minimum=0.01,
|
498 |
maximum=0.99,
|
|
|
505 |
step=1,
|
506 |
value=64
|
507 |
)
|
508 |
+
sample_cutoff = gr.Slider(
|
509 |
+
label="sample cutoff",
|
510 |
+
minimum=0.0,
|
511 |
+
maximum=1.0,
|
512 |
+
value=0.5,
|
513 |
+
step=0.01
|
514 |
+
)
|
515 |
|
516 |
use_coarse2fine = gr.Checkbox(
|
517 |
label="use coarse2fine",
|
|
|
586 |
typical_mass,
|
587 |
typical_min_tokens,
|
588 |
beat_mask_width,
|
589 |
+
beat_mask_downbeats,
|
590 |
seed,
|
591 |
+
# lora_choice,
|
592 |
+
pitch_shift_amt,
|
593 |
+
sample_cutoff
|
594 |
}
|
595 |
|
596 |
# connect widgets
|
@@ -4,14 +4,16 @@ $include:
|
|
4 |
fine_tune: True
|
5 |
|
6 |
train/AudioDataset.n_examples: 100000000
|
7 |
-
val/AudioDataset.n_examples:
|
8 |
|
9 |
|
10 |
NoamScheduler.warmup: 500
|
11 |
|
12 |
batch_size: 7
|
13 |
num_workers: 7
|
14 |
-
save_iters: [
|
|
|
|
|
15 |
|
16 |
AdamW.lr: 0.0001
|
17 |
|
|
|
4 |
fine_tune: True
|
5 |
|
6 |
train/AudioDataset.n_examples: 100000000
|
7 |
+
val/AudioDataset.n_examples: 500
|
8 |
|
9 |
|
10 |
NoamScheduler.warmup: 500
|
11 |
|
12 |
batch_size: 7
|
13 |
num_workers: 7
|
14 |
+
save_iters: [10000, 20000, 30000, 40000, 50000]
|
15 |
+
sample_freq: 1000
|
16 |
+
val_freq: 500
|
17 |
|
18 |
AdamW.lr: 0.0001
|
19 |
|
@@ -1,8 +1,9 @@
|
|
1 |
torch
|
2 |
argbind>=0.3.2
|
3 |
-
numpy==1.
|
4 |
gradio
|
5 |
loralib
|
6 |
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
|
7 |
lac @ git+https://github.com/hugofloresgarcia/lac.git
|
8 |
-
descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2
|
|
|
|
1 |
torch
|
2 |
argbind>=0.3.2
|
3 |
+
numpy==1.23
|
4 |
gradio
|
5 |
loralib
|
6 |
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
|
7 |
lac @ git+https://github.com/hugofloresgarcia/lac.git
|
8 |
+
descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2
|
9 |
+
torch_pitch_shift
|
@@ -48,11 +48,9 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
48 |
}
|
49 |
|
50 |
interface_conf = {
|
51 |
-
"Interface.coarse_ckpt": f"./
|
52 |
-
"Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
|
53 |
|
54 |
-
"Interface.coarse2fine_ckpt": f"./
|
55 |
-
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
56 |
"Interface.wavebeat_ckpt": "./models/wavebeat.pth",
|
57 |
|
58 |
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
|
|
48 |
}
|
49 |
|
50 |
interface_conf = {
|
51 |
+
"Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
|
|
|
52 |
|
53 |
+
"Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
|
|
|
54 |
"Interface.wavebeat_ckpt": "./models/wavebeat.pth",
|
55 |
|
56 |
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
@@ -14,7 +14,7 @@ from audiotools.data import transforms
|
|
14 |
from einops import rearrange
|
15 |
from rich import pretty
|
16 |
from rich.traceback import install
|
17 |
-
from
|
18 |
|
19 |
import vampnet
|
20 |
from vampnet.modules.transformer import VampNet
|
@@ -29,6 +29,9 @@ from audiotools.ml.decorators import (
|
|
29 |
|
30 |
import loralib as lora
|
31 |
|
|
|
|
|
|
|
32 |
|
33 |
# Enable cudnn autotuner to speed up training
|
34 |
# (can be altered by the funcs.seed function)
|
@@ -601,7 +604,7 @@ def train(
|
|
601 |
accel=accel,
|
602 |
tracker=tracker,
|
603 |
save_path=save_path)
|
604 |
-
|
605 |
|
606 |
train_dataloader = accel.prepare_dataloader(
|
607 |
state.train_data,
|
@@ -616,13 +619,15 @@ def train(
|
|
616 |
num_workers=num_workers,
|
617 |
batch_size=batch_size,
|
618 |
collate_fn=state.val_data.collate,
|
619 |
-
persistent_workers=
|
620 |
)
|
|
|
621 |
|
622 |
|
623 |
|
624 |
if fine_tune:
|
625 |
lora.mark_only_lora_as_trainable(state.model)
|
|
|
626 |
|
627 |
# Wrap the functions so that they neatly track in TensorBoard + progress bars
|
628 |
# and only run when specific conditions are met.
|
@@ -637,6 +642,7 @@ def train(
|
|
637 |
save_samples = when(lambda: accel.local_rank == 0)(save_samples)
|
638 |
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
639 |
|
|
|
640 |
with tracker.live:
|
641 |
for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
|
642 |
train_loop(state, batch, accel)
|
|
|
14 |
from einops import rearrange
|
15 |
from rich import pretty
|
16 |
from rich.traceback import install
|
17 |
+
from torch.utils.tensorboard import SummaryWriter
|
18 |
|
19 |
import vampnet
|
20 |
from vampnet.modules.transformer import VampNet
|
|
|
29 |
|
30 |
import loralib as lora
|
31 |
|
32 |
+
import torch._dynamo
|
33 |
+
torch._dynamo.config.verbose=True
|
34 |
+
|
35 |
|
36 |
# Enable cudnn autotuner to speed up training
|
37 |
# (can be altered by the funcs.seed function)
|
|
|
604 |
accel=accel,
|
605 |
tracker=tracker,
|
606 |
save_path=save_path)
|
607 |
+
print("initialized state.")
|
608 |
|
609 |
train_dataloader = accel.prepare_dataloader(
|
610 |
state.train_data,
|
|
|
619 |
num_workers=num_workers,
|
620 |
batch_size=batch_size,
|
621 |
collate_fn=state.val_data.collate,
|
622 |
+
persistent_workers=num_workers > 0,
|
623 |
)
|
624 |
+
print("initialized dataloader.")
|
625 |
|
626 |
|
627 |
|
628 |
if fine_tune:
|
629 |
lora.mark_only_lora_as_trainable(state.model)
|
630 |
+
print("marked only lora as trainable.")
|
631 |
|
632 |
# Wrap the functions so that they neatly track in TensorBoard + progress bars
|
633 |
# and only run when specific conditions are met.
|
|
|
642 |
save_samples = when(lambda: accel.local_rank == 0)(save_samples)
|
643 |
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
644 |
|
645 |
+
print("starting training loop.")
|
646 |
with tracker.live:
|
647 |
for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
|
648 |
train_loop(state, batch, accel)
|
@@ -5,34 +5,19 @@ from audiotools import AudioSignal
|
|
5 |
|
6 |
import argbind
|
7 |
import tqdm
|
|
|
8 |
|
9 |
|
10 |
-
from
|
11 |
-
|
12 |
-
)
|
13 |
-
from pedalboard.io import AudioFile
|
14 |
|
15 |
-
|
16 |
-
samplerate = 44100.0
|
17 |
-
with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
|
18 |
-
audio = f.read(f.frames)
|
19 |
-
|
20 |
-
# Make a pretty interesting sounding guitar pedalboard:
|
21 |
-
board = Pedalboard([
|
22 |
-
Compressor(threshold_db=-50, ratio=25),
|
23 |
-
Gain(gain_db=30),
|
24 |
-
Chorus(),
|
25 |
-
LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
|
26 |
-
Phaser(),
|
27 |
-
Convolution("./guitar_amp.wav", 1.0),
|
28 |
-
Reverb(room_size=0.25),
|
29 |
-
])
|
30 |
|
31 |
|
32 |
@argbind.bind(without_prefix=True)
|
33 |
def augment(
|
34 |
-
audio_folder: Path,
|
35 |
-
dest_folder: Path,
|
36 |
n_augmentations: int = 10,
|
37 |
):
|
38 |
"""
|
@@ -41,7 +26,8 @@ def augment(
|
|
41 |
The dest foler will contain a folder for each of the clean dataset's files.
|
42 |
Under each of these folders, there will be a clean file and many augmented files.
|
43 |
"""
|
44 |
-
|
|
|
45 |
audio_files = at.util.find_audio(audio_folder)
|
46 |
|
47 |
for audio_file in tqdm.tqdm(audio_files):
|
@@ -49,5 +35,33 @@ def augment(
|
|
49 |
subdir = subtree / audio_file.stem
|
50 |
subdir.mkdir(parents=True, exist_ok=True)
|
51 |
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
import argbind
|
7 |
import tqdm
|
8 |
+
import torch
|
9 |
|
10 |
|
11 |
+
from torch_pitch_shift import pitch_shift, get_fast_shifts
|
12 |
+
from torch_time_stretch import time_stretch, get_fast_stretches
|
|
|
|
|
13 |
|
14 |
+
from audiotools.core.util import sample_from_dist
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
@argbind.bind(without_prefix=True)
|
18 |
def augment(
|
19 |
+
audio_folder: Path = None,
|
20 |
+
dest_folder: Path = None,
|
21 |
n_augmentations: int = 10,
|
22 |
):
|
23 |
"""
|
|
|
26 |
The dest foler will contain a folder for each of the clean dataset's files.
|
27 |
Under each of these folders, there will be a clean file and many augmented files.
|
28 |
"""
|
29 |
+
assert audio_folder is not None
|
30 |
+
assert dest_folder is not None
|
31 |
audio_files = at.util.find_audio(audio_folder)
|
32 |
|
33 |
for audio_file in tqdm.tqdm(audio_files):
|
|
|
35 |
subdir = subtree / audio_file.stem
|
36 |
subdir.mkdir(parents=True, exist_ok=True)
|
37 |
|
38 |
+
src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
|
39 |
+
|
40 |
+
|
41 |
+
for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
|
42 |
+
# apply pedalboard transforms
|
43 |
+
for j in range(n_augmentations):
|
44 |
+
# pitch shift between -7 and 7 semitones
|
45 |
+
import random
|
46 |
+
dst = chunk.clone()
|
47 |
+
dst.samples = pitch_shift(
|
48 |
+
dst.samples,
|
49 |
+
shift=random.choice(get_fast_shifts(src.sample_rate,
|
50 |
+
condition=lambda x: x >= 0.25 and x <= 1.0)),
|
51 |
+
sample_rate=src.sample_rate
|
52 |
+
)
|
53 |
+
dst.samples = time_stretch(
|
54 |
+
dst.samples,
|
55 |
+
stretch=random.choice(get_fast_stretches(src.sample_rate,
|
56 |
+
condition=lambda x: x >= 0.667 and x <= 1.5, )),
|
57 |
+
sample_rate=src.sample_rate,
|
58 |
+
)
|
59 |
+
|
60 |
+
dst.cpu().write(subdir / f"{i}-{j}.wav")
|
61 |
+
|
62 |
+
|
63 |
+
if __name__ == "__main__":
|
64 |
+
args = argbind.parse_args()
|
65 |
+
|
66 |
+
with argbind.scope(args):
|
67 |
+
augment()
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# removes files with loudness below 24db
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
import shutil
|
5 |
+
import audiotools as at
|
6 |
+
import argbind
|
7 |
+
|
8 |
+
@argbind.bind(without_prefix=True)
|
9 |
+
def remove_quiet_files(
|
10 |
+
src_dir: Path = None,
|
11 |
+
dest_dir: Path = None,
|
12 |
+
min_loudness: float = -30,
|
13 |
+
):
|
14 |
+
# copy src to dest
|
15 |
+
dest_dir.mkdir(parents=True, exist_ok=True)
|
16 |
+
shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
|
17 |
+
|
18 |
+
audio_files = at.util.find_audio(dest_dir)
|
19 |
+
for audio_file in audio_files:
|
20 |
+
sig = at.AudioSignal(audio_file)
|
21 |
+
if sig.loudness() < min_loudness:
|
22 |
+
audio_file.unlink()
|
23 |
+
print(f"removed {audio_file}")
|
24 |
+
|
25 |
+
if __name__ == "__main__":
|
26 |
+
args = argbind.parse_args()
|
27 |
+
|
28 |
+
with argbind.scope(args):
|
29 |
+
remove_quiet_files()
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import argbind
|
3 |
+
|
4 |
+
import audiotools as at
|
5 |
+
import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
@argbind.bind(without_prefix=True)
|
9 |
+
def split_long_audio_file(
|
10 |
+
file: str = None,
|
11 |
+
max_chunk_size_s: int = 60*10
|
12 |
+
):
|
13 |
+
file = Path(file)
|
14 |
+
output_dir = file.parent / file.stem
|
15 |
+
output_dir.mkdir()
|
16 |
+
|
17 |
+
sig = at.AudioSignal(file)
|
18 |
+
|
19 |
+
# split into chunks
|
20 |
+
for i, sig in tqdm.tqdm(enumerate(sig.windows(
|
21 |
+
window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
|
22 |
+
preprocess=True))
|
23 |
+
):
|
24 |
+
sig.write(output_dir / f"{i}.wav")
|
25 |
+
|
26 |
+
print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
|
27 |
+
|
28 |
+
return output_dir
|
29 |
+
|
30 |
+
if __name__ == "__main__":
|
31 |
+
args = argbind.parse_args()
|
32 |
+
|
33 |
+
with argbind.scope(args):
|
34 |
+
split_long_audio_file()
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from xenopy import Query
|
2 |
+
|
3 |
+
|
4 |
+
SPECIES = [
|
5 |
+
"American Robin",
|
6 |
+
"Northern Cardinal",
|
7 |
+
"Mourning Dove",
|
8 |
+
"American Crow",
|
9 |
+
"Baltimore Oriole",
|
10 |
+
"Blue Jay",
|
11 |
+
"Eastern Bluebird",
|
12 |
+
"House Finch",
|
13 |
+
"American Goldfinch",
|
14 |
+
"House Sparrow",
|
15 |
+
"Song Sparrow",
|
16 |
+
"Tufted Titmouse",
|
17 |
+
"White-breasted Nuthatch",
|
18 |
+
"European Starling",
|
19 |
+
"American Redstart",
|
20 |
+
"Red-winged Blackbird",
|
21 |
+
"Brown-headed Cowbird",
|
22 |
+
"Common Grackle",
|
23 |
+
"Boat-tailed Grackle",
|
24 |
+
"Common Yellowthroat",
|
25 |
+
"Northern Mockingbird",
|
26 |
+
"Carolina Wren",
|
27 |
+
"Eastern Meadowlark",
|
28 |
+
"Chipping Sparrow",
|
29 |
+
"Tree Swallow",
|
30 |
+
"Barn Swallow",
|
31 |
+
"Cliff Swallow",
|
32 |
+
"Pine Siskin",
|
33 |
+
"Indigo Bunting",
|
34 |
+
"Eastern Towhee",
|
35 |
+
"Carolina Chickadee",
|
36 |
+
"Great Crested Flycatcher",
|
37 |
+
"Eastern Wood-Pewee",
|
38 |
+
"Ovenbird",
|
39 |
+
"Northern Flicker",
|
40 |
+
"Red-eyed Vireo",
|
41 |
+
"American Woodcock",
|
42 |
+
"Eastern Phoebe",
|
43 |
+
"Downy Woodpecker",
|
44 |
+
"Scarlet Tanager",
|
45 |
+
"Yellow Warbler",
|
46 |
+
"White-eyed Vireo",
|
47 |
+
"Common Loon",
|
48 |
+
"White-throated Sparrow",
|
49 |
+
"Yellow-throated Vireo",
|
50 |
+
"Great Blue Heron",
|
51 |
+
"Belted Kingfisher",
|
52 |
+
"Pied-billed Grebe",
|
53 |
+
"Wild Turkey",
|
54 |
+
"Wood Thrush",
|
55 |
+
"Rose-breasted Grosbeak",
|
56 |
+
"Field Sparrow",
|
57 |
+
"Hooded Warbler",
|
58 |
+
"Northern Parula",
|
59 |
+
"Chestnut-sided Warbler",
|
60 |
+
"Blue-winged Warbler",
|
61 |
+
"Red-bellied Woodpecker",
|
62 |
+
"Yellow-billed Cuckoo",
|
63 |
+
"Gray Catbird",
|
64 |
+
"Northern Saw-whet Owl",
|
65 |
+
"Osprey",
|
66 |
+
"Common Nighthawk",
|
67 |
+
"Broad-winged Hawk",
|
68 |
+
"Black-throated Green Warbler",
|
69 |
+
"Great Horned Owl",
|
70 |
+
"Common Raven",
|
71 |
+
"Barred Owl",
|
72 |
+
"Canada Warbler",
|
73 |
+
"Magnolia Warbler",
|
74 |
+
"Black-and-white Warbler",
|
75 |
+
"Eastern Kingbird",
|
76 |
+
"Swainson's Thrush",
|
77 |
+
"Worm-eating Warbler",
|
78 |
+
"Prairie Warbler",
|
79 |
+
"Baltimore Oriole",
|
80 |
+
"Black-throated Blue Warbler",
|
81 |
+
"Louisiana Waterthrush",
|
82 |
+
"Blackburnian Warbler",
|
83 |
+
"Black-capped Chickadee",
|
84 |
+
"Cerulean Warbler",
|
85 |
+
"Red-shouldered Hawk",
|
86 |
+
"Cooper's Hawk",
|
87 |
+
"Yellow-throated Warbler",
|
88 |
+
"Blue-headed Vireo",
|
89 |
+
"Blackpoll Warbler",
|
90 |
+
"Ruffed Grouse",
|
91 |
+
"Kentucky Warbler",
|
92 |
+
"Hermit Thrush",
|
93 |
+
"Cedar Waxwing",
|
94 |
+
"Eastern Screech-Owl",
|
95 |
+
"Northern Goshawk",
|
96 |
+
"Green Heron",
|
97 |
+
"Red-tailed Hawk",
|
98 |
+
"Black Vulture",
|
99 |
+
"Hairy Woodpecker",
|
100 |
+
"Golden-crowned Kinglet",
|
101 |
+
"Ruby-crowned Kinglet",
|
102 |
+
"Bicknell's Thrush",
|
103 |
+
"Blue-gray Gnatcatcher",
|
104 |
+
"Veery",
|
105 |
+
"Pileated Woodpecker",
|
106 |
+
"Purple Finch",
|
107 |
+
"White-crowned Sparrow",
|
108 |
+
"Snow Bunting",
|
109 |
+
"Pine Grosbeak",
|
110 |
+
"American Tree Sparrow",
|
111 |
+
"Dark-eyed Junco",
|
112 |
+
"Snowy Owl",
|
113 |
+
"White-winged Crossbill",
|
114 |
+
"Red Crossbill",
|
115 |
+
"Common Redpoll",
|
116 |
+
"Northern Shrike",
|
117 |
+
"Northern Harrier",
|
118 |
+
"Rough-legged Hawk",
|
119 |
+
"Long-eared Owl",
|
120 |
+
"Evening Grosbeak",
|
121 |
+
"Northern Pintail",
|
122 |
+
"American Black Duck",
|
123 |
+
"Mallard",
|
124 |
+
"Canvasback",
|
125 |
+
"Redhead",
|
126 |
+
"Ring-necked Duck",
|
127 |
+
"Greater Scaup",
|
128 |
+
"Lesser Scaup",
|
129 |
+
"Bufflehead",
|
130 |
+
"Common Goldeneye",
|
131 |
+
"Hooded Merganser",
|
132 |
+
"Common Merganser",
|
133 |
+
"Red-breasted Merganser",
|
134 |
+
"Ruddy Duck",
|
135 |
+
"Wood Duck",
|
136 |
+
"Gadwall",
|
137 |
+
"American Wigeon",
|
138 |
+
"Northern Shoveler",
|
139 |
+
"Green-winged Teal",
|
140 |
+
"Blue-winged Teal",
|
141 |
+
"Cinnamon Teal",
|
142 |
+
"Ringed Teal",
|
143 |
+
"Cape Teal",
|
144 |
+
"Northern Fulmar",
|
145 |
+
"Yellow-billed Loon",
|
146 |
+
"Red-throated Loon",
|
147 |
+
"Arctic Loon",
|
148 |
+
"Pacific Loon",
|
149 |
+
"Horned Grebe",
|
150 |
+
"Red-necked Grebe",
|
151 |
+
"Eared Grebe",
|
152 |
+
"Western Grebe",
|
153 |
+
"Clark's Grebe",
|
154 |
+
"Double-crested Cormorant",
|
155 |
+
"Pelagic Cormorant",
|
156 |
+
"Great Cormorant",
|
157 |
+
"American White Pelican",
|
158 |
+
"Brown Pelican",
|
159 |
+
"Brandt's Cormorant",
|
160 |
+
"Least Bittern",
|
161 |
+
"Great Egret",
|
162 |
+
"Snowy Egret",
|
163 |
+
"Little Blue Heron",
|
164 |
+
"Tricolored Heron",
|
165 |
+
"Reddish Egret",
|
166 |
+
"Black-crowned Night-Heron",
|
167 |
+
"Yellow-crowned Night-Heron",
|
168 |
+
"White Ibis",
|
169 |
+
"Glossy Ibis",
|
170 |
+
"Roseate Spoonbill",
|
171 |
+
"Wood Stork",
|
172 |
+
"Black-bellied Whistling-Duck",
|
173 |
+
"Fulvous Whistling-Duck",
|
174 |
+
"Greater White-fronted Goose",
|
175 |
+
"Snow Goose",
|
176 |
+
"Ross's Goose",
|
177 |
+
"Canada Goose",
|
178 |
+
"Brant",
|
179 |
+
"Mute Swan",
|
180 |
+
"Tundra Swan",
|
181 |
+
"Whooper Swan",
|
182 |
+
"Sandhill Crane",
|
183 |
+
"Black-necked Stilt",
|
184 |
+
"American Avocet",
|
185 |
+
"Northern Jacana",
|
186 |
+
"Greater Yellowlegs",
|
187 |
+
"Lesser Yellowlegs",
|
188 |
+
"Willet",
|
189 |
+
"Spotted Sandpiper",
|
190 |
+
"Upland Sandpiper",
|
191 |
+
"Whimbrel",
|
192 |
+
"Long-billed Curlew",
|
193 |
+
"Marbled Godwit",
|
194 |
+
"Ruddy Turnstone",
|
195 |
+
"Red Knot",
|
196 |
+
"Sanderling",
|
197 |
+
"Semipalmated Sandpiper",
|
198 |
+
"Western Sandpiper",
|
199 |
+
"Least Sandpiper",
|
200 |
+
"White-rumped Sandpiper",
|
201 |
+
"Baird's Sandpiper",
|
202 |
+
"Pectoral Sandpiper",
|
203 |
+
"Dunlin",
|
204 |
+
"Buff-breasted Sandpiper",
|
205 |
+
"Short-billed Dowitcher",
|
206 |
+
"Long-billed Dowitcher",
|
207 |
+
"Common Snipe",
|
208 |
+
"American Woodcock",
|
209 |
+
"Wilson's Phalarope",
|
210 |
+
"Red-necked Phalarope",
|
211 |
+
"Red Phalarope"
|
212 |
+
]
|
213 |
+
|
214 |
+
from pathlib import Path
|
215 |
+
|
216 |
+
def remove_spaces(s):
|
217 |
+
return s.replace(" ", "")
|
218 |
+
|
219 |
+
for species in SPECIES:
|
220 |
+
if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
|
221 |
+
continue
|
222 |
+
try:
|
223 |
+
q = Query(
|
224 |
+
name=species, q="A", length="10-30",
|
225 |
+
)
|
226 |
+
|
227 |
+
# retrieve metadata
|
228 |
+
metafiles = q.retrieve_meta(verbose=True)
|
229 |
+
# retrieve recordings
|
230 |
+
q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
|
231 |
+
|
232 |
+
except:
|
233 |
+
print("Failed to download " + species)
|
234 |
+
continue
|
@@ -28,12 +28,13 @@ setup(
|
|
28 |
install_requires=[
|
29 |
"torch",
|
30 |
"argbind>=0.3.2",
|
31 |
-
"numpy==1.
|
32 |
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
|
33 |
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
34 |
"descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
|
35 |
"gradio",
|
36 |
-
"tensorboardX",
|
37 |
"loralib",
|
|
|
|
|
38 |
],
|
39 |
)
|
|
|
28 |
install_requires=[
|
29 |
"torch",
|
30 |
"argbind>=0.3.2",
|
31 |
+
"numpy==1.23",
|
32 |
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
|
33 |
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
34 |
"descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
|
35 |
"gradio",
|
|
|
36 |
"loralib",
|
37 |
+
"torch_pitch_shift",
|
38 |
+
"madmom",
|
39 |
],
|
40 |
)
|
@@ -120,17 +120,16 @@ class Interface(torch.nn.Module):
|
|
120 |
if coarse_ckpt is not None:
|
121 |
self.coarse.to("cpu")
|
122 |
state_dict = torch.load(coarse_ckpt, map_location="cpu")
|
123 |
-
|
124 |
self.coarse.load_state_dict(state_dict, strict=False)
|
125 |
self.coarse.to(self.device)
|
126 |
if c2f_ckpt is not None:
|
127 |
self.c2f.to("cpu")
|
128 |
state_dict = torch.load(c2f_ckpt, map_location="cpu")
|
129 |
-
|
130 |
self.c2f.load_state_dict(state_dict, strict=False)
|
131 |
self.c2f.to(self.device)
|
132 |
|
133 |
-
|
134 |
def s2t(self, seconds: float):
|
135 |
"""seconds to tokens"""
|
136 |
if isinstance(seconds, np.ndarray):
|
|
|
120 |
if coarse_ckpt is not None:
|
121 |
self.coarse.to("cpu")
|
122 |
state_dict = torch.load(coarse_ckpt, map_location="cpu")
|
123 |
+
print(f"loading coarse from {coarse_ckpt}")
|
124 |
self.coarse.load_state_dict(state_dict, strict=False)
|
125 |
self.coarse.to(self.device)
|
126 |
if c2f_ckpt is not None:
|
127 |
self.c2f.to("cpu")
|
128 |
state_dict = torch.load(c2f_ckpt, map_location="cpu")
|
129 |
+
print(f"loading c2f from {c2f_ckpt}")
|
130 |
self.c2f.load_state_dict(state_dict, strict=False)
|
131 |
self.c2f.to(self.device)
|
132 |
|
|
|
133 |
def s2t(self, seconds: float):
|
134 |
"""seconds to tokens"""
|
135 |
if isinstance(seconds, np.ndarray):
|
@@ -191,29 +191,47 @@ def onset_mask(
|
|
191 |
width: int = 1
|
192 |
):
|
193 |
import librosa
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
return mask
|
214 |
|
215 |
|
216 |
|
217 |
if __name__ == "__main__":
|
218 |
-
|
219 |
-
|
|
|
191 |
width: int = 1
|
192 |
):
|
193 |
import librosa
|
194 |
+
import madmom
|
195 |
+
from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
|
196 |
+
import tempfile
|
197 |
+
import numpy as np
|
198 |
+
|
199 |
+
with tempfile.NamedTemporaryFile(suffix='.wav') as f:
|
200 |
+
sig = sig.clone()
|
201 |
+
sig.write(f.name)
|
202 |
+
|
203 |
+
proc = RNNOnsetProcessor(online=False)
|
204 |
+
onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
|
205 |
+
fps=sig.sample_rate/interface.codec.hop_length)
|
206 |
+
|
207 |
+
act = proc(f.name)
|
208 |
+
onset_times = onsetproc(act)
|
209 |
+
|
210 |
+
# convert to indices for z array
|
211 |
+
onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)
|
212 |
+
|
213 |
+
if onset_indices.shape[0] == 0:
|
214 |
+
mask = empty_mask(z)
|
215 |
+
print(f"no onsets found, returning empty mask")
|
216 |
+
else:
|
217 |
+
torch.set_printoptions(threshold=1000)
|
218 |
+
print("onset indices: ", onset_indices)
|
219 |
+
print("onset times: ", onset_times)
|
220 |
+
|
221 |
+
# create a mask, set onset
|
222 |
+
mask = torch.ones_like(z)
|
223 |
+
n_timesteps = z.shape[-1]
|
224 |
+
|
225 |
+
for onset_index in onset_indices:
|
226 |
+
onset_index = min(onset_index, n_timesteps - 1)
|
227 |
+
onset_index = max(onset_index, 0)
|
228 |
+
mask[:, :, onset_index - width:onset_index + width] = 0.0
|
229 |
+
|
230 |
+
print(mask)
|
231 |
|
232 |
return mask
|
233 |
|
234 |
|
235 |
|
236 |
if __name__ == "__main__":
|
237 |
+
pass
|
|
@@ -367,15 +367,6 @@ class TransformerLayer(nn.Module):
|
|
367 |
|
368 |
return x, position_bias, encoder_decoder_position_bias
|
369 |
|
370 |
-
def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
|
371 |
-
x = np.linspace(0, 1, n_steps)
|
372 |
-
a = (0.5 - min_temp) / (max_temp - min_temp)
|
373 |
-
|
374 |
-
x = (x * 12) - 6
|
375 |
-
x0 = np.log((1 / a - 1) + 1e-5) / k
|
376 |
-
y = (1 / (1 + np.exp(- k *(x-x0))))[::-1]
|
377 |
-
|
378 |
-
return y
|
379 |
|
380 |
class TransformerStack(nn.Module):
|
381 |
def __init__(
|
@@ -587,17 +578,18 @@ class VampNet(at.ml.BaseModel):
|
|
587 |
self,
|
588 |
codec,
|
589 |
time_steps: int = 300,
|
590 |
-
sampling_steps: int =
|
591 |
start_tokens: Optional[torch.Tensor] = None,
|
592 |
sampling_temperature: float = 1.0,
|
593 |
mask: Optional[torch.Tensor] = None,
|
594 |
-
mask_temperature: float =
|
595 |
typical_filtering=False,
|
596 |
typical_mass=0.2,
|
597 |
typical_min_tokens=1,
|
598 |
top_p=None,
|
599 |
return_signal=True,
|
600 |
-
seed: int = None
|
|
|
601 |
):
|
602 |
if seed is not None:
|
603 |
at.util.seed(seed)
|
@@ -650,7 +642,6 @@ class VampNet(at.ml.BaseModel):
|
|
650 |
#################
|
651 |
# begin sampling #
|
652 |
#################
|
653 |
-
t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
|
654 |
|
655 |
for i in range(sampling_steps):
|
656 |
logging.debug(f"step {i} of {sampling_steps}")
|
@@ -676,10 +667,13 @@ class VampNet(at.ml.BaseModel):
|
|
676 |
logging.debug(f"permuted logits with shape: {logits.shape}")
|
677 |
|
678 |
sampled_z, selected_probs = sample_from_logits(
|
679 |
-
logits, sample=
|
|
|
|
|
|
|
680 |
typical_filtering=typical_filtering, typical_mass=typical_mass,
|
681 |
typical_min_tokens=typical_min_tokens,
|
682 |
-
top_k=None, top_p=top_p, return_probs=True
|
683 |
)
|
684 |
|
685 |
logging.debug(f"sampled z with shape: {sampled_z.shape}")
|
@@ -839,7 +833,11 @@ def sample_from_logits(
|
|
839 |
|
840 |
|
841 |
|
842 |
-
def mask_by_random_topk(
|
|
|
|
|
|
|
|
|
843 |
"""
|
844 |
Args:
|
845 |
num_to_mask (int): number of tokens to mask
|
@@ -852,7 +850,8 @@ def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: floa
|
|
852 |
logging.debug(f"temperature: {temperature}")
|
853 |
logging.debug("")
|
854 |
|
855 |
-
|
|
|
856 |
logging.debug(f"confidence shape: {confidence.shape}")
|
857 |
|
858 |
sorted_confidence, sorted_idx = confidence.sort(dim=-1)
|
|
|
367 |
|
368 |
return x, position_bias, encoder_decoder_position_bias
|
369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
class TransformerStack(nn.Module):
|
372 |
def __init__(
|
|
|
578 |
self,
|
579 |
codec,
|
580 |
time_steps: int = 300,
|
581 |
+
sampling_steps: int = 36,
|
582 |
start_tokens: Optional[torch.Tensor] = None,
|
583 |
sampling_temperature: float = 1.0,
|
584 |
mask: Optional[torch.Tensor] = None,
|
585 |
+
mask_temperature: float = 10.5,
|
586 |
typical_filtering=False,
|
587 |
typical_mass=0.2,
|
588 |
typical_min_tokens=1,
|
589 |
top_p=None,
|
590 |
return_signal=True,
|
591 |
+
seed: int = None,
|
592 |
+
sample_cutoff: float = 0.5,
|
593 |
):
|
594 |
if seed is not None:
|
595 |
at.util.seed(seed)
|
|
|
642 |
#################
|
643 |
# begin sampling #
|
644 |
#################
|
|
|
645 |
|
646 |
for i in range(sampling_steps):
|
647 |
logging.debug(f"step {i} of {sampling_steps}")
|
|
|
667 |
logging.debug(f"permuted logits with shape: {logits.shape}")
|
668 |
|
669 |
sampled_z, selected_probs = sample_from_logits(
|
670 |
+
logits, sample=(
|
671 |
+
(i / sampling_steps) <= sample_cutoff
|
672 |
+
),
|
673 |
+
temperature=sampling_temperature,
|
674 |
typical_filtering=typical_filtering, typical_mass=typical_mass,
|
675 |
typical_min_tokens=typical_min_tokens,
|
676 |
+
top_k=None, top_p=top_p, return_probs=True,
|
677 |
)
|
678 |
|
679 |
logging.debug(f"sampled z with shape: {sampled_z.shape}")
|
|
|
833 |
|
834 |
|
835 |
|
836 |
+
def mask_by_random_topk(
|
837 |
+
num_to_mask: int,
|
838 |
+
probs: torch.Tensor,
|
839 |
+
temperature: float = 1.0,
|
840 |
+
):
|
841 |
"""
|
842 |
Args:
|
843 |
num_to_mask (int): number of tokens to mask
|
|
|
850 |
logging.debug(f"temperature: {temperature}")
|
851 |
logging.debug("")
|
852 |
|
853 |
+
noise = gumbel_noise_like(probs)
|
854 |
+
confidence = torch.log(probs) + temperature * noise
|
855 |
logging.debug(f"confidence shape: {confidence.shape}")
|
856 |
|
857 |
sorted_confidence, sorted_idx = confidence.sort(dim=-1)
|