MatsRooth committed on
Commit f7e7fda
1 Parent(s): bccf302

Upload strings_chat.py with huggingface_hub

Files changed (1)
  1. strings_chat.py +144 -0
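The commit message indicates the file was pushed with the huggingface_hub client. A minimal sketch of that upload step might look like the following; the repo id and local path below are placeholders, not values taken from this commit:

    # Hypothetical upload sketch; the actual repo id is not shown in this commit view.
    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="strings_chat.py",   # local file to push
        path_in_repo="strings_chat.py",      # destination path in the repo
        repo_id="MatsRooth/<repo_name>",     # placeholder repo id
        commit_message="Upload strings_chat.py with huggingface_hub",
    )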
strings_chat.py ADDED
@@ -0,0 +1,144 @@
+ # Portions copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+ # Connect to a compute node interactively
+ # srun --partition=gpu-interactive --gpus=a5000:1 --mem=16000 --pty /bin/bash
+ # source env/hugh/bin/activate
+ # cd /share/compling/speech/llama_tuning
+
+ # Run it
+ # torchrun --nproc_per_node 1 strings_chat.py --ckpt_dir llama-2-7b-chat/ --tokenizer_path tokenizer.model --max_seq_len 512 --max_batch_size 6 --benchmark_path benchmark/cl23.txt
+ # torchrun --nproc_per_node 1 strings_chat.py --ckpt_dir llama-2-7b-chat/ --tokenizer_path tokenizer.model --max_seq_len 512 --max_batch_size 6 --benchmark_path benchmark/cl23.txt --tuple
+
+ from typing import List, Optional
+
+ import fire
+
+ from llama import Llama, Dialog
+
+ import sys
+ from ast import literal_eval
+
+ import json
+
+ def tuple2dialog(x):
+     systemprompt = {"role": "system", "content": "Always answer with a single word 'True' or 'False'"}
+     userprompt = {"role": "user", "content": f"Consider the string '{x[2]}'. True or False: {x[1]}"}
+     return [userprompt]
+
+     # The system prompt is not needed; the llama assistant already starts with True/False and
+     # includes an explanation of suitable scope.
+     # return [systemprompt, userprompt]
+
+ # Benchmark items look like ('jeremy_a_2_ab', 'every letter is a consonant', 'ab', 'f')
+ def tuple_add_dialog(x):
+     systemprompt = {"role": "system", "content": "Always answer with a single word 'True' or 'False'"}
+     userprompt = {"role": "user", "content": f"Consider the string '{x[2]}'. True or False: {x[1]}"}
+     return (x[0], x[1], x[2], x[3], [userprompt])
+
+ def main(
+     ckpt_dir: str,
+     tokenizer_path: str,
+     benchmark_path: str,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     max_seq_len: int = 512,
+     max_batch_size: int = 8,
+     max_gen_len: Optional[int] = None,
+     tuple: Optional[bool] = False,
+ ):
+     """
+     Entry point of the program for generating text using a pretrained model.
+
+     Args:
+         ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
+         tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
+         benchmark_path (str): The path to the benchmark, e.g. benchmark/cl23.txt.
+         temperature (float, optional): The temperature value for controlling randomness in generation.
+             Defaults to 0.6.
+         top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
+             Defaults to 0.9.
+         max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 512.
+         max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 8.
+         max_gen_len (int, optional): The maximum length of generated sequences. If None, it will be
+             set to the model's max sequence length. Defaults to None.
+         tuple (bool, optional): If True, print each result as a tuple extending the benchmark item
+             with the dialog prompt and the model's reply. Defaults to False.
+     """
+
+     # Need to work out how to add this to the arguments
+     # benchmark_file = '/share/compling/speech/llama/benchmark/cl23.txt'
+
+     # This is iterable
+     # benchmark_stream = map(lambda z: literal_eval(z), open(benchmark_path))
+     # Include the dialog prompt in the tuple
+     benchmark_stream = map(lambda z: tuple_add_dialog(literal_eval(z)), open(benchmark_path))
+
+     # Bunch into groups of five
+     benchmark_by_5 = zip(*(benchmark_stream,) * 5)
+
+
+     # Generator for lists of five dialogs, with 100 elements covering 500 items
+     '''
+     def gen_dialogs():
+         for x in zip(range(100), benchmark_by_5):
+             dialog = map(tuple2dialog, x[1])
+             yield list(dialog)
+     '''
+
+     # Generator for lists of five dialog tuples, with 100 elements covering 500 items
+     def gen_dialog_tuples():
+         for x in zip(range(100), benchmark_by_5):
+             yield list(x[1])
+
+     generator = Llama.build(
+         ckpt_dir=ckpt_dir,
+         tokenizer_path=tokenizer_path,
+         max_seq_len=max_seq_len,
+         max_batch_size=max_batch_size,
+     )
+
+     # for dialogs in gen_dialogs():
+     #     results = generator.chat_completion(
+     #         dialogs,  # type: ignore
+     #         max_gen_len=max_gen_len,
+     #         temperature=temperature,
+     #         top_p=top_p,
+     #     )
+     #     for dialog, result in zip(dialogs, results):
+     #         for msg in dialog:
+     #             print(f"{msg['role'].capitalize()}: {msg['content']}\n")
+     #         print(
+     #             f"> {result['generation']['role'].capitalize()}: {result['generation']['content']}"
+     #         )
+     #         print("\n==================================\n")
+
+
+     for dtuple in gen_dialog_tuples():
+         dialogs = [z[4] for z in dtuple]  # Map out a list of five dialog prompts
+         results = generator.chat_completion(
+             dialogs,  # type: ignore
+             max_gen_len=max_gen_len,
+             temperature=temperature,
+             top_p=top_p,
+         )
+         for tpl, result in zip(dtuple, results):
+             if tuple:
+                 t0 = tpl[0]
+                 t1 = tpl[1]
+                 t2 = tpl[2]
+                 t3 = tpl[3]
+                 t4 = tpl[4]
+                 # json.dumps escapes double quotes and newlines and already wraps the
+                 # content in double quotes, so t5 is not quoted again in the print below.
+                 t5 = json.dumps(result['generation']['content'])
+                 print(f'("{t0}","{t1}","{t2}","{t3}",{t4},{t5})')
+             else:
+                 # Print the dialog prompt from the tuple, then the model's reply
+                 for msg in tpl[4]:
+                     print(f"{msg['role'].capitalize()}: {msg['content']}\n")
+                 print(
+                     f"> {result['generation']['role'].capitalize()}: {result['generation']['content']}"
+                 )
+                 print("\n==================================\n")
+
+
+ if __name__ == "__main__":
+     fire.Fire(main)
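Not part of the uploaded file: as a quick sanity check of the prompt construction, one might run tuple_add_dialog on the example item quoted in the script's comments. This assumes the llama package the script imports is installed, since importing strings_chat pulls it in:

    # Hypothetical snippet, separate from strings_chat.py
    from strings_chat import tuple_add_dialog

    item = ('jeremy_a_2_ab', 'every letter is a consonant', 'ab', 'f')
    name, claim, string, label, dialog = tuple_add_dialog(item)
    print(dialog)
    # [{'role': 'user', 'content': "Consider the string 'ab'. True or False: every letter is a consonant"}]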