File size: 3,191 Bytes
9c9ed59
 
 
 
7f7d37c
9c9ed59
 
 
 
 
 
 
 
 
 
 
 
ca677a9
9c9ed59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca677a9
9c9ed59
 
 
 
 
 
 
 
 
 
ca677a9
 
 
 
 
9c9ed59
 
d2304af
9c9ed59
 
 
 
c8c7772
9c9ed59
 
 
5d623bb
9c9ed59
28d0e79
9c9ed59
 
c8c7772
9c9ed59
 
 
d2304af
9c9ed59
 
 
 
c8c7772
9c9ed59
 
 
 
 
 
 
 
c8c7772
9c9ed59
 
 
5d623bb
cac98cc
9c9ed59
4579d7a
 
cac98cc
4579d7a
 
eda969f
 
ecf9963
 
 
eda969f
 
e95e8e1
 
 
2891dae
f8e42d0
1afe06d
4579d7a
1afe06d
e95e8e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from huggingface_hub import InferenceClient
import gradio as gr

client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)


def format_prompt(message, history):
  prompt = "<s>"
  for user_prompt, bot_response in history:
    prompt += f"[INST] {user_prompt} [/INST]"
    prompt += f" {bot_response}</s> "
  prompt += f"[INST] {message} [/INST]"
  return prompt

def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output
    return output


additional_inputs=[
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.5,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Более высокое значение, даёт более разнообразные результаты.",
    ),
    gr.Slider(
        label="Max new tokens",
        value=20480,
        minimum=0,
        maximum=32768,
        step=64,
        interactive=True,
        info="Максимальное количество токенов",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.75,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Более высокое значение, даёт большее разнообразие ",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Степень наказания за повторение токенов",
    )
]

examples=[["", "Отвечай всегда полностью на русском языке", 0.5, 20480, 0.75, 1.2],
]


description = r"""

"""


article = r"""
[![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
[![GitHub Stars](https://img.shields.io/github/stars/TencentARC/GFPGAN?style=social)](https://github.com/TencentARC/GFPGAN)
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2101.04061)
"""

gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mix-OpenAI-Chat",
    examples=examples,
    description=description,
    concurrency_limit=20,
).launch(show_api=False)