TwitterAccounts / scripts /generative.py
aus10powell's picture
Upload 74 files
8158335
raw history blame
No virus
2.33 kB
"""
Train a language model:
Use PyTorch and spaCy to train a language model on the preprocessed tweet data.
Load a pre-trained language model from spaCy and fine-tune it on your tweet data using PyTorch. Here's an example code snippet:
"""
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import os
def generate_account_text(
    prompt,
    model_dir,
    num_return_sequences=5,
    *,
    max_length=200,
    min_length=10,
    temperature=0.85,
    top_p=0.95,
    allow_new_lines=False,
):
    """Sample text continuations of ``prompt`` from a saved causal language model.

    Args:
        prompt: Text that every generated sequence continues.
        model_dir: Directory loadable by ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        num_return_sequences: Number of independent samples to draw.
        max_length: Maximum total length, in tokens, of each output sequence
            (this count includes the BOS token and the prompt tokens).
        min_length: Minimum total length, in tokens, of each output sequence.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling probability mass passed to ``model.generate``.
        allow_new_lines: When False (the default), each decoded text is cut at
            its first newline so only a single line is returned.

    Returns:
        A list with one ``{"prompt": prompt, "generated_text": text}`` dict per
        sample. ``text`` is the decoded sequence (prompt included) with special
        tokens removed and surrounding whitespace stripped.

        If ``model_dir`` does not exist, an error banner is printed and an error
        message *string* is returned instead of a list. This is kept for
        compatibility with existing callers.
    """
    if not os.path.exists(model_dir):
        print("****************** ERROR **************************")
        print(f"Error: {model_dir} does not exist.")
        print("****************** ERROR **************************")
        return f"Error: {model_dir} does not exist."

    # Load the tokenizer and model (reloaded on every call).
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)

    # Prepend the tokenizer's own BOS token. Fall back to GPT-2's
    # "<|endoftext|>", the value this function previously hard-coded.
    bos_token = tokenizer.bos_token or "<|endoftext|>"
    encoded = tokenizer(
        bos_token + prompt, add_special_tokens=False, return_tensors="pt"
    )
    input_ids = encoded.input_ids.to(model.device)
    # A single unpadded prompt attends to every position. Passing the mask
    # explicitly stops generate() from guessing it, which it otherwise has to
    # do (with a warning) because pad_token_id below equals the EOS id.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    output_sequences = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        min_length=min_length,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        pad_token_id=tokenizer.eos_token_id,
    )

    generated_sequences = []
    for sequence in output_sequences:
        text = tokenizer.decode(
            sequence.tolist(),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        if not allow_new_lines:
            # partition() returns the whole text as [0] when there is no "\n".
            text = text.partition("\n")[0]
        generated_sequences.append(
            {"prompt": prompt, "generated_text": text.strip()}
        )
    return generated_sequences