File size: 2,580 Bytes
0d80816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler

from models.svc.base import SVCInference
from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline
from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
from modules.encoder.condition_encoder import ConditionEncoder


class DiffusionInference(SVCInference):
    """SVC inference model that denoises a mel spectrogram with a diffusers
    scheduler (DDPM / DDIM / PNDM).

    The underlying model is a two-element ``ModuleList``:
    ``[condition_encoder, acoustic_mapper]`` (see ``_build_model``); the
    acoustic mapper is driven through a ``DiffusionInferencePipeline``.
    """

    # Maps the lower-cased scheduler name from the config to the diffusers
    # scheduler class that implements it.
    _SCHEDULER_CLASSES = {
        "ddpm": DDPMScheduler,
        "ddim": DDIMScheduler,
        "pndm": PNDMScheduler,
    }

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        """Initialize the inference model, scheduler, and pipeline.

        Args:
            args: Command-line arguments forwarded to ``SVCInference``.
            cfg: Full config object; reads ``cfg.model.diffusion`` and
                ``cfg.inference.diffusion``.
            infer_type: Inference mode forwarded to ``SVCInference``.

        Raises:
            NotImplementedError: If ``cfg.inference.diffusion.scheduler`` is
                not one of "ddpm", "ddim", "pndm" (case-insensitive).
        """
        super().__init__(args, cfg, infer_type)

        # Inference-time scheduler settings override the training-time ones.
        settings = {
            **cfg.model.diffusion.scheduler_settings,
            **cfg.inference.diffusion.scheduler_settings,
        }
        # num_inference_timesteps is consumed by the pipeline below, not by
        # the scheduler constructor; pop with a default so a missing key
        # fails at the explicit attribute access instead of a bare KeyError.
        settings.pop("num_inference_timesteps", None)

        # Hoist the (repeatedly used) normalized scheduler name once.
        scheduler_name = cfg.inference.diffusion.scheduler.lower()
        scheduler_cls = self._SCHEDULER_CLASSES.get(scheduler_name)
        if scheduler_cls is None:
            raise NotImplementedError(
                "Unsupported scheduler type: {}".format(scheduler_name)
            )
        self.scheduler = scheduler_cls(**settings)
        # .upper() reproduces the original per-branch messages exactly,
        # e.g. "Using DDPM scheduler."
        self.logger.info("Using {} scheduler.".format(scheduler_name.upper()))

        self.pipeline = DiffusionInferencePipeline(
            self.model[1],  # acoustic mapper (DiffusionWrapper)
            self.scheduler,
            cfg.inference.diffusion.scheduler_settings.num_inference_timesteps,
        )

    def _build_model(self):
        """Build and return ``ModuleList([condition_encoder, acoustic_mapper])``."""
        # Propagate the preprocessing F0 range into the condition encoder config.
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
        self.acoustic_mapper = DiffusionWrapper(self.cfg)
        return torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

    def _inference_each_batch(self, batch_data):
        """Denoise one batch: encode conditions, sample noise, run the pipeline.

        Args:
            batch_data: Dict of tensors; must contain a "mel" entry whose
                shape determines the noise to denoise.

        Returns:
            The pipeline output (predicted mel spectrogram tensor).
        """
        device = self.accelerator.device
        # Move every tensor in the batch onto the inference device.
        for key, tensor in batch_data.items():
            batch_data[key] = tensor.to(device)

        conditioner = self.model[0](batch_data)  # condition encoder
        # Start from pure Gaussian noise shaped like the target mel.
        noise = torch.randn_like(batch_data["mel"], device=device)
        return self.pipeline(noise, conditioner)