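# Gradio demo for the MarkelFe/PoliticalSpeech2 causal language model: given the
# start of a text, it generates a continuation conditioned on one of six political
# parties ("Alderdi bakarra" tab) or on all six parties at once ("Alderdi guztiak" tab).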
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn
import gradio as gr
import re
# Configuration
MAX_LENGTH = 1024
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using:", device)
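
# Left padding so that, when several prompts are batched together (the all-parties
# tab), generated tokens continue directly after each prompt rather than after padding.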
tokenizer = AutoTokenizer.from_pretrained("MarkelFe/PoliticalSpeech2", padding_side='left')
model = AutoModelForCausalLM.from_pretrained("MarkelFe/PoliticalSpeech2").to(device)
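
# Decoding strategies selectable in the UI ("Ezer" = "nothing", i.e. plain greedy decoding):
#   "Ezer"        - greedy decoding with no extra constraints
#   "Beam Search" - beam search with an n-gram repetition block
#   "Top K"       - sampling restricted to the k most likely next tokens
#   "Top P"       - nucleus sampling over the smallest token set with cumulative probability >= p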
def return_conf(max_tokens, conf, ngram, beams, top_k, top_p):
    if conf == "Ezer":
        options = {"max_new_tokens": max_tokens, "do_sample": False}
    elif conf == "Beam Search":
        options = {"no_repeat_ngram_size": ngram, "num_beams": beams, "max_new_tokens": max_tokens, "do_sample": False}
    elif conf == "Top K":
        # top_k only takes effect when sampling is enabled
        options = {"top_k": top_k, "max_new_tokens": max_tokens, "do_sample": True}
    elif conf == "Top P":
        # top_p only takes effect when sampling is enabled
        options = {"top_p": top_p, "max_new_tokens": max_tokens, "do_sample": True}
    return options

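# Generate a continuation for a single party. The prompt follows the "[party] text"
# format the model appears to expect, and the prefix is stripped from the output.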
def sortu_testua(alderdia, testua, max_tokens, conf, ngram, beams, top_k, top_p):
    options = return_conf(max_tokens, conf, ngram, beams, top_k, top_p)
    prompt = f"[{alderdia}] {testua}"
    tokens = tokenizer(prompt, return_tensors="pt").to(device)
    generation = model.generate(inputs=tokens['input_ids'], attention_mask=tokens['attention_mask'], **options)[0]
    text = tokenizer.decode(generation)
    # Strip the "[party] " prefix and return only the continuation
    return re.split(r"\[(.*?)\] ", text)[-1]

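# Generate continuations for all six parties in one padded batch, returned in the
# same order as the output textboxes (EAJ, EH Bildu, PP, PSE-EE, EP, UPyD).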
def sortu_testu_guztiak(testua, max_tokens, conf, ngram, beams, top_k, top_p):
    options = return_conf(max_tokens, conf, ngram, beams, top_k, top_p)
    prompts = [f"[\"EAJ\"] {testua}", f"[\"EH Bildu\"] {testua}", f"[\"PP\"] {testua}", f"[\"PSE-EE\"] {testua}", f"[\"EP\"] {testua}", f"[\"UPyD\"] {testua}"]
    tokens = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
    generation = model.generate(inputs=tokens['input_ids'], attention_mask=tokens['attention_mask'], **options)
    texts = tokenizer.batch_decode(generation)
    texts = [re.split(r"\[(.*?)\] ", text)[-1] for text in texts]
    return tuple(texts)

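# UI: two tabs, "Alderdi bakarra" (single party) and "Alderdi guztiak" (all parties),
# each with a text input, decoding controls, and output textbox(es). Labels are in
# Basque, e.g. "Sortu testua" = "Generate text", "Luzera" = "Length".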
with gr.Blocks() as demo:
    with gr.Tab("Alderdi bakarra"):
        with gr.Row():
            with gr.Column(scale=4, min_width=400):
                alderdia = gr.Dropdown(["EAJ", "EH Bildu", "PP", "PSE-EE", "EP", "UPyD"], label="Alderdi politikoa")
                testua = gr.Textbox(label="Testua")
                greet_btn = gr.Button("Sortu testua")
                gr.Markdown("""Aldatu konfigurazioa""")
                new_token = gr.Slider(minimum=1, maximum=MAX_LENGTH, value=30, label="Luzera", info="Zenbat token berri sortuko diren.")
                confi = gr.Radio(["Ezer", "Beam Search", "Top K", "Top P"], value="Beam Search", label="Estrategia", info="Aukeratu ze estrategia erabiliko den erantzunak hobetzeko")
                ngram = gr.Slider(minimum=1, maximum=50, value=4, step=1, label="ngram kopurua", info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago")
                beams = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Beam kopurua", info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago")
                top_k = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="K-balioa", info="Bakarrik kontuan hartuko da \"Top K\" aukeratuta badago")
                top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="P-balioa", info="Bakarrik kontuan hartuko da \"Top P\" aukeratuta badago")
            with gr.Column(scale=3, min_width=200):
                output = gr.Textbox(label="Output")
    with gr.Tab("Alderdi guztiak"):
        with gr.Row():
            with gr.Column(scale=4, min_width=400):
                testua2 = gr.Textbox(label="Testua")
                greet_btn2 = gr.Button("Sortu testuak")
                gr.Markdown("""Aldatu konfigurazioa""")
                new_token2 = gr.Slider(minimum=1, maximum=MAX_LENGTH, value=30, label="Luzera", info="Zenbat token berri sortuko diren.")
                confi2 = gr.Radio(["Ezer", "Beam Search", "Top K", "Top P"], value="Beam Search", label="Estrategia", info="Aukeratu ze estrategia erabiliko den erantzunak hobetzeko")
                ngram2 = gr.Slider(minimum=1, maximum=50, value=4, step=1, label="ngram kopurua", info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago")
                beams2 = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Beam kopurua", info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago")
                top_k2 = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="K-balioa", info="Bakarrik kontuan hartuko da \"Top K\" aukeratuta badago")
                top_p2 = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="P-balioa", info="Bakarrik kontuan hartuko da \"Top P\" aukeratuta badago")
            with gr.Column(scale=3, min_width=200):
                outputEAJ = gr.Textbox(label="EAJ")
                outputBildu = gr.Textbox(label="EH Bildu")
                outputPP = gr.Textbox(label="PP")
                outputPSE = gr.Textbox(label="PSE-EE")
                outputEP = gr.Textbox(label="EP")
                outputUPyD = gr.Textbox(label="UPyD")
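
    # Wire each button to its generation function; api_name exposes the call as an API endpoint.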
    greet_btn.click(fn=sortu_testua, inputs=[alderdia, testua, new_token, confi, ngram, beams, top_k, top_p], outputs=output, api_name="sortu_testua")
    greet_btn2.click(fn=sortu_testu_guztiak, inputs=[testua2, new_token2, confi2, ngram2, beams2, top_k2, top_p2], outputs=[outputEAJ, outputBildu, outputPP, outputPSE, outputEP, outputUPyD], api_name="sortu_testu_guztiak")

demo.launch()