reach-vb HF staff committed on
Commit
2292c41
1 Parent(s): 9a94062
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -10,8 +10,8 @@ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
10
 
11
  device = "cuda"
12
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
13
- model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype=torch.float16)
14
- genlen = 200
15
 
16
  def pred(text_in,):
17
  tokens = tokenizer(text_in, return_tensors="pt")
@@ -25,8 +25,7 @@ def pred(text_in,):
25
  return_dict_in_generate=True,
26
  output_scores=True,
27
  enable_timing=False,
28
- temperature=0.5,
29
- top_k=10,
30
  top_p=0.9,
31
  )
32
  out = fn()
 
10
 
11
  device = "cuda"
12
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
13
+ model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b-slimpj", device=device, dtype=torch.float16)
14
+ genlen = 500
15
 
16
  def pred(text_in,):
17
  tokens = tokenizer(text_in, return_tensors="pt")
 
25
  return_dict_in_generate=True,
26
  output_scores=True,
27
  enable_timing=False,
28
+ temperature=0.7,
 
29
  top_p=0.9,
30
  )
31
  out = fn()