ygauravyy commited on
Commit
9eccb58
1 Parent(s): 6e9a5be

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +194 -0
utils.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ import numpy as np
4
+
5
+
6
+ def get_hparams_from_file(config_path):
7
+ with open(config_path, "r", encoding="utf-8") as f:
8
+ data = f.read()
9
+ config = json.loads(data)
10
+
11
+ hparams = HParams(**config)
12
+ return hparams
13
+
14
+ class HParams:
15
+ def __init__(self, **kwargs):
16
+ for k, v in kwargs.items():
17
+ if type(v) == dict:
18
+ v = HParams(**v)
19
+ self[k] = v
20
+
21
+ def keys(self):
22
+ return self.__dict__.keys()
23
+
24
+ def items(self):
25
+ return self.__dict__.items()
26
+
27
+ def values(self):
28
+ return self.__dict__.values()
29
+
30
+ def __len__(self):
31
+ return len(self.__dict__)
32
+
33
+ def __getitem__(self, key):
34
+ return getattr(self, key)
35
+
36
+ def __setitem__(self, key, value):
37
+ return setattr(self, key, value)
38
+
39
+ def __contains__(self, key):
40
+ return key in self.__dict__
41
+
42
+ def __repr__(self):
43
+ return self.__dict__.__repr__()
44
+
45
+
46
+ def string_to_bits(string, pad_len=8):
47
+ # Convert each character to its ASCII value
48
+ ascii_values = [ord(char) for char in string]
49
+
50
+ # Convert ASCII values to binary representation
51
+ binary_values = [bin(value)[2:].zfill(8) for value in ascii_values]
52
+
53
+ # Convert binary strings to integer arrays
54
+ bit_arrays = [[int(bit) for bit in binary] for binary in binary_values]
55
+
56
+ # Convert list of arrays to NumPy array
57
+ numpy_array = np.array(bit_arrays)
58
+ numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype)
59
+ numpy_array_full[:, 2] = 1
60
+ max_len = min(pad_len, len(numpy_array))
61
+ numpy_array_full[:max_len] = numpy_array[:max_len]
62
+ return numpy_array_full
63
+
64
+
65
+ def bits_to_string(bits_array):
66
+ # Convert each row of the array to a binary string
67
+ binary_values = [''.join(str(bit) for bit in row) for row in bits_array]
68
+
69
+ # Convert binary strings to ASCII values
70
+ ascii_values = [int(binary, 2) for binary in binary_values]
71
+
72
+ # Convert ASCII values to characters
73
+ output_string = ''.join(chr(value) for value in ascii_values)
74
+
75
+ return output_string
76
+
77
+
78
+ def split_sentence(text, min_len=10, language_str='[EN]'):
79
+ if language_str in ['EN']:
80
+ sentences = split_sentences_latin(text, min_len=min_len)
81
+ else:
82
+ sentences = split_sentences_zh(text, min_len=min_len)
83
+ return sentences
84
+
85
+ def split_sentences_latin(text, min_len=10):
86
+ """Split Long sentences into list of short ones
87
+
88
+ Args:
89
+ str: Input sentences.
90
+
91
+ Returns:
92
+ List[str]: list of output sentences.
93
+ """
94
+ # deal with dirty sentences
95
+ text = re.sub('[。!?;]', '.', text)
96
+ text = re.sub('[,]', ',', text)
97
+ text = re.sub('[“”]', '"', text)
98
+ text = re.sub('[‘’]', "'", text)
99
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
100
+ text = re.sub('[\n\t ]+', ' ', text)
101
+ text = re.sub('([,.!?;])', r'\1 $#!', text)
102
+ # split
103
+ sentences = [s.strip() for s in text.split('$#!')]
104
+ if len(sentences[-1]) == 0: del sentences[-1]
105
+
106
+ new_sentences = []
107
+ new_sent = []
108
+ count_len = 0
109
+ for ind, sent in enumerate(sentences):
110
+ # print(sent)
111
+ new_sent.append(sent)
112
+ count_len += len(sent.split(" "))
113
+ if count_len > min_len or ind == len(sentences) - 1:
114
+ count_len = 0
115
+ new_sentences.append(' '.join(new_sent))
116
+ new_sent = []
117
+ return merge_short_sentences_latin(new_sentences)
118
+
119
+
120
+ def merge_short_sentences_latin(sens):
121
+ """Avoid short sentences by merging them with the following sentence.
122
+
123
+ Args:
124
+ List[str]: list of input sentences.
125
+
126
+ Returns:
127
+ List[str]: list of output sentences.
128
+ """
129
+ sens_out = []
130
+ for s in sens:
131
+ # If the previous sentense is too short, merge them with
132
+ # the current sentence.
133
+ if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
134
+ sens_out[-1] = sens_out[-1] + " " + s
135
+ else:
136
+ sens_out.append(s)
137
+ try:
138
+ if len(sens_out[-1].split(" ")) <= 2:
139
+ sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
140
+ sens_out.pop(-1)
141
+ except:
142
+ pass
143
+ return sens_out
144
+
145
+ def split_sentences_zh(text, min_len=10):
146
+ text = re.sub('[。!?;]', '.', text)
147
+ text = re.sub('[,]', ',', text)
148
+ # 将文本中的换行符、空格和制表符替换为空格
149
+ text = re.sub('[\n\t ]+', ' ', text)
150
+ # 在标点符号后添加一个空格
151
+ text = re.sub('([,.!?;])', r'\1 $#!', text)
152
+ # 分隔句子并去除前后空格
153
+ # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
154
+ sentences = [s.strip() for s in text.split('$#!')]
155
+ if len(sentences[-1]) == 0: del sentences[-1]
156
+
157
+ new_sentences = []
158
+ new_sent = []
159
+ count_len = 0
160
+ for ind, sent in enumerate(sentences):
161
+ new_sent.append(sent)
162
+ count_len += len(sent)
163
+ if count_len > min_len or ind == len(sentences) - 1:
164
+ count_len = 0
165
+ new_sentences.append(' '.join(new_sent))
166
+ new_sent = []
167
+ return merge_short_sentences_zh(new_sentences)
168
+
169
+
170
+ def merge_short_sentences_zh(sens):
171
+ # return sens
172
+ """Avoid short sentences by merging them with the following sentence.
173
+
174
+ Args:
175
+ List[str]: list of input sentences.
176
+
177
+ Returns:
178
+ List[str]: list of output sentences.
179
+ """
180
+ sens_out = []
181
+ for s in sens:
182
+ # If the previous sentense is too short, merge them with
183
+ # the current sentence.
184
+ if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
185
+ sens_out[-1] = sens_out[-1] + " " + s
186
+ else:
187
+ sens_out.append(s)
188
+ try:
189
+ if len(sens_out[-1]) <= 2:
190
+ sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
191
+ sens_out.pop(-1)
192
+ except:
193
+ pass
194
+ return sens_out