File size: 1,775 Bytes
2410db3
 
 
1cb6527
b81c54e
2410db3
e59dfd6
1cb6527
e59dfd6
 
 
1cb6527
 
 
bd20bce
1cb6527
bd20bce
1cb6527
bd20bce
1cb6527
2410db3
 
1cb6527
2410db3
1cb6527
 
 
 
2410db3
1cb6527
9d03989
2410db3
 
 
1cb6527
2410db3
1cb6527
2410db3
 
 
 
 
 
 
 
 
1cb6527
2410db3
1cb6527
2410db3
e59dfd6
2410db3
 
1cb6527
2410db3
 
 
 
bd20bce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
from huggingface_hub import InferenceClient

# Shared Hugging Face inference client, bound to the "wop/kosmox" model.
# The chat handler below streams its generations through this object.
client = InferenceClient(model="wop/kosmox")

def format_messages(history, user_message):
    # Create a formatted string according to the specified chat template
    formatted_message = "<s>"
    #if system_message:
    #    formatted_message += f"<|system|>\n{system_message}\n"

    for user_msg, assistant_msg in history:
        if user_msg:
            formatted_message += f"<|user|>\n{user_msg}\n"
        if assistant_msg:
            formatted_message += f"<|assistant|>\n{assistant_msg}\n"
    
    formatted_message += f"<|user|>\n{user_message}\n"
    return formatted_message

def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream the model's reply to ``message`` for ``gr.ChatInterface``.

    Args:
        message: The new user message.
        history: Earlier exchanges as ``(user_msg, assistant_msg)`` pairs.
        system_message: Accepted so the signature matches the interface's
            additional inputs; it is currently not included in the prompt.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated reply text after each generated (non-special) token.
    """
    prompt = format_messages(history, message)

    response = ""
    # format_messages produces a raw prompt string, so it must go through
    # text_generation. chat_completion expects a list of role/content dicts
    # and would reject a plain string.
    for chunk in client.text_generation(
        prompt,
        max_new_tokens=max_tokens,
        stream=True,
        details=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.token
        # Drop special tokens (such as end-of-sequence markers) so they don't
        # show up in the chat window.
        if token.special:
            continue
        response += token.text
        yield response

# Define the Gradio interface.
# ChatInterface calls fn(message, history, *additional_inputs), so the
# additional inputs must line up one-to-one with respond's parameters after
# `history`: system_message, max_tokens, temperature, top_p.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        # Placeholder for respond's `system_message` parameter. Without it,
        # respond receives one argument too few (TypeError: missing 'top_p')
        # and every slider value is shifted one parameter to the left.
        # It is hidden because the system message is not used in the prompt.
        gr.Textbox(value="", label="System message", visible=False),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    # Start the web UI only when run as a script, not when imported.
    demo.launch()