基于llama2训练的模型,你们有一个bug并没有修复

#1
by Trangle - opened

llama2的实际max_position_embeddings=4096,而不是修正前的2048,2048是因为transformers的老代码没更新就提前转换之前默认参数导致。另外可以测测IF的能力,3次模型的表现都很差

Tiger Research org

你好,max_position_embeddings在llama中做rotary embedding的事先cache用,modeling_llama.py中做了实际序列长度大于这个值时的重新计算。由于我们的训练数据处理成了2048长度内,所以我们是用了2048长度的cache。如果您认为这是个bug,可以提供更详细的说明或者case吗?

这个问题在llama2发布的第三天应该就更新修复了,可以查阅最新的模型配置。就是llama2是在2.2T tokens上以4096的长度训练的。

Tiger Research org
edited Aug 9, 2023

对,但是这个max_position_embedding并不会引起计算上的不一致,哪怕是以4096长度训练,在infer构建模型时写2、4、8也是没有问题的,我们挑选了和继续训练的数据长度更一致的2048。另外,您提到的IF,可以提供更详细一些的信息吗?

@Trangle 大部分llama2中文化的工作max_position_embedding都是2048,可以看看。

i4never changed discussion status to closed
i4never changed discussion status to open
Tiger Research org

@Trangle
这里提供一段简单的代码,以解释为什么虽然meta更改了这个参数,但是并不会引起计算上的不一致,因此并不是bug:

import torch
import transformers
from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding

print(transformers.__version__)

bs = 4
num_heads = 64
seq_lengths = [30, 300, 3000, 6000]
hidden_size = 64

llama_linear_scaling_re_4096 = LlamaLinearScalingRotaryEmbedding(dim=hidden_size,
                                                                 max_position_embeddings=4096,
                                                                 device='cpu')
llama_linear_scaling_re_2048 = LlamaLinearScalingRotaryEmbedding(dim=hidden_size,
                                                                 max_position_embeddings=2048,
                                                                 device='cpu')

for seq_length in seq_lengths:
    hidden_state = torch.rand(bs, num_heads, seq_length, hidden_size)
    cos_by_4096, sin_by_4096 = llama_linear_scaling_re_4096(hidden_state, seq_length)
    cos_by_2048, sin_by_2048 = llama_linear_scaling_re_2048(hidden_state, seq_length)
    print(f"""seq_length: {seq_length}, {cos_by_4096.shape} {sin_by_4096.shape} {cos_by_2048.shape} {sin_by_2048.shape}
cos pos is equal: {torch.equal(cos_by_4096, cos_by_2048)}
sin pos is equal: {torch.equal(sin_by_4096, sin_by_2048)}""")

输出如下:

4.31.0
seq_length: 30, torch.Size([1, 1, 30, 64]) torch.Size([1, 1, 30, 64]) torch.Size([1, 1, 30, 64]) torch.Size([1, 1, 30, 64])
cos pos is equal: True
sin pos is equal: True
seq_length: 300, torch.Size([1, 1, 300, 64]) torch.Size([1, 1, 300, 64]) torch.Size([1, 1, 300, 64]) torch.Size([1, 1, 300, 64])
cos pos is equal: True
sin pos is equal: True
seq_length: 3000, torch.Size([1, 1, 3000, 64]) torch.Size([1, 1, 3000, 64]) torch.Size([1, 1, 3000, 64]) torch.Size([1, 1, 3000, 64])
cos pos is equal: True
sin pos is equal: True
seq_length: 6000, torch.Size([1, 1, 6000, 64]) torch.Size([1, 1, 6000, 64]) torch.Size([1, 1, 6000, 64]) torch.Size([1, 1, 6000, 64])
cos pos is equal: True
sin pos is equal: True

如果您认为这个值的确会导致问题,欢迎提供更solid的分析以帮助我们优化模型。
另外可以提供下您提到的IF能力的相关信息和测试方法吗?谢谢。

我好奇想问一下,这个13b-chat模型的上下文长度是多少呢?2K还是4K?

Tiger Research org

我好奇想问一下,这个13b-chat模型的上下文长度是多少呢?2K还是4K?

我们的训练数据长度是2K,但是由于RoPE有外推性,所以实际推理时候可以实际场景和显存情况来。模型见过的数据都是在2k以内,长于2k的效果建议实际测试一下。

i4never changed discussion status to closed

Sign up or log in to comment