File size: 3,934 Bytes
85982c7
7be1664
b9cf639
f6cbd41
7be1664
 
8d621ae
7be1664
 
 
 
4df27ca
 
 
 
7be1664
 
 
 
 
 
 
 
4df27ca
b2497fe
4df27ca
 
 
 
e1ea1d4
b2497fe
8586d5d
 
b2497fe
4df27ca
7be1664
4df27ca
 
7be1664
 
 
4df27ca
 
 
7be1664
4df27ca
7be1664
 
 
 
4df27ca
 
 
7be1664
4df27ca
7be1664
 
 
 
 
4df27ca
 
 
 
7be1664
4df27ca
7be1664
 
 
4df27ca
7be1664
4df27ca
7be1664
4df27ca
7be1664
4df27ca
7be1664
4df27ca
 
7be1664
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
# from optimum.bettertransformer import BetterTransformer
from tokenization_yi import YiTokenizer
import torch
import os
import bitsandbytes
import gradio as gr
import sentencepiece


# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:126'
# Upper bound offered by the "Max New Tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 160000
# Initial position of the "Max New Tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 20000
# NOTE(review): not referenced anywhere in the visible code — presumably meant
# as a cap on prompt length; confirm before relying on it.
MAX_INPUT_TOKEN_LENGTH = 160000
# Markdown/HTML banner rendered at the top of the page.
# NOTE(review): the "Duplicate Space" href below has no "/" between what looks
# like an owner and a space name ("Tonic1Tonics-Yi-6B-200K") — verify the URL.
DESCRIPTION = """
# Welcome to Tonic's YI-6B-200K
You can use this Space to test out the current model [01-ai/Yi-6B-200K](https://huggingface.co/01-ai/Yi-6B-200K)
You can also use YI-200 by cloning this space. Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic1Tonics-Yi-6B-200K/?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
Join us : TeamTonic is always making cool demos! Join our active builder's community on Discord: [Discord](https://discord.gg/nXx5wbX9) On Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On Github: [Polytonic](https://github.com/tonic-ai) & contribute to [PolyGPT](https://github.com/tonic-ai/polygpt-alpha)
"""


# Set up the model and tokenizer
model_name = "01-ai/Yi-6B-200K"
# Device that prompt tensors are moved to before generation (read by `run`).
device = "cuda" if torch.cuda.is_available() else "cpu"
# A tokenizer holds no weights, so no device placement is requested for it.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# NOTE(review): newer transformers releases prefer passing
# `quantization_config=BitsAndBytesConfig(load_in_4bit=True)` over the bare
# `load_in_4bit` keyword — confirm against the pinned transformers version.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
    trust_remote_code=True,
)
# Deliberately no `model.to(device)`: `device_map="auto"` already places the
# weights, and transformers rejects `.to()` on bitsandbytes-quantized models
# (it raises ValueError in many releases).

def run(prompt, max_new_tokens, temperature, top_p, top_k):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    Args:
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.

    Returns:
        The decoded continuation only (prompt tokens removed, special tokens
        skipped), or an empty string when there is no prompt to continue.
    """
    # Nothing to continue: skip tokenization and generation entirely.
    if not prompt:
        return ""

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    prompt_length = input_ids.shape[-1]
    # A prompt that tokenizes to zero tokens gives generate() nothing to
    # condition on, so treat it like an empty prompt.
    if prompt_length == 0:
        return ""

    response_ids = model.generate(
        input_ids,
        # Single unpadded sequence, so every position is a real token. Passing
        # the mask explicitly matters because the pad id is set to the EOS id.
        attention_mask=torch.ones_like(input_ids),
        # UI sliders may deliver integral values as float (or integral floats
        # as int); generate() validates these types, so normalise them here.
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )
    # Drop the echoed prompt tokens and decode only the newly generated part.
    return tokenizer.decode(response_ids[0, prompt_length:], skip_special_tokens=True)

def generate(prompt, max_new_tokens, temperature, top_p, top_k):
    """Click handler for the UI: forward all arguments unchanged to ``run``.

    Returns whatever ``run`` returns for the same arguments.
    """
    sampling_args = (max_new_tokens, temperature, top_p, top_k)
    return run(prompt, *sampling_args)

# Gradio Interface
with gr.Blocks(theme='ParityError/Anime') as demo:
    gr.Markdown(DESCRIPTION)

    # Prompt entry and the button that triggers generation.
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label='Enter your prompt',
                placeholder='Type something...',
                lines=5
            )
            submit_button = gr.Button('Generate')

    # Sampling controls, collapsed by default.
    with gr.Accordion(label='Advanced options', open=False):
        max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
        temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=2.0, step=0.1, value=1.2)
        top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
        top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=900)

    # `interactive=False` is the supported way to make a Textbox read-only;
    # `readonly` is not a Textbox parameter.
    output = gr.Textbox(label='Generated Text', lines=10, interactive=False)

    # Input order must match the positional parameters of `generate`.
    submit_button.click(
        fn=generate,
        inputs=[prompt, max_new_tokens, temperature, top_p, top_k],
        outputs=output
    )

# Queue at most 5 pending requests, then start the server.
demo.queue(max_size=5).launch(show_api=True)