import torch
import gradio as gr
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Mamba has no tokenizer of its own; the state-spaces checkpoints (e.g. state-spaces/mamba-2.8b)
# use the GPT-NeoX-20B tokenizer.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
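
# The Gradio callback below builds a freshly initialized MambaLMHeadModel from the
# user-selected hyperparameters on every request; it does not load the pretrained
# state-spaces/mamba-2.8b weights, so the generated text reflects the chosen
# architecture rather than learned parameters.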
def pred(text_in, d_model, n_layer, vocab_size, genlen, temperature, top_k, top_p):
    # Gradio sliders/number boxes may deliver floats; cast the integer-valued hyperparameters.
    d_model, n_layer, vocab_size = int(d_model), int(n_layer), int(vocab_size)
    genlen, top_k = int(genlen), int(top_k)
    model = MambaLMHeadModel(d_model=d_model, n_layer=n_layer, vocab_size=vocab_size).to(device)
    model.eval()
    tokens = tokenizer(text_in, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    max_length = input_ids.shape[1] + genlen
    # mamba_ssm's generate() samples whenever temperature/top_k/top_p are set; it does not
    # take transformers-style do_sample or attention_mask arguments.
    output = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    text_out = tokenizer.batch_decode(output.sequences.tolist(), skip_special_tokens=True)
    return text_out[0]
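
# Optional local sanity check, kept commented out so it does not run on Space startup.
# The prompt and hyperparameter values are illustrative only and assume a working
# CUDA build of mamba_ssm:
# if device == "cuda":
#     print(pred("Mamba is a state-space model that", 512, 12, 50257, 50, 0.7, 10, 0.9))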
demo = gr.Interface(
    fn=pred,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Slider(minimum=128, maximum=1024, value=512, label="Model Dimension (d_model)"),
        gr.Slider(minimum=1, maximum=24, value=12, label="Number of Layers (n_layer)"),
        gr.Number(value=50257, label="Vocabulary Size (vocab_size)"),
        gr.Slider(minimum=50, maximum=500, value=200, label="Generation Length (genlen)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Temperature"),
        gr.Slider(minimum=1, maximum=50, value=10, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
    ],
    outputs="text",
    title="Welcome to Tonic's 🐍Mamba",
    description="With this demo, you can customize the model hyperparameters of [🐍Mamba](https://www.huggingface.co/state-spaces/mamba-2.8b). Each request instantiates a new model with your settings, so please be patient.",
)
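
# Constructing a model from scratch per request is slow; Gradio's built-in request
# queue can be enabled to handle concurrent users more gracefully:
# demo.queue()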
if __name__ == "__main__":
    demo.launch()