Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			L40S
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			L40S
	
		root
		
	committed on
		
		
					Commit 
							
							·
						
						6e28f61
	
1
								Parent(s):
							
							4846f0c
								
remove fairseq
Browse files
    	
        codeclm/tokenizer/Flow1dVAE/generate_septoken.py
    CHANGED
    
    | @@ -14,8 +14,8 @@ import tools.torch_tools as torch_tools | |
| 14 | 
             
            from safetensors.torch import load_file
         | 
| 15 | 
             
            from third_party.demucs.models.pretrained import get_model_from_yaml
         | 
| 16 | 
             
            from filelock import FileLock
         | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
             
            class Separator:
         | 
| 20 | 
             
                def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
         | 
| 21 | 
             
                    if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
         | 
|  | |
| 14 | 
             
            from safetensors.torch import load_file
         | 
| 15 | 
             
            from third_party.demucs.models.pretrained import get_model_from_yaml
         | 
| 16 | 
             
            from filelock import FileLock
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
             
            class Separator:
         | 
| 20 | 
             
                def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
         | 
| 21 | 
             
                    if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
         | 
    	
        codeclm/tokenizer/Flow1dVAE/model_1rvq.py
    CHANGED
    
    | @@ -19,12 +19,11 @@ from libs.rvq.descript_quantize3 import ResidualVectorQuantize | |
| 19 |  | 
| 20 | 
             
            from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
         | 
| 21 | 
             
            from models_gpt.models.gpt2_config import GPT2Config
         | 
|  | |
| 22 |  | 
| 23 | 
             
            from torch.cuda.amp import autocast
         | 
| 24 |  | 
| 25 |  | 
| 26 | 
            -
            from our_MERT_BESTRQ.test import load_model
         | 
| 27 | 
            -
             | 
| 28 | 
             
            class HubertModelWithFinalProj(HubertModel):
         | 
| 29 | 
             
                def __init__(self, config):
         | 
| 30 | 
             
                    super().__init__(config)
         | 
| @@ -272,6 +271,7 @@ class PromptCondAudioDiffusion(nn.Module): | |
| 272 | 
             
                    ssl_layer=None,
         | 
| 273 | 
             
                    uncondition=True,
         | 
| 274 | 
             
                    out_paint=False,
         | 
|  | |
| 275 | 
             
                ):
         | 
| 276 | 
             
                    super().__init__()
         | 
| 277 |  | 
| @@ -294,28 +294,24 @@ class PromptCondAudioDiffusion(nn.Module): | |
| 294 | 
             
                    self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
         | 
| 295 | 
             
                    # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
         | 
| 296 | 
             
                    # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
         | 
| 297 | 
            -
                    self.bestrq =  | 
| 298 | 
            -
                        model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
         | 
| 299 | 
            -
                        checkpoint_dir='ckpt/encode-s12k.pt',
         | 
| 300 | 
            -
                    )
         | 
| 301 | 
             
                    self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
         | 
| 302 | 
             
                    self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
         | 
| 303 | 
            -
                    for v in self.bestrq.parameters():v.requires_grad = False
         | 
| 304 | 
             
                    self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
         | 
| 305 | 
             
                    for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False
         | 
| 306 | 
             
                    # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
         | 
| 307 | 
             
                    # for v in self.hubert.parameters():v.requires_grad = False
         | 
| 308 | 
             
                    self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
         | 
| 309 | 
             
                    # self.xvecmodel = XVECModel()
         | 
| 310 | 
            -
                    config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
         | 
| 311 | 
            -
                    unet = GPT2Model(config)
         | 
| 312 | 
            -
                    mlp =  nn.Sequential(
         | 
| 313 | 
            -
             | 
| 314 | 
            -
             | 
| 315 | 
            -
             | 
| 316 | 
            -
             | 
| 317 | 
            -
             | 
| 318 | 
            -
                    )
         | 
| 319 | 
             
                    self.set_from = "random"
         | 
| 320 | 
             
                    # self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
         | 
| 321 | 
             
                    self.mask_emb = torch.nn.Embedding(3, 48)
         | 
| @@ -538,8 +534,6 @@ class PromptCondAudioDiffusion(nn.Module): | |
| 538 | 
             
                    input_audio_0 = self.preprocess_audio(input_audio_0)
         | 
| 539 | 
             
                    input_audio_1 = self.preprocess_audio(input_audio_1)
         | 
| 540 |  | 
| 541 | 
            -
                    self.bestrq.eval()
         | 
| 542 | 
            -
             | 
| 543 | 
             
                    # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
         | 
| 544 | 
             
                    # bestrq_middle = bestrq_middle.detach()
         | 
| 545 | 
             
                    # bestrq_last = bestrq_last.detach()
         | 
| @@ -575,8 +569,6 @@ class PromptCondAudioDiffusion(nn.Module): | |
| 575 | 
             
                    input_audio_0 = self.preprocess_audio(input_audio_0)
         | 
| 576 | 
             
                    input_audio_1 = self.preprocess_audio(input_audio_1)
         | 
| 577 |  | 
| 578 | 
            -
                    self.bestrq.eval()
         | 
| 579 | 
            -
             | 
| 580 | 
             
                    # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
         | 
| 581 | 
             
                    # bestrq_middle = bestrq_middle.detach()
         | 
| 582 | 
             
                    # bestrq_last = bestrq_last.detach()
         | 
|  | |
| 19 |  | 
| 20 | 
             
            from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
         | 
| 21 | 
             
            from models_gpt.models.gpt2_config import GPT2Config
         | 
| 22 | 
            +
            from our_MERT_BESTRQ.mert_fairseq.models.musicfm.musicfm_model import MusicFMModel, MusicFMConfig
         | 
| 23 |  | 
| 24 | 
             
            from torch.cuda.amp import autocast
         | 
| 25 |  | 
| 26 |  | 
|  | |
|  | |
| 27 | 
             
            class HubertModelWithFinalProj(HubertModel):
         | 
| 28 | 
             
                def __init__(self, config):
         | 
| 29 | 
             
                    super().__init__(config)
         | 
|  | |
| 271 | 
             
                    ssl_layer=None,
         | 
| 272 | 
             
                    uncondition=True,
         | 
| 273 | 
             
                    out_paint=False,
         | 
| 274 | 
            +
                    ssl_path='ckpt/encode-s12k.pt'
         | 
| 275 | 
             
                ):
         | 
| 276 | 
             
                    super().__init__()
         | 
| 277 |  | 
|  | |
| 294 | 
             
                    self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
         | 
| 295 | 
             
                    # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
         | 
| 296 | 
             
                    # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
         | 
| 297 | 
            +
                    self.bestrq = MusicFMModel(MusicFMConfig())
         | 
|  | |
|  | |
|  | |
| 298 | 
             
                    self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
         | 
| 299 | 
             
                    self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
         | 
|  | |
| 300 | 
             
                    self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
         | 
| 301 | 
             
                    for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False
         | 
| 302 | 
             
                    # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
         | 
| 303 | 
             
                    # for v in self.hubert.parameters():v.requires_grad = False
         | 
| 304 | 
             
                    self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
         | 
| 305 | 
             
                    # self.xvecmodel = XVECModel()
         | 
| 306 | 
            +
                    # config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
         | 
| 307 | 
            +
                    # unet = GPT2Model(config)
         | 
| 308 | 
            +
                    # mlp =  nn.Sequential(
         | 
| 309 | 
            +
                    #     nn.Linear(1200, 1024), 
         | 
| 310 | 
            +
                    #     nn.SiLU(),                  
         | 
| 311 | 
            +
                    #     nn.Linear(1024, 1024),      
         | 
| 312 | 
            +
                    #     nn.SiLU(),                 
         | 
| 313 | 
            +
                    #     nn.Linear(1024, 768)  
         | 
| 314 | 
            +
                    # )
         | 
| 315 | 
             
                    self.set_from = "random"
         | 
| 316 | 
             
                    # self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
         | 
| 317 | 
             
                    self.mask_emb = torch.nn.Embedding(3, 48)
         | 
|  | |
| 534 | 
             
                    input_audio_0 = self.preprocess_audio(input_audio_0)
         | 
| 535 | 
             
                    input_audio_1 = self.preprocess_audio(input_audio_1)
         | 
| 536 |  | 
|  | |
|  | |
| 537 | 
             
                    # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
         | 
| 538 | 
             
                    # bestrq_middle = bestrq_middle.detach()
         | 
| 539 | 
             
                    # bestrq_last = bestrq_last.detach()
         | 
|  | |
| 569 | 
             
                    input_audio_0 = self.preprocess_audio(input_audio_0)
         | 
| 570 | 
             
                    input_audio_1 = self.preprocess_audio(input_audio_1)
         | 
| 571 |  | 
|  | |
|  | |
| 572 | 
             
                    # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
         | 
| 573 | 
             
                    # bestrq_middle = bestrq_middle.detach()
         | 
| 574 | 
             
                    # bestrq_last = bestrq_last.detach()
         | 
    	
        codeclm/tokenizer/Flow1dVAE/model_septoken.py
    CHANGED
    
    | @@ -20,9 +20,9 @@ from libs.rvq.descript_quantize3 import ResidualVectorQuantize | |
| 20 |  | 
| 21 | 
             
            from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
         | 
| 22 | 
             
            from models_gpt.models.gpt2_config import GPT2Config
         | 
|  | |
| 23 |  | 
| 24 | 
             
            from torch.cuda.amp import autocast
         | 
| 25 | 
            -
            from our_MERT_BESTRQ.test import load_model
         | 
| 26 |  | 
| 27 | 
             
            class HubertModelWithFinalProj(HubertModel):
         | 
| 28 | 
             
                def __init__(self, config):
         | 
| @@ -253,6 +253,7 @@ class PromptCondAudioDiffusion(nn.Module): | |
| 253 | 
             
                    snr_gamma=None,
         | 
| 254 | 
             
                    uncondition=True,
         | 
| 255 | 
             
                    out_paint=False,
         | 
|  | |
| 256 | 
             
                ):
         | 
| 257 | 
             
                    super().__init__()
         | 
| 258 |  | 
| @@ -273,13 +274,9 @@ class PromptCondAudioDiffusion(nn.Module): | |
| 273 | 
             
                    self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
         | 
| 274 | 
             
                    # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
         | 
| 275 | 
             
                    # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
         | 
| 276 | 
            -
                    self.bestrq =  | 
| 277 | 
            -
                        model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
         | 
| 278 | 
            -
                        checkpoint_dir='ckpt/encode-s12k.pt',
         | 
| 279 | 
            -
                    )
         | 
| 280 | 
             
                    self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
         | 
| 281 | 
             
                    self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
         | 
| 282 | 
            -
                    for v in self.bestrq.parameters():v.requires_grad = False
         | 
| 283 | 
             
                    self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
         | 
| 284 | 
             
                    self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
         | 
| 285 | 
             
                    # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
         | 
|  | |
| 20 |  | 
| 21 | 
             
            from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
         | 
| 22 | 
             
            from models_gpt.models.gpt2_config import GPT2Config
         | 
| 23 | 
            +
            from our_MERT_BESTRQ.mert_fairseq.models.musicfm.musicfm_model import MusicFMModel, MusicFMConfig
         | 
| 24 |  | 
| 25 | 
             
            from torch.cuda.amp import autocast
         | 
|  | |
| 26 |  | 
| 27 | 
             
            class HubertModelWithFinalProj(HubertModel):
         | 
| 28 | 
             
                def __init__(self, config):
         | 
|  | |
| 253 | 
             
                    snr_gamma=None,
         | 
| 254 | 
             
                    uncondition=True,
         | 
| 255 | 
             
                    out_paint=False,
         | 
| 256 | 
            +
                    ssl_path='ckpt/encode-s12k.pt'
         | 
| 257 | 
             
                ):
         | 
| 258 | 
             
                    super().__init__()
         | 
| 259 |  | 
|  | |
| 274 | 
             
                    self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
         | 
| 275 | 
             
                    # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
         | 
| 276 | 
             
                    # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
         | 
| 277 | 
            +
                    self.bestrq = MusicFMModel(MusicFMConfig())
         | 
|  | |
|  | |
|  | |
| 278 | 
             
                    self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
         | 
| 279 | 
             
                    self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
         | 
|  | |
| 280 | 
             
                    self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
         | 
| 281 | 
             
                    self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
         | 
| 282 | 
             
                    # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
         | 
    	
        codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/musicfm_model.py
    CHANGED
    
    | @@ -4,14 +4,6 @@ except: | |
| 4 | 
             
                import sys, os
         | 
| 5 | 
             
                sys.path.append(os.path.dirname(os.path.abspath(__file__)))
         | 
| 6 | 
             
                from model.musicfm_25hz import MusicFM25Hz
         | 
| 7 | 
            -
            try:
         | 
| 8 | 
            -
                from fairseq.fairseq.dataclass import FairseqDataclass
         | 
| 9 | 
            -
                from fairseq.fairseq.models import BaseFairseqModel, register_model
         | 
| 10 | 
            -
                from fairseq.fairseq.tasks.fairseq_task import FairseqTask
         | 
| 11 | 
            -
            except:
         | 
| 12 | 
            -
                from fairseq.dataclass import FairseqDataclass
         | 
| 13 | 
            -
                from fairseq.models import BaseFairseqModel, register_model
         | 
| 14 | 
            -
                from fairseq.tasks.fairseq_task import FairseqTask
         | 
| 15 |  | 
| 16 | 
             
            from dataclasses import dataclass, field
         | 
| 17 | 
             
            from typing import List, Tuple, Optional
         | 
| @@ -22,7 +14,7 @@ from logging import getLogger | |
| 22 | 
             
            logger = getLogger(__name__)
         | 
| 23 |  | 
| 24 | 
             
            @dataclass
         | 
| 25 | 
            -
            class MusicFMConfig | 
| 26 | 
             
                label_rate:int = field(default=25)
         | 
| 27 | 
             
                num_codebooks:int = field(default=1)
         | 
| 28 | 
             
                codebook_dim:int = field(default=16)
         | 
| @@ -45,9 +37,8 @@ class MusicFMConfig(FairseqDataclass): | |
| 45 |  | 
| 46 | 
             
            SAMPLE_RATE = 24_000
         | 
| 47 |  | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 | 
            -
                def __init__(self, cfg: MusicFMConfig, task_cfg: FairseqTask):
         | 
| 51 | 
             
                    super().__init__()
         | 
| 52 | 
             
                    self.cfg = cfg
         | 
| 53 | 
             
                    self.model = MusicFM25Hz(
         | 
| @@ -91,19 +82,3 @@ class MusicFMModel(BaseFairseqModel): | |
| 91 | 
             
                        result["logits"] = logits
         | 
| 92 | 
             
                        result["hidden_emb"] = hidden_emb
         | 
| 93 | 
             
                        return result
         | 
| 94 | 
            -
             | 
| 95 | 
            -
                @classmethod
         | 
| 96 | 
            -
                def build_model(cls, cfg: MusicFMConfig, task: FairseqTask):
         | 
| 97 | 
            -
                    """Build a new model instance."""
         | 
| 98 | 
            -
             | 
| 99 | 
            -
                    model = MusicFMModel(cfg, task.cfg)
         | 
| 100 | 
            -
                    import numpy as np
         | 
| 101 | 
            -
                    s = 0
         | 
| 102 | 
            -
                    for param in model.parameters():
         | 
| 103 | 
            -
                        s += np.product(param.size())
         | 
| 104 | 
            -
                    print('# of parameters: '+str(s/1024.0/1024.0))
         | 
| 105 | 
            -
                    return model
         | 
| 106 | 
            -
             | 
| 107 | 
            -
                def get_losses(self, result, batch):
         | 
| 108 | 
            -
                    return result['losses']
         | 
| 109 | 
            -
                
         | 
|  | |
| 4 | 
             
                import sys, os
         | 
| 5 | 
             
                sys.path.append(os.path.dirname(os.path.abspath(__file__)))
         | 
| 6 | 
             
                from model.musicfm_25hz import MusicFM25Hz
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 7 |  | 
| 8 | 
             
            from dataclasses import dataclass, field
         | 
| 9 | 
             
            from typing import List, Tuple, Optional
         | 
|  | |
| 14 | 
             
            logger = getLogger(__name__)
         | 
| 15 |  | 
| 16 | 
             
            @dataclass
         | 
| 17 | 
            +
            class MusicFMConfig:
         | 
| 18 | 
             
                label_rate:int = field(default=25)
         | 
| 19 | 
             
                num_codebooks:int = field(default=1)
         | 
| 20 | 
             
                codebook_dim:int = field(default=16)
         | 
|  | |
| 37 |  | 
| 38 | 
             
            SAMPLE_RATE = 24_000
         | 
| 39 |  | 
| 40 | 
            +
            class MusicFMModel(torch.nn.Module):
         | 
| 41 | 
            +
                def __init__(self, cfg: MusicFMConfig):
         | 
|  | |
| 42 | 
             
                    super().__init__()
         | 
| 43 | 
             
                    self.cfg = cfg
         | 
| 44 | 
             
                    self.model = MusicFM25Hz(
         | 
|  | |
| 82 | 
             
                        result["logits"] = logits
         | 
| 83 | 
             
                        result["hidden_emb"] = hidden_emb
         | 
| 84 | 
             
                        return result
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
