|
import os

import gradio as gr
import torch
from transformers import pipeline

# Prefer the GPU whenever torch can see one; otherwise run the model on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Optional Hugging Face access token taken from the environment (None when the variable is unset).
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")

# Hub repository id of the chat model served by this app.
text_generation_model = "cahya/indochat-tiny"

# The pipeline is created once at import time and shared by every generation request.
text_generation = pipeline(
    "text-generation",
    text_generation_model,
    use_auth_token=HF_AUTH_TOKEN,
    device=device,
)
|
|
|
|
|
def get_answer(user_input, decoding_methods, top_k, top_p, temperature, repetition_penalty, penalty_alpha,
               num_beams=4):
    """Generate the assistant's reply to ``user_input`` using the selected decoding strategy.

    Args:
        user_input: The user's question; it is wrapped in a ``User:`` / ``Assistant:`` chat prompt.
        decoding_methods: ``"Beam Search"``, ``"Sampling"`` or ``"Contrastive Search"``.
            Any other value is treated as contrastive search.
        top_k: For sampling, the number of highest-probability tokens kept. For contrastive
            search, the size of the candidate set. Cast to ``int`` because the generator
            requires an integer.
        top_p: Nucleus-sampling threshold (sampling only).
        temperature: Softmax temperature (sampling only).
        repetition_penalty: Penalty applied to repeated tokens (all strategies).
        penalty_alpha: Degeneration penalty weight (contrastive search only).
        num_beams: Beam width used for ``"Beam Search"``.

    Returns:
        Only the generated continuation (the prompt is not included), with surrounding
        whitespace stripped.
    """
    # Each strategy only receives the generation knobs it actually uses.
    generation_kwargs = {"repetition_penalty": repetition_penalty}
    if decoding_methods == "Beam Search":
        # do_sample=False alone is greedy decoding; beam search additionally needs num_beams > 1.
        generation_kwargs.update(do_sample=False, num_beams=num_beams)
    elif decoding_methods == "Sampling":
        generation_kwargs.update(do_sample=True, top_k=int(top_k), top_p=top_p, temperature=temperature)
    else:
        # Contrastive search is selected by do_sample=False together with penalty_alpha > 0 and
        # top_k > 1; without penalty_alpha the generator silently falls back to greedy decoding.
        generation_kwargs.update(do_sample=False, top_k=int(top_k), penalty_alpha=penalty_alpha)

    print(user_input, decoding_methods, generation_kwargs)

    prompt = f"User: {user_input}\nAssistant: "
    # return_full_text=False makes the pipeline return only the continuation. Slicing the prompt
    # off by character count is no longer needed; the old `len(prompt) + 1` slice also dropped
    # the first character of the generated answer.
    generated_text = text_generation(
        prompt,
        min_length=50,
        max_length=200,
        num_return_sequences=1,
        return_full_text=False,
        **generation_kwargs,
    )
    return generated_text[0]["generated_text"].strip()
|
|
|
|
|
# Chat UI: the question box and decoding controls sit on the left, the generated answer on the right.
# Uses gradio's component classes (value=...) rather than the deprecated gr.inputs.* / default=... API.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## IndoChat")

    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(
                placeholder="",
                label="Ask me something in Indonesian or English",
                value="Bagaimana cara mendidik anak supaya tidak berbohong?",
            )
            decoding_methods = gr.Dropdown(
                ["Beam Search", "Sampling", "Contrastive Search"],
                value="Sampling",
                label="Decoding method",
            )
            top_k = gr.Slider(
                label="Top K: The number of highest probability vocabulary tokens to keep",
                value=40,
                maximum=50,
                minimum=1,
                step=1,
            )
            top_p = gr.Slider(label="Top P", value=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.Slider(label="Temperature", value=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.Slider(
                label="Repetition Penalty", value=1.1, step=0.05, minimum=1.0, maximum=2.0
            )
            # Contrastive search scores candidates as (1 - alpha) * confidence - alpha * degeneration
            # penalty, so alpha belongs in [0, 1]; alpha > 1 would reward *unlikely* tokens.
            # 0.6 matches the value used in the Transformers contrastive-search examples.
            penalty_alpha = gr.Slider(
                label="The penalty alpha for contrastive search",
                value=0.6,
                step=0.05,
                minimum=0.0,
                maximum=1.0,
            )
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            generated_answer = gr.Textbox(label="Answer")

    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")

    # Event listeners have to be registered while the Blocks context is open.
    button_generate_story.click(
        get_answer,
        inputs=[user_input, decoding_methods, top_k, top_p, temperature, repetition_penalty, penalty_alpha],
        outputs=[generated_answer],
    )
|
|
|
# Start the web server only when this file is run as a script. Importing the module (e.g. to
# reuse get_answer, or under gradio's reload mode) then no longer blocks inside launch().
if __name__ == "__main__":
    # enable_queue=False: requests are served without gradio's request queue.
    demo.launch(enable_queue=False)