File size: 1,113 Bytes
5d906de
db4c88c
5d906de
db4c88c
 
5d906de
db4c88c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d906de
db4c88c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn.functional as F

from einops import rearrange
import gradio as gr

from transformers import AutoTokenizer, AutoModelForCausalLM

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Module-level setup: load the GPT-NeoX tokenizer (Mamba models reuse its vocab)
# and the 130M-parameter Mamba LM in fp16 on the GPU.
# NOTE(review): assumes a CUDA device is available — this will fail on CPU-only hosts.
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device=device, dtype=torch.float16)

def pred(text_in):
    """Generate a greedy continuation of *text_in* with the Mamba LM.

    Tokenizes the prompt, generates up to 100 new tokens (top_k=1 makes the
    sampling effectively greedy), and returns the decoded text as a string.

    Args:
        text_in: the prompt string from the Gradio text input.

    Returns:
        The decoded sequence (prompt + continuation) as a single string.
    """
    tokens = tokenizer(text_in, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    # NOTE: MambaLMHeadModel.generate does not accept an attention mask, so the
    # tokenizer's attention_mask is intentionally not used here.
    max_length = input_ids.shape[1] + 100
    # cg=True enables CUDA-graph-cached decoding in mamba_ssm for faster steps.
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=1.0,
        top_k=1,
        top_p=1.0,
    )
    # batch_decode returns a list (one entry per sequence); the Gradio "text"
    # output expects a plain string, so return the single decoded sequence.
    text_out = tokenizer.batch_decode(out.sequences.tolist())
    return text_out[0]

# Wire the prediction function into a minimal Gradio UI: one text input,
# one text output.
demo = gr.Interface(fn=pred, inputs="text", outputs="text")
    
# Launch the local Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()