File size: 2,039 Bytes
fc85928
 
 
 
 
 
 
126c386
37dac63
126c386
 
fc85928
 
126c386
508ada6
126c386
508ada6
fc85928
 
 
 
 
126c386
fc85928
f8f1995
fc85928
 
 
 
 
 
 
 
 
37dac63
fc85928
 
 
 
 
 
26b285e
37dac63
fc85928
 
 
 
 
 
 
 
 
 
 
 
26b285e
fc85928
 
 
 
37dac63
 
 
 
26b285e
 
 
 
 
 
fc85928
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import gradio as gr
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from typing import Dict
import os
import pandas as pd
from huggingface_hub import login
# Authenticate with the Hugging Face Hub only when a token is actually
# provided. ``login(token=None)`` falls back to an interactive prompt, which
# cannot be answered in a headless deployment and would stall start-up.
_HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if _HF_TOKEN:
    login(token=_HF_TOKEN)

# Moral foundations scored by the app; each one has its own sequence
# classifier published as ``joshnguyen/mformer-<foundation>``.
FOUNDATIONS = ["authority", "care", "fairness", "loyalty", "sanctity"]

# One tokenizer (from the "authority" checkpoint) is shared by every model.
# NOTE(review): assumes all mformer checkpoints use the same tokenizer — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "joshnguyen/mformer-authority",
    use_auth_token=True
)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Maps foundation name -> classifier switched to eval mode and moved to DEVICE.
MODELS = {}
for foundation in FOUNDATIONS:
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=f"joshnguyen/mformer-{foundation}",
        use_auth_token=True
    )
    model.eval()
    MODELS[foundation] = model.to(DEVICE)


def classify_text(text: str) -> pd.DataFrame:
    """Score ``text`` against every moral foundation.

    The text is tokenized once, and the same encoding is fed to each
    foundation's classifier. A foundation's score is the softmax probability
    of class index 1.

    Args:
        text: The input text to classify.

    Returns:
        A DataFrame with one row per foundation, in ``FOUNDATIONS`` order. It
        has two columns: ``"foundation"`` (the capitalized name) and
        ``"score"`` (a probability in [0, 1]).
    """
    # Encode the prompt once and reuse it for all models.
    inputs = tokenizer(
        [text],
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)

    rows = []
    # Pure inference: disable autograd so no computation graph is built.
    with torch.no_grad():
        for foundation in FOUNDATIONS:
            logits = MODELS[foundation](**inputs).logits
            probs = torch.softmax(logits, dim=1)
            # NOTE(review): assumes class index 1 means "foundation present" — confirm.
            rows.append([foundation.capitalize(), probs[0, 1].item()])
    return pd.DataFrame(rows, columns=["foundation", "score"])


def _build_demo():
    """Assemble the UI: a text box as input, a horizontal score bar chart as output."""
    text_box = gr.Textbox(
        label="Input text",
        placeholder="Enter some text...",
        lines=12,
        scale=10,
        show_label=True,
        container=False,
    )
    # Horizontal bars, scores fixed to [0, 1], axis titles intentionally blank.
    score_chart = gr.BarPlot(
        x="foundation",
        y="score",
        title="Moral foundations scores",
        x_title=" ",
        y_title=" ",
        y_lim=[0, 1],
        vertical=False,
        width=500,
        height=200,
    )
    return gr.Interface(
        fn=classify_text,
        inputs=[text_box],
        outputs=[score_chart],
    )


demo = _build_demo()

# Cap the request queue at 20 pending jobs, then start serving.
demo.queue(max_size=20)
demo.launch()