kertser committed on
Commit
8f30cba
1 Parent(s): 768ac34

Upload 6 files

Browse files

Updated an option to generate images on direct call

Files changed (5) hide show
  1. WarBot.py +43 -4
  2. WarOnline_Chat.py +33 -12
  3. WarServer.py +18 -2
  4. config.py +6 -3
  5. requirements.txt +4 -1
WarBot.py CHANGED
@@ -1,6 +1,6 @@
1
  # Main library for WarBot
2
 
3
- from transformers import AutoTokenizer ,AutoModelForCausalLM
4
  import re
5
  # Speller and punctuation:
6
  import os
@@ -10,6 +10,8 @@ from torch import package
10
  # not very necessary
11
  #import textwrap
12
  from textwrap3 import wrap
 
 
13
 
14
  # util function to get expected len after tokenizing
15
  def get_length_param(text: str, tokenizer) -> str:
@@ -41,8 +43,9 @@ def removeSigns(S):
41
 
42
  def prepare_punct():
43
  # Prepare the Punctuation Model
44
- # Important! Enable for Unix version (python related)
45
- # torch.backends.quantized.engine = 'qnnpack'
 
46
  torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
47
  'latest_silero_models.yml',
48
  progress=False)
@@ -69,12 +72,47 @@ def prepare_punct():
69
  return model_punct
70
 
71
  def initialize():
 
72
  """ Loading the model """
73
  fit_checkpoint = "WarBot"
74
  tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
75
  model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
76
  model_punсt = prepare_punct()
77
- return (model,tokenizer,model_punсt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  def split_string(string,n=256):
80
  return [string[i:i+n] for i in range(0, len(string), n)]
@@ -84,6 +122,7 @@ def get_response(quote:str,model,tokenizer,model_punct,temperature=0.2):
84
  try:
85
  user_inpit_ids = tokenizer.encode(f"|0|{get_length_param(quote, tokenizer)}|" \
86
  + quote + tokenizer.eos_token, return_tensors="pt")
 
87
  except:
88
  return "Exception in tokenization" # Exception in tokenization
89
 
 
1
  # Main library for WarBot
2
 
3
+ from transformers import AutoTokenizer ,AutoModelForCausalLM, AutoModelForSeq2SeqLM
4
  import re
5
  # Speller and punctuation:
6
  import os
 
10
  # not very necessary
11
  #import textwrap
12
  from textwrap3 import wrap
13
+ import replicate #imaging
14
+
15
 
16
  # util function to get expected len after tokenizing
17
  def get_length_param(text: str, tokenizer) -> str:
 
43
 
44
  def prepare_punct():
45
  # Prepare the Punctuation Model
46
+ # Important! Enable next line for Unix version (python related):
47
+ torch.backends.quantized.engine = 'qnnpack'
48
+
49
  torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
50
  'latest_silero_models.yml',
51
  progress=False)
 
72
  return model_punct
73
 
74
  def initialize():
75
+ # Initializes all the settings
76
  """ Loading the model """
77
  fit_checkpoint = "WarBot"
78
  tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
79
  model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
80
  model_punсt = prepare_punct()
81
+
82
+ """ Initialize the translational model """
83
+ os.environ['REPLICATE_API_TOKEN'] = '2254e586b1380c49a948fd00d6802d45962492e4'
84
+ translation_model_name = "Helsinki-NLP/opus-mt-ru-en"
85
+ translation_tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
86
+ translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)
87
+
88
+ """ Initialize the image model """
89
+ imageModel = replicate.models.get("stability-ai/stable-diffusion")
90
+ imgModel_version = imageModel.versions.get("27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478")
91
+
92
+ return (model, tokenizer, model_punсt, translation_model, translation_tokenizer, imgModel_version)
93
+
94
+ def translate(text:str,translation_model,translation_tokenizer):
95
+ # Translates from Russian to English
96
+ src = "ru" # source language
97
+ trg = "en" # target language
98
+
99
+ try:
100
+ batch = translation_tokenizer([text], return_tensors="pt")
101
+ generated_ids = translation_model.generate(**batch)
102
+ translated = translation_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
103
+ except:
104
+ translated = ""
105
+ return translated
106
+
107
+ def generate_image(prompt:str, imgModel_version):
108
+ # Generates an image from prompt and returns a url
109
+ prompt = prompt.replace("?","")
110
+ try:
111
+ output_url = imgModel_version.predict(prompt=prompt)[0]
112
+ except:
113
+ output_url = ""
114
+
115
+ return output_url
116
 
117
  def split_string(string,n=256):
118
  return [string[i:i+n] for i in range(0, len(string), n)]
 
122
  try:
123
  user_inpit_ids = tokenizer.encode(f"|0|{get_length_param(quote, tokenizer)}|" \
124
  + quote + tokenizer.eos_token, return_tensors="pt")
125
+ # Better to force the length parameter to be = {2}
126
  except:
127
  return "Exception in tokenization" # Exception in tokenization
128
 
WarOnline_Chat.py CHANGED
@@ -53,7 +53,7 @@ def remove_non_english_russian_chars(s):
53
 
54
  def remove_extra_spaces(s):
55
  s = re.sub(r"\s+", " ", s) # replace all sequences of whitespace with a single space
56
- s = re.sub(r"\s+([.,])", r"\1", s) # remove spaces before period or comma
57
  return(s)
58
 
59
  def getLastPage(thread_url=config.thread_url):
@@ -68,7 +68,6 @@ def getLastPage(thread_url=config.thread_url):
68
  lastPage = True
69
  return page
70
 
71
-
72
  def login(username=config.username, password=config.password, thread_url=config.thread_url):
73
  # Log-In to the forum and redirect to thread
74
 
@@ -92,12 +91,14 @@ def login(username=config.username, password=config.password, thread_url=config.
92
  print('Login failed!')
93
  exit()
94
 
95
- def post(message="", thread_url=config.thread_url, post_url=config.post_url, quoted_by="",quote_text="",quote_source=""):
96
  #Post a message to the forum (with or without the quote)
97
  #quote_source is in format 'post-3920992'
98
  quote_source = quote_source.split('-')[-1] # Take the numbers only
99
 
100
  if quoted_by:
 
 
101
  message = f'[QUOTE="{quoted_by}, post: {quote_source}"]{quote_text}[/QUOTE]{message}'
102
 
103
  # Retrieve the thread page HTML
@@ -315,16 +316,27 @@ def WarOnlineBot():
315
  message = fixString(message)
316
  print('Reply: ', message)
317
 
318
- # Add the new conversation pair to the database
319
- db.setmessages(username=msg['messengerName'], message_text=originalQuote, bot_reply=message)
320
- # Clean up the excessive records, leaving only the remaining messages
321
- db.cleanup(username=msg['messengerName'], remaining_messages=config.remaining_messages)
322
- # Delete the duplicate records
323
- db.deleteDuplicates()
 
 
324
 
325
- login(username=config.username, password=config.password, thread_url=config.thread_url)
326
- time.sleep(1)
327
- post(message=message, thread_url=config.thread_url, post_url=config.post_url, quoted_by=msg['messengerName'], quote_text=originalQuote, quote_source=msg['messageID'])
 
 
 
 
 
 
 
 
 
328
 
329
  time.sleep(10) # Standby time for server load release
330
 
@@ -335,6 +347,15 @@ if __name__ == '__main__':
335
  while True:
336
  print('Starting Session')
337
  WarOnlineBot()
 
 
 
 
 
 
 
 
 
338
  print('Session finished. Timeout...')
339
 
340
  timer = range(60 * config.timeout)
 
53
 
54
  def remove_extra_spaces(s):
55
  s = re.sub(r"\s+", " ", s) # replace all sequences of whitespace with a single space
56
+ s = re.sub(r"\s+([.,-])", r"\1", s) # remove spaces before period, dash or comma
57
  return(s)
58
 
59
  def getLastPage(thread_url=config.thread_url):
 
68
  lastPage = True
69
  return page
70
 
 
71
  def login(username=config.username, password=config.password, thread_url=config.thread_url):
72
  # Log-In to the forum and redirect to thread
73
 
 
91
  print('Login failed!')
92
  exit()
93
 
94
+ def post(message="", thread_url=config.thread_url, post_url=config.post_url, quoted_by="",quote_text="",quote_source="",img_url=""):
95
  #Post a message to the forum (with or without the quote)
96
  #quote_source is in format 'post-3920992'
97
  quote_source = quote_source.split('-')[-1] # Take the numbers only
98
 
99
  if quoted_by:
100
+ if img_url: # It is an image
101
+ message = f'Примерно вот так: \n[IMG]{img_url}[/IMG]' # Set the image block
102
  message = f'[QUOTE="{quoted_by}, post: {quote_source}"]{quote_text}[/QUOTE]{message}'
103
 
104
  # Retrieve the thread page HTML
 
316
  message = fixString(message)
317
  print('Reply: ', message)
318
 
319
+ if message.endswith('.png'): # It is an image reply:
320
+ # Post an image reply:
321
+ login(username=config.username, password=config.password, thread_url=config.thread_url)
322
+ time.sleep(1)
323
+ post(message="", thread_url=config.thread_url, post_url=config.post_url, quoted_by=msg['messengerName'],
324
+ quote_text=originalQuote, quote_source=msg['messageID'],
325
+ img_url=message)
326
+ # will not be added to the database, if image is a reply
327
 
328
+ else:
329
+
330
+ # Add the new conversation pair to the database
331
+ db.setmessages(username=msg['messengerName'], message_text=originalQuote, bot_reply=message)
332
+ # Clean up the excessive records, leaving only the remaining messages
333
+ db.cleanup(username=msg['messengerName'], remaining_messages=config.remaining_messages)
334
+ # Delete the duplicate records
335
+ db.deleteDuplicates()
336
+
337
+ login(username=config.username, password=config.password, thread_url=config.thread_url)
338
+ time.sleep(1)
339
+ post(message=message, thread_url=config.thread_url, post_url=config.post_url, quoted_by=msg['messengerName'], quote_text=originalQuote, quote_source=msg['messageID'])
340
 
341
  time.sleep(10) # Standby time for server load release
342
 
 
347
  while True:
348
  print('Starting Session')
349
  WarOnlineBot()
350
+
351
+ # Debug Only:
352
+ #imgWord = 'как выглядит'
353
+ """
354
+ login(username=config.username, password=config.password, thread_url=config.thread_url)
355
+ print("logged in")
356
+ post(message="", thread_url=config.thread_url, post_url=config.post_url, quoted_by='Test',
357
+ quote_text="posting an image",img_url='https://replicate.delivery/pbxt/knKBiJt8DPZ0B1o25PaLJSZjgv3D5HcwLoBIn0JESbe3nISIA/out-0.png')
358
+ """
359
  print('Session finished. Timeout...')
360
 
361
  timer = range(60 * config.timeout)
WarServer.py CHANGED
@@ -2,15 +2,20 @@
2
  import socket
3
  import WarBot
4
 
 
5
  import warnings
6
  warnings.filterwarnings("ignore")
7
 
8
- model,tokenizer,model_punct = WarBot.initialize()
 
 
 
9
 
10
  HOST = '10.0.0.125'
11
  PORT = 5000
12
 
13
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
 
14
  server_socket.bind((HOST, PORT))
15
  server_socket.listen()
16
  print(f'Server is listening on port {PORT}')
@@ -20,11 +25,22 @@ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
20
  print(f'Connected by {addr}')
21
  data = conn.recv(1024)
22
  received_string = data.decode()
 
 
23
  print(f'Received string from client: {received_string}')
24
 
25
  response = ""
 
26
  while not response:
27
- response = WarBot.get_response(received_string, model, tokenizer, model_punct, temperature=0.6)
 
 
 
 
 
 
 
 
28
 
29
  response_string = response
30
 
 
2
  import socket
3
  import WarBot
4
 
5
+ # Kill all warnings
6
  import warnings
7
  warnings.filterwarnings("ignore")
8
 
9
+ imgWord = 'как выглядит'
10
+
11
+ # Initialize the base models
12
+ model,tokenizer,model_punct,translation_model,translation_tokenizer,imgModel_version = WarBot.initialize()
13
 
14
  HOST = '10.0.0.125'
15
  PORT = 5000
16
 
17
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
18
+ # Server sockets operation
19
  server_socket.bind((HOST, PORT))
20
  server_socket.listen()
21
  print(f'Server is listening on port {PORT}')
 
25
  print(f'Connected by {addr}')
26
  data = conn.recv(1024)
27
  received_string = data.decode()
28
+ #received_string = data.decode('utf-8')
29
+
30
  print(f'Received string from client: {received_string}')
31
 
32
  response = ""
33
+
34
  while not response:
35
+ if received_string.lower().startswith(imgWord): # check if that is a call for an image:
36
+ received_string = received_string.lower().split(imgWord)[1:][0].strip() # cut the code word for image
37
+ # Translate it to english:
38
+ translated_string = WarBot.translate(received_string, translation_model=translation_model,
39
+ translation_tokenizer=translation_tokenizer)
40
+ # Generated image url
41
+ response = WarBot.generate_image(prompt=translated_string, imgModel_version=imgModel_version)
42
+ else:
43
+ response = WarBot.get_response(received_string, model, tokenizer, model_punct, temperature=0.6)
44
 
45
  response_string = response
46
 
config.py CHANGED
@@ -8,8 +8,8 @@ post_url = "https://waronline.org/fora/index.php?threads/warbot-playground.17636
8
  # SSH settings
9
  ssh_host = '129.159.146.88'
10
  ssh_user = 'ubuntu'
11
- ssh_key_path = 'C:/Users/kerts/OneDrive/Documents/Keys/Ubuntu_Oracle/ssh-key-2023-02-12.key'
12
- #ssh_key_path = 'ssh-key-2023-02-12.key'
13
 
14
  # MySQL settings:
15
  mysql_host = 'localhost' # because we will connect through the SSH tunnel
@@ -35,4 +35,7 @@ MaxWords = 50 # The server is relatively weak to fast-process the long messages
35
  remaining_messages = 2
36
 
37
  # Time between the reply sessions:
38
- timeout = 5 # min
 
 
 
 
8
  # SSH settings
9
  ssh_host = '129.159.146.88'
10
  ssh_user = 'ubuntu'
11
+ #ssh_key_path = 'C:/Users/kerts/OneDrive/Documents/Keys/Ubuntu_Oracle/ssh-key-2023-02-12.key'
12
+ ssh_key_path = 'ssh-key-2023-02-12.key' #Important to change this on the target machine!
13
 
14
  # MySQL settings:
15
  mysql_host = 'localhost' # because we will connect through the SSH tunnel
 
35
  remaining_messages = 2
36
 
37
  # Time between the reply sessions:
38
+ timeout = 5 # min
39
+
40
+ # Call for image generation:
41
+ imgWord = 'как выглядит'
requirements.txt CHANGED
@@ -3,7 +3,7 @@ requests
3
  bs4
4
  transformers
5
  scikit-learn
6
- tensorboardX
7
  #gradio
8
  schedule
9
  tqdm
@@ -11,5 +11,8 @@ pyspellchecker
11
  paramiko
12
  pymysql
13
  sshtunnel
 
 
 
14
  #pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
15
  #pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
 
3
  bs4
4
  transformers
5
  scikit-learn
6
+ #tensorboardX
7
  #gradio
8
  schedule
9
  tqdm
 
11
  paramiko
12
  pymysql
13
  sshtunnel
14
+ # The following ones are for translation and image generation (Replicate API)
15
+ sentencepiece
16
+ replicate
17
  #pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
18
  #pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117