arxiv:2307.02486

LongNet: Scaling Transformers to 1,000,000,000 Tokens

Published on Jul 5 · Featured in Daily Papers on Jul 6

Abstract

Scaling sequence length has become a critical demand in the era of large language models. However, existing methods struggle with either computational complexity or model expressivity, rendering the maximum sequence length restricted. In this work, we introduce LongNet, a Transformer variant that can scale sequence length to more than 1 billion tokens, without sacrificing the performance on shorter sequences. Specifically, we propose dilated attention, which expands the attentive field exponentially as the distance grows. LongNet has significant advantages: 1) it has a linear computation complexity and a logarithmic dependency between tokens; 2) it can serve as a distributed trainer for extremely long sequences; 3) its dilated attention is a drop-in replacement for standard attention, which can be seamlessly integrated with the existing Transformer-based optimization. Experimental results demonstrate that LongNet yields strong performance on both long-sequence modeling and general language tasks. Our work opens up new possibilities for modeling very long sequences, e.g., treating a whole corpus or even the entire Internet as a sequence.
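To make "dilated attention" concrete, here is a minimal pure-Python sketch of a single branch, under stated assumptions: one (segment length w, dilation rate r) pair, one head, no causal mask, and hypothetical helper names `softmax_attention` and `dilated_attention`. The sequence is split into segments of length w, only every r-th row inside each segment is kept, dense attention runs on the kept rows, and the results are scattered back to their original positions. The paper mixes several (w, r) pairs and shifts the offset across heads; this sketch leaves both out.

```python
import math


def softmax_attention(q, k, v):
    """Dense scaled dot-product attention over lists of equal-length vectors."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        top = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(wt * vj[c] for wt, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def dilated_attention(q, k, v, w, r, offset=0):
    """One (segment length w, dilation rate r) branch of dilated attention.

    Rows start + offset, start + offset + r, ... of each length-w segment are
    gathered, attend densely to each other, and are scattered back.
    Rows not selected at this offset are left as zeros.
    """
    n, d = len(q), len(v[0])
    out = [[0.0] * d for _ in range(n)]
    for start in range(0, n, w):
        idx = list(range(start + offset, min(start + w, n), r))
        if not idx:
            continue
        seg = softmax_attention([q[i] for i in idx],
                                [k[i] for i in idx],
                                [v[i] for i in idx])
        for i, row in zip(idx, seg):
            out[i] = row
    return out


# Example: 8 tokens, segments of length 4, keep every 2nd row.
q = [[1.0, 0.0]] * 8
v = [[float(i), 0.0] for i in range(8)]
out = dilated_attention(q, q, v, w=4, r=2)
```

Because every key is identical here, the attention weights are uniform, so `out[0]` is the average of rows 0 and 2 (`[1.0, 0.0]`) and `out[4]` the average of rows 4 and 6 (`[5.0, 0.0]`). Row 1 stays zero until a call with `offset=1` covers it. Each segment costs (w/r)^2 scores instead of w^2, which is where the savings come from.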

Community

They are literally taking all the tricks that vision people used on ViT and re-publishing them. When are they going to publish something like Swin-LLM?

Good points. I think the next one will be Deformable Masked Attention

“Different from vanilla attention, both sizes of K and V are independent of the sequence length N, making the communication cost constant.”
This sentence is doubtful. Show the proof that K_i and V_i are independent of the sequence length N. These tensors' sizes still depend on the sequence length even after the dilation.
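A back-of-the-envelope check of the counts behind this debate, assuming a hypothetical geometric schedule w_i = 2048·2^i and r_i = 2^i (the paper's exact schedule may differ). The function name `dilated_cost` is illustrative only; it counts the keys each dilated segment keeps and the total number of attention scores across all (w, r) branches.

```python
# Hypothetical geometric schedule (an assumption, not necessarily the paper's):
# segment length w_i = 2048 * 2**i, dilation rate r_i = 2**i.
SCHEDULE = [(2048 * 2**i, 2**i) for i in range(10)]


def dilated_cost(n, pairs):
    """Return (total attention scores, keys kept per segment) for dilated attention.

    A (w, r) branch splits n tokens into n // w segments; each keeps w // r
    rows, so the branch computes (n // w) * (w // r) ** 2 scores.
    Pairs whose segment is longer than the sequence are skipped.
    """
    used = [(w, r) for w, r in pairs if w <= n]
    scores = sum((n // w) * (w // r) ** 2 for w, r in used)
    keys_per_segment = [w // r for w, r in used]
    return scores, keys_per_segment


for n in (8192, 2**20):
    scores, keys = dilated_cost(n, SCHEDULE)
    print(n, scores, n * n, set(keys))
```

For n = 8192 this gives 29,360,128 scores versus 67,108,864 for dense attention, and under this schedule every segment keeps w/r = 2048 keys regardless of n, so the score count grows linearly in n. This only speaks to per-segment sizes; whether the per-device K_i and V_i exchanged in the distributed algorithm are also independent of N depends on how segments are partitioned across devices, which is what the comment questions.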

We should have a dislike button too, don't you think?

Is there any model available?

This has been happening for quite some time now: using NLP in CV and CV in NLP.
At the end of the day, it's the math.

Hey, I'm reviewing deep learning papers daily on Twitter in Hebrew under the hashtag #shorthebrewpapereviews (https://twitter.com/hashtag/shorthebrewpapereviews?src=hashtag_click). So far I've briefly reviewed about deep learning papers. You are invited to follow and comment.

This paper review can be found at: https://twitter.com/MikeE_3_14/status/1676988738377744388?s=20

Is this so bad, as long as they cite CV papers? It's... arguably... how science ought to work?

No matter the source of their inspiration (DeepMind always does this...), we want it on Hugging Face ASAP!

Please give me the model.

Is "LongNet: Scaling Transformers to 1,000,000,000 Tokens" something like this?

import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossBar(nn.Module):
    """Expands features into `heads` copies, applies GELU, and averages them back."""

    def __init__(self, dim, heads):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.crossbar_linear = nn.Linear(self.dim, self.dim * self.heads)
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # (..., dim) -> (..., heads, dim)
        x = self.crossbar_linear(x).reshape(*x.shape[:-1], self.heads, -1)
        # average the head copies so the output keeps shape (..., dim)
        return self.scale * F.gelu(x).mean(dim=-2)


class DilatedMHAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, dilation_rates=(1,)):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        assert num_heads % len(dilation_rates) == 0, "heads must split evenly across dilation rates"
        self.dim = dim
        self.num_heads = num_heads
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        # note: the rates only set the number of head groups here; no dilation is applied
        self.dilation_rates = list(dilation_rates)
        group_dim = dim // len(self.dilation_rates)
        self.crossbars = nn.ModuleList([CrossBar(group_dim, num_heads) for _ in self.dilation_rates])

    def forward(self, x):
        # split each projection into heads: (batch, seq, num_heads, head_dim)
        q, k, v = (t.view(*t.shape[:-1], self.num_heads, -1) for t in (self.q(x), self.k(x), self.v(x)))
        groups = len(self.crossbars)
        outputs = []
        for crossbar, qg, kg, vg in zip(self.crossbars, q.chunk(groups, dim=-2),
                                        k.chunk(groups, dim=-2), v.chunk(groups, dim=-2)):
            # per-token, per-head gate from q * k (no mixing across sequence positions)
            gate = (qg * kg).mean(dim=-1, keepdim=True)
            # gated values of this head group, flattened to (batch, seq, dim // groups)
            outputs.append(crossbar((gate * vg).flatten(-2)))
        # concatenate the head groups back to (batch, seq, dim)
        return torch.cat(outputs, dim=-1)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class LongNet(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, num_classes, dilation_rates=None):
        super().__init__()
        # one dilation rate per layer; default to 1 everywhere
        if dilation_rates is None:
            dilation_rates = [1] * depth
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(dim),
                DilatedMHAttention(dim, heads, dilation_rates=[dilation_rates[i]]),
                nn.LayerNorm(dim),
                FeedForward(dim, mlp_dim),
            )
            for i in range(depth)
        ])
        self.classifier = nn.Linear(dim, num_classes)

    def forward(self, x):
        # x: (batch, seq, dim); each block is wrapped in a single residual connection
        for block in self.blocks:
            x = block(x) + x
        # mean-pool over the sequence, then classify
        x = x.mean(dim=1)
        return self.classifier(x)


# usage: LongNet(dim=64, depth=2, heads=8, mlp_dim=128, num_classes=10)(torch.randn(2, 16, 64)).shape
# -> torch.Size([2, 10])

@Emil-Zakirov here is an implementation you can refer to: https://github.com/kyegomez/LongNet

@unknownentity does it accept inputs_embeds like in Hugging Face?


Models citing this paper 1

Datasets citing this paper 0

Spaces citing this paper 0