"""IndoChat Gradio demo.

A proof-of-concept bilingual (English/Indonesian) chatbot UI. User prompts are
forwarded to a remote IndoChat inference API; the generated answer is shown
alongside a machine translation to English (via mtranslate).
"""
import os

import gradio as gr
import requests
from mtranslate import translate

# NOTE(review): HF_AUTH_TOKEN is read but never used in this file — presumably
# consumed elsewhere or left over; kept for backward compatibility.
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
indochat_api = 'https://cahya-indonesian-whisperer.hf.space/api/indochat/v1'
indochat_api_auth_token = os.getenv("INDOCHAT_API_AUTH_TOKEN", "")

# Hard upper bound on generated sequence length requested from the API.
MAX_GENERATION_LENGTH = 300


def get_answer(user_input, decoding_method, num_beams, top_k, top_p,
               temperature, repetition_penalty, penalty_alpha):
    """Query the IndoChat API and build highlighted (answer, translation) pairs.

    Args:
        user_input: The user's prompt (Indonesian or English).
        decoding_method: One of "Beam Search", "Sampling", "Contrastive Search".
        num_beams: Beam count for beam search.
        top_k, top_p, temperature, repetition_penalty, penalty_alpha:
            Sampling / contrastive-search hyperparameters forwarded verbatim.

    Returns:
        A pair of HighlightedText-compatible lists of (text, label) tuples:
        the original-language exchange and its English translation. On an API
        error, both outputs carry the error message so the UI stays consistent.
    """
    print(user_input, decoding_method, num_beams, top_k, top_p, temperature,
          repetition_penalty, penalty_alpha)
    headers = {'Authorization': 'Bearer ' + indochat_api_auth_token}
    data = {
        "text": user_input,
        # Clamp so min_length can never exceed max_length for long prompts
        # (len(user_input) + 50 would overshoot 300 for inputs > 250 chars).
        "min_length": min(len(user_input) + 50, MAX_GENERATION_LENGTH),
        "max_length": MAX_GENERATION_LENGTH,
        "decoding_method": decoding_method,
        "num_beams": num_beams,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "seed": -1,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha
    }
    # Timeout guards against a hung remote Space blocking the handler forever.
    r = requests.post(indochat_api, headers=headers, data=data, timeout=180)
    if r.status_code == 200:
        result = r.json()
        answer = result["generated_text"]
        # mtranslate.translate(text, to_language, from_language):
        # translate Indonesian -> English for the side-by-side view.
        user_input_en = translate(user_input, "en", "id")
        answer_en = translate(answer, "en", "id")
        return [(f"{user_input}\n", None), (answer, "")], \
               [(f"{user_input_en}\n", None), (answer_en, "")]
    # Error path: both wired outputs are HighlightedText components, so return
    # the same (text, label) tuple-list shape for each instead of a bare string.
    error_message = "Error: " + r.text
    return [(error_message, None)], [(error_message, None)]


# Page-level CSS: preserve newlines in the highlighted output and hide the
# per-span labels (labels are only used for color mapping here).
css = """
#answer_id span {white-space: pre-line}
#answer_id span.label {display: none}
#answer_en span {white-space: pre-line}
#answer_en span.label {display: none}
"""

with gr.Blocks(css=css) as demo:
    with gr.Row():
        gr.Markdown("""## IndoChat
        A Prove of Concept of a multilingual Chatbot (in this case a bilingual, English and Indonesian), fine-tuned
        with multilingual instructions dataset. The base model is a GPT2-Medium (340M params) which was pretrained
        with 75GB of Indonesian and English dataset, where English part is only less than 1% of the whole dataset.
        """)
    with gr.Row():
        with gr.Column():
            # Legacy gradio 3.x input API (gr.inputs.*) kept intentionally —
            # this file targets that version.
            user_input = gr.inputs.Textbox(placeholder="",
                                           label="Ask me something in Indonesian or English",
                                           default="Bagaimana cara mendidik anak supaya tidak berbohong?")
            decoding_method = gr.inputs.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
                                                 default="Sampling", label="Decoding Method")
            num_beams = gr.inputs.Slider(label="Number of beams for beam search",
                                         default=1, minimum=1, maximum=10, step=1)
            top_k = gr.inputs.Slider(label="Top K", default=30, maximum=50, minimum=1, step=1)
            top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05,
                                           minimum=0.1, maximum=1.0)
            repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05,
                                                  minimum=1.0, maximum=2.0)
            penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
                                             default=0.5, step=0.05, minimum=0.05, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            # Highlighted output panes; styling handled by the page CSS above
            # via elem_id, so no per-component css kwarg is needed.
            generated_answer = gr.HighlightedText(
                elem_id="answer_id",
                label="Generated Text",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
            generated_answer_en = gr.HighlightedText(
                elem_id="answer_en",
                label="Translation",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")

    button_generate_story.click(get_answer,
                                inputs=[user_input, decoding_method, num_beams, top_k, top_p,
                                        temperature, repetition_penalty, penalty_alpha],
                                outputs=[generated_answer, generated_answer_en])

demo.launch(enable_queue=False)