BERTIN-GPT-J-6B with 8-bit weights (Quantized)

Go here to use the latest checkpoint.

This model (and model card) is an adaptation of hivemind/gpt-j-6B-8bit, so all credit goes to its author.

This is a version of bertin-project/bertin-gpt-j-6B modified so that you can generate text with and fine-tune the model in Colab or on an equivalent desktop GPU (e.g. a single 1080Ti).

Here's how to run it: colab

The original GPT-J takes 22+ GB of memory for float32 parameters alone, and that's before you account for gradients and optimizer state. Even if you cast everything to 16-bit, it will still not fit onto most single-GPU setups short of A6000 and A100. You can run inference on TPU or CPUs, but fine-tuning is far more expensive.
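As a rough check on these numbers, the sketch below estimates memory use from the parameter count. The figure of about 6.05 billion parameters and the "weights + gradients + two Adam buffers" accounting are our assumptions for illustration, and activations are ignored entirely.

```python
# Back-of-the-envelope memory estimate; every figure here is approximate.
N_PARAMS = 6.05e9   # assumed approximate parameter count of GPT-J-6B
GIB = 1024 ** 3


def weights_gib(n_params: float, bytes_per_param: int) -> float:
    """Memory needed to store the parameters alone, in GiB."""
    return n_params * bytes_per_param / GIB


def full_finetune_gib(n_params: float, bytes_per_param: int = 4) -> float:
    """Weights + gradients + two Adam moment buffers at the same precision."""
    return 4 * weights_gib(n_params, bytes_per_param)


fp32 = weights_gib(N_PARAMS, 4)
fp16 = weights_gib(N_PARAMS, 2)
int8 = weights_gib(N_PARAMS, 1)

print(f"float32 weights:          {fp32:.1f} GiB")
print(f"float16 weights:          {fp16:.1f} GiB")
print(f"8-bit weights:            {int8:.1f} GiB")
print(f"float32 full fine-tuning: {full_finetune_gib(N_PARAMS):.1f} GiB (no activations)")
```

Under these assumptions, only the 8-bit weights leave room on an ~11 GB card for activations and adapter state.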

Here, we apply several techniques to make GPT-J usable and fine-tunable on a single GPU with ~11 GB memory:

  • large weight tensors are quantized using dynamic 8-bit quantization and de-quantized just-in-time for multiplication
  • gradient checkpointing stores only one activation per layer, using dramatically less memory at the cost of 30% slower training
  • scalable fine-tuning with LoRA and 8-bit Adam

In other words, all of the large weight matrices are frozen in 8-bit, and you only train small adapters and, optionally, 1-D tensors (layernorm scales, biases).
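The idea of a frozen 8-bit weight with a small trainable adapter can be sketched in plain Python. This is a toy illustration only: it uses simple per-row linear (absmax) quantization instead of the block-wise quantization from bitsandbytes, and the class name `FrozenLinearWithLoRA` and its parameters are hypothetical, not part of this model's code.

```python
import random


def quantize_rowwise(weight):
    """Store each row as int8 codes plus one float scale (linear absmax)."""
    codes, scales = [], []
    for row in weight:
        scale = max(abs(v) for v in row) / 127 or 1.0  # avoid a zero scale
        codes.append([round(v / scale) for v in row])
        scales.append(scale)
    return codes, scales


def dequantize(codes, scales):
    """Rebuild approximate float weights from codes and per-row scales."""
    return [[c * s for c in row] for row, s in zip(codes, scales)]


def matvec(matrix, vector):
    return [sum(m * x for m, x in zip(row, vector)) for row in matrix]


class FrozenLinearWithLoRA:
    """Hypothetical layer: frozen 8-bit weight plus a trainable low-rank adapter."""

    def __init__(self, weight, rank=2, alpha=2.0, seed=0):
        rng = random.Random(seed)
        out_features, in_features = len(weight), len(weight[0])
        self.codes, self.scales = quantize_rowwise(weight)  # frozen storage
        # LoRA init: A is small and random, B is zero, so the adapter starts as a no-op.
        self.lora_a = [[rng.gauss(0, 0.01) for _ in range(in_features)]
                       for _ in range(rank)]
        self.lora_b = [[0.0] * rank for _ in range(out_features)]
        self.scaling = alpha / rank

    def forward(self, x):
        # De-quantize just in time, multiply in floating point, then add the adapter.
        base = matvec(dequantize(self.codes, self.scales), x)
        update = matvec(self.lora_b, matvec(self.lora_a, x))
        return [b + self.scaling * u for b, u in zip(base, update)]

    def trainable_parameters(self):
        return (len(self.lora_a) * len(self.lora_a[0])
                + len(self.lora_b) * len(self.lora_b[0]))


rng = random.Random(1)
weight = [[rng.gauss(0, 0.02) for _ in range(16)] for _ in range(8)]
x = [rng.gauss(0, 1) for _ in range(16)]

layer = FrozenLinearWithLoRA(weight, rank=2)
exact = matvec(weight, x)
approx = layer.forward(x)
max_err = max(abs(a - e) for a, e in zip(approx, exact))
print("max abs error:", max_err)
print("trainable parameters:", layer.trainable_parameters(), "of", 8 * 16)
```

Only `lora_a` and `lora_b` would receive gradients in training; the int8 codes and scales never change.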


Does 8-bit affect model quality? Technically yes, but the effect is negligible in practice. This notebook measures wikitext test perplexity, and it is nearly indistinguishable from that of the original GPT-J. The quantized model is even slightly better, but the difference is not statistically significant.

Our code differs from other 8-bit methods in that we use 8-bit only for storage, and all computations are performed in float16 or float32. As a result, we can take advantage of nonlinear quantization that is fitted to each individual weight distribution. Such nonlinear quantization does not accelerate inference, but it allows for much smaller error.
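To illustrate why a codebook fitted to the weight distribution can reduce error, here is a small sketch comparing an evenly spaced 256-level grid with a 256-level codebook refined by Lloyd's algorithm (1-D k-means). Lloyd's algorithm is our stand-in for "nonlinear quantization"; it is not how bitsandbytes builds its codebooks, and the Gaussian toy weights are an assumption.

```python
import bisect
import random


def nearest_index(codebook, value):
    """Index of the entry in a sorted codebook that is closest to value."""
    i = bisect.bisect_left(codebook, value)
    if i == 0:
        return 0
    if i == len(codebook):
        return len(codebook) - 1
    return i if codebook[i] - value < value - codebook[i - 1] else i - 1


def mean_squared_error(codebook, values):
    return sum((v - codebook[nearest_index(codebook, v)]) ** 2
               for v in values) / len(values)


def uniform_codebook(values, levels=256):
    """Evenly spaced levels between the smallest and largest value."""
    lo, hi = min(values), max(values)
    step = (hi - lo) / (levels - 1)
    return [lo + i * step for i in range(levels)]


def fit_codebook(values, levels=256, iterations=10):
    """Lloyd's algorithm, started from the uniform grid."""
    codebook = uniform_codebook(values, levels)
    for _ in range(iterations):
        sums = [0.0] * levels
        counts = [0] * levels
        for v in values:
            i = nearest_index(codebook, v)
            sums[i] += v
            counts[i] += 1
        # Move each level to the mean of its assigned values; keep unused levels.
        codebook = sorted(sums[i] / counts[i] if counts[i] else codebook[i]
                          for i in range(levels))
    return codebook


rng = random.Random(0)
weights = [rng.gauss(0, 0.02) for _ in range(4000)]  # toy "weight tensor"

uniform_mse = mean_squared_error(uniform_codebook(weights), weights)
fitted = fit_codebook(weights)
fitted_mse = mean_squared_error(fitted, weights)
print(f"uniform grid MSE:    {uniform_mse:.3e}")
print(f"fitted codebook MSE: {fitted_mse:.3e}")
```

Both variants still store one byte per weight; the fitted codebook only changes which 256 float values those bytes map to, which is why it helps accuracy but not speed.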

What about performance? Both checkpointing and de-quantization have some overhead, but it's surprisingly manageable. Depending on GPU and batch size, the quantized model is 1-10% slower than the original model, on top of the roughly 30% overhead from gradient checkpointing. In short, this is because block-wise quantization from bitsandbytes is really fast on GPU.

How should I fine-tune the model?

We recommend starting with the original hyperparameters from the LoRA paper. On top of that, there is one more trick to consider: the overhead from de-quantizing weights does not depend on batch size. As a result, the larger the batch size you can fit, the more efficiently you will train.
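The batch-size argument can be made concrete with a toy cost model: a fixed de-quantization cost per step plus a compute cost that grows with the batch. The millisecond figures below are hypothetical placeholders, not measurements of this model.

```python
def step_time_ms(batch_size, dequant_ms=10.0, per_sample_ms=100.0):
    """Hypothetical cost model: fixed de-quantization cost + per-sample compute."""
    return dequant_ms + per_sample_ms * batch_size


def overhead_fraction(batch_size, dequant_ms=10.0, per_sample_ms=100.0):
    """Share of each training step spent on de-quantization."""
    return dequant_ms / step_time_ms(batch_size, dequant_ms, per_sample_ms)


fractions = [overhead_fraction(b) for b in (1, 4, 16, 64)]
for batch_size, fraction in zip((1, 4, 16, 64), fractions):
    print(f"batch size {batch_size:>2}: {fraction:.1%} of step time is de-quantization")
```

Because the fixed cost is paid once per step regardless of batch size, its share of the step time shrinks as the batch grows.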

Where can I train for free?

You can train fine in Colab, but if you get a K80, it's probably best to switch to another free GPU provider: Kaggle, AWS SageMaker or Paperspace. For instance, this is the same notebook running on Kaggle using a more powerful P100 instance.

Can I use this technique with other models?

The model was converted using this notebook. It can be adapted to work with other model types. However, please bear in mind that some models replace Linear and Embedding with custom alternatives that require their own BNBWhateverWithAdapters.

How to use

wget https://huggingface.co/mrm8488/bertin-gpt-j-6B-ES-8bit/resolve/main/utils.py -O Utils.py
pip install transformers
pip install bitsandbytes-cuda111==0.26.0

import transformers
import torch

from Utils import GPTJBlock, GPTJForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J

ckpt = "mrm8488/bertin-gpt-j-6B-ES-8bit"

tokenizer = transformers.AutoTokenizer.from_pretrained(ckpt)
model = GPTJForCausalLM.from_pretrained(ckpt, pad_token_id=tokenizer.eos_token_id, low_cpu_mem_usage=True).to(device)


prompt = tokenizer("El sentido de la vida es", return_tensors='pt')
prompt = {key: value.to(device) for key, value in prompt.items()}

out = model.generate(**prompt, max_length=64, do_sample=True)

print(tokenizer.decode(out[0]))