File size: 1,338 Bytes
a38f44b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#!/usr/bin/env python3
import sys

# A block to delete starts at a line whose stripped text equals MATCH_PATTERN_1
# and whose immediate successor's stripped text equals MATCH_PATTERN_2.
#
# Other pattern sets left in the original script (overwritten before use or
# commented out), kept here for reference only:
#   "# Copied from transformers.models.bart.modeling_bart._make_causal_mask"
#   "def _make_causal_mask("
#   "def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):"
#   "# create causal mask"   (with END_MATCH_PATTERN_2 = "def forward(")
MATCH_PATTERN_1 = "# Copied from transformers.models.bart.modeling_bart.prepare_4d_attention_mask"
MATCH_PATTERN_2 = "def prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):"

# A block ends at a blank line whose successor's stripped text equals this
# value. With "" that means: the first of two consecutive blank lines.
END_MATCH_PATTERN_2 = ""


def strip_copied_block(lines):
    """Return a copy of *lines* with every matched block removed.

    Deletion starts on a line matching ``MATCH_PATTERN_1`` that is directly
    followed by a line matching ``MATCH_PATTERN_2`` (both compared after
    stripping surrounding whitespace). It stops at the first blank line that
    is followed by a line whose stripped text equals ``END_MATCH_PATTERN_2``;
    that blank line itself is kept. If no such terminator is found, everything
    up to the end of the input is dropped.

    Args:
        lines: The file content as a list of lines (line endings included).

    Returns:
        A new list holding the surviving lines in their original order.
    """
    kept = []
    deleting = False
    last_index = len(lines) - 1

    for index, line in enumerate(lines):
        stripped = line.strip()
        # The final line has no successor, so it can neither open nor close a
        # block; ``None`` never compares equal to a pattern string.
        following = lines[index + 1].strip() if index < last_index else None

        if stripped == MATCH_PATTERN_1 and following == MATCH_PATTERN_2:
            deleting = True
        elif stripped == "" and following == END_MATCH_PATTERN_2:
            deleting = False

        if not deleting:
            kept.append(line)

    return kept


def process_file(filename):
    """Remove matched blocks from *filename* in place.

    The file is read and written as UTF-8 so the result does not depend on the
    platform's locale encoding, and it is only rewritten when at least one
    line was actually removed.

    Args:
        filename: Path of the file to clean up.

    Returns:
        The number of lines removed (0 when the file was left untouched).
    """
    with open(filename, "r", encoding="utf-8") as f:
        lines = f.readlines()

    new_lines = strip_copied_block(lines)
    removed = len(lines) - len(new_lines)

    if removed:
        with open(filename, "w", encoding="utf-8") as f:
            f.writelines(new_lines)

    return removed


def main(argv=None):
    """Process every file named in *argv* (defaults to ``sys.argv[1:]``)."""
    filenames = sys.argv[1:] if argv is None else argv
    for filename in filenames:
        removed = process_file(filename)
        if removed:
            print(f"{filename}: removed {removed} line(s)")


if __name__ == "__main__":
    main()