# Spaces: Runtime error
import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub repository holding the LoRA adapter; the tokenizer is also loaded from it.
peft_model_id = "rwheel/discriminacion_gitana_intervenciones"
# The adapter config records which base checkpoint the adapter was trained on.
config = PeftConfig.from_pretrained(peft_model_id)

# 8-bit loading goes through bitsandbytes, which (at least in the versions this
# app targets) requires a CUDA GPU. On a CPU-only host, fall back to a regular
# load instead of failing at startup. NOTE(review): a full-precision load of a
# large base model may need a lot of RAM — confirm the host can hold it.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    load_in_8bit=torch.cuda.is_available(),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# Wrap the base model with the LoRA adapter weights.
model = PeftModel.from_pretrained(model, peft_model_id)
def predecir_intervencion(text):
    """Generate a suggested intervention and its outcome for the given case facts.

    The prompt is built as ``"<SH>" + text + " Intervención: "``. The model is
    expected to continue with the intervention, then a ``"Resultado:"`` marker
    and the outcome, terminated by ``"<EH>"``.

    Args:
        text: Free-text description of the facts ("hechos") of the case.

    Returns:
        A ``(intervencion, resultado)`` tuple of stripped strings. ``resultado``
        is ``""`` when the generated text contains no ``"Resultado:"`` marker.
    """
    # The marker must be correctly encoded Spanish; presumably this is the text
    # the adapter was fine-tuned on — TODO confirm against the training data.
    prompt = "<SH>" + text + " Intervención: "
    batch = tokenizer(prompt, return_tensors="pt")
    # Place the inputs on the same device as the model (GPU when one is used).
    batch = {name: tensor.to(model.device) for name, tensor in batch.items()}
    prompt_len = batch["input_ids"].shape[1]

    # Mixed precision is only used on CUDA; on CPU, autocast is disabled.
    use_cuda = torch.cuda.is_available()
    with torch.autocast(device_type="cuda" if use_cuda else "cpu", enabled=use_cuda):
        # presumably 50258 is the id of the "<EH>" end marker added to the
        # vocabulary — TODO confirm with tokenizer.convert_tokens_to_ids("<EH>")
        output_tokens = model.generate(**batch, max_new_tokens=256, eos_token_id=50258)

    # Decode only the newly generated tokens, so markers that happen to appear
    # in the user's own text cannot confuse the parsing below.
    generated = tokenizer.decode(output_tokens[0][prompt_len:], skip_special_tokens=False)
    generated = generated.split("<EH>")[0]
    # partition() never raises: a missing "Resultado:" yields an empty outcome.
    intervencion, _, resultado = generated.partition("Resultado:")
    return intervencion.strip(), resultado.strip()
# Gradio UI: one input textbox for the case facts ("hechos") and two output
# textboxes for the predicted intervention and its outcome.
with gr.Blocks() as demo:
    gr.Markdown("Predicción de intervenciones para mitigar el daño racista en el pueblo gitano")
    with gr.Row():
        hechos = gr.Textbox(label="Hechos", placeholder="Un alumno gitano de un Instituto...")
    with gr.Row():
        intervencion = gr.Textbox(label="Intervención")
        resultado = gr.Textbox(label="Resultado")
    btn = gr.Button("Go")
    # predecir_intervencion returns an (intervencion, resultado) tuple, mapped
    # positionally onto the two output boxes.
    btn.click(fn=predecir_intervencion, inputs=hechos, outputs=[intervencion, resultado])

# share=True requests a public share link when run locally.
demo.launch(share=True)