Tonic committed on
Commit
f8e11cc
1 Parent(s): 50eae3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -28
app.py CHANGED
@@ -8,53 +8,46 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
9
  from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
10
 
11
- device = "cuda"
12
- tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
13
- model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype=torch.float16)
14
- genlen = 200
15
-
16
- tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
18
 
19
- def pred(text_in, d_model, n_layer, vocab_size, genlen, temperature, top_k, top_p):
20
- model = MambaLMHeadModel(d_model=d_model, n_layer=n_layer, vocab_size=vocab_size).to(device)
21
- model.eval()
22
  tokens = tokenizer(text_in, return_tensors="pt")
23
  input_ids = tokens.input_ids.to(device=device)
24
- attn_mask = tokens.attention_mask.to(device=device)
25
- max_length = input_ids.shape[1] + genlen
26
-
27
- output = model.generate(
28
  input_ids=input_ids,
29
  max_length=max_length,
30
- cg=True,
31
- do_sample = True,
32
- return_dict_in_generate=True,
33
- output_scores=True,
34
- enable_timing=False,
35
  temperature=temperature,
36
  top_k=top_k,
37
  top_p=top_p,
38
  )
39
-
40
- text_out = tokenizer.batch_decode(output.sequences.tolist(), skip_special_tokens=True)
41
  return text_out[0]
42
- pass
43
  demo = gr.Interface(
44
  fn=pred,
45
  inputs=[
46
  gr.Textbox(label="Input Text"),
47
- gr.Slider(minimum=128, maximum=1024, value=512, label="Model Dimension (d_model)"),
48
- gr.Slider(minimum=1, maximum=24, value=12, label="Number of Layers (n_layer)"),
49
- gr.Number(value=50257, label="Vocabulary Size (vocab_size)"),
50
- gr.Slider(minimum=50, maximum=500, value=200, label="Generation Length (genlen)"),
51
  gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Temperature"),
52
- gr.Slider(minimum=1, maximum=50, value=10, label="Top K"),
53
  gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
 
 
 
 
 
54
  ],
55
  outputs="text",
56
- title="Welcome to Tonic's 🐍Mamba",
57
- description="With this Demo, you can customize the model hyperparameters of [🐍Mamba](https://www.huggingface.co/state-spaces/mamba-2.8b). Every time you send a request it will instantiate the model accordingly, so please be patient."
 
 
58
  )
59
 
60
  if __name__ == "__main__":
 
8
 
9
  from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
10
 
 
 
 
 
 
 
11
# Prefer GPU; fall back to CPU so the app still starts without CUDA.
# NOTE(review): Mamba's fused kernels and cg=True generation are typically
# CUDA-only — confirm the CPU fallback actually works end to end.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Uses the GPT-NeoX tokenizer — presumably because the Mamba checkpoint
# ships no tokenizer of its own; verify against the model card.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# dtype="auto" lets the loader pick the checkpoint's native precision.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype="auto")
14
 
15
def pred(text_in, temperature, top_k, top_p, gen_length, cg, return_dict_in_generate, output_scores, enable_timing):
    """Generate a continuation of ``text_in`` with the module-level Mamba model.

    Args:
        text_in: Prompt text to continue.
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus (top-p) sampling cutoff.
        gen_length: Number of new tokens to generate beyond the prompt.
        cg: Whether to use CUDA-graph caching in generation.
        return_dict_in_generate: If True, ``generate`` returns an object with
            a ``.sequences`` attribute; if False, it returns the token tensor.
        output_scores: Whether to also compute per-step scores.
        enable_timing: Whether the generator should collect timing info.

    Returns:
        The decoded generated text (prompt included) for the first sequence.
    """
    tokens = tokenizer(text_in, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    # Total budget = prompt tokens + requested new tokens.
    max_length = input_ids.shape[1] + gen_length
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=cg,
        return_dict_in_generate=return_dict_in_generate,
        output_scores=output_scores,
        enable_timing=enable_timing,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    # BUG FIX: the original accessed ``out.sequences`` unconditionally, which
    # raises AttributeError when the "Return Dict in Generate" checkbox is
    # unchecked (generate then returns the raw token tensor instead).
    sequences = out.sequences if return_dict_in_generate else out
    text_out = tokenizer.batch_decode(sequences.tolist(), skip_special_tokens=True)
    return text_out[0]
32
+
33
# Gradio UI wiring: one input component per parameter of ``pred``, listed in
# exactly the function's signature order (Gradio maps them positionally).
_input_components = [
    gr.Textbox(label="Input Text"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Temperature"),
    gr.Slider(minimum=1, maximum=10, value=10, label="Top K"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
    gr.Slider(minimum=50, maximum=650, value=200, label="Generation Length (gen_length)"),
    gr.Checkbox(value=True, label="Cache Graph (cg)"),
    gr.Checkbox(value=True, label="Return Dict in Generate"),
    gr.Checkbox(value=True, label="Output Scores"),
    gr.Checkbox(value=False, label="Enable Timing"),
]

demo = gr.Interface(
    fn=pred,
    inputs=_input_components,
    outputs="text",
    title="Welcome👋🏻to🌟Tonic's🐍Mamba 2.8B! 🚀",
    description="""🐍Mamba is quite special because it uses a unique model architecture, has reasonable🏆performance, and a👌🏻tiny size. You can use this Space to test out the current model 🐍[state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) You can also use 🐍mamba-2.8b by cloning this space. Simply click here: [Duplicate Space](https://huggingface.co/spaces/Tonic1/VLChat?duplicate=true)
Join us: 🌟TeamTonic is always making cool demos! Join our active🛠️builder's community on Discord: [Discord](https://discord.gg/nXx5wbX9) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟[PolyGPT](https://github.com/tonic-ai/polygpt-alpha)
"""
)
52
 
53
  if __name__ == "__main__":