Randomtalks committed
Commit 0774a75 · 1 Parent(s): eef1745

First commit with DOME model


Signed-off-by: egor <egorbu@gmail.com>

README.md ADDED
@@ -0,0 +1,224 @@
# DOME wrapper for docstring intent classification
This wrapper allows you to:
* split docstrings into sentences
* convert them to the inputs required by DOME
* predict a class for each sentence of the docstring

## Model architecture
The architecture is based on https://github.com/ICSE-DOME/DOME.

## Usage
```python
docstring = "sentences of docstring"
dome = DOME("dome_location")
sentences, predictions = dome.predict(docstring)
```

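`predictions` is a list of class-name strings taken from `DOME_CLASS_NAMES` (defined in the code further below). A hypothetical way to inspect the result; the actual labels depend entirely on the pretrained weights:

```python
# Illustrative only: labels come from DOME_CLASS_NAMES and depend on the model weights.
for sentence, label in zip(sentences, predictions):
    print(f"{label:>15} | {sentence}")
```
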
## Dependencies
```
spacy
torch
transformers
```

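Note that `Docstring2Sentences` loads spaCy's `en_core_web_sm` pipeline, which is not installed together with the `spacy` package itself. A minimal sketch of a guard that fetches it on first use (assumes network access; the usual alternative is running `python -m spacy download en_core_web_sm` once):

```python
import spacy

try:
    spacy.load("en_core_web_sm")
except OSError:
    # The sentence-splitting model is missing; download it programmatically.
    from spacy.cli import download
    download("en_core_web_sm")
```
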
## Code of the model
````python
"""
The model is based on the replication package for the ICSE'23 paper "Developer-Intent Driven Code Comment Generation".
Initial solution: https://github.com/ICSE-DOME/DOME
The pipeline consists of several parts:
* split the docstring into sentences
* prepare input data for DOMEBertForClassification
* predict a class

How to use:
```python
docstring = "sentences of docstring"
dome = DOME("dome_location")
sentences, predictions = dome.predict(docstring)
```
"""
import re
from typing import Tuple, List

import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, RobertaConfig, RobertaModel

# Maximum number of tokens (including the special tokens) passed to the encoder.
MAX_LENGTH_BERT = 510


class DOME:
    """
    End-to-end pipeline for docstring classification:
    * split sentences
    * prepare inputs
    * classify
    """
    def __init__(self, pretrained_model: str):
        """
        :param pretrained_model: location of the pretrained model
        """
        self.model = DOMEBertForClassification.from_pretrained(pretrained_model)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        self.docstring2sentences = Docstring2Sentences()

    def predict(self, docstring: str) -> Tuple[List[str], List[str]]:
        """
        Predict DOME classes for each sentence in the docstring.
        :param docstring: docstring to process
        :return: tuple with the list of sentences and the list of predictions for each sentence.
        """
        sentences = self.docstring2sentences.docstring2sentences(docstring)
        predictions = [self.model.predict(*dome_preprocess(tokenizer=self.tokenizer, comment=sentence))
                       for sentence in sentences]
        return sentences, predictions


class DOMEBertForClassification(RobertaModel):
    """
    A custom classification model based on RobertaModel for intent classification.

    This model extends RobertaModel with additional linear layers that incorporate
    the comment length as an extra feature for classification.
    """

    DOME_CLASS_NAMES = ["what", "why", "how-to-use", "how-it-is-done", "property", "others"]

    def __init__(self, config: RobertaConfig):
        """
        Initialize the DOMEBertForClassification model.

        :param config: The configuration for the RobertaModel.
        """
        super().__init__(config)

        # The number of classes and layer sizes are fixed because the pretrained model must be loaded as-is.
        # DOME layers for intent classification:
        self.fc1 = nn.Linear(768 + 1, 768 // 3)
        self.fc2 = nn.Linear(768 // 3, 6)
        self.dropout = nn.Dropout(0.2)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None, comment_len: torch.Tensor = None) \
            -> torch.Tensor:
        """
        Forward pass for the DOMEBertForClassification model.

        :param input_ids: Tensor of token ids to be fed to the model.
        :param attention_mask: Mask to avoid performing attention on padding token indices. Always equal to 1 here.
        :param comment_len: Tensor representing the length of the comment. Equal to 1 if the comment has fewer than
            3 words, 0 otherwise.
        :return: The logits after passing through the model.
        """
        # Use the parent class's forward method to get the base outputs
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Extract the pooled output (the [CLS] hidden state passed through the pooler layer)
        pooled_output = outputs.pooler_output
        # DOME custom layers:
        comment_len = comment_len.view(-1, 1).float()  # Ensure comment_len is correctly shaped
        # DOME uses the comment length as an additional feature
        combined_input = torch.cat([pooled_output, comment_len], dim=-1)
        x = self.dropout(F.relu(self.fc1(self.dropout(combined_input))))
        logits = self.fc2(x)
        return logits

    def predict(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None, comment_len: torch.Tensor = None) \
            -> str:
        """
        Predict the class for a tokenized docstring.

        :param input_ids: Tensor of token ids to be fed to the model.
        :param attention_mask: Mask to avoid performing attention on padding token indices. Always equal to 1 here.
        :param comment_len: Tensor representing the length of the comment. Equal to 1 if the comment has fewer than
            3 words, 0 otherwise.
        :return: predicted class name
        """
        logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, comment_len=comment_len)
        return self.DOME_CLASS_NAMES[int(torch.argmax(logits, 1))]


def dome_preprocess(tokenizer, comment: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    DOME preprocessor - returns all values required by "DOMEBertForClassification.forward".
    This function limits the maximum number of tokens so that the input fits into BERT.
    :param tokenizer: tokenizer to use.
    :param comment: text of the sentence from a docstring/comment to be classified by DOMEBertForClassification.
    :return: tuple with (input_ids, attention_mask, comment_len).
    """
    input_ids = tokenizer.convert_tokens_to_ids([tokenizer.cls_token] + tokenizer.tokenize(comment) +
                                                [tokenizer.sep_token])[:MAX_LENGTH_BERT]
    attention_mask = [1] * len(input_ids)
    if len(comment.strip().split()) < 3:
        comment_len = 1
    else:
        comment_len = 0
    return (torch.tensor(input_ids).unsqueeze(0), torch.tensor(attention_mask).unsqueeze(0),
            torch.tensor(comment_len).unsqueeze(0))


class Docstring2Sentences:
    """Helper class to split docstrings into sentences"""
    def __init__(self):
        self.spacy_nlp = spacy.load("en_core_web_sm")

    @staticmethod
    def split_docstring(docstring: str, delimiters: List[Tuple[str, str]]) -> List[str]:
        """
        Splits the docstring into separate text and code blocks, preserving the original formatting.

        :param docstring: The docstring to split.
        :param delimiters: A list of tuples, each containing the start and end delimiters of a code block.
        :return: A list of strings, each either a text block or a code block.
        """

        # Escape delimiter parts for regex and create a combined pattern
        escaped_delimiters = [tuple(map(re.escape, d)) for d in delimiters]
        combined_pattern = '|'.join([f'({start}.*?{end})' for start, end in escaped_delimiters])

        # Split using the combined pattern, preserving the delimiters
        parts = re.split(combined_pattern, docstring, flags=re.DOTALL)

        # Filter out empty strings and the None values produced by unmatched capture groups
        parts = [part for part in parts if part]

        return parts

    @staticmethod
    def is_only_spaces_and_newlines(string: str) -> bool:
        """
        Check if the given string contains only spaces and newlines.

        :param string: The string to check.
        :return: True if the string contains only whitespace, False otherwise.
        """
        return bool(re.match(r'^[\s\n]+$', string))

    def docstring2sentences(self, docstring: str) -> List[str]:
        """
        Splits a docstring into individual sentences, preserving code blocks.

        This method uses `split_docstring` to split the docstring into parts based on predefined code block
        delimiters. It then uses a spaCy NLP model to split the non-code text parts into sentences.
        Code blocks are kept intact as single elements.

        :param docstring: The docstring to be processed, which may contain both regular text and code blocks.
        :return: A list containing individual sentences and intact code blocks.
        """
        delimiters = [("@code", "@endcode"), (r"\code", r"\endcode")]
        parts = self.split_docstring(docstring=docstring, delimiters=delimiters)
        sentences = []
        for part in parts:
            if part[1:5] == "code" and part[-7:] == "endcode":
                # code block
                sentences.append(part)
            else:
                sentences.extend(sentence.text for sentence in self.spacy_nlp(part).sents)

        return [sentence for sentence in sentences if not self.is_only_spaces_and_newlines(sentence)]
````
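For context, a short end-to-end run of the pipeline above on a docstring that contains a Doxygen-style code block; the path is a placeholder for the directory with this repository's files, and the printed labels depend entirely on the released weights:

```python
# Hypothetical example; "dome_location" is a placeholder for the local model directory.
example_docstring = """Compute the arithmetic mean of a list of numbers.

@code
mean([1, 2, 3])
@endcode

Returns 0 for an empty list."""

dome = DOME("dome_location")
sentences, predictions = dome.predict(example_docstring)
for sentence, label in zip(sentences, predictions):
    print(repr(sentence), "->", label)
# The @code ... @endcode block is kept as a single element and classified as one unit.
```
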
config.json ADDED
@@ -0,0 +1,27 @@
{
  "architectures": [
    "DOMEBertForClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.9.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}
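The `architectures` field points at the custom `DOMEBertForClassification` class defined in the README above; since `transformers` does not ship that class, the model has to be instantiated through it explicitly. A minimal sketch of how this config maps onto the class (the path is a placeholder):

```python
from transformers import RobertaConfig

config = RobertaConfig.from_pretrained("dome_location")  # parses this config.json
model = DOMEBertForClassification(config)                # randomly initialised; use .from_pretrained for the released weights
assert config.hidden_size == 768                         # matches the 768-dimensional pooled output consumed by fc1
```
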
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f045aba0a263d2bc109af9476b5673fb666b2e716de91698a6b639505668cb19
size 499457001
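The weights are stored as a Git LFS object (about 499 MB). Assuming the checkpoint was produced by `save_pretrained` on the class above, its state dict should contain the custom `fc1`/`fc2` heads alongside the RoBERTa encoder weights; a small sketch to check, once the file has been pulled from LFS:

```python
import torch

# Inspect the checkpoint contents without instantiating the model.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# Expected (assumption based on the README code): keys such as "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias".
print([key for key in state_dict if key.startswith("fc")])
```
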
special_tokens_map.json ADDED
@@ -0,0 +1 @@
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
{"unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "errors": "replace", "sep_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "cls_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "model_max_length": 512, "special_tokens_map_file": "/home/egor/workdir/github/ICSE-DOME/DOME/src/comment_classifier/pretrained_codebert/special_tokens_map.json", "name_or_path": "/home/egor/workdir/github/ICSE-DOME/DOME/src/comment_classifier/pretrained_codebert", "tokenizer_class": "RobertaTokenizer"}
vocab.json ADDED
The diff for this file is too large to render. See raw diff