README.md CHANGED
@@ -3,8 +3,8 @@ title: Bloom Chat
 emoji: ⚡
 colorFrom: purple
 colorTo: green
-sdk:
-sdk_version:
+sdk: streamlit
+sdk_version: 1.10.0
 app_file: app.py
 pinned: false
 license: openrail
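
Spaces reads the sdk and sdk_version fields from the README front matter to decide how to launch the app, so they cannot be left empty. For reference, the complete header after this change should read as follows (field values are taken from the hunk above; the "---" delimiters and the position of "title:" are assumptions based on the standard Spaces README layout):

    ---
    title: Bloom Chat
    emoji: ⚡
    colorFrom: purple
    colorTo: green
    sdk: streamlit
    sdk_version: 1.10.0
    app_file: app.py
    pinned: false
    license: openrail
    ---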
app.py CHANGED
@@ -1,27 +1,104 @@
-
-
-# text = st.text_area("Prefix", value="DM: You enter the room.")
-# batch = st.number_input("Variants", value=5)
-# st.markdown(f"{text} {batch}")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-col1.image(image, use_column_width=True)
-predictions = pipeline(image)
-
-
-for p in predictions:
-    col2.subheader(f"{ p['label'] }: { round(p['score'] * 100, 1)}%")
+import torch
+import transformers
+import time
+from huggingface_hub import snapshot_download
+import streamlit as st
+import copy
+from transformers import AutoConfig, GPTJForCausalLM
+from transformers.models.gptj.modeling_gptj import GPTJBlock
+from tqdm import trange
+
+
+@st.cache(allow_output_mutation=True)
+def load_model():
+    for down in trange(1, disable=True):
+        fpath = snapshot_download("OpenDungeon/gpt-j-8bit-ffbgem", revision="separate")
+    config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")
+    qconfig = torch.quantization.get_default_qconfig('fbgemm')
+    torch.backends.quantized.engine = 'fbgemm'
+    n_layer, config.n_layer = config.n_layer, 0
+
+    model = GPTJForCausalLM(config)
+    model.load_state_dict(torch.load(fpath + "/blocks/base.pt"))
+    ref_block = torch.quantization.quantize_dynamic(
+        GPTJBlock(config),
+        {torch.nn.Linear: qconfig},
+        dtype=torch.qint8,
+        inplace=True
+    )
+
+    for i in trange(n_layer):
+        new_block = copy.deepcopy(ref_block)
+        new_block.load_state_dict(torch.load(f"{fpath}/blocks/block{i}.pt"))
+        model.transformer.h.append(new_block)
+
+    config.n_layer = len(model.transformer.h)
+    del ref_block
+
+    return transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B"), model
+
+
+def PrintContinuation(prompt, local_model, single_hook=None, batch=1, limit_tokens = 50):
+    past_key_values = None  # used to keep track of conversation history
+    input_dict = tokenizer([prompt] * batch, return_tensors='pt', padding=False)
+    output = [""] * batch
+    batch_time = 0
+
+    with torch.inference_mode():
+        for i in range(limit_tokens + 20):
+            if i == 5:
+                start_time = time.perf_counter()
+
+            outputs = local_model.forward(**input_dict, use_cache=True, past_key_values=past_key_values)
+            last_logits = outputs.logits[:, -1]
+
+            for j in range(batch):
+                last_logits[j, last_logits[j].topk(k=10).indices] += 10
+
+            past_key_values = outputs.past_key_values
+            token_ix = torch.multinomial(last_logits.softmax(-1), 1)
+            output = [stream + tokenizer.decode(ix) for stream, ix in zip(output, token_ix)]
+
+            if single_hook is not None:
+                single_hook(tokenizer.decode(token_ix[0]))
+            if i == limit_tokens:
+                batch_time = (time.perf_counter() - start_time) / (i - 4)
+                break
+
+            input_dict = dict(input_ids=token_ix)
+    return output, batch_time
+
+import sys
+
+def Sureprint(text):
+    text = f"\nDDBG: {text}\n"
+    print(text, flush=True)
+    print(text, file=sys.stderr, flush=True)
+
+Sureprint("ready to load")
+tokenizer, model = load_model()
+Sureprint("loaded")
+text = st.text_area("Prefix", value="DM: You enter the room.")
+Sureprint(f"text acquired '{text}'")
+batch = st.number_input("Variants", value=5)
+
+t = st.empty()
+firstline = ""
+
+def PrintSome(text):
+    global t, firstline
+    firstline += text
+    t.markdown(f"{firstline}...")
+
+Sureprint("before inference")
+choices, batch_time = PrintContinuation(text, model, PrintSome, batch, 50)
+Sureprint("after inference")
+
+final_page = ""
+for i in range(batch):
+    final_page += f"#### choice №{i + 1} \n{choices[i]} \n______ \n"
+final_page += f"Seconds per batch: {batch_time}, Batch: {batch}"
+
+t.markdown(final_page)
+
+Sureprint("all done")
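
The interesting part of load_model is how it rebuilds the 8-bit model: the config is loaded with n_layer set to 0 so GPTJForCausalLM only allocates the embeddings and head, one GPTJBlock is dynamically quantized to serve as a reference, and every real layer is a deepcopy of that reference re-filled from its own checkpoint file. A minimal sketch of the same pattern on a toy module (the module and variable names below are illustrative, not from the diff):

    import copy
    import torch
    import torch.nn as nn

    # Toy stand-in for one transformer block.
    block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

    # Dynamic quantization converts the nn.Linear weights to int8 once...
    qblock = torch.quantization.quantize_dynamic(block, {nn.Linear}, dtype=torch.qint8)

    # ...then each layer is a cheap structural copy of that quantized reference,
    # re-filled from a per-layer checkpoint (like "blocks/block{i}.pt" in the diff).
    layers = []
    for i in range(4):
        new_block = copy.deepcopy(qblock)
        # new_block.load_state_dict(torch.load(f"blocks/block{i}.pt"))  # hypothetical path
        layers.append(new_block)

This keeps peak memory low: only the quantized weights ever exist, and only one block's checkpoint is read at a time.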
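
PrintContinuation's sampling step is effectively a soft top-k: adding a large constant (+10) to the ten highest logits before the softmax concentrates almost all probability mass on those tokens, while torch.multinomial still picks among them at random. Sketched in isolation (toy tensors; k and the boost follow the diff):

    import torch

    logits = torch.randn(2, 50)  # batch of 2 over a toy 50-token vocabulary

    k = 10  # the diff boosts the top 10 logits by +10
    for j in range(logits.size(0)):
        logits[j, logits[j].topk(k=k).indices] += 10

    token_ix = torch.multinomial(logits.softmax(-1), 1)  # one sampled token per row

Note also the timing logic in the loop: start_time is taken at iteration 5 and the average is divided by (i - 4), so batch_time excludes the first few warm-up steps.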