"""Gradio chat demo for the ``cahya/indochat-tiny`` text-generation model.

The generated answer is shown together with an English translation of the
question and answer, produced with ``mtranslate``.
"""
import os

import gradio as gr
import torch
from mtranslate import translate
from transformers import pipeline

# Run on the current CUDA device when one is available, otherwise on the CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
# Optional Hugging Face access token, read from the environment.
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
text_generation_model = "cahya/indochat-tiny"
text_generation = pipeline(
    "text-generation",
    text_generation_model,
    use_auth_token=HF_AUTH_TOKEN,
    device=device,
)


def _decoding_settings(decoding_methods, num_beams, penalty_alpha):
    """Return ``(do_sample, num_beams, penalty_alpha)`` for the chosen decoding method.

    "Beam Search" keeps the requested beam count and disables the contrastive
    penalty. "Sampling" samples with a single beam and no contrastive penalty.
    Any other choice (the UI offers "Contrastive Search") is non-sampling,
    single-beam decoding that keeps ``penalty_alpha``.
    """
    if decoding_methods == "Beam Search":
        # Sliders may deliver floats, but the beam count must be an integer.
        return False, int(num_beams), 0
    if decoding_methods == "Sampling":
        return True, 1, 0
    return False, 1, penalty_alpha


def _strip_prompt(answer, prompt):
    """Return the model's continuation from *answer* without the echoed *prompt*.

    The pipeline output starts with the prompt text. The continuation may or
    may not begin with whitespace, because the prompt already ends with a
    space. Slicing exactly ``len(prompt)`` characters and then stripping
    leading whitespace handles both cases. The previous ``len(prompt) + 1``
    cut off the first real character when no extra space was present.
    """
    return answer[len(prompt):].lstrip()


def get_answer(user_input, decoding_methods, num_beams, top_k, top_p, temperature,
               repetition_penalty, penalty_alpha):
    """Generate an answer to *user_input* and translate both into English.

    Args:
        user_input: The user's question.
        decoding_methods: "Beam Search", "Sampling", or anything else, which
            is treated as contrastive search.
        num_beams: Beam width. It is only honoured for beam search; the other
            methods force a single beam.
        top_k: Top-k cutoff. Cast to ``int`` because sliders may send floats.
        top_p, temperature, repetition_penalty: Passed through to generation.
        penalty_alpha: Contrastive-search penalty. Forced to 0 for beam
            search and sampling.

    Returns:
        Two lists of ``(text, label)`` pairs for ``gr.HighlightedText``. The
        first covers the original question and answer; the second covers
        their English translations. The question carries label ``None``
        (not highlighted) and the answer carries ``""``, which the UI's
        ``color_map`` paints blue.
    """
    do_sample, num_beams, penalty_alpha = _decoding_settings(
        decoding_methods, num_beams, penalty_alpha
    )
    # Top-k filtering requires an integer; sliders may deliver floats.
    top_k = int(top_k)
    print(user_input, decoding_methods, do_sample, top_k, top_p, temperature,
          repetition_penalty, penalty_alpha)

    prompt = f"User: {user_input}\nAssistant: "
    generated_text = text_generation(
        prompt,
        min_length=50,
        max_length=200,
        num_return_sequences=1,
        num_beams=num_beams,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        penalty_alpha=penalty_alpha,
    )
    answer_without_prompt = _strip_prompt(generated_text[0]["generated_text"], prompt)

    # NOTE(review): both calls pass "en" then "id"; presumably the target is
    # English and the source Indonesian. The input label also invites English
    # questions, for which a fixed "id" source may be wrong. Confirm the
    # intended language handling.
    user_input_en = translate(user_input, "en", "id")
    answer_without_prompt_en = translate(answer_without_prompt, "en", "id")
    return ([(f"{user_input} ", None), (answer_without_prompt, "")],
            [(f"{user_input_en} ", None), (answer_without_prompt_en, "")])


# NOTE(review): gr.inputs.*, the ``default=`` keyword, ``.style()`` and
# ``launch(enable_queue=...)`` are Gradio 3.x APIs that were removed in
# Gradio 4. Keep gradio pinned to 3.x or migrate them together.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## IndoChat")
    with gr.Row():
        with gr.Column():
            user_input = gr.inputs.Textbox(
                placeholder="",
                label="Ask me something in Indonesian or English",
                default="Bagaimana cara mendidik anak supaya tidak berbohong?",
            )
            decoding_methods = gr.inputs.Dropdown(
                ["Beam Search", "Sampling", "Contrastive Search"],
                default="Sampling",
            )
            num_beams = gr.inputs.Slider(label="Number of beams for beam search",
                                         default=1, minimum=1, maximum=10, step=1)
            top_k = gr.inputs.Slider(label="Top K",
                                     default=30, maximum=50, minimum=1, step=1)
            top_p = gr.inputs.Slider(label="Top P",
                                     default=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.inputs.Slider(label="Temperature",
                                           default=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.inputs.Slider(label="Repetition Penalty",
                                                  default=1.1, step=0.05, minimum=1.0, maximum=2.0)
            penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
                                             default=0.5, step=0.05, minimum=0.05, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            generated_answer = gr.HighlightedText(
                label="Generated Text",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
            generated_answer_en = gr.HighlightedText(
                label="Translation",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")
    button_generate_story.click(
        get_answer,
        inputs=[user_input, decoding_methods, num_beams, top_k, top_p, temperature,
                repetition_penalty, penalty_alpha],
        outputs=[generated_answer, generated_answer_en],
    )

demo.launch(enable_queue=False)