sarahyurick commited on
Commit
db91539
1 Parent(s): 61e8f13

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +207 -1
README.md CHANGED
@@ -54,7 +54,213 @@ The inference code for this model is available through the NeMo Curator GitHub r
54
  To use the prompt task and complexity classifier, use the following code:
55
 
56
  ```python
57
- # TODO
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  ```
59
 
60
  # Input & Output
 
54
  To use the prompt task and complexity classifier, use the following code:
55
 
56
  ```python
57
+ import numpy as np
58
+ import torch
59
+ import torch.nn as nn
60
+ from huggingface_hub import PyTorchModelHubMixin
61
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
62
+
63
+
64
+ class MeanPooling(nn.Module):
65
+ def __init__(self):
66
+ super(MeanPooling, self).__init__()
67
+
68
+ def forward(self, last_hidden_state, attention_mask):
69
+ input_mask_expanded = (
70
+ attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
71
+ )
72
+ sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
73
+
74
+ sum_mask = input_mask_expanded.sum(1)
75
+ sum_mask = torch.clamp(sum_mask, min=1e-9)
76
+
77
+ mean_embeddings = sum_embeddings / sum_mask
78
+ return mean_embeddings
79
+
80
+
81
+ class MulticlassHead(nn.Module):
82
+ def __init__(self, input_size, num_classes):
83
+ super(MulticlassHead, self).__init__()
84
+ self.fc = nn.Linear(input_size, num_classes)
85
+
86
+ def forward(self, x):
87
+ x = self.fc(x)
88
+ return x
89
+
90
+
91
+ class CustomModel(nn.Module, PyTorchModelHubMixin):
92
+ def __init__(self, target_sizes, task_type_map, weights_map, divisor_map):
93
+ super(CustomModel, self).__init__()
94
+
95
+ self.backbone = AutoModel.from_pretrained("microsoft/DeBERTa-v3-base")
96
+ self.target_sizes = target_sizes.values()
97
+ self.task_type_map = task_type_map
98
+ self.weights_map = weights_map
99
+ self.divisor_map = divisor_map
100
+
101
+ self.heads = [
102
+ MulticlassHead(self.backbone.config.hidden_size, sz)
103
+ for sz in self.target_sizes
104
+ ]
105
+
106
+ for i, head in enumerate(self.heads):
107
+ self.add_module(f"head_{i}", head)
108
+
109
+ self.pool = MeanPooling()
110
+
111
+ def compute_results(self, preds, target, decimal=4):
112
+ if target == "task_type":
113
+ task_type = {}
114
+
115
+ top2_indices = torch.topk(preds, k=2, dim=1).indices
116
+ softmax_probs = torch.softmax(preds, dim=1)
117
+ top2_probs = softmax_probs.gather(1, top2_indices)
118
+ top2 = top2_indices.detach().cpu().tolist()
119
+ top2_prob = top2_probs.detach().cpu().tolist()
120
+
121
+ top2_strings = [
122
+ [self.task_type_map[str(idx)] for idx in sample] for sample in top2
123
+ ]
124
+ top2_prob_rounded = [
125
+ [round(value, 3) for value in sublist] for sublist in top2_prob
126
+ ]
127
+
128
+ counter = 0
129
+ for sublist in top2_prob_rounded:
130
+ if sublist[1] < 0.1:
131
+ top2_strings[counter][1] = "NA"
132
+ counter += 1
133
+
134
+ task_type_1 = [sublist[0] for sublist in top2_strings]
135
+ task_type_2 = [sublist[1] for sublist in top2_strings]
136
+ task_type_prob = [sublist[0] for sublist in top2_prob_rounded]
137
+
138
+ return (task_type_1, task_type_2, task_type_prob)
139
+
140
+ else:
141
+ preds = torch.softmax(preds, dim=1)
142
+
143
+ weights = np.array(self.weights_map[target])
144
+ weighted_sum = np.sum(np.array(preds.detach().cpu()) * weights, axis=1)
145
+ scores = weighted_sum / self.divisor_map[target]
146
+
147
+ scores = [round(value, decimal) for value in scores]
148
+ if target == "number_of_few_shots":
149
+ scores = [x if x >= 0.05 else 0 for x in scores]
150
+ return scores
151
+
152
+ def process_logits(self, logits):
153
+ result = {}
154
+
155
+ # Round 1: "task_type"
156
+ task_type_logits = logits[0]
157
+ task_type_results = self.compute_results(task_type_logits, target="task_type")
158
+ result["task_type_1"] = task_type_results[0]
159
+ result["task_type_2"] = task_type_results[1]
160
+ result["task_type_prob"] = task_type_results[2]
161
+
162
+ # Round 2: "creativity_scope"
163
+ creativity_scope_logits = logits[1]
164
+ target = "creativity_scope"
165
+ result[target] = self.compute_results(creativity_scope_logits, target=target)
166
+
167
+ # Round 3: "reasoning"
168
+ reasoning_logits = logits[2]
169
+ target = "reasoning"
170
+ result[target] = self.compute_results(reasoning_logits, target=target)
171
+
172
+ # Round 4: "contextual_knowledge"
173
+ contextual_knowledge_logits = logits[3]
174
+ target = "contextual_knowledge"
175
+ result[target] = self.compute_results(
176
+ contextual_knowledge_logits, target=target
177
+ )
178
+
179
+ # Round 5: "number_of_few_shots"
180
+ number_of_few_shots_logits = logits[4]
181
+ target = "number_of_few_shots"
182
+ result[target] = self.compute_results(number_of_few_shots_logits, target=target)
183
+
184
+ # Round 6: "domain_knowledge"
185
+ domain_knowledge_logits = logits[5]
186
+ target = "domain_knowledge"
187
+ result[target] = self.compute_results(domain_knowledge_logits, target=target)
188
+
189
+ # Round 7: "no_label_reason"
190
+ no_label_reason_logits = logits[6]
191
+ target = "no_label_reason"
192
+ result[target] = self.compute_results(no_label_reason_logits, target=target)
193
+
194
+ # Round 8: "constraint_ct"
195
+ constraint_ct_logits = logits[7]
196
+ target = "constraint_ct"
197
+ result[target] = self.compute_results(constraint_ct_logits, target=target)
198
+
199
+ # Round 9: "prompt_complexity_score"
200
+ result["prompt_complexity_score"] = [
201
+ round(
202
+ 0.35 * creativity
203
+ + 0.25 * reasoning
204
+ + 0.15 * constraint
205
+ + 0.15 * domain_knowledge
206
+ + 0.05 * contextual_knowledge
207
+ + 0.05 * few_shots,
208
+ 5,
209
+ )
210
+ for creativity, reasoning, constraint, domain_knowledge, contextual_knowledge, few_shots in zip(
211
+ result["creativity_scope"],
212
+ result["reasoning"],
213
+ result["constraint_ct"],
214
+ result["domain_knowledge"],
215
+ result["contextual_knowledge"],
216
+ result["number_of_few_shots"],
217
+ )
218
+ ]
219
+
220
+ return result
221
+
222
+ def forward(self, batch):
223
+ input_ids = batch["input_ids"]
224
+ attention_mask = batch["attention_mask"]
225
+ outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
226
+
227
+ last_hidden_state = outputs.last_hidden_state
228
+ mean_pooled_representation = self.pool(last_hidden_state, attention_mask)
229
+
230
+ logits = [
231
+ self.heads[k](mean_pooled_representation)
232
+ for k in range(len(self.target_sizes))
233
+ ]
234
+
235
+ return self.process_logits(logits)
236
+
237
+
238
+ config = AutoConfig.from_pretrained("nvidia/prompt-task-and-complexity-classifier")
239
+ tokenizer = AutoTokenizer.from_pretrained(
240
+ "nvidia/prompt-task-and-complexity-classifier"
241
+ )
242
+ model = CustomModel(
243
+ target_sizes=config.target_sizes,
244
+ task_type_map=config.task_type_map,
245
+ weights_map=config.weights_map,
246
+ divisor_map=config.divisor_map,
247
+ ).from_pretrained("nvidia/prompt-task-and-complexity-classifier")
248
+ model.eval()
249
+
250
+ prompt = ["Prompt: Write a Python script that uses a for loop."]
251
+
252
+ encoded_texts = tokenizer(
253
+ prompt,
254
+ return_tensors="pt",
255
+ add_special_tokens=True,
256
+ max_length=512,
257
+ padding="max_length",
258
+ truncation=True,
259
+ )
260
+
261
+ result = model(encoded_texts)
262
+ print(result)
263
+ # {'task_type_1': ['Code Generation'], 'task_type_2': ['Text Generation'], 'task_type_prob': [0.767], 'creativity_scope': [0.0826], 'reasoning': [0.0632], 'contextual_knowledge': [0.056], 'number_of_few_shots': [0], 'domain_knowledge': [0.9803], 'no_label_reason': [0.0], 'constraint_ct': [0.5578], 'prompt_complexity_score': [0.27822]}
264
  ```
265
 
266
  # Input & Output