# Hugging Face Spaces demo for Mamba (state-spaces/mamba-2.8b-slimpj).
# NOTE(review): the original paste carried web-viewer residue here
# (Spaces runtime-error banner, commit hashes, and a line-number gutter);
# it was not Python and has been folded into this comment header.
import torch
import torch.nn.functional as F
from einops import rearrange
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# Inference target; mamba_ssm's fused kernels require a CUDA device.
device = "cuda"
# Mamba checkpoints ship without a tokenizer; the GPT-NeoX tokenizer is the
# one the model was trained with.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# fp16 weights keep the 2.8B-parameter model within typical GPU memory.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b-slimpj", device=device, dtype=torch.float16)
# Number of new tokens to sample per request.
genlen = 500
def pred(text_in):
    """Generate a sampled continuation of *text_in* with the Mamba model.

    Args:
        text_in: Prompt string from the Gradio textbox.

    Returns:
        The decoded prompt followed by up to ``genlen`` generated tokens
        (a single string; batch size is always 1 here).
    """
    # Tokenize the prompt and move it to the inference device.
    # (The original also built an attention mask, but model.generate was
    # never given it — Mamba's generate() does not take one — so it is
    # dropped here.)
    input_ids = tokenizer(text_in, return_tensors="pt").input_ids.to(device=device)
    # Total sequence budget = prompt length + generated tokens.
    max_length = input_ids.shape[1] + genlen
    # cg=True enables mamba_ssm's CUDA-graph caching of the decode step.
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=0.9,
        top_p=0.7,
    )
    # batch_decode returns one string per sequence in the batch.
    return tokenizer.batch_decode(out.sequences.tolist())[0]
# Single-textbox Gradio app: prompt in, generated continuation out.
demo = gr.Interface(
    fn=pred,
    inputs="text",
    outputs="text",
    title="Mamba: Selective State Space Model",
    description="A demo for [Mamba](https://github.com/state-spaces/mamba) by Albert & Tri.",
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    # Fix: the original line ended with a stray " |" (paste artifact),
    # which is a Python syntax error.
    demo.launch()