indochat / app.py
cahya's picture
fix cuda device and penalty alpha
a4b0cb5
import torch
import gradio as gr
from transformers import pipeline
import os
# Pick the current CUDA device index when a GPU is present; fall back to the CPU.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = "cpu"

# Optional Hugging Face access token taken from the environment (None when unset).
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")

# Hub identifier of the model that answers the user's questions.
text_generation_model = "cahya/indochat-tiny"

# Shared text-generation pipeline; get_answer() calls it for every request.
text_generation = pipeline(
    "text-generation",
    text_generation_model,
    use_auth_token=HF_AUTH_TOKEN,
    device=device,
)
def get_answer(user_input, decoding_methods, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
    """Generate the assistant's reply to ``user_input``.

    Args:
        user_input: The user's question, wrapped into a "User: ... Assistant: " prompt.
        decoding_methods: "Beam Search" or "Sampling" select those modes and
            force ``penalty_alpha`` to 0; any other value (the UI offers
            "Contrastive Search") keeps ``do_sample=False`` and passes
            ``penalty_alpha`` through unchanged.
        num_beams: Number of beams forwarded to generation.
        top_k, top_p, temperature, repetition_penalty: Forwarded unchanged
            (apart from ``top_k`` being converted to int) to the pipeline.
        penalty_alpha: Contrastive-search penalty forwarded to the pipeline.

    Returns:
        The generated continuation with the prompt and leading whitespace removed.
    """
    if decoding_methods == "Beam Search":
        do_sample = False
        penalty_alpha = 0
    elif decoding_methods == "Sampling":
        do_sample = True
        penalty_alpha = 0
    else:
        # Contrastive search: deterministic decoding steered by penalty_alpha.
        do_sample = False
    # UI sliders may hand over floats; generation expects integer counts here.
    num_beams = int(num_beams)
    top_k = int(top_k)
    print(user_input, decoding_methods, do_sample, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
    prompt = f"User: {user_input}\nAssistant: "
    generated_text = text_generation(prompt, min_length=50, max_length=200, num_return_sequences=1,
                                     num_beams=num_beams, do_sample=do_sample, top_k=top_k, top_p=top_p,
                                     temperature=temperature, repetition_penalty=repetition_penalty,
                                     penalty_alpha=penalty_alpha)
    answer = generated_text[0]["generated_text"]
    # The returned text starts with the prompt. Slicing one extra character
    # (the old len(prompt) + 1) cut off the first letter of the reply whenever
    # it did not begin with a space; drop only the prompt and any leading
    # whitespace instead.
    answer_without_prompt = answer[len(prompt):].lstrip()
    return answer_without_prompt
# Gradio UI: question and decoding controls on the left, the answer on the right.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## IndoChat")
    with gr.Row():
        with gr.Column():
            user_input = gr.inputs.Textbox(placeholder="",
                                           label="Ask me something in Indonesian or English",
                                           default="Bagaimana cara mendidik anak supaya tidak berbohong?")
            decoding_methods = gr.inputs.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
                                                  label="Decoding method",
                                                  default="Sampling")
            num_beams = gr.inputs.Slider(label="Number of beams for beam search",
                                         default=1, minimum=1, maximum=10, step=1)
            top_k = gr.inputs.Slider(label="Top K",
                                     default=30, maximum=50, minimum=1, step=1)
            top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05,
                                                  minimum=1.0, maximum=2.0)
            # Contrastive search weighs model confidence by (1 - alpha) against the
            # degeneration penalty by alpha, so alpha must lie in [0, 1]; the old
            # 1.0-2.0 range zeroed or negated the model-confidence term.
            # 0.6 is the value used in the Hugging Face contrastive-search examples.
            penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search", default=0.6,
                                             step=0.05, minimum=0.0, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            generated_answer = gr.Textbox()
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")
    # Event listeners must be registered inside the Blocks context.
    button_generate_story.click(get_answer, inputs=[user_input, decoding_methods, num_beams, top_k, top_p, temperature,
                                                    repetition_penalty, penalty_alpha], outputs=[generated_answer])
demo.launch(enable_queue=False)