"""Gradio demo Space for the state-spaces/mamba-2.8b language model.

Reconstructed from a mangled diff view of the Space's ``app.py`` ("Update
app.py" commit); this body is the post-change ("+") side of that diff.
"""
import torch
import gradio as gr
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Fall back to CPU when no GPU is present — the pre-change version hard-coded
# "cuda", which is the likely cause of the "Runtime error" banner on the page.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Mamba ships no tokenizer of its own; it reuses the GPT-NeoX-20B tokenizer.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype="auto")


def pred(text_in, temperature, top_k, top_p, gen_length, cg,
         return_dict_in_generate, output_scores, enable_timing):
    """Generate a continuation of ``text_in`` with mamba-2.8b.

    Parameters mirror the Gradio inputs declared below, in order:
      text_in: prompt string.
      temperature, top_p: float sampling knobs, passed through to generate().
      top_k: top-k cutoff; sliders deliver floats, so it is cast to int.
      gen_length: number of NEW tokens to allow (added to the prompt length).
      cg: Mamba's graph-caching flag.
          NOTE(review): presumably requires CUDA graphs — confirm it is safe
          to enable when running on CPU.
      return_dict_in_generate, output_scores, enable_timing: passed through.

    Returns the decoded text (prompt + continuation) for the first sequence.
    """
    tokens = tokenizer(text_in, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    # Sliders hand us floats; max_length must be an int token count.
    max_length = input_ids.shape[1] + int(gen_length)
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=cg,
        return_dict_in_generate=return_dict_in_generate,
        output_scores=output_scores,
        enable_timing=enable_timing,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
    )
    # With return_dict_in_generate=False, generate() returns a bare tensor,
    # not an object with .sequences — the original crashed when the
    # "Return Dict in Generate" checkbox was unticked.
    sequences = out.sequences if hasattr(out, "sequences") else out
    text_out = tokenizer.batch_decode(sequences.tolist(), skip_special_tokens=True)
    return text_out[0]


# Input order must match pred()'s parameter order after text_in.
demo = gr.Interface(
    fn=pred,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Temperature"),
        gr.Slider(minimum=1, maximum=10, value=10, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
        gr.Slider(minimum=50, maximum=650, value=200, label="Generation Length (gen_length)"),
        gr.Checkbox(value=True, label="Cache Graph (cg)"),
        gr.Checkbox(value=True, label="Return Dict in Generate"),
        gr.Checkbox(value=True, label="Output Scores"),
        gr.Checkbox(value=False, label="Enable Timing"),
    ],
    outputs="text",
    title="Welcome👋🏻to🌟Tonic's🐍Mamba 2.8B! 🚀",
    description="""🐍Mamba is quite special because it uses a unique model architecture, has reasonable🏆performance, and a👌🏻tiny size. You can use this Space to test out the current model 🐍[state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) You can also use 🐍mamba-2.8b by cloning this space. Simply click here: [Duplicate Space](https://huggingface.co/spaces/Tonic1/VLChat?duplicate=true)
Join us: 🌟TeamTonic is always making cool demos! Join our active🛠️builder's community on Discord: [Discord](https://discord.gg/nXx5wbX9) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟[PolyGPT](https://github.com/tonic-ai/polygpt-alpha)
""",
)

if __name__ == "__main__":
    # The scrape truncates the guard's body; demo.launch() is the standard
    # Gradio entry point — confirm against the live Space.
    demo.launch()