nishan-chatterjee committed on
Commit
468c17d
1 Parent(s): 933881f

additional comments

Browse files
Files changed (1) hide show
  1. inference.py +18 -2
inference.py CHANGED
@@ -3,6 +3,7 @@ import numpy as np
3
  import networkx as nx
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
 
 
6
  def _make_logits_consistent(x, R):
7
  c_out = x.unsqueeze(1) + 10
8
  c_out = c_out.expand(len(x), R.shape[1], R.shape[1])
@@ -10,8 +11,9 @@ def _make_logits_consistent(x, R):
10
  final_out, _ = torch.max(R_batch * c_out, dim=2)
11
  return final_out - 10
12
 
 
13
  def initialize_model():
14
- model_dir = "."
15
  G = nx.DiGraph()
16
  edges = [
17
  ("ROOT", "Logos"),
@@ -29,12 +31,17 @@ def initialize_model():
29
  ]
30
  G.add_edges_from(edges)
31
 
 
 
 
32
  tokenizer = AutoTokenizer.from_pretrained(model_dir)
33
  model = AutoModelForSequenceClassification.from_pretrained(model_dir)
34
 
 
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  model.to(device)
37
 
 
38
  A = nx.to_numpy_array(G).transpose()
39
  R = np.zeros(A.shape)
40
  np.fill_diagonal(R, 1)
@@ -47,7 +54,9 @@ def initialize_model():
47
 
48
  return tokenizer, model, R, G, device
49
 
 
50
  def predict_persuasion_labels(text, tokenizer, model, R, G, device):
 
51
  encoding = tokenizer.encode_plus(
52
  text,
53
  add_special_tokens=True,
@@ -58,17 +67,23 @@ def predict_persuasion_labels(text, tokenizer, model, R, G, device):
58
  return_attention_mask=True,
59
  return_tensors="pt",
60
  )
61
-
 
62
  with torch.no_grad():
63
  outputs = model(
64
  input_ids=encoding["input_ids"].to(device),
65
  attention_mask=encoding["attention_mask"].to(device),
66
  )
 
 
67
  logits = _make_logits_consistent(outputs.logits, R)
68
  logits[:, 0] = -1.0
69
  logits = logits > 0.0
 
 
70
  complete_predicted_hierarchy = np.array(G.nodes)[logits[0].cpu().nonzero()].flatten().tolist()
71
 
 
72
  child_only_labels = []
73
  for label in complete_predicted_hierarchy:
74
  if not list(G.successors(label)):
@@ -78,6 +93,7 @@ def predict_persuasion_labels(text, tokenizer, model, R, G, device):
78
 
79
  tokenizer, model, R, G, device = initialize_model()
80
 
 
81
  def inference(text):
82
  return predict_persuasion_labels(text, tokenizer, model, R, G, device)
83
 
 
3
  import networkx as nx
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
 
6
+ # Function to make logits consistent based on the hierarchy matrix R
7
  def _make_logits_consistent(x, R):
8
  c_out = x.unsqueeze(1) + 10
9
  c_out = c_out.expand(len(x), R.shape[1], R.shape[1])
 
11
  final_out, _ = torch.max(R_batch * c_out, dim=2)
12
  return final_out - 10
13
 
14
+ # Function to initialize the model, tokenizer, and hierarchy matrix
15
  def initialize_model():
16
+ # Define the hierarchy graph
17
  G = nx.DiGraph()
18
  edges = [
19
  ("ROOT", "Logos"),
 
31
  ]
32
  G.add_edges_from(edges)
33
 
34
+ # The model and tokenizer are saved in the current directory
35
+ model_dir = "."
36
+ # loading the model and tokenizer
37
  tokenizer = AutoTokenizer.from_pretrained(model_dir)
38
  model = AutoModelForSequenceClassification.from_pretrained(model_dir)
39
 
40
+ # Use the GPU if one is available, otherwise fall back to the CPU
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
  model.to(device)
43
 
44
+ # Create the hierarchy matrix R based on the graph structure
45
  A = nx.to_numpy_array(G).transpose()
46
  R = np.zeros(A.shape)
47
  np.fill_diagonal(R, 1)
 
54
 
55
  return tokenizer, model, R, G, device
56
 
57
+ # Function to predict persuasion labels for a given text
58
  def predict_persuasion_labels(text, tokenizer, model, R, G, device):
59
+ # Tokenize and encode the input text
60
  encoding = tokenizer.encode_plus(
61
  text,
62
  add_special_tokens=True,
 
67
  return_attention_mask=True,
68
  return_tensors="pt",
69
  )
70
+
71
+ # Forward pass through the model
72
  with torch.no_grad():
73
  outputs = model(
74
  input_ids=encoding["input_ids"].to(device),
75
  attention_mask=encoding["attention_mask"].to(device),
76
  )
77
+
78
+ # Make logits consistent based on the hierarchy matrix R
79
  logits = _make_logits_consistent(outputs.logits, R)
80
  logits[:, 0] = -1.0
81
  logits = logits > 0.0
82
+
83
+ # Get the complete predicted hierarchy of labels
84
  complete_predicted_hierarchy = np.array(G.nodes)[logits[0].cpu().nonzero()].flatten().tolist()
85
 
86
+ # Get the child-only labels (labels without any successors)
87
  child_only_labels = []
88
  for label in complete_predicted_hierarchy:
89
  if not list(G.successors(label)):
 
93
 
94
  tokenizer, model, R, G, device = initialize_model()
95
 
96
+ # Main inference function
97
  def inference(text):
98
  return predict_persuasion_labels(text, tokenizer, model, R, G, device)
99