# mamba / app.py
# Author: reach-vb (HF staff) — commit 77ac825, "Update app.py" (1.13 kB)
import torch
import torch.nn.functional as F
from einops import rearrange
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# Module-level resources shared by every request handled by `pred`.
device = "cuda"

# The Mamba checkpoint below is paired with the GPT-NeoX-20B tokenizer.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Weights are loaded straight onto `device` in half precision.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b",
    device=device,
    dtype=torch.float16,
)

# Number of new tokens to generate beyond the prompt length.
genlen = 200
def pred(text_in):
    """Generate a continuation of ``text_in`` with the Mamba model.

    Args:
        text_in: Prompt text from the Gradio textbox (may be empty or None).

    Returns:
        The decoded output sequence for the prompt. ``max_length`` counts the
        prompt tokens, so at most ``genlen`` new tokens are generated.
        Returns ``""`` when the prompt is empty or None.
    """
    # An empty prompt would reach generate() with no tokens to condition on
    # (presumably an error inside the decoder); return early instead.
    if not text_in:
        return ""

    input_ids = tokenizer(text_in, return_tensors="pt").input_ids.to(device=device)
    # generate() takes a total length, so add the prompt length to genlen.
    max_length = input_ids.shape[1] + genlen
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        # NOTE(review): top_k=1 presumably means greedy (argmax) decoding,
        # which would make temperature/top_p have no effect. Confirm against
        # mamba_ssm's decode() before tuning these values.
        temperature=0.5,
        top_k=1,
        top_p=0.9,
    )
    # Single prompt => single row in the batch; return its decoded text.
    return tokenizer.batch_decode(out.sequences.tolist())[0]
# Minimal web UI: one textbox in, one textbox out, backed by `pred`.
demo = gr.Interface(
    fn=pred,
    inputs="text",
    outputs="text",
)


def _launch():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _launch()