File size: 3,220 Bytes
54fe16b
 
31a1ff8
e2b5fc2
d6839dc
d34759f
d6839dc
54fe16b
9e4ba23
af83917
54fe16b
 
87aa391
54fe16b
 
 
 
 
 
 
 
 
31a1ff8
54fe16b
0a1707e
54fe16b
 
 
 
 
 
b2926ec
54fe16b
 
 
 
87aa391
 
54fe16b
0a1707e
 
54fe16b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ab71e4
54fe16b
 
 
 
 
 
e9ac030
23521bb
0a1707e
 
54fe16b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import os
import spaces


import gradio as gr

import json
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_LENGTH = 4096
DEFAULT_MAX_NEW_TOKENS = 1024


def parse_args():
    """Parse the command-line options of the demo script.

    Returns:
        argparse.Namespace: ``base_model`` (str, ``None`` when the flag is
        not given) and ``n_gpus`` (int, defaults to 1).
    """
    cli = argparse.ArgumentParser()
    # One (flag, add_argument keyword arguments) entry per supported option.
    option_specs = (
        ("--base_model", {"type": str}),  # model path
        ("--n_gpus", {"type": int, "default": 1}),  # number of GPUs
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_tokens):
    """Stream the assistant reply for ``message`` given the chat ``history``.

    Builds a ChatML-style prompt from the system prompt, the previous turns
    and the new message, tokenizes it, runs ``model.generate`` on a
    background thread and yields the reply accumulated so far each time the
    streamer produces a new chunk.

    Relies on the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        message: The latest user message.
        history: Iterable of ``(user_text, assistant_text)`` pairs for the
            previous turns.
        system_prompt: Text of the system turn. When empty or ``None`` the
            built-in default prompt is used.
        temperature: Sampling temperature. ``None`` falls back to 0.5;
            values <= 0 switch generation to greedy decoding.
        max_tokens: Maximum number of new tokens to generate. ``None`` falls
            back to ``DEFAULT_MAX_NEW_TOKENS``.

    Yields:
        str: The reply text generated so far.
    """
    global model, tokenizer, device

    # Honor the system prompt supplied by the UI; keep the previous
    # hard-coded text as the fallback for empty input.
    if not system_prompt:
        system_prompt = (
            "A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions."
        )

    # NOTE(review): assistant turns from history are not closed with
    # <|im_end|>; the format is kept as-is here -- confirm against the
    # model's chat template before changing it.
    parts = ["<|im_start|>system\n" + system_prompt + "\n<|im_end|>\n"]
    for human, assistant in history:
        parts.append('<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant)
    parts.append('\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n')
    instruction = "".join(parts)

    stop_tokens = ["<|endoftext|>", "<|im_end|>"]
    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids
    attention_mask = enc.attention_mask

    if input_ids.shape[1] > MAX_LENGTH:
        # Keep the most recent tokens. The mask must be cut identically,
        # otherwise generate() receives tensors of different lengths.
        input_ids = input_ids[:, -MAX_LENGTH:]
        attention_mask = attention_mask[:, -MAX_LENGTH:]

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # UI sliders may deliver floats; fall back to the old fixed values when
    # a caller passes None.
    temperature = 0.5 if temperature is None else float(temperature)
    max_new_tokens = DEFAULT_MAX_NEW_TOKENS if max_tokens is None else int(max_tokens)
    # Sampling requires a strictly positive temperature; a value of 0 from
    # the slider is treated as a request for greedy decoding.
    use_sampling = temperature > 0

    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        do_sample=use_sampling,
        max_new_tokens=max_new_tokens,
    )
    if use_sampling:
        generate_kwargs.update(top_p=0.95, temperature=temperature)

    # generate() blocks until completion, so it runs on a worker thread
    # while this generator consumes the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        if text in stop_tokens:
            break
        print(text)
        yield "".join(outputs)



if __name__ == "__main__":
    # Script entry point: load the tokenizer and model, move the model to
    # the available device, then serve the Gradio chat UI.
    args = parse_args()
    # --base_model used to be parsed and then ignored; honor it and fall
    # back to the original checkpoint when the flag is not given.
    model_name = args.base_model or "stabilityai/stable-code-instruct-3b"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    gr.ChatInterface(
        predict,
        title="Stable Code Instruct Chat - Demo",
        description="Chat Model Stable Code 3B",
        theme="soft",
        chatbot=gr.Chatbot(height=1400, label="Chat History",),
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        # Passed to predict() after (message, history), in this order:
        # system_prompt, temperature, max_tokens.
        additional_inputs=[
            gr.Textbox("A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", label="System Prompt"),
            gr.Slider(0, 1, 0.9, label="Temperature"),
            gr.Slider(100, 2048, 1024, label="Max Tokens"),
        ],
        additional_inputs_accordion_name="Parameters",
    ).queue().launch()