import torch
import gradio as gr
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device=device, dtype=torch.float16)

def pred(text_in, temperature, top_k, top_p, gen_length, cg, return_dict_in_generate, output_scores, enable_timing):
    # mamba_ssm's generate() takes no attention mask, so only input_ids are needed.
    input_ids = tokenizer(text_in, return_tensors="pt").input_ids.to(device=device)
    # Gradio sliders can deliver floats; generate() expects integer lengths and top_k.
    max_length = input_ids.shape[1] + int(gen_length)
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=cg,
        return_dict_in_generate=return_dict_in_generate,
        output_scores=output_scores,
        enable_timing=enable_timing,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
    )
    # With return_dict_in_generate=True, generate() returns an output object whose
    # .sequences field holds the token ids; otherwise it returns the ids directly.
    sequences = out.sequences if return_dict_in_generate else out
    text_out = tokenizer.batch_decode(sequences.tolist(), skip_special_tokens=True)
    return text_out[0]
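
# A minimal smoke test of pred() outside the Gradio UI, mirroring the slider defaults
# below but with a shorter gen_length for a quick check; a sketch that assumes a CUDA
# device, since the model is loaded in float16. Uncomment to try it:
# print(pred("Mamba is a state-space model that", 0.5, 10, 0.9, 50, True, True, True, False))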

demo = gr.Interface(
    fn=pred,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Temperature"),
        gr.Slider(minimum=1, maximum=10, value=10, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
        gr.Slider(minimum=50, maximum=650, value=200, label="Generation Length (gen_length)"),
        gr.Checkbox(value=True, label="Cache Graph (cg)"),
        gr.Checkbox(value=True, label="Return Dict in Generate"),
        gr.Checkbox(value=True, label="Output Scores"),
        gr.Checkbox(value=False, label="Enable Timing"),
    ],
    outputs="text",
    title="Welcome👋🏻to🌟Tonic's🐍Mamba 2.8B! 🚀",
    description="""🐍Mamba is quite special because it uses a unique model architecture, has reasonable🏆performance, and a👌🏻tiny size. You can use this Space to test out the current model 🐍[state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) You can also use  🐍mamba-2.8b by cloning this space.   Simply click here: [Duplicate Space](https://huggingface.co/spaces/Tonic1/VLChat?duplicate=true) 
Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community 👻  [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [DataTonic](https://github.com/Tonic-AI/DataTonic) 🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""
)
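
# Optional: queue incoming requests so concurrent users take turns on the single GPU.
# demo.queue() is stock Gradio API; a sketch, worth enabling only if the Space sees
# parallel traffic.
# demo.queue()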

if __name__ == "__main__":
    demo.launch()