reach-vb HF staff committed on
Commit
77ac825
1 Parent(s): ef0b2de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -10,13 +10,14 @@ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
10
 
11
  device = "cuda"
12
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
13
- model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device=device, dtype=torch.float16)
 
14
 
15
  def pred(text_in):
16
  tokens = tokenizer(text_in, return_tensors="pt")
17
  input_ids = tokens.input_ids.to(device=device)
18
  attn_mask = tokens.attention_mask.to(device=device)
19
- max_length = input_ids.shape[1] + 100
20
  fn = lambda: model.generate(
21
  input_ids=input_ids,
22
  max_length=max_length,
@@ -24,13 +25,13 @@ def pred(text_in):
24
  return_dict_in_generate=True,
25
  output_scores=True,
26
  enable_timing=False,
27
- temperature=1.0,
28
  top_k=1,
29
- top_p=1.0,
30
  )
31
  out = fn()
32
  text_out = tokenizer.batch_decode(out.sequences.tolist())
33
- return text_out
34
 
35
  demo = gr.Interface(fn=pred, inputs="text", outputs="text")
36
 
 
10
 
11
  device = "cuda"
12
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
13
+ model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype=torch.float16)
14
+ genlen = 200
15
 
16
  def pred(text_in):
17
  tokens = tokenizer(text_in, return_tensors="pt")
18
  input_ids = tokens.input_ids.to(device=device)
19
  attn_mask = tokens.attention_mask.to(device=device)
20
+ max_length = input_ids.shape[1] + genlen
21
  fn = lambda: model.generate(
22
  input_ids=input_ids,
23
  max_length=max_length,
 
25
  return_dict_in_generate=True,
26
  output_scores=True,
27
  enable_timing=False,
28
+ temperature=0.5,
29
  top_k=1,
30
+ top_p=0.9,
31
  )
32
  out = fn()
33
  text_out = tokenizer.batch_decode(out.sequences.tolist())
34
+ return text_out[0]
35
 
36
  demo = gr.Interface(fn=pred, inputs="text", outputs="text")
37