import torch
import gradio as gr
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"
# Mamba reuses the GPT-NeoX tokenizer; the 130M checkpoint is loaded in fp16.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device=device, dtype=torch.float16
)


def pred(text_in):
    tokens = tokenizer(text_in, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    # Generate up to 100 new tokens past the prompt. cg=True caches the decoding
    # step with CUDA graphs for speed; top_k=1 makes decoding greedy.
    max_length = input_ids.shape[1] + 100
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=1.0,
        top_k=1,
        top_p=1.0,
    )
    # batch_decode returns a list of strings (one per sequence); the Gradio
    # text output expects a single string, so return the first element.
    text_out = tokenizer.batch_decode(out.sequences.tolist())
    return text_out[0]


demo = gr.Interface(fn=pred, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()
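
# Usage sketch (assumptions: the script is saved as app.py, mamba_ssm and
# gradio are installed, and a CUDA GPU is available; the prompt below is
# illustrative, not from the original):
#
#   $ python app.py        # serves the demo at the local URL Gradio prints
#
# or, to smoke-test the generation path without the UI:
#
#   >>> pred("The meaning of life is")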