from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr
import re

# Configuration
MAX_LENGTH = 1024
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using:", device)

# Left-side padding so that batched generation with a decoder-only model works correctly.
tokenizer = AutoTokenizer.from_pretrained("MarkelFe/PoliticalSpeech2", padding_side='left')
model = AutoModelForCausalLM.from_pretrained("MarkelFe/PoliticalSpeech2").to(device)
def return_conf(max_tokens, conf, ngram, beams, top_k, top_p):
    """Build the keyword arguments for model.generate() for the selected decoding strategy."""
    if conf == "Ezer":  # "Ezer" = "None": plain greedy decoding
        options = {"max_new_tokens": max_tokens, "do_sample": False}
    elif conf == "Beam Search":
        options = {"no_repeat_ngram_size": ngram, "num_beams": beams, "max_new_tokens": max_tokens, "do_sample": False}
    elif conf == "Top K":
        # Sampling must be enabled, otherwise top_k is ignored and generation falls back to greedy decoding.
        options = {"top_k": top_k, "max_new_tokens": max_tokens, "do_sample": True}
    elif conf == "Top P":
        # Sampling must be enabled, otherwise top_p is ignored and generation falls back to greedy decoding.
        options = {"top_p": top_p, "max_new_tokens": max_tokens, "do_sample": True}
    else:
        # Fallback for an unknown strategy: behave like "Ezer".
        options = {"max_new_tokens": max_tokens, "do_sample": False}
    return options
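# Illustrative only: with the default UI settings ("Beam Search", 30 new tokens, 4-gram blocking,
# 5 beams) the helper would return roughly these generate() kwargs:
#   return_conf(30, "Beam Search", 4, 5, 5, 0.9)
#   -> {"no_repeat_ngram_size": 4, "num_beams": 5, "max_new_tokens": 30, "do_sample": False}
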
def sortu_testua(alderdia, testua, max_tokens, conf, ngram, beams, top_k, top_p):
    """Generate a speech continuation for a single party ("sortu testua" = "generate text")."""
    options = return_conf(max_tokens, conf, ngram, beams, top_k, top_p)
    # The party tag is prepended to the prompt, e.g. "[EAJ] ...".
    prompt = f"[{alderdia}] {testua}"
    tokens = tokenizer(prompt, return_tensors="pt").to(device)
    generation = model.generate(inputs=tokens['input_ids'], attention_mask=tokens['attention_mask'], **options)[0]
    text = tokenizer.decode(generation, skip_special_tokens=True)
    # Drop everything up to and including the party tag so only the generated speech is returned.
    return re.split(r"\[(.*?)\] ", text)[-1]
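# Rough usage sketch (the "<testua>" prompt is just a placeholder, not a real example from the app):
#   sortu_testua("EAJ", "<testua>", 30, "Beam Search", 4, 5, 5, 0.9)
# returns the generated continuation as a plain string with the leading "[EAJ]" tag stripped.
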
def sortu_testu_guztiak(testua, max_tokens, conf, ngram, beams, top_k, top_p):
    """Generate a speech continuation for every party at once ("sortu testu guztiak" = "generate all texts")."""
    options = return_conf(max_tokens, conf, ngram, beams, top_k, top_p)
    # One prompt per party; this tab quotes the party name inside the tag, e.g. '["EAJ"] ...'.
    alderdiak = ["EAJ", "EH Bildu", "PP", "PSE-EE", "EP", "UPyD"]
    prompts = [f'["{alderdia}"] {testua}' for alderdia in alderdiak]
    # Left padding (see tokenizer setup) lets all six prompts be generated in a single batch.
    tokens = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
    generation = model.generate(inputs=tokens['input_ids'], attention_mask=tokens['attention_mask'], **options)
    texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
    texts = [re.split(r"\[(.*?)\] ", text)[-1] for text in texts]
    return tuple(texts)
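# Rough usage sketch (placeholder prompt, illustrative only):
#   sortu_testu_guztiak("<testua>", 30, "Top K", 4, 5, 5, 0.9)
# returns a 6-tuple of strings, one per party, in the same order as the output boxes defined below.
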
# Gradio UI. Labels are in Basque: "Alderdi bakarra" = "Single party", "Alderdi guztiak" = "All parties".
with gr.Blocks() as demo:
    with gr.Tab("Alderdi bakarra"):
        with gr.Row():
            with gr.Column(scale=4, min_width=400):
                alderdia = gr.Dropdown(["EAJ", "EH Bildu", "PP", "PSE-EE", "EP", "UPyD"], label="Alderdi politikoa")
                testua = gr.Textbox(label="Testua")
                greet_btn = gr.Button("Sortu testua")
                gr.Markdown("""Aldatu konfigurazioa""")
                # Generation settings ("Aldatu konfigurazioa" = "Change the configuration").
                new_token = gr.Slider(minimum=1, maximum=MAX_LENGTH, value=30, label="Luzera", info="Zenbat token berri sortuko diren.")
                confi = gr.Radio(["Ezer", "Beam Search", "Top K", "Top P"], value="Beam Search", label="Estrategia", info="Aukeratu ze estrategia erabiliko den erantzunak hobetzeko")
                ngram = gr.Slider(minimum=1, maximum=50, value=4, step=1, label="ngram kopurua", info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago")
                beams = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Beam kopurua", info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago")
                top_k = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="K-balioa", info="Bakarrik kontuan hartuko da \"Top K\" aukeratuta badago")
                top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="P-balioa", info="Bakarrik kontuan hartuko da \"Top P\" aukeratuta badago")
            with gr.Column(scale=3, min_width=200):
                output = gr.Textbox(label="Output")
with gr.Tab("Alderdi guztiak"): | |
with gr.Row(): | |
with gr.Column(scale=4, min_width=400): | |
testua2 = gr.Textbox(label="Testua") | |
greet_btn2 = gr.Button("Sortu testuak") | |
gr.Markdown("""Aldatu konfigurazioa""") | |
new_token2 = gr.Slider(minimum=1, maximum=MAX_LENGTH, value=30, label="Luzera", info="Zenbat token berri sortuko diren.") | |
confi2 = gr.Radio(["Ezer", "Beam Search", "Top K", "Top P"], value="Beam Search", label="Estrategia", info="Aukeratu ze estrategia erabiliko den erantzunak hobetzeko") | |
ngram2 = gr.Slider(minimum=1, maximum=50, value=4, step=1, label="ngram kopurua", info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago") | |
beams2 = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Beam kopurua", info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago") | |
top_k2 = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="K-balioa", info="Bakarrik kontuan hartuko da \"Top K\" aukeratuta badago") | |
top_p2 = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="P-balioa", info="Bakarrik kontuan hartuko da \"Top P\" aukeratuta badago") | |
with gr.Column(scale=3, min_width=200): | |
outputEAJ = gr.Textbox(label="EAJ") | |
outputBildu = gr.Textbox(label="EH Bildu") | |
outputPP = gr.Textbox(label="PP") | |
outputPSE = gr.Textbox(label="PSE-EE") | |
outputEP = gr.Textbox(label="EP") | |
outputUPyD = gr.Textbox(label="UPyD") | |
    # Wire the buttons to the generation functions.
    greet_btn.click(fn=sortu_testua, inputs=[alderdia, testua, new_token, confi, ngram, beams, top_k, top_p], outputs=output, api_name="sortu_testua")
    greet_btn2.click(fn=sortu_testu_guztiak, inputs=[testua2, new_token2, confi2, ngram2, beams2, top_k2, top_p2], outputs=[outputEAJ, outputBildu, outputPP, outputPSE, outputEP, outputUPyD], api_name="sortu_testu_guztiak")

demo.launch()