"""Gradio demo that builds a Mamba language model with user-chosen
hyperparameters and samples a continuation of a text prompt.

NOTE(review): per the UI description, each request instantiates a *fresh*
(randomly initialized) ``MambaLMHeadModel`` from the chosen hyperparameters,
so the generated text is not expected to be coherent. The original
module-level ``MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")``
call was dead code — its result was never used by ``pred`` — and has been
removed so the 2.8B checkpoint is no longer loaded onto the GPU for nothing.
"""
import torch
import torch.nn.functional as F  # noqa: F401 — kept from original file
from einops import rearrange  # noqa: F401 — kept from original file
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM  # noqa: F401 — kept from original file
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Choose the device *before* anything is placed on it. The original
# hard-coded "cuda" first and only checked availability afterwards, which
# crashed on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer is request-independent; load it once (the original loaded
# the same repo twice).
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")


def pred(text_in, d_model, n_layer, vocab_size, genlen, temperature, top_k, top_p):
    """Instantiate a Mamba LM with the given hyperparameters and sample text.

    Parameters
    ----------
    text_in : str
        Prompt to continue.
    d_model, n_layer, vocab_size : number
        Model hyperparameters. Gradio sliders/number boxes may deliver
        floats, so they are cast to int before use.
    genlen : number
        Number of new tokens to generate on top of the prompt.
    temperature, top_k, top_p : number
        Sampling controls forwarded to ``model.generate``.

    Returns
    -------
    str
        The decoded generation (prompt included, special tokens stripped).
    """
    # NOTE(review): keyword construction (d_model=..., n_layer=...) matches
    # the original call; newer mamba_ssm releases take a MambaConfig — verify
    # against the installed version.
    model = MambaLMHeadModel(
        d_model=int(d_model),
        n_layer=int(n_layer),
        vocab_size=int(vocab_size),
    ).to(device)
    model.eval()

    tokens = tokenizer(text_in, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    max_length = input_ids.shape[1] + int(genlen)

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            do_sample=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
        )
    text_out = tokenizer.batch_decode(
        output.sequences.tolist(), skip_special_tokens=True
    )
    return text_out[0]


demo = gr.Interface(
    fn=pred,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Slider(minimum=128, maximum=1024, value=512, label="Model Dimension (d_model)"),
        gr.Slider(minimum=1, maximum=24, value=12, label="Number of Layers (n_layer)"),
        gr.Number(value=50257, label="Vocabulary Size (vocab_size)"),
        gr.Slider(minimum=50, maximum=500, value=200, label="Generation Length (genlen)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Temperature"),
        gr.Slider(minimum=1, maximum=50, value=10, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
    ],
    outputs="text",
    title="Welcome to Tonic's 🐍Mamba",
    # Implicit string concatenation: the original literal was broken across a
    # raw newline, which is a syntax error for a single-quoted string.
    description=(
        "With this Demo, you can customize the model hyperparameters of "
        "[🐍Mamba](https://www.huggingface.co/state-spaces/mamba-2.8b). "
        "Every time you send a request it will instantiate the model "
        "accordingly, so please be patient."
    ),
)

if __name__ == "__main__":
    demo.launch()