halejosm's picture
Update app.py
26e0e94 verified
raw
history blame
2.99 kB
import streamlit as st
from transformers import pipeline
import time
import matplotlib.pyplot as plt
#######################creacion de funciones######################
@st.cache_resource
#funcion eleccion del modelo
def cargar_modelo(model_name):
return pipeline("zero-shot-classification", model=model_name, device=0)
#funcion que grafica el resultado del modelo
def graficar(result):
# Crear un gr谩fico de pastel
fig, ax = plt.subplots()
ax.pie(result["scores"], labels=result["labels"], autopct="%1.1f%%", startangle=90)
ax.axis("equal") # Hace que el gr谩fico sea un c铆rculo perfecto
# Mostrar el gr谩fico en Streamlit
st.pyplot(fig)
#######################fin creacion de funciones######################
st.title("Ejercicio interfaz transformers pipeline con modelos Zero-shot")
model_options = [
"MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
"MoritzLaurer/multilingual-MiniLMv2-L6-mnli-xnli",
"facebook/bart-large-mnli",
]
# "roberta-large-mnli",
# "cross-encoder/nli-roberta-base"
# Seleccionar el modelo del pipeline
selected_model = st.selectbox("Selecciona un modelo", model_options)
#asignacion del modelo a usar
classifier = cargar_modelo(selected_model)
# Entrada de texto para la oraci贸n
nombre = st.text_input("Ingrese un nombre o una oraci贸n a clasificar:")
etiquetas = st.text_area(
"Ingresa las categor铆as separadas por comas:",
placeholder="Ejemplo: masculino,femenino"
)
# Bot贸n para clasificar
if st.button("Clasificar"):
# Show a spinner during a process
with st.spinner(text="Ejecutando el modelo"):
if nombre and etiquetas:
# basado en el string crear array labels que pasan como clases al modelo
labels = [label.strip() for label in etiquetas.split(",")]
#Llamado del modelo y las clases elegidas por el usuario
result = classifier(nombre, candidate_labels=labels)
# Mostrar los resultados
st.subheader("Resultados de Clasificaci贸n")
#for label, score in zip(result["labels"], result["scores"]):
# st.write(f"**{label}**: {score:.2f}")
for label, score in zip(result["labels"], result["scores"]):
with st.container():
st.markdown(
f"""
<div style="background-color: #283747;padding:10px;margin-bottom:10px;border-radius:5px;border: 1px solid #ddd;">
<h4 style="margin:0;">{label}</h4>
<p style="margin:0;">Confianza: <b>{score:.2f}</b></p>
</div>
""",
unsafe_allow_html=True,
)
graficar(result)
else:
st.warning("Por favor, ingresa una oraci贸n y categor铆as v谩lidas.")
st.success("Revisar el resultado obtenido en la parte superior")