Checkpoints of facebookresearch/coconut running on an A100 40GB.
Logs are available at wandb.
Quickstart:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import torch
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
from coconut import Coconut
def main():
    """Load a Coconut checkpoint on top of GPT-2 and decode one example prompt.

    Steps:
      1. Load the pretrained GPT-2 model and tokenizer, and register the three
         latent marker tokens in the same order as the original script so
         their ids stay the same.
      2. Resize the embedding matrix and initialise the new rows from the
         embedding of the existing "<<" token.
      3. Wrap the model in ``Coconut``, load the saved weights non-strictly,
         and print the result of ``load_state_dict`` (missing/unexpected keys).
      4. Generate up to 20 new tokens for a fixed math prompt and print each
         decoded output.

    Runs on CUDA when available and falls back to CPU otherwise.
    """
    load_model_path = "save_models/gsm-coconut/checkpoint_22"
    model_id = "openai-community/gpt2"
    # Pick the device once so every tensor/module ends up in the same place and
    # the quickstart still works on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Loading from {load_model_path}.")
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # Order matters: the ids assigned here must match the ones the checkpoint
    # was trained with, so keep start / end / latent in this sequence.
    for special_token in ("<|start-latent|>", "<|end-latent|>", "<|latent|>"):
        tokenizer.add_tokens(special_token)
    latent_id = tokenizer.convert_tokens_to_ids("<|latent|>")
    start_id = tokenizer.convert_tokens_to_ids("<|start-latent|>")
    end_id = tokenizer.convert_tokens_to_ids("<|end-latent|>")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a source you trust.
    saved_weights = torch.load(load_model_path, map_location=device)

    model.resize_token_embeddings(len(tokenizer))
    embeddings = model.get_input_embeddings()
    lm_head = model.lm_head
    target_id = tokenizer.convert_tokens_to_ids("<<")
    # Initialize the new token embeddings with a known token; it helps
    # stabilize training. The original code copied each row onto itself
    # (a no-op) instead of copying from target_id.
    for token_id in (latent_id, start_id, end_id):
        embeddings.weight.data[token_id] = embeddings.weight.data[target_id]
        # The input embeddings and lm head are tied in GPT-2, so this second
        # copy is redundant there; it is kept for models with untied weights.
        lm_head.weight.data[token_id] = lm_head.weight.data[target_id]

    model = Coconut(model, latent_id, start_id, end_id, tokenizer.eos_token_id)
    # strict=False tolerates key mismatches; the printed result lists them.
    print(model.load_state_dict(saved_weights, strict=False))
    model = model.to(device)
    model.eval()

    prompt = "Sally received the following scores on her math quizzes: 50, 80, 80. Find her mean score."
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=20,
        )
    for i, o in enumerate(output):
        print(f"Output {i}: {tokenizer.decode(o, skip_special_tokens=True)}")


if __name__ == "__main__":
    main()
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support
HF Inference deployability: The model has no library tag.