File size: 4,469 Bytes
69cc5ab
fb69876
fe4be4e
832a8c0
69cc5ab
fb69876
832a8c0
 
69cc5ab
832a8c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50d93bb
832a8c0
50d93bb
 
9aae25d
 
 
 
 
 
 
 
50d93bb
e55e013
 
 
 
 
 
50d93bb
 
 
 
 
832a8c0
07c107f
a4b0cb5
 
 
 
50d93bb
 
 
8454eb5
 
50d93bb
 
 
38f57a9
 
3284509
3c15525
38f57a9
3aa662c
38f57a9
fe4be4e
3284509
fe4be4e
 
 
50d93bb
 
 
1ce5f7f
832a8c0
1ce5f7f
 
50d93bb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import os
from mtranslate import translate
import requests

# Hugging Face access token read from the environment (None when unset).
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
# Endpoint of the remote IndoChat generation service that get_answer POSTs to.
indochat_api = "https://cahya-indonesian-whisperer.hf.space/api/indochat/v1"
# Bearer token for that service; an empty string when the variable is unset.
indochat_api_auth_token = os.environ.get("INDOCHAT_API_AUTH_TOKEN", "")

def get_answer(user_input, decoding_method, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
    """Query the IndoChat API and return the answer plus an English translation.

    The prompt and decoding settings are POSTed (form-encoded) to
    ``indochat_api``. On success, both the prompt and the generated text are
    translated to English with ``mtranslate.translate``.

    Args:
        user_input: Prompt text typed by the user.
        decoding_method: Decoding strategy name chosen in the UI dropdown.
        num_beams, top_k, top_p, temperature, repetition_penalty,
        penalty_alpha: Generation settings forwarded verbatim to the API.

    Returns:
        A 2-tuple, one value per ``gr.HighlightedText`` output. Each value is a
        list of ``(text, label)`` tuples:

        * success: ``[(prompt + "\\n", None), (answer, "")]`` and the same
          shape for the English translation;
        * failure: ``[("Error: ...", None)]`` for both outputs. This keeps the
          return arity equal to the two outputs wired in the UI.
    """
    print(user_input, decoding_method, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
    headers = {'Authorization': 'Bearer ' + indochat_api_auth_token}
    data = {
        "text": user_input,
        # NOTE(review): character count of the prompt is used as a length
        # budget; presumably the server interprets it in tokens - confirm.
        "min_length": len(user_input) + 50,
        "max_length": 300,
        "decoding_method": decoding_method,
        "num_beams": num_beams,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "seed": -1,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha
    }
    # Generous upper bound (seconds): without any timeout a stalled remote
    # service would block this handler indefinitely.
    request_timeout = 120
    try:
        r = requests.post(indochat_api, headers=headers, data=data, timeout=request_timeout)
    except requests.exceptions.RequestException as err:
        error = [(f"Error: {err}", None)]
        return error, error
    if r.status_code != 200:
        # Both outputs must receive a HighlightedText-compatible value.
        error = [("Error: " + r.text, None)]
        return error, error
    result = r.json()
    answer = result["generated_text"]
    # Source language is fixed to Indonesian ("id"); English prompts are
    # passed through the same call.
    user_input_en = translate(user_input, "en", "id")
    answer_en = translate(answer, "en", "id")
    # Label None leaves the prompt un-highlighted; label "" highlights the answer.
    return [(f"{user_input}\n", None), (answer, "")], \
        [(f"{user_input_en}\n", None), (answer_en, "")]


# Page-level CSS for both HighlightedText outputs (element ids "answer_id"
# and "answer_en"): keep line breaks inside the highlighted spans and hide
# the per-span category label badges.
css = "\n" + "".join(
    f"#{elem_id} span {{white-space: pre-line}}\n"
    f"#{elem_id} span.label {{display: none}}\n"
    for elem_id in ("answer_id", "answer_en")
)

# Build the IndoChat demo page: a description header, a column of generation
# controls, and two HighlightedText outputs (answer and English translation).
with gr.Blocks(css=css) as demo:
    with gr.Row():
        gr.Markdown("""## IndoChat
        
        A Proof of Concept of a multilingual chatbot (in this case a bilingual one, English and Indonesian), fine-tuned
        with a multilingual instruction dataset. The base model is a GPT2-Medium (340M params) which was pretrained on
        75GB of Indonesian and English data, of which English makes up less than 1%.
        """)
    with gr.Row():
        # Left column: prompt and decoding controls, in the same order as
        # get_answer's positional parameters.
        with gr.Column():
            user_input = gr.inputs.Textbox(placeholder="",
                                           label="Ask me something in Indonesian or English",
                                           default="Bagaimana cara mendidik anak supaya tidak berbohong?")
            decoding_method = gr.inputs.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
                                                  default="Sampling", label="Decoding Method")
            num_beams = gr.inputs.Slider(label="Number of beams for beam search",
                                     default=1, minimum=1, maximum=10, step=1)
            top_k = gr.inputs.Slider(label="Top K",
                                     default=30, maximum=50, minimum=1, step=1)
            top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
            penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
                                             default=0.5, step=0.05, minimum=0.05, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        # Right column: outputs. The elem_ids match the selectors in `css`.
        # get_answer labels the answer span "", which color_map renders blue.
        # NOTE(review): the "-" (green) entry is not produced by get_answer.
        with gr.Column():
            generated_answer = gr.HighlightedText(
                elem_id="answer_id",
                label="Generated Text",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
            generated_answer_en = gr.HighlightedText(
                elem_id="answer_en",
                label="Translation",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")

    # get_answer returns exactly two values, one per output component.
    button_generate_story.click(get_answer,
                                inputs=[user_input, decoding_method, num_beams, top_k, top_p, temperature,
                                        repetition_penalty, penalty_alpha],
                                outputs=[generated_answer, generated_answer_en])

demo.launch(enable_queue=False)