import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Example inputs for the Gradio interface
mis_ejemplos = [
    ["La cocina de los gallegos es fabulosa."],
    ["Los niños juegan a la pelota."],
    ["Los científicos son muy listos."],
    ["Las enfermeras se esforzaron mucho durante la pandemia."],
    ["Los políticos no son del agrado de los ciudadanos."]
]
# Load complete model in 4 bits
##################
hub_model = 'Andresmfs/merged_aguila-prueba-guardado'
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(hub_model, trust_remote_code=True)
## Load model in 4 bits
# bitsandbytes quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False
)
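# Note: the NF4 ("normal float 4") config above stores the weights in 4 bits
# and dequantizes to bfloat16 for compute. Loading in 4 bits requires the
# `bitsandbytes` package and a CUDA GPU.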
# model
model = AutoModelForCausalLM.from_pretrained(
    hub_model,
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map="auto"
)
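# device_map="auto" above relies on the `accelerate` package to place model
# layers across the available devices automatically.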
# generation_config
generation_config = model.generation_config
generation_config.max_new_tokens = 100
generation_config.temperature = 0.7
generation_config.top_p = 0.7
generation_config.num_return_sequences = 1
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
generation_config.do_sample = True  # sample instead of greedy decoding
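# With do_sample=True, temperature and top_p jointly control nucleus
# sampling; lower values make the output more deterministic.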
# Define inference function
def translate_es_inclusivo(exclusive_text):
    # Build the input prompt
    eval_prompt = f"""Reescribe el siguiente texto utilizando lenguaje inclusivo.\n
Texto: {exclusive_text}\n
Texto en lenguaje inclusivo:"""
    # Tokenize the input and move it to the model's device
    model_input = tokenizer(eval_prompt, return_tensors="pt").to(model.device)
    # Length of the encoded prompt, used to strip it from the decoded output
    prompt_token_len = len(model_input['input_ids'][0])
    # For long inputs, allow ~20% more new tokens than the prompt length
    if prompt_token_len > 80:
        generation_config.max_new_tokens = int(1.2 * prompt_token_len)
    # Generate and decode only the tokens produced after the prompt
    with torch.no_grad():
        inclusive_text = tokenizer.decode(
            model.generate(**model_input, generation_config=generation_config)[0][prompt_token_len:],
            skip_special_tokens=True
        )
    return inclusive_text
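
# Minimal usage sketch, left commented out so the Space only launches the UI;
# assumes the model above loaded successfully:
# print(translate_es_inclusivo("Los científicos son muy listos."))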
iface = gr.Interface(
    fn=translate_es_inclusivo,
    inputs="text",
    outputs="text",
    title="ES Inclusive Language",
    description="Enter a Spanish phrase and get it converted into neutral/inclusive form.",
    examples=mis_ejemplos
)

iface.launch()