indochat / app.py
cahya's picture
use external inference
832a8c0
raw
history blame
4.47 kB
import gradio as gr
import os
from mtranslate import translate
import requests
# Hugging Face token taken from the environment (None when unset).
# NOTE(review): not read anywhere else in this file — possibly a leftover.
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
# Endpoint of the external IndoChat text-generation service.
indochat_api = (
    "https://cahya-indonesian-whisperer.hf.space"
    "/api/indochat/v1"
)
# Bearer token for the IndoChat API; empty string when not configured.
indochat_api_auth_token = os.environ.get("INDOCHAT_API_AUTH_TOKEN", "")
def get_answer(user_input, decoding_method, num_beams, top_k, top_p, temperature,
               repetition_penalty, penalty_alpha, *, request_timeout=180):
    """Generate an answer with the remote IndoChat API and translate it to English.

    The prompt and decoding settings are POSTed (form-encoded) to ``indochat_api``
    using ``indochat_api_auth_token`` as a bearer token. On success, the prompt and
    the generated answer are returned twice as ``HighlightedText`` segments:
    once as-is and once translated to English with ``mtranslate``.

    Args:
        user_input: Prompt text typed by the user.
        decoding_method: Dropdown choice ("Beam Search", "Sampling" or
            "Contrastive Search"), forwarded unchanged to the API.
        num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha:
            Generation settings forwarded unchanged to the API.
        request_timeout: Seconds to wait for the API before giving up.

    Returns:
        A pair ``(answer, answer_en)``, one value per ``HighlightedText`` output.
        Each value is a list of ``(text, label)`` segments. The prompt segment
        has label ``None`` and the answer segment has label ``""``. On any failure,
        the first value is a single ``"Error: ..."`` segment and the second
        is ``None``. A pair is always returned because the click handler wires
        two outputs.
    """
    print(user_input, decoding_method, num_beams, top_k, top_p, temperature,
          repetition_penalty, penalty_alpha)
    headers = {'Authorization': 'Bearer ' + indochat_api_auth_token}
    data = {
        "text": user_input,
        # Character count of the prompt plus 50.
        # NOTE(review): the API may interpret lengths as tokens — confirm.
        "min_length": len(user_input) + 50,
        "max_length": 300,
        "decoding_method": decoding_method,
        "num_beams": num_beams,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        # Presumably -1 asks the API for a random seed — confirm with the API.
        "seed": -1,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha,
    }
    try:
        # Without a timeout a stalled API would block this worker indefinitely.
        r = requests.post(indochat_api, headers=headers, data=data,
                          timeout=request_timeout)
    except requests.RequestException as err:
        return [(f"Error: {err}", None)], None
    if r.status_code != 200:
        return [("Error: " + r.text, None)], None
    try:
        answer = r.json()["generated_text"]
    except (ValueError, KeyError, TypeError) as err:
        # Body was not JSON, or not an object containing "generated_text".
        return [(f"Error: unexpected API response: {err!r}", None)], None
    # mtranslate's argument order is (text, to_language, from_language):
    # this translates from Indonesian ("id") to English ("en").
    user_input_en = translate(user_input, "en", "id")
    answer_en = translate(answer, "en", "id")
    return (
        [(f"{user_input}\n", None), (answer, "")],
        [(f"{user_input_en}\n", None), (answer_en, "")],
    )
css = """
#answer_id span {white-space: pre-line}
#answer_id span.label {display: none}
#answer_en span {white-space: pre-line}
#answer_en span.label {display: none}
"""
# Build the UI: a header, a settings column on the left, the generated answer
# and its English translation on the right, and a visitor badge at the bottom.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        gr.Markdown("""## IndoChat
A Proof of Concept of a multilingual chatbot (in this case a bilingual one: English and Indonesian), fine-tuned with
a multilingual instruction dataset. The base model is a GPT2-Medium (340M params) which was pretrained on 75GB
of Indonesian and English data, where the English part is less than 1% of the whole dataset.
""")
    with gr.Row():
        # Left column: the prompt and the generation settings passed to get_answer.
        with gr.Column():
            user_input = gr.Textbox(
                value="Bagaimana cara mendidik anak supaya tidak berbohong?",
                placeholder="",
                label="Ask me something in Indonesian or English",
            )
            decoding_method = gr.Dropdown(
                choices=["Beam Search", "Sampling", "Contrastive Search"],
                value="Sampling",
                label="Decoding Method",
            )
            num_beams = gr.Slider(label="Number of beams for beam search",
                                  value=1, minimum=1, maximum=10, step=1)
            top_k = gr.Slider(label="Top K",
                              value=30, minimum=1, maximum=50, step=1)
            top_p = gr.Slider(label="Top P",
                              value=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.Slider(label="Temperature",
                                    value=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.Slider(label="Repetition Penalty",
                                           value=1.1, step=0.05, minimum=1.0, maximum=2.0)
            penalty_alpha = gr.Slider(label="The penalty alpha for contrastive search",
                                      value=0.5, step=0.05, minimum=0.05, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        # Right column: the answer and its English translation. Their styling
        # comes from the global `css` via the elem_ids.
        with gr.Column():
            generated_answer = gr.HighlightedText(
                elem_id="answer_id",
                label="Generated Text",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
            generated_answer_en = gr.HighlightedText(
                elem_id="answer_en",
                label="Translation",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")
    # get_answer must return one value per output (two HighlightedText panels).
    button_generate_story.click(
        get_answer,
        inputs=[user_input, decoding_method, num_beams, top_k, top_p, temperature,
                repetition_penalty, penalty_alpha],
        outputs=[generated_answer, generated_answer_en],
    )
# Start the server only when run as a script (`python app.py`), so importing
# this module (e.g. for reload mode or tests) has no side effects.
if __name__ == "__main__":
    # NOTE(review): `enable_queue` in launch() is deprecated in later Gradio 3.x
    # releases and removed in Gradio 4 — revisit when upgrading Gradio.
    demo.launch(enable_queue=False)