#!/usr/bin/env python3
"""Compare next-token predictions and logits of a metaseq OPT checkpoint
against the Hugging Face OPTForCausalLM port of the same model."""
import os

import torch
from megatron.initialize import initialize_megatron
from metaseq import checkpoint_utils
from transformers import GPT2Tokenizer, OPTForCausalLM

path = "./model"

# Megatron just needs *some* args to initialize; they don't need to
# correspond to the "correct" architecture for this checkpoint.
initialize_megatron(args_defaults={
    "micro_batch_size": 1,
    "num_layers": 12,
    "hidden_size": 768,
    "num_attention_heads": 12,
    "max_position_embeddings": 2048,
    "encoder_seq_length": 2048,
})

vocab_file = os.path.join(path, "gpt2-vocab.json")
merges_file = os.path.join(path, "gpt2-merges.txt")

tokenizer = GPT2Tokenizer(vocab_file, merges_file)
tokenizer.save_pretrained(path)

# Load the metaseq checkpoint; load_model_ensemble_and_task returns
# (models, cfg, task), so checkpoint[0][0] is the first (and only) model.
checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    [os.path.join(path, "restored.pt")],
    arg_overrides={
        "vocab_filename": vocab_file,
        "merges_filename": merges_file,
    },
)
model = checkpoint[0][0].eval()
model = model.to("cuda:0").half()

# Put the converted Hugging Face model on a second GPU so both fit in memory.
hf_model = OPTForCausalLM.from_pretrained("../opt-350m").to("cuda:1").half()


def single_batch_forward_logits(prompts):
    """Forward pass through the metaseq model."""
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    # Prepend token id 0 as a beginning-of-sequence marker; the raw
    # GPT2Tokenizer does not add one on its own.
    input_ids = torch.cat([torch.tensor([[0]]), input_ids], dim=-1)
    input_ids = input_ids.to("cuda:0")
    with torch.no_grad():
        logits = model(input_ids)[0]
    return logits


def forward_hf(prompts):
    """Forward pass through the Hugging Face model."""
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    input_ids = torch.cat([torch.tensor([[0]]), input_ids], dim=-1)
    input_ids = input_ids.to("cuda:1")
    with torch.no_grad():
        logits = hf_model(input_ids)[0]
    return logits


prompts = [
    "Today is a beautiful day and I want to",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]
# Restrict to a single prompt for a quick check; drop this override to
# compare all four prompts.
prompts = [
    "Today is a beautiful day and I want to",
]

print("Next word generation")
for prompt in prompts:
    print("-------------")
    print(f"Prompt: {prompt}...\n")
    logits_fsq = single_batch_forward_logits(prompt)
    pred_next_token = torch.argmax(logits_fsq[0, -1], -1)
    next_token = tokenizer.convert_ids_to_tokens([pred_next_token])
    next_token = next_token[0].replace("Ġ", "")
    print(f"Next word: {next_token}")
    print("-------------")
    logits = forward_hf(prompt)
    pred_next_token = torch.argmax(logits[0, -1], -1)
    next_token = tokenizer.convert_ids_to_tokens([pred_next_token])
    next_token = next_token[0].replace("Ġ", "")
    print(f"Next word: {next_token}")
    print("-------------")

# The two logit tensors should agree to within fp16 tolerance.
print("Logits match:", torch.allclose(logits_fsq.cpu(), logits.cpu(), atol=1e-3))