"""
Train a language model:

Use PyTorch and spaCy to train a language model on the preprocessed tweet data.
Load a pre-trained language model from spaCy and fine-tune it on your tweet data using PyTorch. Here's an example code snippet:
"""

from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# When False, generated text is truncated at the first newline
ALLOW_NEW_LINES = False
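
# The model loaded by generate_account_text() below is assumed to have been
# fine-tuned beforehand. This is a minimal sketch of that step, assuming
# GPT-2 as the base model and a plain-text file with one preprocessed tweet
# per line; both paths are hypothetical placeholders.
def fine_tune_on_tweets(
    train_file="data/tweets.txt",  # hypothetical: preprocessed tweets, one per line
    output_dir="models/tweet-gpt2",  # hypothetical: where the checkpoint is saved
):
    from transformers import (
        DataCollatorForLanguageModeling,
        TextDataset,
        Trainer,
        TrainingArguments,
    )

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Chunk the corpus into fixed-length blocks of token ids
    dataset = TextDataset(tokenizer=tokenizer, file_path=train_file, block_size=128)
    # mlm=False produces causal (next-token) language-modeling labels
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir=output_dir, num_train_epochs=3),
        data_collator=collator,
        train_dataset=dataset,
    )
    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)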


def generate_account_text(prompt, model_dir, num_return_sequences=5):
    """Generate num_return_sequences continuations of prompt using the
    fine-tuned model saved in model_dir.

    Returns a list of {"prompt", "generated_text"} dicts, or an error
    string if model_dir does not exist.
    """

    if not os.path.exists(model_dir):
        print(f"Error: {model_dir} does not exist.")
        return f"Error: {model_dir} does not exist."

    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)

    # Prepend the BOS (beginning-of-sequence) token, GPT-2's <|endoftext|>, to the prompt
    start_with_bos = "<|endoftext|>" + prompt

    # Encode the prompt with the model's tokenizer and return a PyTorch tensor
    encoded_prompt = tokenizer(
        start_with_bos, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    encoded_prompt = encoded_prompt.to(model.device)

    # Generate sequences using the encoded prompt as input
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=200,
        min_length=10,
        temperature=0.85,
        top_p=0.95,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode the generated sequences and store them in a list of dictionaries
    generated_sequences = []
    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(
            generated_sequence,
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )
        if not ALLOW_NEW_LINES:
            limit = text.find("\n")
            text = text[: limit if limit != -1 else None]
        generated_sequences.append({"prompt": prompt, "generated_text": text.strip()})

    return generated_sequences
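

# Example usage: a minimal sketch. The prompt and the checkpoint directory
# are hypothetical placeholders; point model_dir at wherever your fine-tuned
# model was actually saved (e.g. the output_dir used in fine_tune_on_tweets).
if __name__ == "__main__":
    results = generate_account_text(
        prompt="just realized that",
        model_dir="models/tweet-gpt2",  # hypothetical checkpoint path
        num_return_sequences=3,
    )
    if isinstance(results, list):  # an error string is returned on failure
        for i, result in enumerate(results, start=1):
            print(f"{i}. {result['generated_text']}")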