File size: 4,405 Bytes
501a925
 
 
 
 
 
412cd1c
 
501a925
 
 
412cd1c
501a925
412cd1c
501a925
1ef1a6a
6d13d04
501a925
 
 
 
 
 
 
 
 
412cd1c
501a925
 
 
e2c318c
 
501a925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eef0f34
 
1ef1a6a
501a925
 
 
 
412cd1c
1ef1a6a
a389ea6
501a925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb1dc3e
501a925
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import os
from mtranslate import translate
import requests

# Remote text-generation endpoint that get_answer() POSTs prompts to.
text_generator_api = "https://cahya-indonesian-whisperer.hf.space/api/text-generator/v1"
# Bearer token sent to that endpoint; falls back to "" when the variable is unset.
text_generator_api_auth_token = os.environ.get("TEXT_GENERATOR_API_AUTH_TOKEN", "")
# Hugging Face token from the environment (None when unset); not referenced
# anywhere else in the visible code.
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")

def get_answer(user_input, decoding_method, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
    """Generate an answer for *user_input* with the remote text-generator API.

    All generation settings are forwarded unchanged in the request payload.
    The prompt and the generated answer are then translated to English with
    ``mtranslate.translate`` so the UI can show both versions.

    Args:
        user_input: Prompt typed by the user.
        decoding_method: Decoding strategy picked in the UI dropdown
            ("Beam Search" or "Sampling").
        num_beams: Number of beams for beam search.
        top_k: Top-K cutoff.
        top_p: Top-P cutoff.
        temperature: Sampling temperature.
        repetition_penalty: Penalty for repeated tokens.
        penalty_alpha: Penalty alpha for contrastive search.

    Returns:
        A 2-tuple ``(original, english)`` matching the click handler's two
        ``gr.HighlightedText`` outputs. Each element is a list of
        ``(text, label)`` pairs: the prompt (label ``None``) followed by the
        answer (label ``""``). On any failure (network error, non-200 status,
        unexpected response body), both elements hold a single
        ``("Error: ...", None)`` pair instead of raising.
    """
    # Seconds to wait for the remote API; without a timeout, requests can
    # block the worker indefinitely if the server never answers.
    request_timeout = 120

    def _error(message):
        # Fill BOTH outputs: returning a single value to a two-output
        # handler makes Gradio fail instead of displaying the message.
        item = [(f"Error: {message}", None)]
        return item, item

    print(user_input, decoding_method, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
    headers = {'Authorization': 'Bearer ' + text_generator_api_auth_token}
    data = {
        "model_name": "bloomz-1b1-instruct",
        "text": user_input,
        # NOTE(review): lengths are derived from the prompt's *character*
        # count; presumably the API interprets them as tokens — confirm.
        "min_length": len(user_input) + 10,
        "max_length": len(user_input) + 200,
        "decoding_method": decoding_method,
        "num_beams": num_beams,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "seed": -1,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha
    }
    try:
        r = requests.post(text_generator_api, headers=headers, data=data, timeout=request_timeout)
    except requests.exceptions.RequestException as err:
        return _error(err)
    if r.status_code != 200:
        return _error(r.text)
    try:
        answer = r.json()["generated_text"]
    except (ValueError, KeyError, TypeError):
        # Body was not JSON, or not an object with "generated_text".
        return _error("unexpected response: " + r.text)
    # Translate to English; the source language is auto-detected.
    user_input_en = translate(user_input, "en", "auto")
    answer_en = translate(answer, "en", "auto")
    # The trailing "\n" after the prompt is rendered as a line break by the
    # `white-space: pre-line` CSS applied to the output components.
    return [(f"{user_input}\n", None), (answer, "")], \
        [(f"{user_input_en}\n", None), (answer_en, "")]


# Stylesheet passed to gr.Blocks(css=...) when building the UI.
# `white-space: pre-line` keeps the "\n" that get_answer() appends after the
# prompt, so prompt and answer render on separate lines in both
# HighlightedText outputs (elem_ids "answer_id" and "answer_en").
# `span.label {display: none}` hides spans with the "label" class inside those
# outputs — presumably Gradio's per-segment label badges; confirm against the
# Gradio version in use.
css = """
#answer_id span {white-space: pre-line}
#answer_id span.label {display: none}
#answer_en span {white-space: pre-line}
#answer_en span.label {display: none}
"""

# Build the Gradio UI. Component creation order inside each `with` context
# defines the on-screen layout, so statements here are order-sensitive.
# The module-level `css` styles the two answer outputs by elem_id.
with gr.Blocks(css=css) as demo:
    # Header / description.
    # NOTE(review): this text says Bloomz-1b7, but get_answer() requests model
    # "bloomz-1b1-instruct" and the visitor badge uses page_id
    # cahya_bloomz-1b1-instruct — confirm which model is actually served.
    # "Vietnam" in the language list presumably means "Vietnamese".
    with gr.Row():
        gr.Markdown("""## Bloomz-1b7-Instruct        
        We fine-tuned the BigScience model Bloomz-1b7 with cross-lingual instructions dataset. Some of the supported 
        languages are: English, Indonesian, Vietnam, Hindi, Spanish, French, and Chinese.
        """)
    with gr.Row():
        # Left column: prompt and generation settings. Every control is passed
        # to get_answer() positionally, in the order listed in `inputs` below.
        with gr.Column():
            user_input = gr.inputs.Textbox(placeholder="",
                                           label="Ask me something",
                                           default="Will we ever cure cancer? Please answer in Chinese.")
            decoding_method = gr.inputs.Dropdown(["Beam Search", "Sampling"],
                                                  default="Sampling", label="Decoding Method")
            # Sent with every request regardless of the chosen decoding method.
            num_beams = gr.inputs.Slider(label="Number of beams for beam search",
                                     default=1, minimum=1, maximum=10, step=1)
            top_k = gr.inputs.Slider(label="Top K",
                                     default=30, maximum=50, minimum=1, step=1)
            top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
            # NOTE(review): labelled for contrastive search, but the dropdown only
            # offers "Beam Search" and "Sampling" — confirm when the API uses it.
            penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
                                             default=0.5, step=0.05, minimum=0.05, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        # Right column: generated text and its English translation.
        with gr.Column():
            # generated_answer = gr.Textbox()
            generated_answer = gr.HighlightedText(
                elem_id="answer_id",
                label="Generated Text",
                combine_adjacent=True,
                # NOTE(review): this rule targets "#htext", which matches no
                # elem_id here; the Blocks-level css already styles #answer_id.
                # Looks like a leftover — confirm before removing.
                css="#htext span {white-space: pre-line}",
            # get_answer() labels the answer "" (blue) and the prompt None;
            # no visible code emits the "-" (green) label.
            ).style(color_map={"": "blue", "-": "green"})
            generated_answer_en = gr.HighlightedText(
                elem_id="answer_en",
                label="Translation",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_bloomz-1b1-instruct)")

    # `inputs` order must match get_answer()'s parameter order; get_answer()
    # must return exactly two values, one per output component.
    button_generate_story.click(get_answer,
                                inputs=[user_input, decoding_method, num_beams, top_k, top_p, temperature,
                                        repetition_penalty, penalty_alpha],
                                outputs=[generated_answer, generated_answer_en])

# Starts the server as a module-level side effect (no __main__ guard),
# with request queuing disabled.
demo.launch(enable_queue=False)