from huggingface_hub import InferenceClient
import gradio as gr

# Serverless Inference API client pointed at the hosted Mistral endpoint.
client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)


def format_prompt(message, history):
    """Build a Mistral-instruct prompt string from the chat history."""
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
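
# Example of the format produced above:
#   format_prompt("How are you?", [("Hi", "Hello!")])
#   -> "<s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]"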

def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    # The endpoint requires a strictly positive temperature, so clamp
    # near-zero values instead of failing.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history)

    # Stream tokens and yield the accumulated text so the Gradio UI
    # updates incrementally as the model generates.
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""

    for response in stream:
        output += response.token.text
        yield output
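
# Quick standalone sanity check (assumes network access to the Inference API):
#   last = ""
#   for last in generate("Say hi in five words.", history=[]):
#       pass
#   print(last)  # the final accumulated response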


additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens to generate",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
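
# gr.ChatInterface passes these slider values to `generate` as extra
# positional arguments after (message, history), in the order listed above.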

css = """
  #mkd {
    height: 200px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        examples=[
            ["Create a ten-point markdown outline with emojis about: Decreased α-ketoglutarate dehydrogenase activity in astrocytes"],
            ["Create a ten-point markdown outline with emojis about: Lewy body dementia"],
            ["Create a ten-point markdown outline with emojis about: Delusional disorder"],
            ["Create a ten-point markdown outline with emojis about: Galantamine"],
            ["Create a ten-point markdown outline with emojis about: Neural crest"],
            ["Create a ten-point markdown outline with emojis about: Progressive multifocal leukoencephalopathy (PML)"],
            ["Create a ten-point markdown outline with emojis about: CT head"],
            ["Create a ten-point markdown outline with emojis about: β-Galactocerebrosidase"],
            ["Create a ten-point markdown outline with emojis about: Dopamine"],
            ["Create a ten-point markdown outline with emojis about: G protein-coupled receptors"],
            ["Create a ten-point markdown outline with emojis about: CT scan of the head without contrast"],
            ["Create a ten-point markdown outline with emojis about: Pyogenic brain abscess"],
            ["Create a ten-point markdown outline with emojis about: Pneumocystis jirovecii"],
        ],
    )
    gr.HTML("""<h2>πŸ€– Mistral Chat - Gradio πŸ€–</h2>
        In this demo, you can chat with <a href='https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1'>Mistral-7B-Instruct</a> model. πŸ’¬
        Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. πŸ“š
        <h2>πŸ›  Model Features πŸ› </h2>
        <ul>
          <li>πŸͺŸ Sliding Window Attention with 128K tokens span</li>
          <li>πŸš€ GQA for faster inference</li>
          <li>πŸ“ Byte-fallback BPE tokenizer</li>
        </ul>
        <h3>πŸ“œ License πŸ“œ  Released under Apache 2.0 License</h3>
        <h3>πŸ“¦ Usage πŸ“¦</h3>
        <ul>
          <li>πŸ“š Available on Huggingface Hub</li>
          <li>🐍 Python code snippets for easy setup</li>
          <li>πŸ“ˆ Expected speedups with Flash Attention 2</li>
        </ul>
    """)
demo.queue().launch(debug=True)