Calculation of _mscale during YARN RoPE scaling

#4
by sszymczyk - opened

I noticed that you calculate the cached sin and cos YARN RoPE values like this:

        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
        )

But in config.json, self.mscale (0.707) is equal to self.mscale_all_dim (also 0.707), so yarn_get_mscale(self.scaling_factor, self.mscale) will equal yarn_get_mscale(self.scaling_factor, self.mscale_all_dim), and therefore _mscale will simply be 1.0. Is this intentional?

If anyone is interested I think I finally figured it out: https://github.com/ggerganov/llama.cpp/discussions/7416
