Cat0125 committed on
Commit
94d6d2b
·
1 Parent(s): 5f3fb7b

Update models & training system

Browse files
datamanager.py CHANGED
@@ -1,7 +1,7 @@
1
  import json
2
  import pickle
3
 
4
- from files import read_lines
5
 
6
  models = json.load(open("models/models.json"))
7
  TEXT_PATH = 'models/%s/text.txt'
@@ -21,6 +21,9 @@ def get_texts(model_name):
21
  """
22
  return read_lines(TEXT_PATH % model_name)
23
 
 
 
 
24
  def set_data(model_name, data):
25
  """
26
  This function saves data to a file using the pickle module, with the filename specified by the
 
1
  import json
2
  import pickle
3
 
4
+ from files import read_lines, read_file
5
 
6
  models = json.load(open("models/models.json"))
7
  TEXT_PATH = 'models/%s/text.txt'
 
21
  """
22
  return read_lines(TEXT_PATH % model_name)
23
 
24
+ def get_text(model_name):
25
+ return read_file(TEXT_PATH % model_name)
26
+
27
  def set_data(model_name, data):
28
  """
29
  This function saves data to a file using the pickle module, with the filename specified by the
main.py CHANGED
@@ -13,25 +13,6 @@ WEIGHTS_MAP = [
13
  ]
14
 
15
  def get_next_word_results(db:dict, message:str, prev_word:str, text:str, repeat:int = 0):
16
- """
17
- This function takes in a database, a message, a previous word, and an optional repeat count, and
18
- returns a list of tokens from the database that match the previous word and have a score based on
19
- their context in the message.
20
-
21
- :param db: a dictionary containing information about words and their contexts
22
- :param message: a string representing the message or text being analyzed
23
- :type message: str
24
- :param prev_word: The previous word that we want to find the next word(s) for
25
- :type prev_word: str
26
- :param repeat: The repeat parameter is an optional integer parameter that specifies how many times
27
- the previous word can be repeated in the message before it is no longer considered a valid context
28
- for the next word. If repeat is set to 0, then there is no limit on the number of times the previous
29
- word can be repeated, defaults to 0
30
- :type repeat: int (optional)
31
- :return: a list of Token objects that are the next possible words in a given message based on the
32
- previous word and its contexts in a database. If the previous word is not in the database, an empty
33
- list is returned.
34
- """
35
  results = []
36
  if prev_word not in db:
37
  return []
@@ -94,7 +75,7 @@ def generator(user_message, word_count, mode, model_name):
94
  yield text
95
  break
96
  if i == 0 and text.strip() == '.':
97
- raise gr.Error("Error in generating. Try to use another prompt")
98
  i += 1
99
  yield text.strip()
100
 
@@ -103,11 +84,8 @@ demo = gr.Blocks(
103
  )
104
 
105
  title_html = """
106
- <center>
107
- <h1>Text Generator v2</h1>
108
- <p>Generates text using per-word context system</p>
109
- <a href="http://j93153xm.beget.tech/app/index.html?id=text-ai"><img src="https://img.shields.io/badge/Text%20Generator%20v1-RU%20only-brightgreen"></a>
110
- </center>
111
  """
112
  info_text = """
113
  # Information about the models
@@ -127,16 +105,6 @@ info_text = """
127
  `Language`: Russian
128
  `Quality`: 7-8/10
129
  `Sources`: http://staging.budsvetom.com/literature_items/ochen-dlinnyy-tekst
130
-
131
- # Training
132
- ```bash
133
- python train.py -r <models to train> [-t] [-l ...]
134
- ```
135
- `--rebuild` (`-r`) - Models that will be trained.
136
- `--turbo` (`-t`) - Enables turbo training. Will skip morphological analysis and just add all words directly.
137
- `--log` (`-l`) - Logs listed databases to the console after training.
138
-
139
- > **Note:** Use `--turbo` only when training with Russian texts.
140
  """
141
  with demo:
142
  gr.HTML(title_html)
 
13
  ]
14
 
15
  def get_next_word_results(db:dict, message:str, prev_word:str, text:str, repeat:int = 0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  results = []
17
  if prev_word not in db:
18
  return []
 
75
  yield text
76
  break
77
  if i == 0 and text.strip() == '.':
78
+ raise gr.Error("Error while generating. Please try again.")
79
  i += 1
80
  yield text.strip()
81
 
 
84
  )
85
 
86
  title_html = """
87
+ <h1>Text Generator v2</h1>
88
+ <a href="http://j93153xm.beget.tech/app/index.html?id=text-ai"><img src="https://img.shields.io/badge/Text%20Generator%20v1-RU%20only-brightgreen"></a>
 
 
 
89
  """
90
  info_text = """
91
  # Information about the models
 
105
  `Language`: Russian
106
  `Quality`: 7-8/10
107
  `Sources`: http://staging.budsvetom.com/literature_items/ochen-dlinnyy-tekst
 
 
 
 
 
 
 
 
 
 
108
  """
109
  with demo:
110
  gr.HTML(title_html)
models/en/data.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0a44379910fb94a41548dc22de9d5c94d31b74a71d50951804c9f4b904311ae9
3
- size 892450
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:decf142f192f3ae9576f87a8aa119e39dd852317c8bdba2d83fc4eddebb3cc3b
3
+ size 3717733
models/en/data3.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fc6f7a452bcc7aa0e6d9508a38c498c23df71e430bb26e9452e3492e464e786e
3
- size 926524
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:707e9ea7c67970a84c25f068aa74c130b45b11d1235d9e59efead924c6efd3a7
3
+ size 3698343
models/ru-lg/data.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a37fc4eab3fb56240c4db1e4e0011f29b4f7c454dd2c723ce38b36ccbc38da25
3
- size 3436464
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa404e61cfc8a102e228335ed384af25fae86ecead0143acbff67954adde0bb7
3
+ size 3997228
models/ru-lg/data3.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:722b3c087d60a91f35333faeb2bc98ca5d609aebfabcfd7ea57c8561d29fcda2
3
- size 3449818
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f839b88e09df9f6fb3b010cdee4945ead7b8385d252cb0f30529d3aa5229d8eb
3
+ size 4022056
models/ru-lg/text.txt CHANGED
The diff for this file is too large to render. See raw diff
 
models/ru-lite/data.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:13461ea05cfe7f067013241ade45a97bfa104ce90ea4f4b3edfaf21e35beda92
3
  size 560516
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63b42b3a11e0216544a7fca69986557e246cfb0cd773c2ca6687932fe9ede410
3
  size 560516
models/ru-lite/data3.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f38a913d682d96840deab6b5a9468539c4373c5b14ab3eb1ef7009c2c1dc8dc9
3
  size 573246
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8bb5f68b452a5cacbdbac92b096f3cfd8fe18ca6012951a4f42ee21615dcf17
3
  size 573246
train.py CHANGED
@@ -5,20 +5,21 @@ from pprint import pprint
5
  from tqdm import tqdm
6
 
7
  from classes import Token
8
- from datamanager import get_data, get_texts, models, set_data, set_data_v3
9
 
10
  turbo = False
11
 
12
 
13
  def normalize_text(sentence):
14
  sentence = sentence.strip()
 
15
  sentence = re.sub(r'\s+([.,!?;:])', r'\1', sentence)
16
  sentence = re.sub(r'([.,!?;:])(\S)', r'\1 \2', sentence)
17
  sentence = re.sub(r'\s+\'|\'\s+', '\'', sentence)
18
  sentence = re.sub(r'\s+', ' ', sentence)
19
  return sentence
20
 
21
- def process_sentence(db, db3, sentence:str, text):
22
  words = sentence.strip().split()
23
  for i in range(len(words)):
24
  word = words[i].strip()
@@ -39,15 +40,9 @@ def train(model_name):
39
  db = []
40
  db3 = {}
41
  print(f'Rebuilding database for "{model_name}"...')
42
- k = 0
43
- texts = get_texts(model_name)
44
- total_texts = len(texts)
45
- for text in texts:
46
- k += 1
47
- print(f'Processing text {k} of {total_texts}...')
48
- text = normalize_text(text)
49
- process_text(db, db3, text)
50
-
51
  set_data(model_name, db)
52
  models[model_name]["db"] = db
53
  set_data_v3(model_name, db3)
@@ -55,8 +50,8 @@ def train(model_name):
55
 
56
  if __name__ == '__main__':
57
  parser = argparse.ArgumentParser(
58
- prog='Text Generator v2',
59
- description='Generates text from a text file')
60
  parser.add_argument('-r', '--rebuild', action='extend', nargs="+", type=str)
61
  parser.add_argument('-l', '--log', action='extend', nargs="+", type=str)
62
  parser.add_argument('-t', '--turbo', action='store_true')
 
5
  from tqdm import tqdm
6
 
7
  from classes import Token
8
+ from datamanager import get_data, get_text, models, set_data, set_data_v3
9
 
10
  turbo = False
11
 
12
 
13
  def normalize_text(sentence):
14
  sentence = sentence.strip()
15
+ sentence = re.sub(r'(\s+|\n+)', ' ', sentence)
16
  sentence = re.sub(r'\s+([.,!?;:])', r'\1', sentence)
17
  sentence = re.sub(r'([.,!?;:])(\S)', r'\1 \2', sentence)
18
  sentence = re.sub(r'\s+\'|\'\s+', '\'', sentence)
19
  sentence = re.sub(r'\s+', ' ', sentence)
20
  return sentence
21
 
22
+ def process_sentence(db, db3, sentence, text):
23
  words = sentence.strip().split()
24
  for i in range(len(words)):
25
  word = words[i].strip()
 
40
  db = []
41
  db3 = {}
42
  print(f'Rebuilding database for "{model_name}"...')
43
+ text = get_text(model_name)
44
+ text = normalize_text(text)
45
+ process_text(db, db3, text)
 
 
 
 
 
 
46
  set_data(model_name, db)
47
  models[model_name]["db"] = db
48
  set_data_v3(model_name, db3)
 
50
 
51
  if __name__ == '__main__':
52
  parser = argparse.ArgumentParser(
53
+ prog='Train',
54
+ description='Training system for Text Generator v2')
55
  parser.add_argument('-r', '--rebuild', action='extend', nargs="+", type=str)
56
  parser.add_argument('-l', '--log', action='extend', nargs="+", type=str)
57
  parser.add_argument('-t', '--turbo', action='store_true')