File size: 1,338 Bytes
a38f44b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
#!/usr/bin/env python3
"""Delete a duplicated helper definition from Python source files, in place.

Usage: pass one or more file paths on the command line.

For each file, deletion starts at a line that (whitespace-stripped) equals
MATCH_PATTERN_1 and is immediately followed by a line equal to
MATCH_PATTERN_2. Deletion stops at a blank line whose following line, stripped,
equals END_MATCH_PATTERN_2; that blank line itself is kept. If no such
terminator follows, everything up to the end of the file is removed.
"""
import sys

# Alternative preset: remove the copied ``_make_causal_mask`` helper instead.
# MATCH_PATTERN_1 = "# Copied from transformers.models.bart.modeling_bart._make_causal_mask"
# MATCH_PATTERN_2 = "def _make_causal_mask("

# Active preset: the "Copied from" marker line and the def line right after it.
MATCH_PATTERN_1 = "# Copied from transformers.models.bart.modeling_bart.prepare_4d_attention_mask"
MATCH_PATTERN_2 = "def prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):"
# "" means deletion ends at the first of two consecutive blank lines.
END_MATCH_PATTERN_2 = ""

# Alternative preset: remove ``_prepare_decoder_attention_mask`` (when its next
# line is ``# create causal mask``) up to the blank line before ``def forward(``.
# MATCH_PATTERN_1 = "def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):"
# MATCH_PATTERN_2 = "# create causal mask"
# END_MATCH_PATTERN_2 = "def forward("


def strip_block(lines, start_marker=MATCH_PATTERN_1, start_def=MATCH_PATTERN_2,
                end_next=END_MATCH_PATTERN_2):
    """Remove every marked block from *lines*.

    Args:
        lines: Source lines, each possibly ending in a newline sequence.
        start_marker: Stripped text of the line that opens a block.
        start_def: Stripped text that must be on the line right after it.
        end_next: Stripped text of the line following the blank line that
            closes a block.

    Returns:
        A ``(kept_lines, start_indices)`` tuple. ``kept_lines`` holds the
        surviving lines unmodified. ``start_indices`` holds the 0-based index
        of each line where a deletion started.
    """
    stripped = [line.strip() for line in lines]
    kept = []
    starts = []
    deleting = False
    for i, line in enumerate(lines):
        # None when there is no next line, so neither pattern can match.
        nxt = stripped[i + 1] if i + 1 < len(stripped) else None
        if stripped[i] == start_marker and nxt == start_def:
            deleting = True
            starts.append(i)
        elif stripped[i] == "" and nxt == end_next:
            # The terminating blank line is kept (appended below).
            deleting = False
        if not deleting:
            kept.append(line)
    return kept, starts


def process_file(filename):
    """Strip marked blocks from *filename* in place.

    Returns the number of blocks removed. The file is only rewritten when at
    least one block was removed.
    """
    # newline="" disables newline translation, so LF/CRLF endings survive the
    # round trip; explicit UTF-8 avoids depending on the platform locale.
    with open(filename, "r", encoding="utf-8", newline="") as f:
        lines = f.readlines()
    kept, starts = strip_block(lines)
    for i in starts:
        print(f"{filename}:{i + 1}: removing block")
    if starts:
        with open(filename, "w", encoding="utf-8", newline="") as f:
            f.writelines(kept)
    return len(starts)


def main(argv=None):
    """Process each path in *argv* (default ``sys.argv[1:]``).

    Returns the total number of blocks removed.
    """
    filenames = sys.argv[1:] if argv is None else argv
    return sum(process_file(filename) for filename in filenames)


if __name__ == "__main__":
    main()
|