Checkpoints of facebookresearch/coconut, trained on an A100 40GB.

Training logs are available on wandb.
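
If the checkpoint is not already on disk, it can be fetched from the Hub before running the quickstart below. This is a minimal sketch: the repository id and filename are placeholders, not the real identifiers of this model.

```python
# Minimal sketch for downloading the checkpoint from the Hugging Face Hub.
# NOTE: repo_id and filename below are hypothetical; substitute the actual
# identifiers of this model before use.
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="your-username/gsm-coconut",  # hypothetical repo id
    filename="checkpoint_22",             # hypothetical checkpoint filename
)
print(f"Checkpoint downloaded to {checkpoint_path}")
```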

Quickstart:

```python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from coconut import Coconut

def main():
    load_model_path = "save_models/gsm-coconut/checkpoint_22"
    model_id = "openai-community/gpt2"
    # load the configuration file
    print(f"Loading from {load_model_path}.")

    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
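    # Coconut reasons in continuous space through three special tokens:
    # <|start-latent|> and <|end-latent|> delimit the latent-reasoning span,
    # and each <|latent|> placeholder is where the model feeds back its last
    # hidden state instead of a token embedding.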
    tokenizer.add_tokens("<|start-latent|>")
    tokenizer.add_tokens("<|end-latent|>")
    tokenizer.add_tokens("<|latent|>")
    latent_id = tokenizer.convert_tokens_to_ids("<|latent|>")
    start_id = tokenizer.convert_tokens_to_ids("<|start-latent|>")
    end_id = tokenizer.convert_tokens_to_ids("<|end-latent|>")

    saved_weights = torch.load(
        load_model_path, map_location=torch.device("cuda")
    )
    
    model.resize_token_embeddings(len(tokenizer))
    embeddings = model.get_input_embeddings()
    target_id = tokenizer.convert_tokens_to_ids("<<")
    # initialize the new token embeddings with the embedding of a known token ("<<");
    # this helps stabilize training
    for token_id in [latent_id, start_id, end_id]:
        target_embedding = embeddings.weight.data[target_id]
        embeddings.weight.data[token_id] = target_embedding
        # The input embeddings and the LM head are tied in GPT-2, so the line below is redundant.
        lm_head = model.lm_head
        lm_head.weight.data[token_id] = lm_head.weight.data[target_id]

    model = Coconut(model, latent_id, start_id, end_id, tokenizer.eos_token_id)
    print(model.load_state_dict(saved_weights, strict=False))
    model = model.to("cuda")
    
    prompt = "Sally received the following scores on her math quizzes: 50, 80, 80. Find her mean score."
    prompt = tokenizer(prompt, return_tensors="pt").to("cuda")
    output = model.generate(
        **prompt,
        max_new_tokens=20
    )
    for i, o in enumerate(output):
        print(f"Output {i}: {tokenizer.decode(o, skip_special_tokens=True)}")
    
if __name__ == "__main__":
    main()
```
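
The quickstart feeds the raw question only. At inference time, Coconut-style inputs typically also append the latent markers after the question so the model reasons in continuous space before answering. The sketch below continues from the variables defined inside main() and assumes the checkpoint expects six latent tokens; the exact count is an assumption and must match the training configuration of checkpoint_22.

```python
# Hedged sketch (continues from main() above): append latent-reasoning markers
# to the question before generating.
n_latent = 6  # assumed number of latent tokens; adjust to the training config

question = "Sally received the following scores on her math quizzes: 50, 80, 80. Find her mean score."
input_ids = tokenizer(question, return_tensors="pt").input_ids
latent_part = torch.tensor([[start_id] + [latent_id] * n_latent + [end_id]])
input_ids = torch.cat([input_ids, latent_part], dim=1).to("cuda")
attention_mask = torch.ones_like(input_ids)

output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```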