File size: 3,697 Bytes
dda1539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23e105f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
import json
import os
import spaces
import torch

from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download

from superposed.llama.superposed_generation import SuperposedLlama
from superposed.llama.tokenizer import Tokenizer
from superposed.ngrams.ngram_models import make_models

# Authenticate with the Hugging Face Hub using the token from the
# HF_ACCESS_TOKEN environment variable. (For local runs with a .env file,
# call load_dotenv() before this line.)
login(os.getenv("HF_ACCESS_TOKEN"))

# Download the Llama-2-7b checkpoint into a local directory.
# makedirs(exist_ok=True) avoids the exists/mkdir race and handles nested paths.
weight_path = "./weights/"
os.makedirs(weight_path, exist_ok=True)
snapshot_download(repo_id="meta-llama/Llama-2-7b", local_dir=weight_path)
tokenizer_path = os.path.join(weight_path, "tokenizer.model")

# Load decoding hyper-parameters from the JSON config.
param_file = "params/p15_d3_mixed.json"
with open(param_file, "r", encoding="utf-8") as f:
    params = json.load(f)
alpha = params["alpha"]
temp = params["temp"]
n_drafts = params["n_drafts"]
prompt_len = params["prompt_len"]
n_token_sample = params["n_token_sample"]
i_weights = params["i_weights"]
i_length = params["i_length"]

# Load main model.
model = SuperposedLlama.build(ckpt_dir=weight_path,
                              tokenizer_path=tokenizer_path,
                              max_seq_len=100,
                              max_batch_size=32,
                              model_parallel_size=1)
tokenizer = Tokenizer(tokenizer_path)

# Create n-gram models (bigram through sixgram) from the "ckpts-200k" checkpoints.
ngrams = make_models("ckpts-200k", bigram=True, trigram=True, fourgram=True, fivegram=True, sixgram=True, sevengram=False)

def decode(tokenizer, encoding):
    """
    Decode a 1-D tensor of token ids to text, stopping at the first EOS token.

    Args:
        tokenizer (Any): Tokenizer exposing ``eos_id`` and ``decode(list[int])``.
        encoding (torch.Tensor): 1-D tensor of token ids (any numeric dtype;
            it is cast to int32 before decoding).
    Returns:
        decoding (str): Text of the tokens that precede the first ``eos_id``,
            or of the whole sequence when no EOS token is present.
    """
    # Positions of every EOS token in the 1-D sequence.
    eos_locs = (encoding == tokenizer.eos_id).nonzero(as_tuple=True)[0]
    if len(eos_locs) > 0:
        # Keep only the tokens before the first EOS; int() gives a plain index.
        encoding = encoding[: int(eos_locs[0])]
    return tokenizer.decode(encoding.to(torch.int32).tolist())

@spaces.GPU
def update_options(input, num_tokens):
    """
    Generate superposed-decoding continuations for the current textbox text.

    Args:
        input (str): Current textbox content, used as the prompt.
        num_tokens (int | float): Generation length from the slider; coerced
            to int before being passed as ``max_gen_len``.
    Returns:
        tuple[str, str, str]: Decoded text of the first three drafts, one per
            option button. When fewer than three drafts exist (``n_drafts < 3``)
            the missing entries are "" instead of raising IndexError.
    """
    tokenized_prompts = tokenizer.encode([input], True, False)
    alive_gens, _ = model.sup_generate(prompt_tokens=tokenized_prompts,
                                       smoothing="geom",
                                       max_gen_len=int(num_tokens),
                                       n_token_sample=n_token_sample,
                                       alpha=alpha,
                                       temp=temp,
                                       n_drafts=n_drafts,
                                       i_weights=i_weights,
                                       i_length=i_length,
                                       ngrams=ngrams,
                                       get_time=False,
                                       penalty=200)
    # alive_gens[0] is reshaped into n_drafts rows; row i is draft i.
    gens = alive_gens[0].reshape(n_drafts, -1)
    # The UI has exactly three option buttons: decode at most three drafts and
    # pad with empty strings if the config requests fewer.
    n_buttons = 3
    options = [decode(tokenizer, gen) for gen in gens[:n_buttons]]
    options += [""] * (n_buttons - len(options))
    return tuple(options)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
    """
    # Superposed Decoding
    Start typing below to see suggestions.
    """)
    slider = gr.Slider(minimum=1, maximum=10, step=1, label="Generation length", value=10)
    inp = gr.Textbox(placeholder="Type anything!", lines=3)
    option1 = gr.Button(value="Option 1")
    option2 = gr.Button(value="Option 2")
    option3 = gr.Button(value="Option 3")
    options = [option1, option2, option3]

    # Recompute the three suggestions whenever the prompt text changes, and
    # also when the generation length changes, so the buttons never show
    # results computed for a stale slider value.
    inp.change(update_options, inputs=[inp, slider], outputs=options)
    slider.change(update_options, inputs=[inp, slider], outputs=options)

    def append_option(curr, txt):
        """
        Append a clicked suggestion to the textbox.

        Args:
            curr (str): Current textbox content.
            txt (str): Label of the clicked button, i.e. the suggestion text
                set by update_options (or its initial "Option N" label if
                no suggestions have been generated yet).
        Returns:
            str: ``curr + txt``, written back into the textbox.
        """
        return curr + txt

    # All buttons share one handler; each passes its own label as the suggestion.
    for option in options:
        option.click(append_option, inputs=[inp, option], outputs=inp)

if __name__ == "__main__":
    demo.launch()