"""
Generate text with a fine-tuned language model.

Example snippet: load a causal language model that was fine-tuned on
preprocessed tweet data (Hugging Face ``transformers`` + PyTorch) and
sample text continuations of a prompt.
"""
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
import os | |
def generate_account_text(prompt, model_dir, num_return_sequences=5, allow_new_lines=False):
    """Sample text continuations of ``prompt`` from a fine-tuned causal LM.

    Args:
        prompt: Seed text that the model continues.
        model_dir: Path to a directory containing a saved tokenizer and model
            (as produced by ``save_pretrained``).
        num_return_sequences: Number of independent samples to generate.
        allow_new_lines: When False (default), each generated sample is
            truncated at its first newline.

    Returns:
        A list of ``{"prompt": ..., "generated_text": ...}`` dicts on success,
        or an error *string* when ``model_dir`` does not exist (the string
        return is kept for backward compatibility with existing callers).
    """
    if not os.path.exists(model_dir):
        print("****************** ERROR **************************")
        print(f"Error: {model_dir} does not exist.")
        print("****************** ERROR **************************")
        return f"Error: {model_dir} does not exist."
    # Load the fine-tuned tokenizer and model from disk.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    # Prepend a BOS token so generation starts from a document boundary.
    # Prefer the tokenizer's configured BOS; fall back to GPT-2's literal
    # token (the original hard-coded value) when none is set.
    bos = tokenizer.bos_token or "<|endoftext|>"
    # Encode the prompt and move it to the model's device.
    encoded_prompt = tokenizer(
        bos + prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(model.device)
    # Draw all samples in a single batched generate() call.
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=200,
        min_length=10,
        temperature=0.85,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode each sampled token sequence back to text.
    generated_sequences = []
    for sequence in output_sequences:
        text = tokenizer.decode(
            sequence.tolist(),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        if not allow_new_lines:
            # Keep only the first line of the sample.
            newline_at = text.find("\n")
            if newline_at != -1:
                text = text[:newline_at]
        generated_sequences.append(
            {"prompt": prompt, "generated_text": text.strip()}
        )
    return generated_sequences