Spaces:
Runtime error
Runtime error
File size: 2,039 Bytes
fc85928 126c386 37dac63 126c386 fc85928 126c386 508ada6 126c386 508ada6 fc85928 126c386 fc85928 f8f1995 fc85928 37dac63 fc85928 26b285e 37dac63 fc85928 26b285e fc85928 37dac63 26b285e fc85928 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import torch
import gradio as gr
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
)
from typing import Dict
import os
import pandas as pd
from huggingface_hub import login
# Read the Hub token once; it is passed explicitly to every download below.
# When it is None, huggingface_hub falls back to any locally stored token.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if HF_TOKEN:
    # Only log in when a token is configured: login(token=None) falls back to
    # an interactive prompt, which cannot be answered in a headless app.
    login(token=HF_TOKEN)

# Moral foundations scored by the app; each has its own classifier checkpoint.
FOUNDATIONS = ["authority", "care", "fairness", "loyalty", "sanctity"]

# A single tokenizer (from the "authority" checkpoint) is reused for every model.
# NOTE(review): assumes all mformer-* checkpoints share this tokenizer — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "joshnguyen/mformer-authority",
    token=HF_TOKEN,  # replaces the deprecated use_auth_token=True
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_model(foundation: str):
    """Download the classifier for *foundation*, put it in eval mode, and move it to DEVICE."""
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=f"joshnguyen/mformer-{foundation}",
        token=HF_TOKEN,
    )
    model.eval()
    return model.to(DEVICE)


# foundation name -> loaded model, in FOUNDATIONS order.
MODELS = {foundation: _load_model(foundation) for foundation in FOUNDATIONS}
def classify_text(text: str) -> pd.DataFrame:
    """Score *text* with every moral-foundation classifier.

    The text is tokenized once and fed to each model in ``MODELS``. The
    softmax probability of class index 1 is used as that foundation's score.

    Args:
        text: Input text from the Gradio textbox. ``None`` is treated as an
            empty string.

    Returns:
        A DataFrame with columns ``"foundation"`` (capitalized foundation
        name) and ``"score"`` (a softmax probability in [0, 1]). It has one
        row per entry of ``FOUNDATIONS``, in that order, which is the tabular
        form the bar plot reads its ``x``/``y`` columns from.
    """
    if text is None:
        text = ""

    # Tokenize once; the same encoded batch is reused for every model.
    inputs = tokenizer(
        [text],
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)

    rows = []
    # Pure inference: disable autograd so no graph is built for each forward pass.
    with torch.inference_mode():
        for foundation in FOUNDATIONS:
            logits = MODELS[foundation](**inputs).logits
            # Probability of class 1 for the single input row.
            # NOTE(review): assumes label 1 means "foundation present" — confirm
            # against the model card.
            prob = torch.softmax(logits, dim=1)[0, 1].item()
            rows.append([foundation.capitalize(), prob])

    return pd.DataFrame(rows, columns=["foundation", "score"])
def _build_demo() -> gr.Interface:
    """Assemble the UI: a multi-line text box in, a horizontal bar chart of per-foundation scores out."""
    text_box = gr.Textbox(
        label="Input text",
        container=False,
        show_label=True,
        placeholder="Enter some text...",
        lines=12,
        scale=10,
    )
    # Plots the "foundation"/"score" columns of the DataFrame returned by classify_text.
    score_chart = gr.BarPlot(
        x="foundation",
        y="score",
        title="Moral foundations scores",
        x_title=" ",
        y_title=" ",
        y_lim=[0, 1],
        vertical=False,
        height=200,
        width=500,
    )
    return gr.Interface(
        fn=classify_text,
        inputs=[text_box],
        outputs=[score_chart],
    )


demo = _build_demo()
# Queue requests (at most 20 pending), then start the server.
demo.queue(max_size=20).launch()
|