Upload 6 files
Browse filesUpdated an option to generate images on direct call
- WarBot.py +43 -4
- WarOnline_Chat.py +33 -12
- WarServer.py +18 -2
- config.py +6 -3
- requirements.txt +4 -1
WarBot.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
# Main library for WarBot
|
2 |
|
3 |
-
from transformers import AutoTokenizer ,AutoModelForCausalLM
|
4 |
import re
|
5 |
# Speller and punctuation:
|
6 |
import os
|
@@ -10,6 +10,8 @@ from torch import package
|
|
10 |
# not very necessary
|
11 |
#import textwrap
|
12 |
from textwrap3 import wrap
|
|
|
|
|
13 |
|
14 |
# util function to get expected len after tokenizing
|
15 |
def get_length_param(text: str, tokenizer) -> str:
|
@@ -41,8 +43,9 @@ def removeSigns(S):
|
|
41 |
|
42 |
def prepare_punct():
|
43 |
# Prepare the Punctuation Model
|
44 |
-
# Important! Enable for Unix version (python related)
|
45 |
-
|
|
|
46 |
torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
|
47 |
'latest_silero_models.yml',
|
48 |
progress=False)
|
@@ -69,12 +72,47 @@ def prepare_punct():
|
|
69 |
return model_punct
|
70 |
|
71 |
def initialize():
|
|
|
72 |
""" Loading the model """
|
73 |
fit_checkpoint = "WarBot"
|
74 |
tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
|
75 |
model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
|
76 |
model_punсt = prepare_punct()
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
def split_string(string,n=256):
|
80 |
return [string[i:i+n] for i in range(0, len(string), n)]
|
@@ -84,6 +122,7 @@ def get_response(quote:str,model,tokenizer,model_punct,temperature=0.2):
|
|
84 |
try:
|
85 |
user_inpit_ids = tokenizer.encode(f"|0|{get_length_param(quote, tokenizer)}|" \
|
86 |
+ quote + tokenizer.eos_token, return_tensors="pt")
|
|
|
87 |
except:
|
88 |
return "Exception in tokenization" # Exception in tokenization
|
89 |
|
|
|
1 |
# Main library for WarBot
|
2 |
|
3 |
+
from transformers import AutoTokenizer ,AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
4 |
import re
|
5 |
# Speller and punctuation:
|
6 |
import os
|
|
|
10 |
# not very necessary
|
11 |
#import textwrap
|
12 |
from textwrap3 import wrap
|
13 |
+
import replicate #imaging
|
14 |
+
|
15 |
|
16 |
# util function to get expected len after tokenizing
|
17 |
def get_length_param(text: str, tokenizer) -> str:
|
|
|
43 |
|
44 |
def prepare_punct():
|
45 |
# Prepare the Punctuation Model
|
46 |
+
# Important! Enable next line for Unix version (python related):
|
47 |
+
torch.backends.quantized.engine = 'qnnpack'
|
48 |
+
|
49 |
torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
|
50 |
'latest_silero_models.yml',
|
51 |
progress=False)
|
|
|
72 |
return model_punct
|
73 |
|
74 |
def initialize():
|
75 |
+
# Initializes all the settings
|
76 |
""" Loading the model """
|
77 |
fit_checkpoint = "WarBot"
|
78 |
tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
|
79 |
model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
|
80 |
model_punсt = prepare_punct()
|
81 |
+
|
82 |
+
""" Initialize the translational model """
|
83 |
+
os.environ['REPLICATE_API_TOKEN'] = '2254e586b1380c49a948fd00d6802d45962492e4'
|
84 |
+
translation_model_name = "Helsinki-NLP/opus-mt-ru-en"
|
85 |
+
translation_tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
|
86 |
+
translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)
|
87 |
+
|
88 |
+
""" Initialize the image model """
|
89 |
+
imageModel = replicate.models.get("stability-ai/stable-diffusion")
|
90 |
+
imgModel_version = imageModel.versions.get("27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478")
|
91 |
+
|
92 |
+
return (model, tokenizer, model_punсt, translation_model, translation_tokenizer, imgModel_version)
|
93 |
+
|
94 |
+
def translate(text:str,translation_model,translation_tokenizer):
|
95 |
+
# Translates from Russian to English
|
96 |
+
src = "ru" # source language
|
97 |
+
trg = "en" # target language
|
98 |
+
|
99 |
+
try:
|
100 |
+
batch = translation_tokenizer([text], return_tensors="pt")
|
101 |
+
generated_ids = translation_model.generate(**batch)
|
102 |
+
translated = translation_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
103 |
+
except:
|
104 |
+
translated = ""
|
105 |
+
return translated
|
106 |
+
|
107 |
+
def generate_image(prompt:str, imgModel_version):
|
108 |
+
# Generates an image from prompt and returns a url
|
109 |
+
prompt = prompt.replace("?","")
|
110 |
+
try:
|
111 |
+
output_url = imgModel_version.predict(prompt=prompt)[0]
|
112 |
+
except:
|
113 |
+
output_url = ""
|
114 |
+
|
115 |
+
return output_url
|
116 |
|
117 |
def split_string(string,n=256):
|
118 |
return [string[i:i+n] for i in range(0, len(string), n)]
|
|
|
122 |
try:
|
123 |
user_inpit_ids = tokenizer.encode(f"|0|{get_length_param(quote, tokenizer)}|" \
|
124 |
+ quote + tokenizer.eos_token, return_tensors="pt")
|
125 |
+
# Better to force the lenparameter to be = {2}
|
126 |
except:
|
127 |
return "Exception in tokenization" # Exception in tokenization
|
128 |
|
WarOnline_Chat.py
CHANGED
@@ -53,7 +53,7 @@ def remove_non_english_russian_chars(s):
|
|
53 |
|
54 |
def remove_extra_spaces(s):
|
55 |
s = re.sub(r"\s+", " ", s) # replace all sequences of whitespace with a single space
|
56 |
-
s = re.sub(r"\s+([
|
57 |
return(s)
|
58 |
|
59 |
def getLastPage(thread_url=config.thread_url):
|
@@ -68,7 +68,6 @@ def getLastPage(thread_url=config.thread_url):
|
|
68 |
lastPage = True
|
69 |
return page
|
70 |
|
71 |
-
|
72 |
def login(username=config.username, password=config.password, thread_url=config.thread_url):
|
73 |
# Log-In to the forum and redirect to thread
|
74 |
|
@@ -92,12 +91,14 @@ def login(username=config.username, password=config.password, thread_url=config.
|
|
92 |
print('Login failed!')
|
93 |
exit()
|
94 |
|
95 |
-
def post(message="", thread_url=config.thread_url, post_url=config.post_url, quoted_by="",quote_text="",quote_source=""):
|
96 |
#Post a message to the forum (with or without the quote
|
97 |
#quote_source is in format 'post-3920992'
|
98 |
quote_source = quote_source.split('-')[-1] # Take the numbers only
|
99 |
|
100 |
if quoted_by:
|
|
|
|
|
101 |
message = f'[QUOTE="{quoted_by}, post: {quote_source}"]{quote_text}[/QUOTE]{message}'
|
102 |
|
103 |
# Retrieve the thread page HTML
|
@@ -315,16 +316,27 @@ def WarOnlineBot():
|
|
315 |
message = fixString(message)
|
316 |
print('Reply: ', message)
|
317 |
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
|
|
|
|
324 |
|
325 |
-
|
326 |
-
|
327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
|
329 |
time.sleep(10) # Standby time for server load release
|
330 |
|
@@ -335,6 +347,15 @@ if __name__ == '__main__':
|
|
335 |
while True:
|
336 |
print('Starting Session')
|
337 |
WarOnlineBot()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
print('Session finished. Timeout...')
|
339 |
|
340 |
timer = range(60 * config.timeout)
|
|
|
53 |
|
54 |
def remove_extra_spaces(s):
|
55 |
s = re.sub(r"\s+", " ", s) # replace all sequences of whitespace with a single space
|
56 |
+
s = re.sub(r"\s+([.,-])", r"\1", s) # remove spaces before period, dash or comma
|
57 |
return(s)
|
58 |
|
59 |
def getLastPage(thread_url=config.thread_url):
|
|
|
68 |
lastPage = True
|
69 |
return page
|
70 |
|
|
|
71 |
def login(username=config.username, password=config.password, thread_url=config.thread_url):
|
72 |
# Log-In to the forum and redirect to thread
|
73 |
|
|
|
91 |
print('Login failed!')
|
92 |
exit()
|
93 |
|
94 |
+
def post(message="", thread_url=config.thread_url, post_url=config.post_url, quoted_by="",quote_text="",quote_source="",img_url=""):
|
95 |
#Post a message to the forum (with or without the quote
|
96 |
#quote_source is in format 'post-3920992'
|
97 |
quote_source = quote_source.split('-')[-1] # Take the numbers only
|
98 |
|
99 |
if quoted_by:
|
100 |
+
if img_url: # It is an image
|
101 |
+
message = f'Примерно вот так: \n[IMG]{img_url}[/IMG]' # Set the image block
|
102 |
message = f'[QUOTE="{quoted_by}, post: {quote_source}"]{quote_text}[/QUOTE]{message}'
|
103 |
|
104 |
# Retrieve the thread page HTML
|
|
|
316 |
message = fixString(message)
|
317 |
print('Reply: ', message)
|
318 |
|
319 |
+
if message.endswith('.png'): # It is an image reply:
|
320 |
+
# Post an image reply:
|
321 |
+
login(username=config.username, password=config.password, thread_url=config.thread_url)
|
322 |
+
time.sleep(1)
|
323 |
+
post(message="", thread_url=config.thread_url, post_url=config.post_url, quoted_by=msg['messengerName'],
|
324 |
+
quote_text=originalQuote, quote_source=msg['messageID'],
|
325 |
+
img_url=message)
|
326 |
+
# will not be added to the database, if image is a reply
|
327 |
|
328 |
+
else:
|
329 |
+
|
330 |
+
# Add the new conversation pair to the database
|
331 |
+
db.setmessages(username=msg['messengerName'], message_text=originalQuote, bot_reply=message)
|
332 |
+
# Clean up the excessive records, leaving only the remaining messages
|
333 |
+
db.cleanup(username=msg['messengerName'], remaining_messages=config.remaining_messages)
|
334 |
+
# Delete the duplicate records
|
335 |
+
db.deleteDuplicates()
|
336 |
+
|
337 |
+
login(username=config.username, password=config.password, thread_url=config.thread_url)
|
338 |
+
time.sleep(1)
|
339 |
+
post(message=message, thread_url=config.thread_url, post_url=config.post_url, quoted_by=msg['messengerName'], quote_text=originalQuote, quote_source=msg['messageID'])
|
340 |
|
341 |
time.sleep(10) # Standby time for server load release
|
342 |
|
|
|
347 |
while True:
|
348 |
print('Starting Session')
|
349 |
WarOnlineBot()
|
350 |
+
|
351 |
+
# Debug Only:
|
352 |
+
#imgWord = 'как выглядит'
|
353 |
+
"""
|
354 |
+
login(username=config.username, password=config.password, thread_url=config.thread_url)
|
355 |
+
print("logged in")
|
356 |
+
post(message="", thread_url=config.thread_url, post_url=config.post_url, quoted_by='Test',
|
357 |
+
quote_text="posting an image",img_url='https://replicate.delivery/pbxt/knKBiJt8DPZ0B1o25PaLJSZjgv3D5HcwLoBIn0JESbe3nISIA/out-0.png')
|
358 |
+
"""
|
359 |
print('Session finished. Timeout...')
|
360 |
|
361 |
timer = range(60 * config.timeout)
|
WarServer.py
CHANGED
@@ -2,15 +2,20 @@
|
|
2 |
import socket
|
3 |
import WarBot
|
4 |
|
|
|
5 |
import warnings
|
6 |
warnings.filterwarnings("ignore")
|
7 |
|
8 |
-
|
|
|
|
|
|
|
9 |
|
10 |
HOST = '10.0.0.125'
|
11 |
PORT = 5000
|
12 |
|
13 |
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
|
|
|
14 |
server_socket.bind((HOST, PORT))
|
15 |
server_socket.listen()
|
16 |
print(f'Server is listening on port {PORT}')
|
@@ -20,11 +25,22 @@ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
|
|
20 |
print(f'Connected by {addr}')
|
21 |
data = conn.recv(1024)
|
22 |
received_string = data.decode()
|
|
|
|
|
23 |
print(f'Received string from client: {received_string}')
|
24 |
|
25 |
response = ""
|
|
|
26 |
while not response:
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
response_string = response
|
30 |
|
|
|
2 |
import socket
|
3 |
import WarBot
|
4 |
|
5 |
+
# Kill all warnings
|
6 |
import warnings
|
7 |
warnings.filterwarnings("ignore")
|
8 |
|
9 |
+
imgWord = 'как выглядит'
|
10 |
+
|
11 |
+
# Initialize the base models
|
12 |
+
model,tokenizer,model_punct,translation_model,translation_tokenizer,imgModel_version = WarBot.initialize()
|
13 |
|
14 |
HOST = '10.0.0.125'
|
15 |
PORT = 5000
|
16 |
|
17 |
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
|
18 |
+
# Server sockets operation
|
19 |
server_socket.bind((HOST, PORT))
|
20 |
server_socket.listen()
|
21 |
print(f'Server is listening on port {PORT}')
|
|
|
25 |
print(f'Connected by {addr}')
|
26 |
data = conn.recv(1024)
|
27 |
received_string = data.decode()
|
28 |
+
#received_string = data.decode('utf-8')
|
29 |
+
|
30 |
print(f'Received string from client: {received_string}')
|
31 |
|
32 |
response = ""
|
33 |
+
|
34 |
while not response:
|
35 |
+
if received_string.lower().startswith(imgWord): # check if that is a call for an image:
|
36 |
+
received_string = received_string.lower().split(imgWord)[1:][0].strip() # cut the code word for image
|
37 |
+
# Translate it to english:
|
38 |
+
translated_string = WarBot.translate(received_string, translation_model=translation_model,
|
39 |
+
translation_tokenizer=translation_tokenizer)
|
40 |
+
# Generated image url
|
41 |
+
response = WarBot.generate_image(prompt=translated_string, imgModel_version=imgModel_version)
|
42 |
+
else:
|
43 |
+
response = WarBot.get_response(received_string, model, tokenizer, model_punct, temperature=0.6)
|
44 |
|
45 |
response_string = response
|
46 |
|
config.py
CHANGED
@@ -8,8 +8,8 @@ post_url = "https://waronline.org/fora/index.php?threads/warbot-playground.17636
|
|
8 |
# SSH settings
|
9 |
ssh_host = '129.159.146.88'
|
10 |
ssh_user = 'ubuntu'
|
11 |
-
ssh_key_path = 'C:/Users/kerts/OneDrive/Documents/Keys/Ubuntu_Oracle/ssh-key-2023-02-12.key'
|
12 |
-
|
13 |
|
14 |
# MySQL settings:
|
15 |
mysql_host = 'localhost' # because we will connect through the SSH tunnel
|
@@ -35,4 +35,7 @@ MaxWords = 50 # The server is relatively weak to fast-process the long messages
|
|
35 |
remaining_messages = 2
|
36 |
|
37 |
# Time between the reply sessions:
|
38 |
-
timeout = 5 # min
|
|
|
|
|
|
|
|
8 |
# SSH settings
|
9 |
ssh_host = '129.159.146.88'
|
10 |
ssh_user = 'ubuntu'
|
11 |
+
#ssh_key_path = 'C:/Users/kerts/OneDrive/Documents/Keys/Ubuntu_Oracle/ssh-key-2023-02-12.key'
|
12 |
+
ssh_key_path = 'ssh-key-2023-02-12.key' #Important to change this on the target machine!
|
13 |
|
14 |
# MySQL settings:
|
15 |
mysql_host = 'localhost' # because we will connect through the SSH tunnel
|
|
|
35 |
remaining_messages = 2
|
36 |
|
37 |
# Time between the reply sessions:
|
38 |
+
timeout = 5 # min
|
39 |
+
|
40 |
+
# Call for image generation:
|
41 |
+
imgWord = 'как выглядит'
|
requirements.txt
CHANGED
@@ -3,7 +3,7 @@ requests
|
|
3 |
bs4
|
4 |
transformers
|
5 |
scikit-learn
|
6 |
-
tensorboardX
|
7 |
#gradio
|
8 |
schedule
|
9 |
tqdm
|
@@ -11,5 +11,8 @@ pyspellchecker
|
|
11 |
paramiko
|
12 |
pymysql
|
13 |
sshtunnel
|
|
|
|
|
|
|
14 |
#pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
|
15 |
#pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
|
|
|
3 |
bs4
|
4 |
transformers
|
5 |
scikit-learn
|
6 |
+
#tensorboardX
|
7 |
#gradio
|
8 |
schedule
|
9 |
tqdm
|
|
|
11 |
paramiko
|
12 |
pymysql
|
13 |
sshtunnel
|
14 |
+
# Following ones are foo translation and replication
|
15 |
+
sentencepiece
|
16 |
+
replicate
|
17 |
#pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
|
18 |
#pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
|