kertser commited on
Commit
2aed2a1
1 Parent(s): e281f7f

Upload 5 files

Browse files

Model training and evaluation

Train_WarBot_of_GPT.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
WarBot_test.ipynb ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {
7
+ "collapsed": true
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "from transformers import AutoTokenizer ,AutoModelForCausalLM\n",
12
+ "import torch\n",
13
+ "import re\n",
14
+ "from sklearn.utils import shuffle"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 2,
20
+ "outputs": [],
21
+ "source": [
22
+ "# util function to get expected len after tokenizing\n",
23
+ "def get_length_param(text: str, tokenizer) -> str:\n",
24
+ " tokens_count = len(tokenizer.encode(text))\n",
25
+ " if tokens_count <= 15:\n",
26
+ " len_param = '1'\n",
27
+ " elif tokens_count <= 50:\n",
28
+ " len_param = '2'\n",
29
+ " elif tokens_count <= 256:\n",
30
+ " len_param = '3'\n",
31
+ " else:\n",
32
+ " len_param = '-'\n",
33
+ " return len_param"
34
+ ],
35
+ "metadata": {
36
+ "collapsed": false
37
+ }
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 7,
42
+ "outputs": [],
43
+ "source": [
44
+ "def remove_duplicates(S):\n",
45
+ " S = re.sub(r'[a-zA-Z]+', '', S) #Remove english\n",
46
+ " S = S.split()\n",
47
+ " result = \"\"\n",
48
+ " for subst in S:\n",
49
+ " if subst not in result:\n",
50
+ " result += subst+\" \"\n",
51
+ " return result.rstrip()"
52
+ ],
53
+ "metadata": {
54
+ "collapsed": false
55
+ }
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 3,
60
+ "outputs": [],
61
+ "source": [
62
+ "fit_checkpoint = \"WarBot\"\n",
63
+ "tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)\n",
64
+ "model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)"
65
+ ],
66
+ "metadata": {
67
+ "collapsed": false
68
+ }
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 12,
73
+ "outputs": [],
74
+ "source": [
75
+ "quote = \"Однажды мы проснёмся и поймём, что бригада Нахаль наваляла десантникам по самые помидоры\""
76
+ ],
77
+ "metadata": {
78
+ "collapsed": false
79
+ }
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 72,
84
+ "outputs": [],
85
+ "source": [
86
+ "# encode the input, add the eos_token and return a tensor in Pytorch\n",
87
+ "user_inpit_ids = tokenizer.encode(f\"|0|{get_length_param(quote, tokenizer)}|\" \\\n",
88
+ " + quote + tokenizer.eos_token, return_tensors=\"pt\")\n",
89
+ "\n",
90
+ "#chat_history_ids = torch.cat([chat_history_ids, user_inpit_ids], dim=-1)\n",
91
+ "\n",
92
+ "chat_history_ids = user_inpit_ids # To be changed\n",
93
+ "\n",
94
+ "output_id = model.generate(\n",
95
+ " chat_history_ids,\n",
96
+ " num_return_sequences=1, # use for more variants, but have to print [i]\n",
97
+ " max_length=300, #512\n",
98
+ " no_repeat_ngram_size=1, #3\n",
99
+ " do_sample=True, #True\n",
100
+ " top_k=50,#50\n",
101
+ " top_p=0.9, #0.9\n",
102
+ " temperature = 0.45, # was 0.6, 0 for greedy\n",
103
+ " #mask_token_id=tokenizer.mask_token_id,\n",
104
+ " eos_token_id=tokenizer.eos_token_id,\n",
105
+ " #unk_token_id=tokenizer.unk_token_id,\n",
106
+ " pad_token_id=tokenizer.pad_token_id,\n",
107
+ " #pad_token_id=tokenizer.eos_token_id,\n",
108
+ " #device='cpu'\n",
109
+ " )"
110
+ ],
111
+ "metadata": {
112
+ "collapsed": false
113
+ }
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 59,
118
+ "outputs": [],
119
+ "source": [
120
+ "def removeSigns(S):\n",
121
+ " last_index = max(S.rfind(\".\"), S.rfind(\"!\"))\n",
122
+ " if last_index >= 0:\n",
123
+ " S = S[:last_index+1]\n",
124
+ " return S"
125
+ ],
126
+ "metadata": {
127
+ "collapsed": false
128
+ }
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": 63,
133
+ "outputs": [],
134
+ "source": [
135
+ "def getResponce():\n",
136
+ " response = tokenizer.decode(output_id[0], skip_special_tokens=True)\n",
137
+ " response = removeSigns(response)\n",
138
+ " #response = re.sub(r'[^а-яА-Я;.,!?]', '', response) # Clear the response, remains only russian\n",
139
+ " response_с = response.split(quote)[-1] #Remove the Quote\n",
140
+ " clean_response = remove_duplicates(re.sub(r\"\\d{4,}\", \"\", response_с)) # Remove the consequent numbers with 4 or more digits\n",
141
+ " return clean_response"
142
+ ],
143
+ "metadata": {
144
+ "collapsed": false
145
+ }
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 73,
150
+ "outputs": [
151
+ {
152
+ "name": "stdout",
153
+ "output_type": "stream",
154
+ "text": [
155
+ "Response: в бригаде есть несколько батальонов \"йерихон\". они основном для того чтобы отражать атаки хезов. батальонный уровень это ротные минометы на джипах прицепами (на уровне батальона). если них будет достаточно ракет могут даже накрыть батарею 120мм минометной установки или из состава бригады может быть хуже чем батареи 122 м-109 которые находятся под управлением роты/бат аля рейнджеры... опять все зависит ситуации например после первой ливанской артиллеристы стали очень сильно нервничать когда обстреливали израильские бпла типа 28, как приходилось отвечать свои задачи. теперь вот примеру американцы решили полностью перевести всю бригаду второй эшелон : 1) сократив количество артдивизионов 4х; 3 пехотных батальонах(м113); 5 танковых + отдельный армейский который сможет прикрывать танки непосредственно перед атакой противника.. правда нужно еще иметь возможность поддерживать свой штатную авиацию огнем своего штатного места без необходимости перебрасывать туда часть танков.... короче говоря вся эта система должна работать вместе /при условии полного отсутствия взаимозачетчиков между ними...но тут надо смотреть кто первый окажется дежурным батареей техасских коптеров..и итп....\n"
156
+ ]
157
+ }
158
+ ],
159
+ "source": [
160
+ "print(\"Response:\",getResponce())"
161
+ ],
162
+ "metadata": {
163
+ "collapsed": false
164
+ }
165
+ },
166
+ {
167
+ "cell_type": "markdown",
168
+ "source": [
169
+ "Spelling Fix:"
170
+ ],
171
+ "metadata": {
172
+ "collapsed": false
173
+ }
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 74,
178
+ "outputs": [],
179
+ "source": [
180
+ "from autocorrect import Speller\n",
181
+ "spell = Speller('ru')\n",
182
+ "spell_fix_response = spell(getResponce())"
183
+ ],
184
+ "metadata": {
185
+ "collapsed": false
186
+ }
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": 75,
191
+ "outputs": [
192
+ {
193
+ "name": "stdout",
194
+ "output_type": "stream",
195
+ "text": [
196
+ "в бригаде есть несколько батальонов \"рихон\". они основном для того чтобы отражать атаки уезов. батальонный уровень это ротные минометы на джипах прицепами (на уровне батальона). если них будет достаточно ракет могут даже накрыть батарею 120мм минометной установки или из состава бригады может быть хуже чем батареи 122 м-109 которые находятся под управлением роты/бат аля рейнджеры... опять все зависит ситуации например после первой ливанской артиллеристы стали очень сильно нервничать когда обстреливали израильские была типа 28, как приходилось отвечать свои задачи. теперь вот примеру американцы решили полностью перевести всю бригаду второй эшелон : 1) сократив количество артдивизионов 4х; 3 пехотных батальонах(м113); 5 танковых + отдельный армейский который сможет прикрывать танки непосредственно перед атакой противника.. правда нужно еще иметь возможность поддерживать свой штатную авиацию огнем своего штатного места без необходимости перебрасывать туда часть танков.... короче говоря вся эта система должна работать вместе /при условии полного отсутствия взаимозачетчиков между ними...но тут надо смотреть кто первый окажется дежурным батареей техасских коптеров..и итп....\n"
197
+ ]
198
+ }
199
+ ],
200
+ "source": [
201
+ "print(spell_fix_response)"
202
+ ],
203
+ "metadata": {
204
+ "collapsed": false
205
+ }
206
+ }
207
+ ],
208
+ "metadata": {
209
+ "kernelspec": {
210
+ "display_name": "Python 3",
211
+ "language": "python",
212
+ "name": "python3"
213
+ },
214
+ "language_info": {
215
+ "codemirror_mode": {
216
+ "name": "ipython",
217
+ "version": 2
218
+ },
219
+ "file_extension": ".py",
220
+ "mimetype": "text/x-python",
221
+ "name": "python",
222
+ "nbconvert_exporter": "python",
223
+ "pygments_lexer": "ipython2",
224
+ "version": "2.7.6"
225
+ }
226
+ },
227
+ "nbformat": 4,
228
+ "nbformat_minor": 0
229
+ }
datacleaner.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import glob
3
+ import re
4
+
5
+ def clean(text):
6
+ if type(text) == str:
7
+ url_pattern = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
8
+ email_pattern = re.compile(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}')
9
+ www_pattern = re.compile(r'\b\w*www\.\w*\b')
10
+ ftp_pattern = re.compile(r'ftp://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
11
+ file_pattern = re.compile(r'file://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
12
+ underscore_pattern = re.compile(r'\b\w*_\w*\b')
13
+ quote_pattern = re.compile(r'""([^"]*)""')
14
+ curly_brackets_pattern = re.compile(r'\{[^}]*\}')
15
+ post_pattern = re.compile(r'#post')
16
+ write_pattern = re.compile(r'(\b\w+)\sнаписал\(а\)\b')
17
+
18
+ text = re.sub(url_pattern, '', text)
19
+ text = re.sub(email_pattern, '', text)
20
+ text = re.sub(ftp_pattern, '', text)
21
+ text = re.sub(www_pattern, '', text)
22
+ text = re.sub(underscore_pattern, '', text)
23
+ #text = re.sub(quote_pattern, '', text)
24
+ text = re.sub(curly_brackets_pattern, '', text)
25
+ text = re.sub(file_pattern, '', text)
26
+ text = re.sub(post_pattern, '', text)
27
+ text = re.sub(write_pattern, '', text)
28
+
29
+ return text
30
+
31
+ path = r'Data' # use your path
32
+ all_files = glob.glob(path + "/*.csv")
33
+
34
+ li = []
35
+
36
+ for filename in all_files:
37
+ df = pd.read_csv(filename, index_col=None, header=0)
38
+ li.append(df)
39
+
40
+ frame = pd.concat(li, axis=0, ignore_index=True)
41
+
42
+ frame = frame.drop_duplicates()
43
+
44
+ #frame = frame[~frame.applymap(lambda x: x == 'nan').any(1)]
45
+ frame = frame.applymap(clean)
46
+ frame = frame.applymap(lambda x: str(x).replace("посмотреть вложение", ""))
47
+ #frame = frame.applymap(lambda x: str(x).replace('"', ''))
48
+
49
+ # And again:
50
+ frame = frame.drop_duplicates()
51
+ frame = frame[frame.apply(lambda x: 'nan' not in x.values, axis=1)]
52
+
53
+ frame.to_csv(path+'/combined.csv',index = False)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pandas
2
+ requests
3
+ bs4
4
+ transformers
5
+ scikit-learn
6
+ tensorboardX
7
+ sentencepiece # summaruization
8
+ autocorrect # spelling
9
+ # pip install git+https://github.com/RussianNLP/russian_paraphrasers@master
10
+ #pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
11
+ #pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
waronlineforum.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """WarOnlineForum.ipynb"""
3
+
4
+ # Extracting messages from forum
5
+
6
+ import requests
7
+ from bs4 import BeautifulSoup
8
+ import re
9
+ import pandas as pd
10
+ import urllib.request as urllib
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+
14
+ # initiate the corpus of Quote->Response texts
15
+ corpus = pd.DataFrame(columns=['Quote', 'Response'])
16
+
17
+ def remove_substring(string, substring):
18
+ index = string.find(substring)
19
+ if index != -1:
20
+ start_index = string.rfind(" ", 0, index) + 1
21
+ end_index = string.find(" ", index)
22
+ if end_index == -1:
23
+ end_index = len(string)
24
+ return string[:start_index] + string[end_index:]
25
+ return string
26
+
27
+ def remove_attachments(string, substring='Посмотреть вложение'):
28
+ index = string.find(substring)
29
+ if index != -1:
30
+ end_index = string.find(" ", index)
31
+ if end_index == -1:
32
+ end_index = len(string)
33
+ return string[:index] + string[end_index:]
34
+ return string
35
+
36
+ def collectDataFromPage(url):
37
+ # specify the URL of the XenForo forum page you want to extract messages from
38
+
39
+ # send a request to the URL and get the HTML response
40
+ response = requests.get(url)
41
+ html = response.content
42
+
43
+ # parse the HTML using BeautifulSoup
44
+ soup = BeautifulSoup(response.content, "html.parser")
45
+
46
+ # Find all elements with class "messageContent"
47
+ message_contents = soup.find_all("div", class_="bbWrapper")
48
+
49
+ # Loop through each messageContent element
50
+ for message_content in message_contents:
51
+ # Find the text within the messageContent element
52
+ message_text = message_content.text.strip()
53
+
54
+ # Find the quoted text within the messageContent element
55
+ try:
56
+ quoted_text = message_content.find("blockquote").text.strip()
57
+ quoted_text = ''.join(BeautifulSoup(quoted_text, "html.parser").findAll(string=True))
58
+ quoted_text = quoted_text.replace('Нажмите для раскрытия...', '')
59
+ message_text = message_text.replace('Нажмите для раскрытия...', '')
60
+ # Remove the text in between "bbCodeBlock-expandLink js-expandLink" and "</div>"
61
+
62
+
63
+ # Print the message text and quoted text
64
+ Quote = re.sub(r'http\S+', '', ' '.join(quoted_text.split()).partition('(а): ')[2])
65
+ Quote = remove_substring(Quote,".com")
66
+ Quote = remove_attachments(Quote)
67
+ Quote = ' '.join(remove_substring(Quote,"@").split())
68
+
69
+ Message = ' '.join(message_text.replace(quoted_text,'').split())
70
+ Message = remove_substring(Message,".com")
71
+ Message = remove_attachments(Message)
72
+ Message = ' '.join(remove_substring(Message,"@").split())
73
+
74
+ if Message and Quote:
75
+ # corpus is a dataframe (global)
76
+ corpus.loc[len(corpus)]=[Quote,Message]
77
+ #print("Quoted Text:", Quote)
78
+ #print("Message Text:", Message)
79
+ #print('________________________')
80
+ except:
81
+ pass
82
+
83
+ def compare_pages(url1, url2):
84
+ page1 = requests.get(url1).text
85
+ page2 = requests.get(url2).text
86
+ # Stupid, but must be working
87
+ return len(page1) == len(page2)
88
+
89
+ def compare_pages2(url1, url2):
90
+ return urllib.urlopen(url1).geturl() == urllib.urlopen(url2).geturl()
91
+
92
+
93
+ def pages_of_thread(thread,startingPage=1):
94
+ page = startingPage
95
+ lastPage = False
96
+ while not lastPage:
97
+ response = requests.get(thread+'/page-'+str(page))
98
+ if response.status_code == 200:
99
+ collectDataFromPage(url = thread+'/page-'+str(page))
100
+ print(f'finished page #{page}')
101
+ if not compare_pages2(thread+'/page-'+str(page),thread+'/page-'+str(page+1)):
102
+ page+=1
103
+ else:
104
+ lastPage = True
105
+ else:
106
+ lastPage = True
107
+
108
+ # Usage Example:
109
+ #pages_of_thread(0,800) # Thread #0, starting page 800
110
+
111
+ """______________________________________ Main Code __________________________________________"""
112
+
113
+ # Define the URLs to be crawled
114
+ base_url = 'https://waronline.org'
115
+ # Pehota base subforum
116
+ #url = "https://waronline.org/fora/index.php?forums/%D0%9F%D0%B5%D1%85%D0%BE%D1%82%D0%B0.3/"
117
+ # Obshevoyskovie base subforum
118
+ #url = "https://waronline.org/fora/index.php?forums/%D0%9E%D0%B1%D1%89%D0%B5%D0%B2%D0%BE%D0%B9%D1%81%D0%BA%D0%BE%D0%B2%D1%8B%D0%B5-%D1%82%D0%B5%D0%BC%D1%8B.4/"
119
+ # VMF
120
+ url = "https://waronline.org/fora/index.php?forums/%D0%92%D0%9C%D0%A4-%D0%B3%D1%80%D0%B0%D0%B6%D0%B4%D0%B0%D0%BD%D1%81%D0%BA%D0%B8%D0%B9-%D1%84%D0%BB%D0%BE%D1%82.12/"
121
+
122
+ base_page = 1 #Starting with page-1
123
+ lastSubForumPage = False
124
+
125
+ while not lastSubForumPage:
126
+
127
+ # Send a GET request to the URL
128
+ response = requests.get(url+'page-'+str(base_page))
129
+ forum_threads = [] #threads on this page of subforum
130
+
131
+ # Check if the request was successful
132
+ if response.status_code == 200:
133
+ # Parse the HTML content of the page
134
+ soup = BeautifulSoup(response.content, "html.parser")
135
+
136
+ # Get all the thread-links on the page
137
+ links = soup.find_all("a")
138
+
139
+ # Get the links
140
+ for link in links:
141
+ lnk = link.get("href")
142
+ if lnk:
143
+ if 'threads' in lnk:
144
+ forum_threads.append((base_url+lnk).rsplit("/", 1)[0])
145
+
146
+ # Clear the duplicate links
147
+ forum_threads = list(set(forum_threads))
148
+
149
+ for trd in forum_threads:
150
+ pages_of_thread(trd) # Starting at page=1
151
+ print(f'finished thread: {trd}')
152
+
153
+ if not compare_pages2(url+'page-'+str(base_page),url+'page-'+str(base_page+1)):
154
+ print(f'finished subforum page #{base_page}')
155
+ base_page+=1
156
+ else:
157
+ lastSubForumPage = True
158
+
159
+ else:
160
+ print("Failed to load the page")
161
+ lastSubForumPage = True
162
+
163
+ # Lowercase all
164
+ corpus['Quote'] = corpus['Quote'].apply(lambda x: x.lower() if isinstance(x,str) else x)
165
+ corpus['Response'] = corpus['Response'].apply(lambda x: x.lower() if isinstance(x,str) else x)
166
+
167
+ # Remove all non-alphanumericals
168
+ corpus.Quote.str.replace('[^a-zA-Z]', '')
169
+ corpus.Response.str.replace('[^a-zA-Z]', '')
170
+
171
+ #Export to csv
172
+ pathToDrive = ''
173
+ filename = 'part5.csv'
174
+ corpus.to_csv(pathToDrive+filename,index=False)