File size: 2,268 Bytes
e674a51
1e6822b
e674a51
555f172
3e69d5d
984fbb4
 
 
 
555f172
e674a51
 
 
49ecfa7
984fbb4
555f172
9f453e8
5bce228
49ecfa7
 
e674a51
153d27f
984fbb4
 
70872ee
 
 
 
 
 
 
 
 
 
6b3b146
 
 
 
 
e684352
e674a51
 
 
 
 
 
 
692215d
 
e674a51
984fbb4
 
692215d
9f453e8
e674a51
 
 
 
 
 
5bce228
 
49ecfa7
 
 
 
 
 
 
e674a51
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import functools

import gradio as gr
from transformers import pipeline

# Selectable models for the UI: display label -> Hugging Face Hub repository id.
# The Radio input below is built from ``models.items()``, so insertion order here
# is the order in which the choices are shown.
models = {
    "DistilGPT2 SD": "FredZhang7/distilgpt2-stable-diffusion",
    "Llama-SmolTalk-3.2-1B": "prithivMLmods/Llama-SmolTalk-3.2-1B-Instruct",
    "Dolphin-Phi 3 2.9.2": "cognitivecomputations/dolphin-2.9.2-Phi-3-Medium",
    "EXAONE 3.5 2.4B": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
    "Granite 3.3 2B": "ibm-granite/granite-3.3-2b-instruct",
}

@functools.lru_cache(maxsize=1)
def _get_pipeline(model: str):
    """Return a text-generation pipeline for *model*, reusing the last one built.

    Building a pipeline instantiates the full model, which is far too slow to
    repeat for every chat message. ``maxsize=1`` keeps only the most recently
    selected model cached, so switching models in the UI does not keep several
    large models alive at once.
    """
    return pipeline("text-generation", model=model)


def respond(
    message,
    _: list[tuple[str, str]],
    system_prompt: str,
    model: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int
):
    """Generate one reply to *message* with the selected model.

    The chat history (second positional argument) is deliberately ignored:
    each request is answered independently of earlier turns.

    Args:
        message: The user's latest message.
        _: Chat history supplied by ``gr.ChatInterface`` (unused).
        system_prompt: Instruction sent ahead of the user message.
        model: Hugging Face Hub repository id to generate with.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens considered per step.

    Yields:
        Exactly once: for models with a chat template, the last entry of the
        pipeline's ``generated_text`` (the new assistant turn); for models
        without one, the generated continuation string.
    """
    pipe = _get_pipeline(model)

    # Cast defensively: slider values may arrive as floats, while these
    # generation arguments are token counts.
    generation_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": int(top_k),
    }

    if getattr(pipe.tokenizer, "chat_template", None):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ]
        # For chat input the pipeline returns the whole conversation; the
        # final element is the freshly generated assistant message.
        yield pipe(messages, **generation_kwargs)[0]["generated_text"][-1]
    else:
        # Plain completion models (presumably including the DistilGPT2 SD
        # prompt generator) have no chat template, so a messages list cannot
        # be rendered. Fall back to a raw text prompt and return only the
        # newly generated continuation.
        prompt = f"{system_prompt}\n\n{message}" if system_prompt else message
        yield pipe(prompt, return_full_text=False, **generation_kwargs)[0]["generated_text"]


# For information on how to customize the ChatInterface, see the Gradio docs:
# https://www.gradio.app/docs/chatinterface
#
# ``additional_inputs`` are passed to ``respond`` positionally after
# (message, history), so their order must match respond()'s parameters:
# system_prompt, model, max_new_tokens, temperature, top_p, top_k.
demo = gr.ChatInterface(
    respond,
    title="Prompt Enhancer Test",
    type="messages",
    additional_inputs=[
        # Alternative system prompt kept for experimentation:
        # "Enhance the provided text so that it is more vibrant and detailed."
        gr.Textbox(
            value="When the user provides two sentences, return a longer sentence that fuses the two together with a natural motion in between.",
            lines=5,
            show_label=True,
            label="System prompt",
        ),
        # Choices are (label, repo id) pairs: the label is displayed and the
        # repo id is what respond() receives. The default is looked up from
        # ``models`` so it cannot drift out of sync with the choices.
        gr.Radio(
            list(models.items()),
            value=models["DistilGPT2 SD"],
            type="value",
            label="Model",
        ),
        gr.Slider(minimum=8, maximum=128, value=64, step=8, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
        ),
        gr.Slider(
            minimum=10,
            maximum=100,
            value=30,
            step=5,
            label="Top-k",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()