Tonic committed on
Commit
f61eeff
1 Parent(s): 3fbacf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -12
app.py CHANGED
@@ -13,30 +13,50 @@ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
13
  model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype=torch.float16)
14
  genlen = 200
15
 
16
- def pred(text_in,):
 
 
 
 
 
17
  tokens = tokenizer(text_in, return_tensors="pt")
18
  input_ids = tokens.input_ids.to(device=device)
19
  attn_mask = tokens.attention_mask.to(device=device)
20
  max_length = input_ids.shape[1] + genlen
21
- fn = lambda: model.generate(
 
22
  input_ids=input_ids,
23
  max_length=max_length,
24
  cg=True,
 
25
  return_dict_in_generate=True,
26
  output_scores=True,
27
  enable_timing=False,
28
- temperature=0.5,
29
- top_k=10,
30
- top_p=0.9,
31
  )
32
- out = fn()
33
- text_out = tokenizer.batch_decode(out.sequences.tolist())
34
  return text_out[0]
35
 
 
36
  demo = gr.Interface(
37
- title="Mamba: Selective State Space Model",
38
- description="A demo for [Mamba](https://github.com/state-spaces/mamba) by Albert & Tri.",
39
- fn=pred, inputs="text", outputs="text")
40
-
 
 
 
 
 
 
 
 
 
 
 
 
41
  if __name__ == "__main__":
42
- demo.launch()
 
13
# Resolve the compute device and tokenizer *before* anything that needs them.
# The original defined `device` after using it in from_pretrained(), which
# raises NameError at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"

# GPT-NeoX tokenizer, as used by the published Mamba checkpoints.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# NOTE(review): this pretrained 2.8B checkpoint is never used by pred(),
# which instantiates its own model per request -- presumably kept for a
# future "use pretrained weights" path; confirm or drop to save memory.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype=torch.float16)

# Default number of new tokens to generate beyond the prompt.
genlen = 200
19
def pred(text_in, d_model, n_layer, vocab_size, genlen, temperature, top_k, top_p):
    """Generate a continuation of *text_in* with a freshly built Mamba LM.

    A new (randomly initialised, untrained) MambaLMHeadModel is constructed on
    every call from the user-chosen hyperparameters, so the output will not be
    coherent text -- the demo exists to exercise the architecture knobs.

    Args:
        text_in: prompt string to tokenize and continue.
        d_model: model width; Gradio sliders deliver floats, coerced to int.
        n_layer: number of Mamba blocks (coerced to int).
        vocab_size: embedding/output vocabulary size (coerced to int).
        genlen: number of new tokens to generate beyond the prompt.
        temperature: softmax temperature for sampling.
        top_k: top-k sampling cutoff (coerced to int).
        top_p: nucleus-sampling cumulative-probability cutoff.

    Returns:
        The decoded sequence (prompt + continuation) as a single string.
    """
    # Slider/Number widgets return floats; the model config needs ints.
    model = MambaLMHeadModel(
        d_model=int(d_model), n_layer=int(n_layer), vocab_size=int(vocab_size)
    ).to(device)
    model.eval()

    tokens = tokenizer(text_in, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    max_length = input_ids.shape[1] + int(genlen)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            # BUG FIX: original had `do_sample = True` with no trailing comma
            # inside the call, which is a SyntaxError.
            do_sample=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
        )

    text_out = tokenizer.batch_decode(output.sequences.tolist(), skip_special_tokens=True)
    return text_out[0]
42
 
43
# Gradio UI: a prompt box plus controls for every model/generation
# hyperparameter pred() accepts, in the same order as its signature.
# Uses the current top-level component API (gr.Textbox, gr.Slider, gr.Number);
# the legacy `gr.inputs.*` namespace and its `default=` kwarg were removed in
# modern Gradio releases, so the original construction fails there.
demo = gr.Interface(
    fn=pred,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Slider(minimum=128, maximum=1024, value=512, label="Model Dimension (d_model)"),
        gr.Slider(minimum=1, maximum=24, value=12, label="Number of Layers (n_layer)"),
        gr.Number(value=50257, label="Vocabulary Size (vocab_size)"),
        gr.Slider(minimum=50, maximum=500, value=200, label="Generation Length (genlen)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Temperature"),
        gr.Slider(minimum=1, maximum=50, value=10, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
    ],
    outputs="text",
    title="Welcome to Tonic's 🐍Mamba",
    description="With this Demo, you can customize the model hyperparameters of [🐍Mamba](https://www.huggingface.co/state-spaces/mamba-2.8b) . Everytime you send a request it will instantiate the model accordingly, so please be patient."
)

if __name__ == "__main__":
    demo.launch()