gururise committed on
Commit e956bee
1 Parent(s): 466213b

add application file

Files changed (2)
  1. app.py +168 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,168 @@
+ import gradio as gr
+ import threading
+ import codecs
+ from datetime import datetime
+
+ import os
+ # Configure the HF cache location and CUDA allocator before the heavy imports.
+ os.environ['TRANSFORMERS_CACHE'] = '/data/.modelcache/huggingface/hub/'
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:516"
+
+ from transformers import BloomTokenizerFast
+ from petals.client import DistributedBloomForCausalLM
+ import torch
+
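+ # Inference device, dtype, and the two Petals checkpoints being compared.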
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ TORCH_DTYPE = torch.bfloat16
+ MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]
+
+ models = {}   # model_name -> (tokenizer, model), filled lazily in infer()
+ output = {}   # model_name -> generated text, written by gen_thread()
+
+
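+ # Thread worker: generates a completion for one model and stores the decoded
+ # text in the shared `output` dict, keyed by model name.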
+ def gen_thread(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty):
+     global output
+     n_input_tokens = inputs.shape[1]
+     outputs = models[model_name][1].generate(
+         inputs,
+         max_new_tokens=max_new_tokens,
+         min_length=min_length,
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+     )
+     # Decode only the newly generated tokens, skipping the prompt.
+     output[model_name] = models[model_name][0].decode(outputs[0, n_input_tokens:])
+
+ def to_md(text):
+     return text.replace("\n", "<br />")
+
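+ # Gradio handler: validates the parameters, runs both models on the same
+ # prompt in parallel threads, and returns the two completions.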
+ def infer(
+     prompt,
+     min_length=2,
+     max_new_tokens=10,
+     temperature=0.1,
+     top_p=1.0,
+     repetition_penalty=1.0,
+     stop="\n",
+     num_completions=1,
+     seed=42,
+ ):
+     # Load the tokenizer and distributed model for each checkpoint lazily,
+     # on the first request.
+     if not models:
+         for model_name in MODEL_NAMES:
+             tokenizer = BloomTokenizerFast.from_pretrained(model_name)
+             model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE)
+             model = model.to(DEVICE)
+             models[model_name] = tokenizer, model
+
+     max_new_tokens = int(max_new_tokens)
+     num_completions = int(num_completions)  # currently unused: one completion per model
+     temperature = float(temperature)
+     top_p = float(top_p)
+     repetition_penalty = float(repetition_penalty)
+
+     # Interpret escape sequences (e.g. "\\n") in the stop string, then split
+     # it into a list of comma-separated stop words (done once, not per model).
+     stop = codecs.getdecoder("unicode_escape")(stop.split(";")[0])[0]
+     stop = [x.strip(' ') for x in stop.split(',')]
+
+     assert 1 <= max_new_tokens <= 384
+     assert 0 <= min_length <= max_new_tokens
+     assert 1 <= num_completions <= 5
+     assert 0.0 <= temperature <= 1.0
+     assert 0.0 <= top_p <= 1.0
+     assert 0.9 <= repetition_penalty <= 3.0
+
+     if temperature == 0.0:
+         temperature = 0.01
+     if prompt == "":
+         prompt = " "
+
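+     # Tokenize the prompt and launch one generation thread per model.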
+     threads = []
+     print(f"START -> ({datetime.now()})\n")
+     print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
+     for model_name in MODEL_NAMES:
+         inputs = models[model_name][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
+         x = threading.Thread(target=gen_thread, args=(model_name, inputs, max_new_tokens, min_length, temperature, top_p, repetition_penalty))
+         threads.append(x)
+         x.start()
+
+     # Wait for both generation threads to finish.
+     for model_name, thread in zip(MODEL_NAMES, threads):
+         print(f"waiting on: {model_name}\n")
+         thread.join()
+         print(f"{model_name} thread done\n")
+
+
+     # Truncate each completion at the first stop word that appears.
+     for model_name in MODEL_NAMES:
+         for stop_word in stop:
+             if stop_word != '' and stop_word in output[model_name]:
+                 output[model_name] = output[model_name][:output[model_name].find(stop_word)]
+
+         print(f"--- START: {model_name} ---\n{output[model_name]}\n--- END {model_name} ---\n\n")
+
+     print(f"DONE -> ({datetime.now()})\n")
+     return output[MODEL_NAMES[0]], output[MODEL_NAMES[1]]
+
+
+ examples = [
+     [
+         # Question answering
+         '''Please answer the following question:
+ Question: What is the capital of Germany?
+ Answer:''', 1, 3, 0.2, 1.0, 1.0, "\\n,</s>"],
+     [
+         # Natural language inference
+         '''Given a pair of sentences, choose whether the two sentences agree (entailment) or disagree (contradiction) with each other.
+ Possible labels: 1. entailment 2. contradiction
+ Sentence 1: The skier was on the edge of the ramp. Sentence 2: The skier was dressed in winter clothes.
+ Label: entailment
+ Sentence 1: The boy skated down the staircase railing. Sentence 2: The boy is a newbie skater.
+ Label: contradiction
+ Sentence 1: Two middle-aged people stand by a golf hole. Sentence 2: A couple riding in a golf cart.
+ Label:''', 1, 2, 0.2, 1.0, 1.0, "\\n,</s>"]
+ ]
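+
+ # Build the Gradio interface: a prompt box and sampling controls, with
+ # side-by-side output boxes for BLOOM and BLOOMZ.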
+ def main():
+     iface = gr.Interface(
+         fn=infer,
+         allow_flagging="never",
+         inputs=[
+             gr.Textbox(lines=20),                  # prompt
+             gr.Slider(0, 256, value=1),            # min_length
+             gr.Slider(1, 384, value=20),           # max_new_tokens
+             gr.Slider(0.0, 1.0, value=0.2),        # temperature
+             gr.Slider(0.0, 1.0, value=0.9),        # top_p
+             gr.Slider(0.9, 3.0, value=1.0),        # repetition_penalty
+             gr.Textbox(lines=1, value="\\n,</s>")  # stop words (comma-separated)
+         ],
+         outputs=[gr.Textbox(lines=7, label="BLOOM OUTPUT:"), gr.Textbox(lines=7, label="BLOOMZ OUTPUT:")],
+         examples=examples,
+         cache_examples=True,
+         title="BLOOM vs BLOOMZ",
+         description='''<p>Compare outputs of the BLOOM and BLOOMZ 176-billion-parameter models using the <a href="https://petals.ml/">Petals</a> network. Please consider joining the Petals network to help speed up inference.</p><p>Big thanks to <a href="https://www.rftcapital.com">RFTCapital</a> for providing initial compute resources.</p>'''
+     )
+
+     iface.launch(debug=True, share=False)
+
+
+ if __name__ == '__main__':
+     main()
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ threadpoolctl==3.1.0
+ transformers==4.25.1
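+ petals  # assumed missing dependency: app.py imports petals.client (left unpinned)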