"""Smoke-test text generation with a locally saved RetNet causal LM.

Loads the RetNet model from the current directory, pairs it with the GPT-2
tokenizer, generates a continuation for a fixed prompt, and prints the
decoded text.
"""
from transformers import AutoTokenizer
from retnet.modeling_retnet import RetNetForCausalLM

# Model weights and config are read from the current working directory.
model = RetNetForCausalLM.from_pretrained("./")

# The GPT-2 tokenizer is reused for this model. Raise its length cap so inputs
# up to 16384 tokens are accepted without the tokenizer's over-length warning.
# NOTE(review): 16384 is assumed to match the model's context length — TODO
# confirm against the RetNet config.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.model_max_length = 16384
# GPT-2 has no pad token; reuse EOS for padding and for the other special
# tokens so they all resolve to the same known token.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.unk_token = tokenizer.eos_token
tokenizer.bos_token = tokenizer.eos_token

# NOTE(review): the trailing space in the prompt is encoded as its own BPE
# token, which can degrade the continuation; kept as-is here.
inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")

# max_length counts the prompt tokens as well as the generated ones.
# Setting tokenizer.pad_token does not update the model's generation config,
# so pass pad_token_id explicitly instead of relying on generate()'s
# EOS fallback (and the warning it emits).
generation_output = model.generate(
    **inputs,
    max_length=50,  # Adjust max_length as needed
    pad_token_id=tokenizer.pad_token_id,
)
output = tokenizer.decode(generation_output[0], skip_special_tokens=True)
print(output)