File size: 4,075 Bytes
fb69876
69cc5ab
50d93bb
fb69876
fe4be4e
69cc5ab
a4b0cb5
fb69876
50d93bb
fb69876
69cc5ab
 
a4b0cb5
50d93bb
 
a4b0cb5
50d93bb
 
a4b0cb5
8454eb5
50d93bb
 
8454eb5
50d93bb
 
 
a4b0cb5
 
 
50d93bb
 
fe4be4e
 
 
 
50d93bb
 
 
 
a4b0cb5
50d93bb
 
 
 
 
99167bb
 
a4b0cb5
 
 
 
50d93bb
 
 
8454eb5
 
50d93bb
 
 
38f57a9
 
3c15525
38f57a9
 
fe4be4e
 
 
 
50d93bb
 
 
1ce5f7f
 
 
 
50d93bb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import gradio as gr
from transformers import pipeline
import os
from mtranslate import translate

# Pick the inference device: the current CUDA device when one is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = "cpu"
# Optional Hugging Face token read from the environment (None when unset).
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
text_generation_model = "cahya/indochat-tiny"
# Text-generation pipeline shared by every request handled by get_answer().
text_generation = pipeline(
    "text-generation",
    text_generation_model,
    use_auth_token=HF_AUTH_TOKEN,
    device=device,
)


def get_answer(user_input, decoding_methods, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
    """Generate an IndoChat reply to ``user_input`` plus its English translation.

    Args:
        user_input: The user's question as typed in the textbox.
        decoding_methods: "Beam Search", "Sampling", or anything else, which is
            treated as contrastive search.
        num_beams: Beam width. It is forced to 1 for sampling and contrastive search.
        top_k: Forwarded to the pipeline (used by sampling and contrastive search).
        top_p: Forwarded to the pipeline.
        temperature: Forwarded to the pipeline.
        repetition_penalty: Forwarded to the pipeline.
        penalty_alpha: Contrastive-search penalty. It is forced to 0 for beam
            search and sampling.

    Returns:
        Two HighlightedText values: the original exchange and its translation
        (requested from "id" to "en"). Each is a list of ``(text, label)``
        tuples; the user text carries label ``None`` and the answer label ``""``.
    """
    if decoding_methods == "Beam Search":
        do_sample = False
        penalty_alpha = 0
    elif decoding_methods == "Sampling":
        do_sample = True
        penalty_alpha = 0
        num_beams = 1
    else:
        # Contrastive search: deterministic, single beam, driven by penalty_alpha and top_k.
        do_sample = False
        num_beams = 1
    # Slider values can arrive as floats; these generation arguments are counts.
    num_beams = int(num_beams)
    top_k = int(top_k)
    print(user_input, decoding_methods, do_sample, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
    prompt = f"User: {user_input}\nAssistant: "
    generated_text = text_generation(prompt, min_length=50, max_length=200, num_return_sequences=1,
                                     num_beams=num_beams, do_sample=do_sample, top_k=top_k, top_p=top_p,
                                     temperature=temperature, repetition_penalty=repetition_penalty,
                                     penalty_alpha=penalty_alpha)
    answer = generated_text[0]["generated_text"]
    # The pipeline returns the prompt followed by the continuation. Slice off
    # exactly the prompt and trim whitespace. Skipping one extra character
    # would cut the first letter of the answer whenever the model does not
    # start its reply with a space.
    answer_without_prompt = answer[len(prompt):].strip()
    user_input_en = translate(user_input, "en", "id")
    answer_without_prompt_en = translate(answer_without_prompt, "en", "id")
    return [(f"{user_input} ", None), (answer_without_prompt, "")], \
        [(f"{user_input_en} ", None), (answer_without_prompt_en, "")]


# Layout: a title row, then two columns. The left column holds the question
# and generation controls; the right column holds the generated answer and
# its English translation.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## IndoChat")
    with gr.Row():
        with gr.Column():
            user_input = gr.inputs.Textbox(placeholder="",
                                           label="Ask me something in Indonesian or English",
                                           default="Bagaimana cara mendidik anak supaya tidak berbohong?")
            # Labelled so the control is identifiable in the rendered UI.
            decoding_methods = gr.inputs.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
                                                  default="Sampling", label="Decoding method")
            num_beams = gr.inputs.Slider(label="Number of beams for beam search",
                                         default=1, minimum=1, maximum=10, step=1)
            top_k = gr.inputs.Slider(label="Top K",
                                     default=30, maximum=50, minimum=1, step=1)
            top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
            penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
                                             default=0.5, step=0.05, minimum=0.05, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            # Segments labelled "" are drawn in blue; the user text (label None) is not highlighted.
            generated_answer = gr.HighlightedText(
                label="Generated Text",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
            generated_answer_en = gr.HighlightedText(
                label="Translation",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")

    # The input order must match get_answer's positional parameters.
    button_generate_story.click(get_answer,
                                inputs=[user_input, decoding_methods, num_beams, top_k, top_p, temperature,
                                        repetition_penalty, penalty_alpha],
                                outputs=[generated_answer, generated_answer_en])

demo.launch(enable_queue=False)