import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification
import torch

# Function to load model and tokenizer
def load_model():
    tokenizer = BertTokenizer.from_pretrained("Minej/bert-base-personality")
    model = BertForSequenceClassification.from_pretrained("Minej/bert-base-personality")
    return tokenizer, model

# Load the model and tokenizer once at startup
tokenizer, model = load_model()
# Function to predict Big Five personality trait scores for a piece of text
def personality_detection(text):
    inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over the five logits, so the scores sum to 1
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().numpy()
    label_names = ['Extroversion', 'Neuroticism', 'Agreeableness', 'Conscientiousness', 'Openness']
    # Convert numpy scalars to plain floats so gr.Label can serialize them
    result = {label_names[i]: float(predictions[i]) for i in range(len(label_names))}
    return result
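# Example (illustrative sentence): personality_detection("I love meeting new people.")
# returns a dict keyed by the five trait names above, with one float score per trait,
# which gr.Label renders as a bar per trait.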
# Create the Gradio interface
interface = gr.Interface(
    fn=personality_detection,
    inputs=gr.Textbox(lines=2, placeholder="Enter a sentence here..."),
    outputs=gr.Label(),
    title="Personality Analyzer",
    description="Enter a sentence and get a prediction of personality traits."
)

# Launch the Gradio app on a specific port
interface.launch(server_port=7861)  # You can change 7861 to another port if necessary
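
# Once the app is running, the endpoint can also be queried programmatically.
# The sketch below is a separate client script (run it in another process); it
# assumes the default Gradio endpoint name "/predict" and the port chosen above.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7861/")
#   result = client.predict("I enjoy quiet evenings at home.", api_name="/predict")
#   print(result)  # the trait/confidence output produced by gr.Label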