import gradio as gr
import re
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from keybert import KeyBERT
# Checkpoint identifier handed to `from_pretrained` (a Hub repo id or local path).
model_identifier = "karalif/myTestModel"
# Tokenizer and sequence-classification model are loaded once at import time and
# shared by every call to get_prediction.
new_tokenizer = AutoTokenizer.from_pretrained(model_identifier)
new_model = AutoModelForSequenceClassification.from_pretrained(model_identifier)
# Shared KeyBERT instance, created lazily by _get_keyword_model().
_keyword_model = None


def _get_keyword_model():
    """Return a shared KeyBERT instance, constructing it on first use.

    Per the KeyBERT docs, constructing ``KeyBERT()`` loads an embedding model,
    so building it once instead of on every request avoids repeating that cost.
    """
    global _keyword_model
    if _keyword_model is None:
        _keyword_model = KeyBERT()
    return _keyword_model


def get_prediction(text):
    """Score *text* with the classifier and extract its most relevant keywords.

    Args:
        text: The input string to analyse.

    Returns:
        A 3-tuple ``(response, keywords, influential_keywords)``:

        * ``response`` -- one ``"<Label>: <pct>%\\n"`` line per label.
        * ``keywords`` -- the raw ``(keyword, score)`` pairs returned by
          ``KeyBERT.extract_keywords``.
        * ``influential_keywords`` -- a plain-text listing of those keywords,
          headed by ``"INFLUENTIAL KEYWORDS:"``.
    """
    encoding = new_tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=200,
    )
    # Put every input tensor on the same device as the model.
    encoding = {k: v.to(new_model.device) for k, v in encoding.items()}
    with torch.no_grad():
        logits = new_model(**encoding).logits
    # Sigmoid is applied element-wise (not softmax), so each label's percentage
    # is independent and the values need not sum to 100.
    probs = torch.sigmoid(logits.squeeze().cpu()).numpy()

    keywords = _get_keyword_model().extract_keywords(
        text,
        keyphrase_ngram_range=(1, 1),
        stop_words="english",
        use_maxsum=True,
        nr_candidates=20,
        top_n=5,
    )

    # Assumed to match the order of the classifier's output logits -- TODO
    # confirm against the model's config (id2label).
    labels = ["Politeness", "Toxicity", "Sentiment", "Formality"]
    response = "".join(
        f"{label}: {probs[i] * 100:.1f}%\n" for i, label in enumerate(labels)
    )

    influential_keywords = "INFLUENTIAL KEYWORDS:\n" + "".join(
        f"{keyword} (Score: {score:.2f})\n" for keyword, score in keywords
    )
    return response, keywords, influential_keywords
# Compiled once at import time rather than on every request.
# NOTE: this is a prefix match with no word boundary, so e.g. "Hæ" also matches
# words that merely start with those letters (such as "Hættu").
_GREETING_PATTERN = re.compile(
    r"^(Halló|Hæ|Sæl|Góðan dag|Kær kveðja|Daginn|Kvöldið|Ágætis|Elsku)",
    re.IGNORECASE,
)


def _highlight_keywords(text, keywords, highlight_format):
    """Wrap each whole-word, case-insensitive keyword occurrence in *text*.

    All keywords are matched in one regex pass, so text inserted by one
    replacement is never re-matched by another keyword.
    """
    words = [keyword for keyword, _ in keywords if keyword]
    if not words:
        return text
    # Longest first so a longer keyword wins over a shorter one it starts with.
    alternation = "|".join(
        re.escape(word) for word in sorted(words, key=len, reverse=True)
    )
    pattern = re.compile(rf"\b(?:{alternation})\b", re.IGNORECASE)
    return pattern.sub(lambda match: highlight_format.format(match.group(0)), text)


def predict(text, *, highlight_format="**{}**"):
    """Build the plain-text report for *text*.

    Args:
        text: The user's input message.
        highlight_format: ``str.format`` template applied to each keyword
            occurrence in the echoed input (for HTML output use e.g.
            ``"<mark>{}</mark>"``; pass ``"{}"`` to disable highlighting).

    Returns:
        One string with an ``INPUT:`` section (the input with keywords
        highlighted), a ``MY PREDICTION:`` section, the influential-keyword
        listing, and an ``OTHER FEEDBACK:`` section when the text does not start
        with one of the words in ``_GREETING_PATTERN``.
    """
    prediction_output, keywords, influential_keywords = get_prediction(text)

    # The original replace(keyword, f"{keyword}") loop was a no-op; highlight
    # for real.
    modified_input = _highlight_keywords(text, keywords, highlight_format)

    greeting_feedback = ""
    # lstrip() so leading whitespace does not hide an otherwise valid greeting.
    if not _GREETING_PATTERN.match(text.lstrip()):
        greeting_feedback = "OTHER FEEDBACK:\nHeilsaðu dóninn þinn\n"

    return (
        f"INPUT:\n{modified_input}\n"
        f"MY PREDICTION:\n{prediction_output}\n"
        f"{influential_keywords}\n"
        f"{greeting_feedback}"
    )
description_html = """