Spaces:
Running
Running
Cat0125
commited on
Commit
·
94d6d2b
1
Parent(s):
5f3fb7b
Update models & training system
Browse files- datamanager.py +4 -1
- main.py +3 -35
- models/en/data.pkl +2 -2
- models/en/data3.pkl +2 -2
- models/ru-lg/data.pkl +2 -2
- models/ru-lg/data3.pkl +2 -2
- models/ru-lg/text.txt +0 -0
- models/ru-lite/data.pkl +1 -1
- models/ru-lite/data3.pkl +1 -1
- train.py +8 -13
datamanager.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import json
|
2 |
import pickle
|
3 |
|
4 |
-
from files import read_lines
|
5 |
|
6 |
models = json.load(open("models/models.json"))
|
7 |
TEXT_PATH = 'models/%s/text.txt'
|
@@ -21,6 +21,9 @@ def get_texts(model_name):
|
|
21 |
"""
|
22 |
return read_lines(TEXT_PATH % model_name)
|
23 |
|
|
|
|
|
|
|
24 |
def set_data(model_name, data):
|
25 |
"""
|
26 |
This function saves data to a file using the pickle module, with the filename specified by the
|
|
|
1 |
import json
|
2 |
import pickle
|
3 |
|
4 |
+
from files import read_lines, read_file
|
5 |
|
6 |
models = json.load(open("models/models.json"))
|
7 |
TEXT_PATH = 'models/%s/text.txt'
|
|
|
21 |
"""
|
22 |
return read_lines(TEXT_PATH % model_name)
|
23 |
|
24 |
+
def get_text(model_name):
|
25 |
+
return read_file(TEXT_PATH % model_name)
|
26 |
+
|
27 |
def set_data(model_name, data):
|
28 |
"""
|
29 |
This function saves data to a file using the pickle module, with the filename specified by the
|
main.py
CHANGED
@@ -13,25 +13,6 @@ WEIGHTS_MAP = [
|
|
13 |
]
|
14 |
|
15 |
def get_next_word_results(db:dict, message:str, prev_word:str, text:str, repeat:int = 0):
|
16 |
-
"""
|
17 |
-
This function takes in a database, a message, a previous word, and an optional repeat count, and
|
18 |
-
returns a list of tokens from the database that match the previous word and have a score based on
|
19 |
-
their context in the message.
|
20 |
-
|
21 |
-
:param db: a dictionary containing information about words and their contexts
|
22 |
-
:param message: a string representing the message or text being analyzed
|
23 |
-
:type message: str
|
24 |
-
:param prev_word: The previous word that we want to find the next word(s) for
|
25 |
-
:type prev_word: str
|
26 |
-
:param repeat: The repeat parameter is an optional integer parameter that specifies how many times
|
27 |
-
the previous word can be repeated in the message before it is no longer considered a valid context
|
28 |
-
for the next word. If repeat is set to 0, then there is no limit on the number of times the previous
|
29 |
-
word can be repeated, defaults to 0
|
30 |
-
:type repeat: int (optional)
|
31 |
-
:return: a list of Token objects that are the next possible words in a given message based on the
|
32 |
-
previous word and its contexts in a database. If the previous word is not in the database, an empty
|
33 |
-
list is returned.
|
34 |
-
"""
|
35 |
results = []
|
36 |
if prev_word not in db:
|
37 |
return []
|
@@ -94,7 +75,7 @@ def generator(user_message, word_count, mode, model_name):
|
|
94 |
yield text
|
95 |
break
|
96 |
if i == 0 and text.strip() == '.':
|
97 |
-
raise gr.Error("Error
|
98 |
i += 1
|
99 |
yield text.strip()
|
100 |
|
@@ -103,11 +84,8 @@ demo = gr.Blocks(
|
|
103 |
)
|
104 |
|
105 |
title_html = """
|
106 |
-
<
|
107 |
-
|
108 |
-
<p>Generates text using per-word context system</p>
|
109 |
-
<a href="http://j93153xm.beget.tech/app/index.html?id=text-ai"><img src="https://img.shields.io/badge/Text%20Generator%20v1-RU%20only-brightgreen"></a>
|
110 |
-
</center>
|
111 |
"""
|
112 |
info_text = """
|
113 |
# Information about the models
|
@@ -127,16 +105,6 @@ info_text = """
|
|
127 |
`Language`: Russian
|
128 |
`Quality`: 7-8/10
|
129 |
`Sources`: http://staging.budsvetom.com/literature_items/ochen-dlinnyy-tekst
|
130 |
-
|
131 |
-
# Training
|
132 |
-
```bash
|
133 |
-
python train.py -r <models to train> [-t] [-l ...]
|
134 |
-
```
|
135 |
-
`--rebuild` (`-r`) - Models that will be trained.
|
136 |
-
`--turbo` (`-t`) - Enables turbo training. Will skip morphological analysis and just add all words directly.
|
137 |
-
`--log` (`-l`) - Logs listed databases to the console after training.
|
138 |
-
|
139 |
-
> **Note:** Use `--turbo` only when training with Russian texts.
|
140 |
"""
|
141 |
with demo:
|
142 |
gr.HTML(title_html)
|
|
|
13 |
]
|
14 |
|
15 |
def get_next_word_results(db:dict, message:str, prev_word:str, text:str, repeat:int = 0):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
results = []
|
17 |
if prev_word not in db:
|
18 |
return []
|
|
|
75 |
yield text
|
76 |
break
|
77 |
if i == 0 and text.strip() == '.':
|
78 |
+
raise gr.Error("Error while generating. Please try again.")
|
79 |
i += 1
|
80 |
yield text.strip()
|
81 |
|
|
|
84 |
)
|
85 |
|
86 |
title_html = """
|
87 |
+
<h1>Text Generator v2</h1>
|
88 |
+
<a href="http://j93153xm.beget.tech/app/index.html?id=text-ai"><img src="https://img.shields.io/badge/Text%20Generator%20v1-RU%20only-brightgreen"></a>
|
|
|
|
|
|
|
89 |
"""
|
90 |
info_text = """
|
91 |
# Information about the models
|
|
|
105 |
`Language`: Russian
|
106 |
`Quality`: 7-8/10
|
107 |
`Sources`: http://staging.budsvetom.com/literature_items/ochen-dlinnyy-tekst
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
"""
|
109 |
with demo:
|
110 |
gr.HTML(title_html)
|
models/en/data.pkl
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:decf142f192f3ae9576f87a8aa119e39dd852317c8bdba2d83fc4eddebb3cc3b
|
3 |
+
size 3717733
|
models/en/data3.pkl
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:707e9ea7c67970a84c25f068aa74c130b45b11d1235d9e59efead924c6efd3a7
|
3 |
+
size 3698343
|
models/ru-lg/data.pkl
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fa404e61cfc8a102e228335ed384af25fae86ecead0143acbff67954adde0bb7
|
3 |
+
size 3997228
|
models/ru-lg/data3.pkl
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f839b88e09df9f6fb3b010cdee4945ead7b8385d252cb0f30529d3aa5229d8eb
|
3 |
+
size 4022056
|
models/ru-lg/text.txt
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
models/ru-lite/data.pkl
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 560516
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:63b42b3a11e0216544a7fca69986557e246cfb0cd773c2ca6687932fe9ede410
|
3 |
size 560516
|
models/ru-lite/data3.pkl
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 573246
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a8bb5f68b452a5cacbdbac92b096f3cfd8fe18ca6012951a4f42ee21615dcf17
|
3 |
size 573246
|
train.py
CHANGED
@@ -5,20 +5,21 @@ from pprint import pprint
|
|
5 |
from tqdm import tqdm
|
6 |
|
7 |
from classes import Token
|
8 |
-
from datamanager import get_data,
|
9 |
|
10 |
turbo = False
|
11 |
|
12 |
|
13 |
def normalize_text(sentence):
|
14 |
sentence = sentence.strip()
|
|
|
15 |
sentence = re.sub(r'\s+([.,!?;:])', r'\1', sentence)
|
16 |
sentence = re.sub(r'([.,!?;:])(\S)', r'\1 \2', sentence)
|
17 |
sentence = re.sub(r'\s+\'|\'\s+', '\'', sentence)
|
18 |
sentence = re.sub(r'\s+', ' ', sentence)
|
19 |
return sentence
|
20 |
|
21 |
-
def process_sentence(db, db3, sentence
|
22 |
words = sentence.strip().split()
|
23 |
for i in range(len(words)):
|
24 |
word = words[i].strip()
|
@@ -39,15 +40,9 @@ def train(model_name):
|
|
39 |
db = []
|
40 |
db3 = {}
|
41 |
print(f'Rebuilding database for "{model_name}"...')
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
for text in texts:
|
46 |
-
k += 1
|
47 |
-
print(f'Processing text {k} of {total_texts}...')
|
48 |
-
text = normalize_text(text)
|
49 |
-
process_text(db, db3, text)
|
50 |
-
|
51 |
set_data(model_name, db)
|
52 |
models[model_name]["db"] = db
|
53 |
set_data_v3(model_name, db3)
|
@@ -55,8 +50,8 @@ def train(model_name):
|
|
55 |
|
56 |
if __name__ == '__main__':
|
57 |
parser = argparse.ArgumentParser(
|
58 |
-
prog='
|
59 |
-
description='
|
60 |
parser.add_argument('-r', '--rebuild', action='extend', nargs="+", type=str)
|
61 |
parser.add_argument('-l', '--log', action='extend', nargs="+", type=str)
|
62 |
parser.add_argument('-t', '--turbo', action='store_true')
|
|
|
5 |
from tqdm import tqdm
|
6 |
|
7 |
from classes import Token
|
8 |
+
from datamanager import get_data, get_text, models, set_data, set_data_v3
|
9 |
|
10 |
turbo = False
|
11 |
|
12 |
|
13 |
def normalize_text(sentence):
|
14 |
sentence = sentence.strip()
|
15 |
+
sentence = re.sub(r'(\s+|\n+)', ' ', sentence)
|
16 |
sentence = re.sub(r'\s+([.,!?;:])', r'\1', sentence)
|
17 |
sentence = re.sub(r'([.,!?;:])(\S)', r'\1 \2', sentence)
|
18 |
sentence = re.sub(r'\s+\'|\'\s+', '\'', sentence)
|
19 |
sentence = re.sub(r'\s+', ' ', sentence)
|
20 |
return sentence
|
21 |
|
22 |
+
def process_sentence(db, db3, sentence, text):
|
23 |
words = sentence.strip().split()
|
24 |
for i in range(len(words)):
|
25 |
word = words[i].strip()
|
|
|
40 |
db = []
|
41 |
db3 = {}
|
42 |
print(f'Rebuilding database for "{model_name}"...')
|
43 |
+
text = get_text(model_name)
|
44 |
+
text = normalize_text(text)
|
45 |
+
process_text(db, db3, text)
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
set_data(model_name, db)
|
47 |
models[model_name]["db"] = db
|
48 |
set_data_v3(model_name, db3)
|
|
|
50 |
|
51 |
if __name__ == '__main__':
|
52 |
parser = argparse.ArgumentParser(
|
53 |
+
prog='Train',
|
54 |
+
description='Training system for Text Generator v2')
|
55 |
parser.add_argument('-r', '--rebuild', action='extend', nargs="+", type=str)
|
56 |
parser.add_argument('-l', '--log', action='extend', nargs="+", type=str)
|
57 |
parser.add_argument('-t', '--turbo', action='store_true')
|