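"""Gradio demo: text generation with the Mamba selective state space model.

Loads the state-spaces/mamba-2.8b checkpoint in float16 on a CUDA device and
serves a simple text-in/text-out interface.
"""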
import torch
import gradio as gr

from transformers import AutoTokenizer

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"
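# Mamba checkpoints ship without a tokenizer; the models were trained with the
# GPT-NeoX-20B tokenizer, so load it separately.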
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b", device=device, dtype=torch.float16
)
genlen = 200  # number of new tokens to generate per request

def pred(text_in):
    # MambaLMHeadModel.generate takes no attention mask, so only the input IDs
    # are needed here.
    input_ids = tokenizer(text_in, return_tensors="pt").input_ids.to(device=device)
    max_length = input_ids.shape[1] + genlen
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,  # capture the decoding step in a CUDA graph for faster generation
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=0.5,
        top_k=1,  # note: top_k=1 selects the argmax token, making temperature/top_p no-ops
        top_p=0.9,
    )
    text_out = tokenizer.batch_decode(out.sequences.tolist())
    return text_out[0]

demo = gr.Interface(
    title="Mamba: Selective State Space Model",
    description="A demo for [Mamba](https://github.com/state-spaces/mamba)",
    fn=pred,
    inputs="text",
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()  # Gradio serves on http://127.0.0.1:7860 by default