import gradio as gr
from huggingface_hub import InferenceClient

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
client = InferenceClient("shisa-ai/shisa-llama3-8b-v1")


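# ChatInterface callback: receives the latest user message, the (user, assistant)
# history, and the values of the additional inputs defined below, and yields the
# partially built reply so the UI can stream it.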
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content

        # Streamed chunks can carry an empty or None delta (e.g. the final chunk),
        # so only append and yield when there is new text.
        if token:
            response += token
            yield response

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()

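# An earlier local-inference version of this app (loading augmxnt/shisa-7b-v1 with
# transformers) is kept below, disabled inside a triple-quoted string, for reference.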
'''
# https://www.gradio.app/guides/using-hugging-face-integrations

import gradio as gr
import logging
import html
from   pprint import pprint
import time
import torch
from   threading import Thread
from   transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

# Model
model_name = "augmxnt/shisa-7b-v1"

# UI Settings
title = "Shisa 7B"
description = "Test out <a href='https://huggingface.co/augmxnt/shisa-7b-v1'>Shisa 7B</a> in either English or Japanese. If you aren't getting the right language outputs, you can try changing the system prompt to the appropriate language.\n\nNote: we are running this model quantized at `load_in_4bit` to fit in 16GB of VRAM."
placeholder = "Type Here / ここに入力してください" 
examples = [
    ["What are the best slices of pizza in New York City?"],
    ["東京でおすすめのラーメン屋ってどこ?"],
    ['How do I program a simple "hello world" in Python?'],
    ["Pythonでシンプルな「ハローワールド」をプログラムするにはどうすればいいですか?"],
]

# LLM Settings
# Initial
system_prompt = 'You are a helpful, bilingual assistant. Reply in same language as the user.'
default_prompt = system_prompt

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # load_in_8bit=True,
    load_in_4bit=True,
    use_flash_attention_2=True,
)
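# Note: newer transformers releases prefer expressing 4-bit loading via
# quantization_config=BitsAndBytesConfig(load_in_4bit=True) rather than the bare
# load_in_4bit flag; the flag works here but is deprecated upstream.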

def chat(message, history, system_prompt):
    if not system_prompt:
        system_prompt = default_prompt

    print('---')
    print('Prompt:', system_prompt)
    pprint(history)
    print(message)

    # Rebuild the full message list from scratch on every turn; it's simpler than tracking state
    chat_history = [{"role": "system", "content": system_prompt}]
    for h in history:
        chat_history.append({"role": "user", "content": h[0]})
        chat_history.append({"role": "assistant", "content": h[1]})
    chat_history.append({"role": "user", "content": message})

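    # Render the conversation with the model's chat template and append the
    # generation prompt so the model continues as the assistant.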
    input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt")

    # for multi-gpu, find the device of the first parameter of the model
    first_param_device = next(model.parameters()).device
    input_ids = input_ids.to(first_param_device)

    generate_kwargs = dict(
        inputs=input_ids,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        repetition_penalty=1.15,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

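    # Generate the full reply in one pass (this older version does not stream),
    # then decode only the tokens produced after the prompt.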
    output_ids = model.generate(**generate_kwargs)
    new_tokens = output_ids[0, input_ids.size(1):]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True) 
    return response


chat_interface = gr.ChatInterface(
    chat,
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(placeholder=placeholder, container=False, scale=7),
    title=title,
    description=description,
    theme="soft",
    examples=examples,
    cache_examples=False,
    undo_btn="Delete Previous",
    clear_btn="Clear",
    additional_inputs=[
        gr.Textbox(system_prompt, label="System Prompt (Change the language of the prompt for better replies)"),
    ],
)

# https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219 - we wrap the interface in a `with gr.Blocks()` block because Gradio breaks on autoreload otherwise
with gr.Blocks() as demo:
    chat_interface.render()
    gr.Markdown("You can try asking this question in Japanese or English. We limit output to 200 tokens.")

demo.queue().launch()

'''