import gradio as gr
import codecs
from datetime import datetime
import gc
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM
import torch
import time

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16
MODEL_NAMES = ["bigscience/bloom-petals", "bigscience/bloomz-petals"]

models = {"model": None, "model_name": None}
output = {MODEL_NAMES[0]: "", MODEL_NAMES[1]: ""}
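# `models` caches the currently loaded (tokenizer, model) pair; `output` holds the text
# streamed so far for each model.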


print(DEVICE)

def to_md(text):
    return text.replace("\n", "<br />")

def infer(
        prompt,
        model_idx = 0,
        max_new_tokens=10,
        temperature=0.1,
        top_p=1.0,
        repetition_penalty = 1.0,
        stop="\n",
        num_completions=1,
        seed=42,
):
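    """Stream generated text from the selected Petals-hosted model, one token at a time."""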
    global output
    global models

    print("Loading Models\n")
    model_name = MODEL_NAMES[model_idx]
    if models["model_name"] != model_name:
        # A different model was requested: drop the old one and free memory first.
        models = {"model": None, "model_name": None}
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        tokenizer = BloomTokenizerFast.from_pretrained(model_name)
        model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=TORCH_DTYPE, request_timeout=300)
        model = model.to(DEVICE)
        models["model"] = tokenizer, model
        models["model_name"] = model_name
    output[model_name] = ""

    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    stop =  [x.strip(' ') for x in stop.split(',')]
    repetition_penalty = float(repetition_penalty)
    # Note: `num_completions` and `seed` are accepted but not currently used during generation.

    assert 1 <= max_new_tokens <= 384
    assert 1 <= num_completions <= 5
    assert 0.0 <= temperature <= 1.0
    assert 0.0 <= top_p <= 1.0
    assert 0.9 <= repetition_penalty <= 3.0

    if temperature == 0.0:
        temperature = 0.01
    if prompt == "":
        prompt = " "
    
    print(f"START -> ({datetime.now()})\n")
    print(f"PROMPT ({datetime.now()}):\n-------\n{prompt}\n")
    
    token_cnt = 0
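    # The Petals inference session keeps attention caches on the remote servers, so each
    # single-token generate() call below only needs to send the newest token. max_length
    # bounds the total of prompt plus generated tokens for this session.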
    with models["model"][1].inference_session(max_length=512) as sess:
        print(f"Encode Input Prompt")
        output[model_name] = ""
        inputs = models["model"][0](prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        n_input_tokens = inputs.shape[1]
        done = False
        print(f"Start Inference ({sess})")
        while not done:
            outputs = models["model"][1].generate(
                inputs, 
                max_new_tokens=1, 
                do_sample=True, 
                top_p=top_p, 
                temperature=temperature, 
                repetition_penalty=repetition_penalty,
                session=sess
            )
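            # Decode only the newly generated tokens (everything past the prompt) and append them.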
            output[model_name] += models["model"][0].decode(outputs[0, n_input_tokens:])
            token_cnt += 1
            print(f"\n[{model_name}]{output[model_name]}", end="", flush=True)
            yield output[model_name]
            for stop_word in stop:
                stop_word = codecs.getdecoder("unicode_escape")(stop_word)[0]
                if stop_word != '' and stop_word in output[model_name]:
                    print(f"\nDONE (stop)")
                    done = True
            if token_cnt >= max_new_tokens:
                print(f"\nDONE (max tokens)")
                done = True
            inputs = None  # Prefix is passed only for the 1st token of the bot's response
            n_input_tokens = 0
        print(f"\nEnd")
    yield output[model_name]

examples = [
    [
        # Question Answering
        '''Please answer the following question:
Question: What is the capital of Germany?
Answer:''',"BLOOMZ" , 3, 0.2, 1.0, 1.0, "\\n,</s>"],
    [
        # Chatbot 1
        '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
Alex: Good morning, Fritz!
Fritz:''',"BLOOM" , 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"],
    [
        # Chatbot 2 (same prompt, run on BLOOMZ)
        '''This is a conversation between Alex (an AI based on the 2020 GPT-3 language model), and Fritz (an AI based on the 2021 Jurassic-1 language model). They are exploring each other's capabilities, and trying to ask interesting, complex, and 'ungoogleable' questions of one another, to test the limits of the AI...
Alex: Good morning, Fritz!
Fritz:''',"BLOOMZ" , 160, 0.85, 0.9, 1.0, "\\n\\n,</s>"],
    [
        # Expert Answers 
        '''Expert Questions & Helpful Answers
Ask Research Experts
Question:
Are humans good or bad?

Full Answer:''',"BLOOM" , 120, 0.85, 0.9, 1.0, "</s>"],
    [
        # Creative writing outline
        '''You are the writing assistant for Stephen King. You have worked in the fiction/horror genre for 30 years. You are a Pulitzer Prize-winning author, and now you are tasked with developing a skeletal outline for his newest novel, set to be completed in the spring of 2024. Create a title and brief description for the first 5 chapters of this work.\n\nTitle:''',"BLOOM" , 120, 0.85, 0.9, 1.0, "</s>"
    ]
]



iface = gr.Interface(
    fn=infer,
    allow_flagging="never",
    inputs=[
        gr.Textbox(lines=20, label="Input Prompt", max_lines=20),  # prompt
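        # type="index" makes the Radio pass 0 (BLOOM) or 1 (BLOOMZ) as infer's model_idx.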
        gr.Radio(["BLOOM", "BLOOMZ"], value="BLOOM", type="index", label="Choose a 176-billion-parameter model"),
        gr.Slider(1, 256, value=15, label="Max new tokens"),
        gr.Slider(0.0, 1.0, value=0.2, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, label="Top-p"),
        gr.Slider(0.9, 3.0, value=1.0, label="Repetition penalty"),
        gr.Textbox(lines=1, value="\\n\\n,</s>", label="Stop sequences (comma-separated)")
    ],
    outputs=gr.Textbox(lines=20, label="Generated Output:"),
    
    examples=examples,
    cache_examples=False,
    title="BLOOM vs BLOOMZ",
    description='''<p>Compare outputs of the BLOOM and BLOOMZ 176-billion-parameter models using the Petals network. <b>WARNING:</b> Initial inference may take a long time. Keep the input prompt small to speed things up.</p>
    <p>Please consider contributing your unused GPU cycles to the <a href='https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity'>petals swarm</a> to help speed up inference. Check the <a href='http://health.petals.ml/'>Health</a> of the Petals Swarm.</p>
    <p>Big thanks to <a href='https://www.rftcapital.com/'>RFT Capital</a> for providing initial compute resources.</p>'''
)
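# Queueing is required for the generator-based infer() to stream partial outputs to the UI.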

iface.queue(concurrency_count=2)
iface.launch()