File size: 1,354 Bytes
d381b23
b98bf6d
d381b23
 
b98bf6d
 
 
 
 
 
d381b23
b98bf6d
d381b23
 
b98bf6d
 
 
 
 
 
 
 
 
d381b23
 
 
b98bf6d
d381b23
 
 
 
 
 
4bdfa1f
b98bf6d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import gradio as gr
from transformers import BertTokenizer, BertForSequenceClassification
import torch

def load_model():
    """Load the pretrained personality tokenizer and classifier.

    Both objects are created with ``from_pretrained`` from the same
    ``"Minej/bert-base-personality"`` identifier; the tokenizer is loaded
    first, then the model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "Minej/bert-base-personality"
    return (
        BertTokenizer.from_pretrained(checkpoint),
        BertForSequenceClassification.from_pretrained(checkpoint),
    )

# Load the tokenizer and model once at import time; personality_detection()
# reads these module-level names on every call instead of reloading them.
tokenizer, model = load_model()

def personality_detection(text):
    """Score ``text`` against five personality-trait labels.

    The text is tokenized with the module-level ``tokenizer`` and run through
    the module-level ``model`` with gradient tracking disabled. The logits are
    converted to a softmax distribution over the last dimension.

    Args:
        text: Input handed to the tokenizer. The Gradio textbox wired to this
            function is expected to supply a single string.

    Returns:
        A dict mapping each trait name to its softmax score as a built-in
        ``float``.

    Raises:
        IndexError: If the model yields fewer scores than there are labels.
    """
    inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # ``tolist()`` produces built-in Python floats rather than NumPy float32
    # scalars, so the returned dict is plain data that serializes to JSON
    # regardless of how the UI layer encodes it.
    scores = probabilities.squeeze().tolist()

    # NOTE(review): assumes the model's output order matches these names --
    # confirm against the model card.
    label_names = (
        'Extroversion',
        'Neuroticism',
        'Agreeableness',
        'Conscientiousness',
        'Openness',
    )
    # Indexed lookup (not zip) so a label/output count mismatch still raises.
    return {name: scores[i] for i, name in enumerate(label_names)}

# Build the Gradio UI: a two-line text box feeding personality_detection(),
# whose result is rendered by a label component.
text_input = gr.Textbox(lines=2, placeholder="Enter a sentence here...")
trait_output = gr.Label()

interface = gr.Interface(
    fn=personality_detection,
    inputs=text_input,
    outputs=trait_output,
    title="Personality Analyzer",
    description="Enter a sentence and get a prediction of personality traits.",
)

# Port the Gradio server is asked to listen on; pick another value if 7861 is
# unavailable.
SERVER_PORT = 7861

# Start serving the app.
interface.launch(server_port=SERVER_PORT)