metricv committed on
Commit
d06f65e
·
verified ·
1 Parent(s): 5414abe

Update model

Browse files
Files changed (4) hide show
  1. model_consts.py +2 -2
  2. segmenter.ckpt +2 -2
  3. train.py +1 -1
  4. utils.py +60 -64
model_consts.py CHANGED
@@ -4,6 +4,6 @@ else:
4
  from .utils import get_upenn_tags_dict
5
 
6
  input_size = len(get_upenn_tags_dict())
7
- embedding_size = 128
8
- hidden_size = 128
9
  num_layers = 2
 
4
  from .utils import get_upenn_tags_dict
5
 
6
  input_size = len(get_upenn_tags_dict())
7
+ embedding_size = 256
8
+ hidden_size = 256
9
  num_layers = 2
segmenter.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:26c09246a1ed23aa5be9656e36878d30b1b39aa649dbd9a24bbef7ecee5a4e7d
3
- size 2665888
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8e6209584d0021684bb3a09ec1b717843f3086dfcc6411c57276f743f8e62fa
3
+ size 10584544
train.py CHANGED
@@ -26,6 +26,6 @@ if __name__ == "__main__":
26
 
27
  model.to(device)
28
 
29
- train_bidirlstm_embedding_model(model, dataset, num_epochs=75, batch_size=2)
30
 
31
  torch.save(model.state_dict(), "segmenter.ckpt")
 
26
 
27
  model.to(device)
28
 
29
+ train_bidirlstm_embedding_model(model, dataset, num_epochs=150, batch_size=2)
30
 
31
  torch.save(model.state_dict(), "segmenter.ckpt")
utils.py CHANGED
@@ -4,6 +4,64 @@ from stable_whisper.result import WordTiming
4
  import numpy as np
5
  import torch
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def bind_wordtimings_to_tags(wt: list[WordTiming]):
8
  raw_words = [w.word for w in wt]
9
 
@@ -16,6 +74,7 @@ def bind_wordtimings_to_tags(wt: list[WordTiming]):
16
  tokens_wordtiming_map.append(len(tokens_word))
17
 
18
  tagged_words = nltk.pos_tag(tokenized_raw_words)
 
19
 
20
  grouped_tags = []
21
 
@@ -49,6 +108,7 @@ def tag_training_data(filename: str):
49
 
50
  tokenized_full_text = nltk.word_tokenize(full_text)
51
  tagged_full_text = nltk.pos_tag(tokenized_full_text)
 
52
 
53
  tagged_full_text_copy = tagged_full_text
54
 
@@ -75,70 +135,6 @@ def tag_training_data(filename: str):
75
 
76
  return reconstructed_tags
77
 
78
- def get_upenn_tags_dict():
79
- # tagger = PerceptronTagger()
80
-
81
- # tags = list(tagger.tagdict.values())
82
-
83
- # # https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html
84
- # tags.extend(["CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS", "LS", "MD", "NN", "NNS", "NNP", "NNPS", "PDT", "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT", "WP", "WP$", "WRB"])
85
- # tags = list(set(tags))
86
- # tags.sort()
87
- # tags.append("BREAK")
88
-
89
- # tags_dict = dict()
90
-
91
- # for index, tag in enumerate(tags):
92
- # tags_dict[tag] = index
93
-
94
- return {'#': 0,
95
- '$': 1,
96
- "''": 2,
97
- '(': 3,
98
- ')': 4,
99
- ',': 5,
100
- '.': 6,
101
- ':': 7,
102
- 'CC': 8,
103
- 'CD': 9,
104
- 'DT': 10,
105
- 'EX': 11,
106
- 'FW': 12,
107
- 'IN': 13,
108
- 'JJ': 14,
109
- 'JJR': 15,
110
- 'JJS': 16,
111
- 'LS': 17,
112
- 'MD': 18,
113
- 'NN': 19,
114
- 'NNP': 20,
115
- 'NNPS': 21,
116
- 'NNS': 22,
117
- 'PDT': 23,
118
- 'POS': 24,
119
- 'PRP': 25,
120
- 'PRP$': 26,
121
- 'RB': 27,
122
- 'RBR': 28,
123
- 'RBS': 29,
124
- 'RP': 30,
125
- 'SYM': 31,
126
- 'TO': 32,
127
- 'UH': 33,
128
- 'VB': 34,
129
- 'VBD': 35,
130
- 'VBG': 36,
131
- 'VBN': 37,
132
- 'VBP': 38,
133
- 'VBZ': 39,
134
- 'WDT': 40,
135
- 'WP': 41,
136
- 'WP$': 42,
137
- 'WRB': 43,
138
- '``': 44,
139
- 'BREAK': 45}
140
-
141
-
142
  def parse_tags(reconstructed_tags):
143
  """
144
  Parse reconstructed tags into input/tag datapoint.
 
4
  import numpy as np
5
  import torch
6
 
7
+ additional_tags = {
8
+ "as": "`AS",
9
+ "and": "`AND",
10
+ "of": "`OF",
11
+ "how": "`HOW",
12
+ "but": "`BUT",
13
+ "the": "`THE",
14
+ "a": "`A",
15
+ "an": "`A",
16
+ "which": "`WHICH",
17
+ "what": "`WHAT",
18
+ "where": "`WHERE",
19
+ "that": "`THAT",
20
+ "who": "`WHO",
21
+ "when": "`WHEN",
22
+ }
23
+
24
+ def get_upenn_tags_dict():
25
+ # tagger = PerceptronTagger()
26
+
27
+ # tags = list(tagger.tagdict.values())
28
+
29
+ # # https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html
30
+ # tags.extend(["CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS", "LS", "MD", "NN", "NNS", "NNP", "NNPS", "PDT", "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT", "WP", "WP$", "WRB"])
31
+ # tags = list(set(tags))
32
+ # tags.sort()
33
+ # tags.append("BREAK")
34
+
35
+ # tags_dict = dict()
36
+
37
+ # for index, tag in enumerate(tags):
38
+ # tags_dict[tag] = index
39
+
40
+ return {'#': 0, '$': 1, "''": 2,'(': 3,')': 4,',': 5,'.': 6,':': 7,'CC': 8,'CD': 9,'DT': 10,'EX': 11,'FW': 12,'IN': 13,'JJ': 14,'JJR': 15,'JJS': 16,'LS': 17,'MD': 18,'NN': 19,'NNP': 20,'NNPS': 21,'NNS': 22,'PDT': 23,'POS': 24,'PRP': 25,'PRP$': 26,'RB': 27,'RBR': 28,'RBS': 29,'RP': 30,'SYM': 31,'TO': 32,'UH': 33,'VB': 34,'VBD': 35,'VBG': 36,'VBN': 37,'VBP': 38,'VBZ': 39,'WDT': 40,'WP': 41,'WP$': 42,'WRB': 43,'``': 44,'BREAK': 45,
41
+ '`AS': 46,
42
+ '`AND': 47,
43
+ '`OF': 48,
44
+ '`HOW': 49,
45
+ '`BUT': 50,
46
+ '`THE': 51,
47
+ '`A': 52,
48
+ '`WHICH': 53,
49
+ '`WHAT': 54,
50
+ '`WHERE': 55,
51
+ '`THAT': 56,
52
+ '`WHO': 57,
53
+ '`WHEN': 58
54
+ }
55
+
56
+ def nltk_extend_tags(tagged_text: list[tuple[str, str]]):
57
+ result = []
58
+ for text, tag in tagged_text:
59
+ text_lower = text.lower().strip()
60
+ if text_lower in additional_tags:
61
+ yield (text, additional_tags[text_lower])
62
+ else:
63
+ yield (text, tag)
64
+
65
  def bind_wordtimings_to_tags(wt: list[WordTiming]):
66
  raw_words = [w.word for w in wt]
67
 
 
74
  tokens_wordtiming_map.append(len(tokens_word))
75
 
76
  tagged_words = nltk.pos_tag(tokenized_raw_words)
77
+ tagged_words = list(nltk_extend_tags(tagged_words))
78
 
79
  grouped_tags = []
80
 
 
108
 
109
  tokenized_full_text = nltk.word_tokenize(full_text)
110
  tagged_full_text = nltk.pos_tag(tokenized_full_text)
111
+ tagged_full_text = list(nltk_extend_tags(tagged_full_text))
112
 
113
  tagged_full_text_copy = tagged_full_text
114
 
 
135
 
136
  return reconstructed_tags
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def parse_tags(reconstructed_tags):
139
  """
140
  Parse reconstructed tags into input/tag datapoint.