param-bharat committed on
Commit 347d909 · verified · 1 Parent(s): 5ba8934

Upload DocumentSentenceRelevancePipeline

Files changed (2):
  1. config.json +9 -0
  2. pipeline.py +109 -0
config.json CHANGED
@@ -7,6 +7,15 @@
     "AutoModel": "modeling.MultiHeadModel"
   },
   "classifier_dropout": 0.1,
+  "custom_pipelines": {
+    "context-relevance": {
+      "impl": "pipeline.DocumentSentenceRelevancePipeline",
+      "pt": [
+        "AutoModel"
+      ],
+      "tf": []
+    }
+  },
   "encoder_name": "tasksource/deberta-base-long-nli",
   "id2label": {
     "0": "irrelevant",
pipeline.py ADDED
@@ -0,0 +1,109 @@
from typing import Union

import torch
from transformers import Pipeline


def convert_to_list(data):
    """Convert a dict of parallel lists into a list of per-item dicts."""
    first_list = next(iter(data.values()))
    return [
        {key: values[i] for key, values in data.items()}
        for i in range(len(first_list))
    ]


class DocumentSentenceRelevancePipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        # No extra preprocess/forward parameters; threshold is consumed by postprocess.
        threshold = kwargs.get("threshold", 0.5)
        return {}, {}, {"threshold": threshold}

    def preprocess(self, inputs):
        question = inputs.get("question", "")
        context = inputs.get("context", [""])
        response = inputs.get("response", "")

        # Encode the question and response separately; each keeps its own
        # special tokens so the model sees [CLS] ... [SEP] boundaries.
        q_enc = self.tokenizer(question, add_special_tokens=True, truncation=False, padding=False)
        r_enc = self.tokenizer(response, add_special_tokens=True, truncation=False, padding=False)

        question_ids = q_enc["input_ids"]
        response_ids = r_enc["input_ids"]

        # Encode each context sentence on its own; the leading [CLS] of each
        # sentence later serves as its pooling position.
        document_sentences_ids = []
        for s in context:
            s_enc = self.tokenizer(s, add_special_tokens=True, truncation=False, padding=False)
            document_sentences_ids.append(s_enc["input_ids"])

        # First segment: question + response. Second segment: all context sentences.
        ids = question_ids + response_ids
        pair_ids = []
        for s_ids in document_sentences_ids:
            pair_ids.extend(s_ids)

        # If the combined sequence is too long, truncate only the context segment.
        total_length = len(ids) + len(pair_ids)
        if total_length > self.tokenizer.model_max_length:
            num_tokens_to_remove = total_length - self.tokenizer.model_max_length
            ids, pair_ids, _ = self.tokenizer.truncate_sequences(
                ids=ids,
                pair_ids=pair_ids,
                num_tokens_to_remove=num_tokens_to_remove,
                truncation_strategy="only_second",
                stride=0,
            )
        combined_ids = ids + pair_ids
        token_types = [0] * len(ids) + [1] * len(pair_ids)
        attention_mask = [1] * len(combined_ids)

        # Each context sentence starts with a [CLS] token; record those positions
        # so the model can pool a per-sentence representation at each one.
        sentence_positions = []
        current_pos = len(ids)
        for i, tok_id in enumerate(pair_ids):
            if tok_id == self.tokenizer.cls_token_id:
                sentence_positions.append(current_pos + i)

        return {
            "input_ids": torch.tensor([combined_ids], dtype=torch.long),
            "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
            "token_type_ids": torch.tensor([token_types], dtype=torch.long),
            "sentence_positions": torch.tensor([sentence_positions], dtype=torch.long),
        }

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def __call__(self, inputs: Union[dict[str, str], list[dict[str, str]]], **kwargs):
        if isinstance(inputs, dict):
            inputs = [inputs]
        model_outputs = super().__call__(inputs, **kwargs)
        pipeline_outputs = []
        for i, output in enumerate(model_outputs):
            # Attach the original sentences, then flatten the per-key lists
            # into one dict per sentence.
            sentences = inputs[i]["context"]
            output["sentences"]["sentence"] = sentences
            output["sentences"] = convert_to_list(output["sentences"])
            pipeline_outputs.append(output)
        return pipeline_outputs if len(pipeline_outputs) > 1 else pipeline_outputs[0]

    def postprocess(self, model_outputs, threshold=0.5):
        # Assumed shapes from the indexing below: doc_logits (1, 2),
        # sent_logits (1, num_sentences, 2).
        doc_logits = model_outputs.doc_logits
        sent_logits = model_outputs.sent_logits
        document_probabilities = torch.softmax(doc_logits, dim=-1)
        sentence_probabilities = torch.softmax(sent_logits, dim=-1)

        # Class 1 ("relevant") wins when its probability exceeds the threshold.
        document_best_class = (document_probabilities[:, 1] > threshold).long()
        sentence_best_class = (sentence_probabilities[:, :, 1] > threshold).long()
        document_score = document_probabilities[:, document_best_class]
        # squeeze(0), not squeeze(): a bare squeeze() would also drop the
        # sentence axis when there is only one context sentence.
        sentence_best_class = sentence_best_class.squeeze(0)
        sentence_indices = torch.arange(sentence_probabilities.size(1))
        sentence_scores = sentence_probabilities.squeeze(0)[sentence_indices, sentence_best_class]

        best_document_label = self.model.config.id2label[document_best_class.item()]
        best_sentence_labels = [self.model.config.id2label[label] for label in sentence_best_class.tolist()]

        document_output = {"label": best_document_label, "score": document_score.item()}
        sentence_output = {"label": best_sentence_labels, "score": sentence_scores.tolist()}
        return {"document": document_output, "sentences": sentence_output}