from huggingface_hub import InferenceClient
import gradio as gr
import os

# InferenceClient accepts either a full Inference API URL or a bare model id.
API_URL = {
    "Mistral": "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Codestral": "mistralai/Codestral-22B-v0.1"
}

# Both variables must be set in the environment (e.g. as Space secrets); os.environ raises KeyError if they are missing.
HF_TOKEN = os.environ['HF_TOKEN']
Hinglish_Prompt = os.environ['Hinglish_Prompt']

mistralClient = InferenceClient(
    model=API_URL["Mistral"],
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)

codestralClient = InferenceClient(
    model=API_URL["Codestral"],
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)

def format_prompt(message, history, enable_hinglish=False):
    prompt = "<s>"
    # Prepend the Hinglish system prompt once, if enabled and not already present in the history
    if enable_hinglish and not any("[INST] You are a Hinglish LLM." in user_prompt for user_prompt, bot_response in history):
        prompt += Hinglish_Prompt

    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
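
# Illustrative example (the messages below are hypothetical): with
# history = [("Hi", "Hello!")] and message = "Tell me a joke", format_prompt
# returns a string shaped like:
#   <s>[INST] Hi [/INST] Hello!</s> [INST] Tell me a joke [/INST]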

def generate(prompt, history, model = "Mistral", enable_hinglish=False, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # Sanitize generation arguments
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
        
    top_p = float(top_p)
    
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Select the client based on the chosen model
    client = mistralClient if model == "Mistral" else codestralClient
    
    formatted_prompt = format_prompt(prompt, history, enable_hinglish)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield output
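
# Hedged usage sketch: generate() is a generator, so it can also be consumed
# directly outside Gradio for a quick smoke test, for example:
#     for partial in generate("Write a haiku about Python", history=[]):
#         print(partial)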

additional_inputs=[
    gr.Dropdown(
        choices = ["Mistral","Codestral"],
        value = "Mistral",
        label = "Model to be used",
        interactive=True,
        info = "Mistral for general-purpose chatting and codestral for code related task (Supports 80+ languages)"
    ),
    gr.Checkbox(
        label="Hinglish",
        value=False,
        interactive=True,
        info="Enables the MistralTalk to talk in Hinglish (Combination of Hindi and English)",
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

css = """
  #mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""

with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>MistralTalk🗣️<h1><center>")
    gr.HTML("<h3><center>In this demo, you can chat with <a href='https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1'>Mixtral-8x7B</a> model. 💬<h3><center>")
    gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. 📚<h3><center>")
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        theme=gr.themes.Soft(),
        examples=[["What is the secret to life?"], ["How does the universe work?"], ["What can you do?"],
                  ["What is quantum mechanics?"], ["Do you believe in an afterlife?"],
                  ["Java function to check if a URL is valid or not."]],
    )

demo.queue(max_size=100).launch(debug=True)