import torch
import numpy as np
import networkx as nx
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Make the logits consistent with the hierarchy matrix R: each label's logit
# is replaced by the max over itself and all of its descendants, so a
# positively predicted child always implies its ancestors (the max-constraint
# used by coherent hierarchical multi-label classifiers).
def _make_logits_consistent(x, R):
    # Shift by +10 so the zero entries of R cannot win the max below
    # (this assumes raw logits stay above -10, which holds in practice)
    c_out = x.unsqueeze(1) + 10
    c_out = c_out.expand(len(x), R.shape[1], R.shape[1])
    R_batch = R.expand(len(x), R.shape[1], R.shape[1]).to(x.device)
    # Row i of R marks label i and its descendants; the max over that row
    # pools the strongest evidence from anywhere below i in the hierarchy
    final_out, _ = torch.max(R_batch * c_out, dim=2)
    return final_out - 10
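
# Worked toy example of the constraint (illustrative numbers, not model output):
# for a chain ROOT -> A -> B with raw logits [0.3, -2.0, 1.5], B's positive
# logit propagates upward, so A's consistent logit becomes max(-2.0, 1.5) = 1.5
# and B can never be predicted without A.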

# Function to initialize the model, tokenizer, and hierarchy matrix
def initialize_model():
    # Define the persuasion-technique hierarchy. Note that it is a DAG rather
    # than a tree: some techniques (e.g. Whataboutism, Bandwagon, Transfer,
    # Appeal to authority, Flag-waving, Appeal to fear/prejudice) sit under
    # more than one parent.
    G = nx.DiGraph()
    edges = [
        ("ROOT", "Logos"),
            ("Logos", "Repetition"), ("Logos", "Obfuscation, Intentional vagueness, Confusion"), ("Logos", "Reasoning"), ("Logos", "Justification"),
                ("Justification", "Slogans"), ("Justification", "Bandwagon"), ("Justification", "Appeal to authority"), ("Justification", "Flag-waving"), ("Justification", "Appeal to fear/prejudice"),
                ("Reasoning", "Simplification"),
                    ("Simplification", "Causal Oversimplification"), ("Simplification", "Black-and-white Fallacy/Dictatorship"), ("Simplification", "Thought-terminating cliché"),
                ("Reasoning", "Distraction"),
                    ("Distraction", "Misrepresentation of Someone's Position (Straw Man)"), ("Distraction", "Presenting Irrelevant Data (Red Herring)"), ("Distraction", "Whataboutism"),
        ("ROOT", "Ethos"),
            ("Ethos", "Appeal to authority"), ("Ethos", "Glittering generalities (Virtue)"), ("Ethos", "Bandwagon"), ("Ethos", "Ad Hominem"), ("Ethos", "Transfer"),
                ("Ad Hominem", "Doubt"), ("Ad Hominem", "Name calling/Labeling"), ("Ad Hominem", "Smears"), ("Ad Hominem", "Reductio ad hitlerum"), ("Ad Hominem", "Whataboutism"),
        ("ROOT", "Pathos"),
            ("Pathos", "Exaggeration/Minimisation"), ("Pathos", "Loaded Language"), ("Pathos", "Appeal to (Strong) Emotions"), ("Pathos", "Appeal to fear/prejudice"), ("Pathos", "Flag-waving"), ("Pathos", "Transfer")
    ]
    G.add_edges_from(edges)
    
    # The model and tokenizer are saved in the current directory
    model_dir = "."
    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    
    # Set device to GPU if available and put the model in inference mode
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # disable dropout for deterministic predictions

    # Create the hierarchy matrix R based on the graph structure. After the
    # final transpose, R[0, i, j] == 1 iff label j is label i itself or one
    # of its descendants in G, which is exactly the set that
    # _make_logits_consistent maxes over.
    A = nx.to_numpy_array(G).transpose()
    R = np.zeros(A.shape)
    np.fill_diagonal(R, 1)
    # A is the transposed adjacency matrix, so descendants in g are ancestors in G
    g = nx.DiGraph(A)
    for i in range(len(A)):
        descendants = list(nx.descendants(g, i))
        if descendants:
            R[i, descendants] = 1
    R = torch.tensor(R).transpose(1, 0).unsqueeze(0)
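    # For a chain ROOT -> A -> B (node order ROOT, A, B) this would give
    #     R[0] = [[1, 1, 1],
    #             [0, 1, 1],
    #             [0, 0, 1]]
    # i.e. each row marks the label itself plus everything below it.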
    
    return tokenizer, model, R, G, device

# Function to predict persuasion labels for a given text
def predict_persuasion_labels(text, tokenizer, model, R, G, device):
    # Tokenize and encode the input text
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        return_token_type_ids=False,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
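    # Note: inputs longer than max_length (128 tokens) are silently truncated,
    # so long passages may need to be split before calling this function.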
    
    # Forward pass through the model
    with torch.no_grad():
        outputs = model(
            input_ids=encoding["input_ids"].to(device),
            attention_mask=encoding["attention_mask"].to(device),
        )

    # Make the logits consistent with the hierarchy matrix R, then threshold
    logits = _make_logits_consistent(outputs.logits, R)
    logits[:, 0] = -1.0  # never predict the artificial ROOT node
    logits = logits > 0.0  # logit > 0 is equivalent to sigmoid(logit) > 0.5

    # Map positive indices back to node names; np.array(G.nodes) follows the
    # same node ordering that nx.to_numpy_array used when building R
    predicted_indices = logits[0].cpu().numpy().nonzero()[0]
    complete_predicted_hierarchy = np.array(G.nodes)[predicted_indices].tolist()

    # Get the child-only labels (labels without any successors)
    child_only_labels = []
    for label in complete_predicted_hierarchy:
        if not list(G.successors(label)):
            child_only_labels.append(label)

    return complete_predicted_hierarchy, child_only_labels

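# Load everything once at import time so repeated inference() calls reuse the model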
tokenizer, model, R, G, device = initialize_model()

# Main inference function
def inference(text):
    return predict_persuasion_labels(text, tokenizer, model, R, G, device)
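
# Example usage (actual labels depend on the checkpoint in model_dir):
#   full, leaves = inference("They started it, so our response is justified!")
#   `full` holds every predicted node up to the hierarchy roots and `leaves`
#   keeps only the most specific techniques.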

if __name__ == "__main__":
    # Ask the user for input
    text = input("Enter the text: ")
    print(inference(text))