import gradio as gr

from utils import (
    device,
    jina_tokenizer,
    jina_model,
    embeddings_predict_relevance,
    stsb_model,
    stsb_tokenizer,
    cross_encoder_predict_relevance
)

def predict(system_prompt, user_prompt):
    # Score the (system prompt, user prompt) pair with both fine-tuned models.
    predicted_label_jina, probabilities_jina = embeddings_predict_relevance(
        system_prompt, user_prompt, jina_model, jina_tokenizer, device
    )
    predicted_label_stsb, probabilities_stsb = cross_encoder_predict_relevance(
        system_prompt, user_prompt, stsb_model, stsb_tokenizer, device
    )
    result = f"""
**Prediction Summary**

**1. Model: jinaai/jina-embeddings-v2-small-en**
- **Prediction**: {"🟥 Off-topic" if predicted_label_jina == 1 else "🟩 On-topic"}
- **Probability of being off-topic**: {probabilities_jina[0][1]:.2%}

**2. Model: cross-encoder/stsb-roberta-base**
- **Prediction**: {"🟥 Off-topic" if predicted_label_stsb == 1 else "🟩 On-topic"}
- **Probability of being off-topic**: {probabilities_stsb[0][1]:.2%}
"""
    return result

with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as app:
    gr.Markdown("# Off-Topic Classification using Fine-tuned Embeddings and Cross-Encoder Models")

    with gr.Row():
        system_prompt = gr.TextArea(label="System Prompt", lines=5)
        user_prompt = gr.TextArea(label="User Prompt", lines=5)

    # Button to run the prediction
    get_classification = gr.Button("Check Content")
    output_result = gr.Markdown(label="Classification and Probabilities")

    get_classification.click(
        fn=predict,
        inputs=[system_prompt, user_prompt],
        outputs=output_result
    )

if __name__ == "__main__":
    app.launch()
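
The two prediction helpers are imported from the Space's `utils` module, which is not shown in this listing. As a rough idea of the interface `predict` relies on, here is a minimal sketch of what `cross_encoder_predict_relevance` could look like, assuming the cross-encoder checkpoint was fine-tuned as a two-class sequence classifier (label 1 = off-topic); the actual `utils.py` in the Space may differ.

```python
# Hypothetical sketch of utils.cross_encoder_predict_relevance — an assumption,
# not the Space's actual implementation.
import torch


def cross_encoder_predict_relevance(system_prompt, user_prompt, model, tokenizer, device):
    # Encode the (system prompt, user prompt) pair jointly so the cross-encoder
    # attends across both texts in a single forward pass.
    inputs = tokenizer(
        system_prompt,
        user_prompt,
        truncation=True,
        padding=True,
        return_tensors="pt",
    ).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (1, num_labels)
    probabilities = torch.softmax(logits, dim=-1).cpu().numpy()
    predicted_label = int(probabilities[0].argmax())  # 1 = off-topic, 0 = on-topic
    return predicted_label, probabilities
```

`embeddings_predict_relevance` is assumed to expose the same return signature (a label plus a `(1, 2)` probability array), which is why `predict` can index both results as `probabilities[0][1]`.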