nmaina committed on
Commit
d4fe87d
1 Parent(s): 740c0a2

Upload app.py

Files changed (1):
  app.py (+110, -0)
app.py ADDED
@@ -0,0 +1,110 @@
"""Run a chatbot with FlexGen and OPT models."""
import argparse

from transformers import AutoTokenizer
from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice,
                              CompressionConfig, Env, Task, get_opt_config)


def main(args):
    # Initialize the environment: a GPU, the CPU, and a disk directory for offloading.
    gpu = TorchDevice("cuda:0")
    cpu = TorchDevice("cpu")
    disk = TorchDisk(args.offload_dir)
    env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))

    # Offloading policy. The first two arguments are the GPU batch size and the
    # number of GPU batches; the six percentages split weights, attention cache,
    # and activations between GPU and CPU (the remainder goes to disk).
    policy = Policy(1, 1,
                    args.percent[0], args.percent[1],
                    args.percent[2], args.percent[3],
                    args.percent[4], args.percent[5],
                    overlap=True, sep_layer=True, pin_weight=True,
                    cpu_cache_compute=False, attn_sparsity=1.0,
                    compress_weight=args.compress_weight,
                    comp_weight_config=CompressionConfig(
                        num_bits=4, group_size=64,
                        group_dim=0, symmetric=False),
                    compress_cache=args.compress_cache,
                    comp_cache_config=CompressionConfig(
                        num_bits=4, group_size=64,
                        group_dim=2, symmetric=False))

    # Model. All OPT checkpoints share the same tokenizer, so a fixed one is used here.
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", padding_side="left")
    tokenizer.add_bos_token = False
    stop = tokenizer("\n").input_ids[0]  # stop generation at the first newline

    print("Initialize...")
    opt_config = get_opt_config(args.model)
    model = OptLM(opt_config, env, args.path, policy)
    model.init_all_weights()

    context = (
        "A chat between a curious human and a knowledgeable artificial intelligence assistant.\n"
        "Human: Hello! What can you do?\n"
        "Assistant: As an AI assistant, I can answer questions and chat with you.\n"
        "Human: What is the name of the tallest mountain in the world?\n"
        "Assistant: Everest.\n"
    )

    # Chat loop: an empty input exits.
    print(context, end="")
    while True:
        inp = input("Human: ")
        if not inp:
            print("exit...")
            break

        context += "Human: " + inp + "\n"
        inputs = tokenizer([context])
        output_ids = model.generate(
            inputs.input_ids,
            do_sample=True,
            temperature=0.7,
            max_new_tokens=96,
            stop=stop)
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        # Truncate the output at the first newline after the prompt,
        # appending one if the model never emitted it.
        try:
            index = outputs.index("\n", len(context))
        except ValueError:
            outputs += "\n"
            index = outputs.index("\n", len(context))

        outputs = outputs[:index + 1]
        print(outputs[len(context):], end="")
        context = outputs

    # TODO: optimize the performance by reducing redundant computation.

    # Shutdown
    model.delete_all_weights()
    disk.close_copy_threads()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="facebook/opt-6.7b",
        help="The model name.")
    parser.add_argument("--path", type=str, default="~/opt_weights",
        help="The path to the model weights. If there are no cached weights, "
             "FlexGen will automatically download them from HuggingFace.")
    parser.add_argument("--offload-dir", type=str, default="~/flexgen_offload_dir",
        help="The directory to offload tensors.")
    parser.add_argument("--percent", nargs="+", type=int,
        default=[100, 0, 100, 0, 100, 0],
        help="Six numbers. They are "
             "the percentage of weight on GPU, "
             "the percentage of weight on CPU, "
             "the percentage of attention cache on GPU, "
             "the percentage of attention cache on CPU, "
             "the percentage of activations on GPU, "
             "the percentage of activations on CPU")
    parser.add_argument("--compress-weight", action="store_true",
        help="Whether to compress weight.")
    parser.add_argument("--compress-cache", action="store_true",
        help="Whether to compress cache.")
    args = parser.parse_args()

    assert len(args.percent) == 6

    main(args)
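
For reference, a typical invocation using the flags defined above might look like this (the model name is illustrative; any OPT checkpoint supported by FlexGen should work, and the six --percent numbers follow the weight/cache/activation split described in the help text):

    python app.py --model facebook/opt-6.7b --percent 100 0 100 0 100 0

Adding --compress-weight and --compress-cache turns on the 4-bit group-wise compression configured in main(), which shrinks the memory footprint of the weights and the attention cache at a possible small cost in output quality.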