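"""IndoChat: a Gradio demo that chats in Indonesian.

User input is fed to the cahya/indochat-tiny text-generation model; the
answer, together with an English translation (via mtranslate), is rendered
as highlighted text.
"""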
import os

import torch
import gradio as gr
from mtranslate import translate
from transformers import pipeline

# Use the current CUDA device when one is available, otherwise fall back to CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
# Hub token for private/gated models (set as a secret in the Space).
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
text_generation_model = "cahya/indochat-tiny"
text_generation = pipeline("text-generation", text_generation_model,
                           use_auth_token=HF_AUTH_TOKEN, device=device)
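
# The UI exposes three decoding strategies. Each branch below zeroes out the
# parameters that do not apply, so the remaining sliders fully control generation:
#   - Beam Search:        deterministic, driven by num_beams
#   - Sampling:           stochastic, driven by top_k / top_p / temperature
#   - Contrastive Search: deterministic, driven by top_k and penalty_alpha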
def get_answer(user_input, decoding_methods, num_beams, top_k, top_p, temperature,
               repetition_penalty, penalty_alpha):
    if decoding_methods == "Beam Search":
        do_sample = False
        penalty_alpha = 0
    elif decoding_methods == "Sampling":
        do_sample = True
        penalty_alpha = 0
        num_beams = 1
    else:  # Contrastive Search
        do_sample = False
        num_beams = 1
    # Log the effective generation settings for debugging.
    print(user_input, decoding_methods, do_sample, top_k, top_p, temperature,
          repetition_penalty, penalty_alpha)
    prompt = f"User: {user_input}\nAssistant: "
    generated_text = text_generation(prompt, min_length=50, max_length=200, num_return_sequences=1,
                                     num_beams=num_beams, do_sample=do_sample, top_k=top_k, top_p=top_p,
                                     temperature=temperature, repetition_penalty=repetition_penalty,
                                     penalty_alpha=penalty_alpha)
    answer = generated_text[0]["generated_text"]
    # Drop the prompt prefix so only the assistant's reply is shown.
    answer_without_prompt = answer[len(prompt) + 1:]
    # Translate question and answer from Indonesian to English for the second panel.
    user_input_en = translate(user_input, "en", "id")
    answer_without_prompt_en = translate(answer_without_prompt, "en", "id")
    # HighlightedText expects (text, label) tuples; the "" label colors the answer.
    return [(f"{user_input} ", None), (answer_without_prompt, "")], \
           [(f"{user_input_en} ", None), (answer_without_prompt_en, "")]
# UI layout: controls on the left, generated text and its translation on the right.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## IndoChat")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Ask me something in Indonesian or English",
                                    value="Bagaimana cara mendidik anak supaya tidak berbohong?")
            decoding_methods = gr.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
                                           value="Sampling", label="Decoding method")
            num_beams = gr.Slider(label="Number of beams for beam search",
                                  value=1, minimum=1, maximum=10, step=1)
            top_k = gr.Slider(label="Top K", value=30, minimum=1, maximum=50, step=1)
            top_p = gr.Slider(label="Top P", value=0.9, minimum=0.1, maximum=1.0, step=0.05)
            temperature = gr.Slider(label="Temperature", value=0.5, minimum=0.1, maximum=1.0, step=0.05)
            repetition_penalty = gr.Slider(label="Repetition Penalty", value=1.1, minimum=1.0, maximum=2.0, step=0.05)
            penalty_alpha = gr.Slider(label="The penalty alpha for contrastive search",
                                      value=0.5, minimum=0.05, maximum=1.0, step=0.05)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            # HighlightedText renders (text, label) tuples with per-label colors.
            generated_answer = gr.HighlightedText(
                label="Generated Text",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
            generated_answer_en = gr.HighlightedText(
                label="Translation",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")

    button_generate_story.click(get_answer,
                                inputs=[user_input, decoding_methods, num_beams, top_k, top_p, temperature,
                                        repetition_penalty, penalty_alpha],
                                outputs=[generated_answer, generated_answer_en])

demo.launch(enable_queue=False)
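
# Run locally with `python app.py`; set HF_AUTH_TOKEN in the environment if the
# model requires authentication on the Hugging Face Hub.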