File size: 5,087 Bytes
4e2136a
0bb16c0
 
 
 
 
 
 
 
 
4e2136a
 
 
 
 
 
 
 
0bb16c0
 
 
 
 
4e2136a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bb16c0
 
 
4e2136a
 
 
 
0bb16c0
4e2136a
 
 
 
 
 
 
 
 
 
 
0bb16c0
 
4e2136a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bb16c0
 
23ef29e
 
 
 
 
 
0bb16c0
23ef29e
 
 
 
4e2136a
 
23ef29e
4e2136a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bb16c0
 
4e2136a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import itertools

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Run on the GPU when one is available; the model is moved to this device below.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

# The slow (non-"fast") tokenizer is requested explicitly via use_fast=False.
tokenizer = AutoTokenizer.from_pretrained(
    "rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False
)
model = AutoModelForCausalLM.from_pretrained(
    "rinna/japanese-gpt-neox-3.6b-instruction-sft",
    device_map="auto",
    # float16 halves GPU memory, but half-precision kernels may be missing or
    # very slow on CPU (depending on the PyTorch version), so use float32 when
    # no GPU is present. The UI explicitly supports running on CPU.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model = model.to(device)


@torch.no_grad()
def inference_func(prompt, max_new_tokens=128, temperature=0.7):
    """Generate a reply for ``prompt`` with the module-level ``model``/``tokenizer``.

    Args:
        prompt: Prompt text. Line breaks are expected to be encoded as ``<NL>``,
            which is the format ``make_prompt`` produces.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. A value <= 0 selects deterministic
            greedy decoding instead of sampling.

    Returns:
        Only the newly generated text (prompt tokens are stripped), with the
        model's ``<NL>`` markers converted back to real newlines.
    """
    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    prompt_len = token_ids.size(1)

    # The UI slider can deliver ints (e.g. 2) and 0.0. transformers' temperature
    # warper requires a strictly positive float, so coerce to float and treat a
    # non-positive temperature as a request for greedy decoding.
    temperature = float(temperature)
    if temperature > 0:
        sampling_kwargs = {"do_sample": True, "temperature": temperature}
    else:
        sampling_kwargs = {"do_sample": False}

    output_ids = model.generate(
        token_ids.to(model.device),
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        **sampling_kwargs,
    )
    # Decode only the continuation of the single (batch index 0) sequence.
    output = tokenizer.decode(
        output_ids[0, prompt_len:].tolist(), skip_special_tokens=True
    )
    return output.replace("<NL>", "\n")


def make_prompt(message, chat_history, max_context_size: int = 10):
    """Build the ``<NL>``-joined dialogue prompt for the instruction-tuned model.

    The history is flattened into alternating utterances, followed by the new
    ``message`` and an empty slot for the model's reply. Each utterance is
    labelled ``ユーザー`` (user) or ``システム`` (system), counting parity back
    from the empty reply slot. As a result, the prompt always ends with
    ``"システム: "``.

    Args:
        message: The user's new message.
        chat_history: Iterable of ``(user, reply)`` pairs; it is not modified.
        max_context_size: If > 0, keep only the last ``max_context_size - 1``
            items of the flattened sequence (reply slot included), but never
            fewer than the current message plus the reply slot. If <= 0,
            effectively keep everything.

    Returns:
        The prompt string.
    """
    window = list(itertools.chain.from_iterable(chat_history))
    window += [message, ""]

    if max_context_size > 0:
        # Clamp to 2: a size of 0 would make ``window[-0:]`` keep the whole
        # history, and 1 would drop the user's current message.
        context_size = max(int(max_context_size) - 1, 2)
    else:
        context_size = 100000
    window = window[-context_size:]

    # Index 0 of ``roles`` belongs to the last item (the empty reply slot).
    roles = ("システム", "ユーザー")
    last = len(window) - 1
    lines = [f"{roles[(last - i) % 2]}: {text}" for i, text in enumerate(window)]
    return "<NL>".join(lines)


def interact_func(message, chat_history, max_context_size, max_new_tokens, temperature):
    """Handle one chat turn submitted from the UI.

    Builds a prompt from ``message`` and ``chat_history`` with
    ``make_prompt``, generates a reply with ``inference_func``, and logs both
    to stdout.

    Returns:
        ``("", chat_history)``. The empty string clears the input textbox.
        ``chat_history`` is the same list object, extended in place with the
        new ``(message, reply)`` tuple.
    """
    prompt_text = make_prompt(message, chat_history, max_context_size)
    print(f"prompt: {prompt_text}")

    answer = inference_func(prompt_text, max_new_tokens, temperature)
    print(f"generated: {answer}")

    new_turn = (message, answer)
    chat_history.append(new_turn)
    cleared_textbox = ""
    return cleared_textbox, chat_history


# Space this demo was first published under; used to detect duplicated copies.
ORIGINAL_SPACE_ID = "mkshing/rinna-japanese-gpt-neox-3.6b-instruction-sft"
# Overridable through the SPACE_ID environment variable.
SPACE_ID = os.getenv("SPACE_ID", ORIGINAL_SPACE_ID)

# Markdown banner with a "Duplicate Space" button pointing at SPACE_ID.
SHARED_UI_WARNING = f"""# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
"""

# "Settings" becomes a link to the Space's settings page only when SYSTEM is
# "spaces" and SPACE_ID differs from ORIGINAL_SPACE_ID; otherwise plain text.
SETTINGS = (
    f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
    if os.getenv("SYSTEM") == "spaces" and SPACE_ID != ORIGINAL_SPACE_ID
    else "Settings"
)

# Markdown banner shown when CUDA is unavailable; embeds SETTINGS.
CUDA_NOT_AVAILABLE_WARNING = f"""# Attention - Running on CPU.
<center>
You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
"T4 small" is sufficient to run this demo.
</center>
"""


def show_warning(warning_text: str) -> gr.Blocks:
    """Create a ``gr.Blocks`` that shows ``warning_text`` as Markdown inside a ``gr.Box``.

    Args:
        warning_text: Markdown/HTML content of the warning.

    Returns:
        The newly built ``gr.Blocks`` instance.
    """
    warning_blocks = gr.Blocks()
    with warning_blocks:
        with gr.Box():
            gr.Markdown(warning_text)
    return warning_blocks

# Top-level chat UI. Components are created in the order they should appear.
with gr.Blocks() as demo:
    # Optional banners: a shared-UI notice when IS_SHARED_UI is set, and a
    # CPU notice when CUDA is unavailable.
    if os.getenv('IS_SHARED_UI'):
        show_warning(SHARED_UI_WARNING)
    if not torch.cuda.is_available():
        show_warning(CUDA_NOT_AVAILABLE_WARNING)
    gr.Markdown("""# Chat with `rinna/japanese-gpt-neox-3.6b-instruction-sft`
    <a href=\"https://colab.research.google.com/github/mkshing/notebooks/blob/main/rinna_japanese_gpt_neox_3_6b_instruction_sft.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>

    This demo is a chat UI for [rinna/japanese-gpt-neox-3.6b-instruction-sft](https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft). 
    """)
    # Generation settings (collapsed by default), passed to interact_func on submit.
    with gr.Accordion("Configs", open=False):
        # max_context_size = the number of turns * 2
        # (a window over history utterances; values <= 0 keep the whole history,
        # see make_prompt)
        max_context_size = gr.Number(value=10, label="max_context_size", precision=0)
        max_new_tokens = gr.Number(value=128, label="max_new_tokens", precision=0)
        temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="temperature")
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    # Submitting the textbox runs one chat turn. interact_func returns
    # ("", history), which clears the textbox and refreshes the chatbot.
    msg.submit(
        interact_func,
        [msg, chatbot, max_context_size, max_new_tokens, temperature],
        [msg, chatbot],
    )
    # "Clear" resets the chatbot's value to None, bypassing the queue.
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    # Launch the Gradio app with debug=True when run as a script.
    demo.launch(debug=True)