import os

import torch
from transformers import GPT2Tokenizer

from megatron.initialize import initialize_megatron
from metaseq import checkpoint_utils

# Directory containing the restored metaseq checkpoint and the GPT-2
# vocab/merges files.
path = "./model"

# Initialize Megatron with the checkpoint's architecture hyperparameters
# (these values match a 125M-parameter OPT configuration).
initialize_megatron(args_defaults={
    "micro_batch_size": 1,
    "num_layers": 12,
    "hidden_size": 768,
    "num_attention_heads": 12,
    "max_position_embeddings": 2048,
    "encoder_seq_length": 2048,
})

# Build a GPT-2 BPE tokenizer from the vocab/merges files shipped with the
# checkpoint and save it alongside the model.
vocab_file = os.path.join(path, "gpt2-vocab.json")
merges_file = os.path.join(path, "gpt2-merges.txt")

tokenizer = GPT2Tokenizer(vocab_file, merges_file)
tokenizer.save_pretrained(path)
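
# Optional sanity check (not part of the original walkthrough): a tokenizer
# reloaded from the directory we just saved should encode identically.
reloaded = GPT2Tokenizer.from_pretrained(path)
assert reloaded.encode("Hello world") == tokenizer.encode("Hello world")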

# Load the restored checkpoint through metaseq, pointing it at the same
# vocab/merges files.
checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    [os.path.join(path, "restored.pt")],
    arg_overrides={
        "vocab_filename": vocab_file,
        "merges_filename": merges_file,
    },
)

# load_model_ensemble_and_task returns (models, cfg, task); take the first
# (and only) model, then move it to the GPU in half precision.
model = checkpoint[0][0].eval()
model = model.cuda().half()
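
# Quick check (an illustrative addition, not in the original script): print
# the parameter count to confirm the expected model size was loaded.
n_params = sum(p.numel() for p in model.parameters())
print(f"Loaded model with {n_params / 1e6:.0f}M parameters")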


def single_batch_forward_logits(prompt):
    # Tokenize, prepend token id 0 as the beginning-of-sequence marker
    # (as in the original script), and run one forward pass.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = torch.cat([torch.tensor([[0]]), input_ids], dim=-1)
    input_ids = input_ids.cuda()
    with torch.no_grad():
        logits = model(input_ids)[0]
    return logits
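

# Illustrative extra (not in the original script): inspect the top-k candidate
# next tokens instead of only the argmax. "Ġ" is GPT-2 BPE's marker for a
# leading space.
def top_k_next_tokens(prompt, k=5):
    logits = single_batch_forward_logits(prompt)
    values, indices = torch.topk(logits[0, -1], k)
    tokens = tokenizer.convert_ids_to_tokens(indices.tolist())
    return [(tok.replace("Ġ", " "), val.item()) for tok, val in zip(tokens, values)]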


prompts = [
    "Today is a beautiful day and I want to",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]

print("Next word generation")
for prompt in prompts:
    print("-------------")
    print(f"Prompt: {prompt}...\n")
    logits = single_batch_forward_logits(prompt)
    # Greedily pick the highest-scoring token at the last position.
    pred_next_token = torch.argmax(logits[0, -1], -1)
    next_token = tokenizer.convert_ids_to_tokens([pred_next_token.item()])
    next_token = next_token[0].replace("Ġ", "")
    print(f"Next word: {next_token}")
    print("-------------")