Tonic committed on
Commit
74d73f1
1 Parent(s): f8e11cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -1,21 +1,20 @@
1
  import torch
2
  import torch.nn.functional as F
3
-
4
  from einops import rearrange
5
  import gradio as gr
6
-
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
-
9
  from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
13
- model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype="auto")
14
 
15
  def pred(text_in, temperature, top_k, top_p, gen_length, cg, return_dict_in_generate, output_scores, enable_timing):
16
  tokens = tokenizer(text_in, return_tensors="pt")
17
  input_ids = tokens.input_ids.to(device=device)
18
- max_length = input_ids.shape[1] + gen_length
 
19
  out = model.generate(
20
  input_ids=input_ids,
21
  max_length=max_length,
 
1
  import torch
2
  import torch.nn.functional as F
3
+ import einops
4
  from einops import rearrange
5
  import gradio as gr
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
7
  from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
11
+ model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype=torch.float16)
12
 
13
  def pred(text_in, temperature, top_k, top_p, gen_length, cg, return_dict_in_generate, output_scores, enable_timing):
14
  tokens = tokenizer(text_in, return_tensors="pt")
15
  input_ids = tokens.input_ids.to(device=device)
16
+ attn_mask = tokens.attention_mask.to(device=device)
17
+ max_length = input_ids.shape[1] + gen_length
18
  out = model.generate(
19
  input_ids=input_ids,
20
  max_length=max_length,