a.korepanov committed on
Commit
a614a1a
1 Parent(s): 044cecb

some formatting

Browse files
Files changed (2) hide show
  1. sbert_punc_case_ru/sbertpunccase.py +77 -47
  2. setup.py +22 -17
sbert_punc_case_ru/sbertpunccase.py CHANGED
@@ -8,62 +8,66 @@ import numpy as np
8
  from transformers import AutoTokenizer, AutoModelForTokenClassification
9
 
10
  # Прогнозируемые знаки препинания
11
- PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}
12
 
13
  # Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа,
14
  # UPPER_TOTAL - верхний регистр для всех символов
15
- LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL']
16
  # Добавим в пунктуацию метку O, означающую отсутствие пунктуации
17
- LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values())
18
 
19
  # Сформируем метки на основе комбинаций регистра и пунктуации
20
  LABELS_list = []
21
  for case in LABELS_CASE:
22
  for punc in LABELS_PUNC:
23
- LABELS_list.append(f'{case}_{punc}')
24
- LABELS = {label: i+1 for i, label in enumerate(LABELS_list)}
25
- LABELS['O'] = -100
26
  INVERSE_LABELS = {i: label for label, i in LABELS.items()}
27
 
28
- LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'}
29
- LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'}
 
 
 
 
30
 
31
 
32
  def token_to_label(token, label):
33
  if type(label) == int:
34
  label = INVERSE_LABELS[label]
35
- if label == 'LOWER_O':
36
  return token
37
- if label == 'LOWER_PERIOD':
38
- return token + '.'
39
- if label == 'LOWER_COMMA':
40
- return token + ','
41
- if label == 'LOWER_QUESTION':
42
- return token + '?'
43
- if label == 'UPPER_O':
44
  return token.capitalize()
45
- if label == 'UPPER_PERIOD':
46
- return token.capitalize() + '.'
47
- if label == 'UPPER_COMMA':
48
- return token.capitalize() + ','
49
- if label == 'UPPER_QUESTION':
50
- return token.capitalize() + '?'
51
- if label == 'UPPER_TOTAL_O':
52
  return token.upper()
53
- if label == 'UPPER_TOTAL_PERIOD':
54
- return token.upper() + '.'
55
- if label == 'UPPER_TOTAL_COMMA':
56
- return token.upper() + ','
57
- if label == 'UPPER_TOTAL_QUESTION':
58
- return token.upper() + '?'
59
- if label == 'O':
60
  return token
61
 
62
 
63
- def decode_label(label, classes='all'):
64
- if classes == 'punc':
65
  return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
66
- if classes == 'case':
67
  return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
68
  else:
69
  return INVERSE_LABELS[label]
@@ -76,14 +80,12 @@ class SbertPuncCase(nn.Module):
76
  def __init__(self):
77
  super().__init__()
78
 
79
- self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO,
80
- strip_accents=False)
81
  self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
82
  self.model.eval()
83
 
84
  def forward(self, input_ids, attention_mask):
85
- return self.model(input_ids=input_ids,
86
- attention_mask=attention_mask)
87
 
88
  def punctuate(self, text):
89
  text = text.strip().lower()
@@ -94,10 +96,23 @@ class SbertPuncCase(nn.Module):
94
  tokenizer_output = self.tokenizer(words, is_split_into_words=True)
95
 
96
  if len(tokenizer_output.input_ids) > 512:
97
- return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])
98
-
99
- predictions = self(torch.tensor([tokenizer_output.input_ids], device=self.model.device),
100
- torch.tensor([tokenizer_output.attention_mask], device=self.model.device)).logits.cpu().data.numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  predictions = np.argmax(predictions, axis=2)
102
 
103
  # decode punctuation and casing
@@ -108,16 +123,31 @@ class SbertPuncCase(nn.Module):
108
  label_id = predictions[0][label_pos]
109
  label = decode_label(label_id)
110
  splitted_text.append(token_to_label(word, label))
111
- capitalized_text = ' '.join(splitted_text)
112
  return capitalized_text
113
 
114
 
115
- if __name__ == '__main__':
116
- parser = argparse.ArgumentParser("Punctuation and case restoration model sbert_punc_case_ru")
117
- parser.add_argument("-i", "--input", type=str, help="text to restore", default='sbert punc case расставляет точки запятые и знаки вопроса вам нравится')
118
- parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  args = parser.parse_args()
120
  print(f"Source text: {args.input}\n")
121
  sbertpunc = SbertPuncCase().to(args.device)
122
  punctuated_text = sbertpunc.punctuate(args.input)
123
- print(f"Restored text: {punctuated_text}")
 
8
  from transformers import AutoTokenizer, AutoModelForTokenClassification
9
 
10
  # Прогнозируемые знаки препинания
11
+ PUNK_MAPPING = {".": "PERIOD", ",": "COMMA", "?": "QUESTION"}
12
 
13
  # Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа,
14
  # UPPER_TOTAL - верхний регистр для всех символов
15
+ LABELS_CASE = ["LOWER", "UPPER", "UPPER_TOTAL"]
16
  # Добавим в пунктуацию метку O, означающую отсутствие пунктуации
17
+ LABELS_PUNC = ["O"] + list(PUNK_MAPPING.values())
18
 
19
  # Сформируем метки на основе комбинаций регистра и пунктуации
20
  LABELS_list = []
21
  for case in LABELS_CASE:
22
  for punc in LABELS_PUNC:
23
+ LABELS_list.append(f"{case}_{punc}")
24
+ LABELS = {label: i + 1 for i, label in enumerate(LABELS_list)}
25
+ LABELS["O"] = -100
26
  INVERSE_LABELS = {i: label for label, i in LABELS.items()}
27
 
28
+ LABEL_TO_PUNC_LABEL = {
29
+ label: label.split("_")[-1] for label in LABELS.keys() if label != "O"
30
+ }
31
+ LABEL_TO_CASE_LABEL = {
32
+ label: "_".join(label.split("_")[:-1]) for label in LABELS.keys() if label != "O"
33
+ }
34
 
35
 
36
  def token_to_label(token, label):
37
  if type(label) == int:
38
  label = INVERSE_LABELS[label]
39
+ if label == "LOWER_O":
40
  return token
41
+ if label == "LOWER_PERIOD":
42
+ return token + "."
43
+ if label == "LOWER_COMMA":
44
+ return token + ","
45
+ if label == "LOWER_QUESTION":
46
+ return token + "?"
47
+ if label == "UPPER_O":
48
  return token.capitalize()
49
+ if label == "UPPER_PERIOD":
50
+ return token.capitalize() + "."
51
+ if label == "UPPER_COMMA":
52
+ return token.capitalize() + ","
53
+ if label == "UPPER_QUESTION":
54
+ return token.capitalize() + "?"
55
+ if label == "UPPER_TOTAL_O":
56
  return token.upper()
57
+ if label == "UPPER_TOTAL_PERIOD":
58
+ return token.upper() + "."
59
+ if label == "UPPER_TOTAL_COMMA":
60
+ return token.upper() + ","
61
+ if label == "UPPER_TOTAL_QUESTION":
62
+ return token.upper() + "?"
63
+ if label == "O":
64
  return token
65
 
66
 
67
+ def decode_label(label, classes="all"):
68
+ if classes == "punc":
69
  return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
70
+ if classes == "case":
71
  return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
72
  else:
73
  return INVERSE_LABELS[label]
 
80
  def __init__(self):
81
  super().__init__()
82
 
83
+ self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, strip_accents=False)
 
84
  self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO)
85
  self.model.eval()
86
 
87
  def forward(self, input_ids, attention_mask):
88
+ return self.model(input_ids=input_ids, attention_mask=attention_mask)
 
89
 
90
  def punctuate(self, text):
91
  text = text.strip().lower()
 
96
  tokenizer_output = self.tokenizer(words, is_split_into_words=True)
97
 
98
  if len(tokenizer_output.input_ids) > 512:
99
+ return " ".join(
100
+ [
101
+ self.punctuate(" ".join(text_part))
102
+ for text_part in np.array_split(words, 2)
103
+ ]
104
+ )
105
+
106
+ predictions = (
107
+ self(
108
+ torch.tensor([tokenizer_output.input_ids], device=self.model.device),
109
+ torch.tensor(
110
+ [tokenizer_output.attention_mask], device=self.model.device
111
+ ),
112
+ )
113
+ .logits.cpu()
114
+ .data.numpy()
115
+ )
116
  predictions = np.argmax(predictions, axis=2)
117
 
118
  # decode punctuation and casing
 
123
  label_id = predictions[0][label_pos]
124
  label = decode_label(label_id)
125
  splitted_text.append(token_to_label(word, label))
126
+ capitalized_text = " ".join(splitted_text)
127
  return capitalized_text
128
 
129
 
130
+ if __name__ == "__main__":
131
+ parser = argparse.ArgumentParser(
132
+ "Punctuation and case restoration model sbert_punc_case_ru"
133
+ )
134
+ parser.add_argument(
135
+ "-i",
136
+ "--input",
137
+ type=str,
138
+ help="text to restore",
139
+ default="sbert punc case расставляет точки запятые и знаки вопроса вам нравится",
140
+ )
141
+ parser.add_argument(
142
+ "-d",
143
+ "--device",
144
+ type=str,
145
+ help="run model on cpu or gpu",
146
+ choices=["cpu", "cuda"],
147
+ default="cpu",
148
+ )
149
  args = parser.parse_args()
150
  print(f"Source text: {args.input}\n")
151
  sbertpunc = SbertPuncCase().to(args.device)
152
  punctuated_text = sbertpunc.punctuate(args.input)
153
+ print(f"Restored text: {punctuated_text}")
setup.py CHANGED
@@ -1,19 +1,24 @@
1
  from distutils.core import setup
2
 
3
- setup(name='sbert_punc_case_ru',
4
- version='0.1',
5
- description='Punctuation and Case Restoration model based on https://huggingface.co/sberbank-ai/sbert_large_nlu_ru',
6
- author='Almira Murtazina',
7
- author_email='ar.murtazina@skbkontur.ru',
8
- packages=['sbert_punc_case_ru'],
9
- install_requires=['transformers>=4.18.3'],
10
- classifiers=[
11
- "Operating System :: OS Independent",
12
- "Programming Language :: Python :: 3",
13
- "Programming Language :: Python :: 3.6",
14
- "Programming Language :: Python :: 3.7",
15
- "Programming Language :: Python :: 3.8",
16
- "Programming Language :: Python :: 3.9",
17
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
18
- ]
19
- )
 
 
 
 
 
 
1
  from distutils.core import setup
2
 
3
+ setup(
4
+ name="sbert_punc_case_ru",
5
+ version="0.2",
6
+ description="Punctuation and Case Restoration model based on https://huggingface.co/sberbank-ai/sbert_large_nlu_ru",
7
+ author="Almira Murtazina",
8
+ author_email="ar.murtazina@skbkontur.ru",
9
+ packages=["sbert_punc_case_ru"],
10
+ install_requires=[
11
+ "transformers>=4.36.2",
12
+ "torch",
13
+ "numpy"
14
+ ],
15
+ classifiers=[
16
+ "Operating System :: OS Independent",
17
+ "Programming Language :: Python :: 3",
18
+ "Programming Language :: Python :: 3.6",
19
+ "Programming Language :: Python :: 3.7",
20
+ "Programming Language :: Python :: 3.8",
21
+ "Programming Language :: Python :: 3.9",
22
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
23
+ ],
24
+ )