# Portions copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Connect to a compute node interactively
# srun --partition=gpu-interactive --gpus=a5000:1 --mem=16000 --pty /bin/bash
# source env/hugh/bin/activate
# cd /share/compling/speech/llama_tuning

# Run it
# torchrun --nproc_per_node 1 strings_chat.py --ckpt_dir llama-2-7b-chat/ --tokenizer_path tokenizer.model --max_seq_len 512 --max_batch_size 6 --benchmark_path benchmark/cl23.txt
# torchrun --nproc_per_node 1 strings_chat.py --ckpt_dir llama-2-7b-chat/ --tokenizer_path tokenizer.model --max_seq_len 512 --max_batch_size 6 --benchmark_path benchmark/cl23.txt --tuple_output

from typing import Optional
from ast import literal_eval
import json

import fire

from llama import Llama, Dialog

# A system prompt turned out to be unnecessary: the llama chat assistant already
# opens its reply with True/False and adds an explanation of suitable scope.
def tuple2dialog(x):
    userprompt = {"role": "user", "content": f"Consider the string '{x[2]}'. True or False: {x[1]}"}
    return [userprompt]


# Benchmark rows look like: ('jeremy_a_2_ab', 'every letter is a consonant', 'ab', 'f')
def tuple_add_dialog(x):
    userprompt = {"role": "user", "content": f"Consider the string '{x[2]}'. True or False: {x[1]}"}
    return (x[0], x[1], x[2], x[3], [userprompt])
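
# A quick sanity check (hypothetical, using the example row from the comment above):
# >>> tuple_add_dialog(('jeremy_a_2_ab', 'every letter is a consonant', 'ab', 'f'))
# ('jeremy_a_2_ab', 'every letter is a consonant', 'ab', 'f',
#  [{'role': 'user', 'content': "Consider the string 'ab'. True or False: every letter is a consonant"}])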

def main(
    ckpt_dir: str,
    tokenizer_path: str,
    benchmark_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 512,
    max_batch_size: int = 8,
    max_gen_len: Optional[int] = None,
    tuple_output: bool = False,
):
    """
    Entry point of the program for generating text using a pretrained model.

    Args:
        ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
        tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
        benchmark_path (str): The path to the benchmark e.g. benchmark/cl23.txt.
        temperature (float, optional): The temperature value for controlling randomness in generation.
            Defaults to 0.6.
        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
            Defaults to 0.9.
        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 512.
        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 8.
        max_gen_len (int, optional): The maximum length of generated sequences. If None, it will be
            set to the model's max sequence length. Defaults to None.
        tuple_output (bool, optional): If True, print each result as a Python tuple literal that
            extends the benchmark row with its dialog prompt and the model's answer. Defaults to False.
    """

    # Lazily stream the benchmark file, parsing each line into a tuple and
    # appending its dialog prompt (benchmark_path is supplied on the command line)
    benchmark_stream = map(lambda z: tuple_add_dialog(literal_eval(z)), open(benchmark_path))
    
    # Bunch the stream into non-overlapping groups of five
    benchmark_by_5 = zip(*(benchmark_stream,) * 5)
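
    # Note: zip(*(iterator,) * 5) yields consecutive non-overlapping 5-tuples
    # because all five zip arguments alias the same iterator; any trailing
    # remainder of fewer than five rows is silently dropped. For example:
    #   list(zip(*(iter(range(7)),) * 3)) == [(0, 1, 2), (3, 4, 5)]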
    
    
    # An earlier variant yielded bare dialog lists rather than full tuples:
    # def gen_dialogs():
    #     for x in zip(range(100), benchmark_by_5):
    #         yield list(map(tuple2dialog, x[1]))

    # Generator for lists of five dialog tuples, with 100 batches covering 500 items
    def gen_dialog_tuples():
        for x in zip(range(100), benchmark_by_5):
            yield list(x[1])

    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    # An earlier version printed every dialog/result pair directly; that
    # behaviour now lives in the else branch of the loop below.

    for dtuple in gen_dialog_tuples():
        dialogs = [z[4] for z in dtuple]  # Pull out the list of five dialog prompts
        results = generator.chat_completion(
            dialogs,  # type: ignore
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
        )
        for tpl, result in zip(dtuple, results):
            if tuple_output:
                t0, t1, t2, t3, t4 = tpl
                # json.dumps escapes double quotes and newlines, and already wraps
                # the answer in double quotes, so t5 must not be quoted again below
                t5 = json.dumps(result['generation']['content'])
                print(f'("{t0}","{t1}","{t2}","{t3}",{t4},{t5})')
            else:
                for msg in tpl[4]:
                    print(f"{msg['role'].capitalize()}: {msg['content']}\n")
                print(
                    f"> {result['generation']['role'].capitalize()}: {result['generation']['content']}"
                )
                print("\n==================================\n")

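# Downstream consumption (a sketch, not part of this script): with --tuple_output,
# each printed line is a Python tuple literal, so a results file captured by a
# hypothetical redirect such as `torchrun ... --tuple_output > results.txt` can be
# parsed back, provided the benchmark fields contain no double quotes:
#
#   from ast import literal_eval
#   for line in open('results.txt'):
#       row_id, prop, string, gold, dialog, answer = literal_eval(line)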

if __name__ == "__main__":
    fire.Fire(main)