rodrigomasini committed on
Commit
ba553c4
1 Parent(s): 1cb7677

Upload 12 files

modules/callbacks.py ADDED
@@ -0,0 +1,94 @@
+ import gc
+ import traceback
+ from queue import Queue
+ from threading import Thread
+
+ import torch
+ import transformers
+
+ import modules.shared as shared
+
+
+ class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
+     def __init__(self):
+         transformers.StoppingCriteria.__init__(self)
+
+     def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
+         return shared.stop_everything
+
+
+ class Stream(transformers.StoppingCriteria):
+     def __init__(self, callback_func=None):
+         self.callback_func = callback_func
+
+     def __call__(self, input_ids, scores) -> bool:
+         if self.callback_func is not None:
+             self.callback_func(input_ids[0])
+         return False
+
+
+ class Iteratorize:
+
+     """
+     Transforms a function that takes a callback
+     into a lazy iterator (generator).
+
+     Adapted from: https://stackoverflow.com/a/9969000
+     """
+
+     def __init__(self, func, args=None, kwargs=None, callback=None):
+         self.mfunc = func
+         self.c_callback = callback
+         self.q = Queue()
+         self.sentinel = object()
+         self.args = args or []
+         self.kwargs = kwargs or {}
+         self.stop_now = False
+
+         def _callback(val):
+             if self.stop_now or shared.stop_everything:
+                 raise ValueError
+             self.q.put(val)
+
+         def gentask():
+             try:
+                 ret = self.mfunc(callback=_callback, *self.args, **self.kwargs)
+             except ValueError:
+                 pass
+             except:
+                 traceback.print_exc()
+                 pass
+
+             clear_torch_cache()
+             self.q.put(self.sentinel)
+             if self.c_callback:
+                 self.c_callback(ret)
+
+         self.thread = Thread(target=gentask)
+         self.thread.start()
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         obj = self.q.get(True, None)
+         if obj is self.sentinel:
+             raise StopIteration
+         else:
+             return obj
+
+     def __del__(self):
+         clear_torch_cache()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.stop_now = True
+         clear_torch_cache()
+
+
+ def clear_torch_cache():
+     gc.collect()
+     if not shared.args.cpu:
+         torch.cuda.empty_cache()
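For context, a minimal usage sketch of how Stream and Iteratorize are typically combined to turn a blocking model.generate() call into a token stream. The generate_with_callback/generate_with_streaming helpers below are illustrative stand-ins and are not part of this upload.

# Hypothetical usage sketch (not part of the uploaded files): stream tokens out of
# a blocking generate() call by routing every generation step through a callback.
import torch
import transformers

import modules.shared as shared
from modules.callbacks import Iteratorize, Stream


def generate_with_callback(callback=None, **kwargs):
    # Stream calls `callback` with the tokens produced so far on every generation step.
    kwargs['stopping_criteria'] = transformers.StoppingCriteriaList([Stream(callback_func=callback)])
    with torch.no_grad():
        shared.model.generate(**kwargs)


def generate_with_streaming(**kwargs):
    # Iteratorize runs generate_with_callback in a thread and yields each callback value.
    return Iteratorize(generate_with_callback, kwargs=kwargs)

# for output_ids in generate_with_streaming(inputs=input_ids, max_new_tokens=200):
#     print(shared.tokenizer.decode(output_ids, skip_special_tokens=True))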
modules/chat.py ADDED
@@ -0,0 +1,664 @@
+ import base64
+ import copy
+ import functools
+ import json
+ import re
+ from datetime import datetime
+ from pathlib import Path
+
+ import gradio as gr
+ import yaml
+ from PIL import Image
+
+ import modules.shared as shared
+ from modules.extensions import apply_extensions
+ from modules.html_generator import chat_html_wrapper, make_thumbnail
+ from modules.logging_colors import logger
+ from modules.text_generation import (
+     generate_reply,
+     get_encoded_length,
+     get_max_prompt_length
+ )
+ from modules.utils import (
+     delete_file,
+     get_available_characters,
+     replace_all,
+     save_file
+ )
+
+
+ def str_presenter(dumper, data):
+     """
+     Copied from https://github.com/yaml/pyyaml/issues/240
+     Makes pyyaml output prettier multiline strings.
+     """
+
+     if data.count('\n') > 0:
+         return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
+
+     return dumper.represent_scalar('tag:yaml.org,2002:str', data)
+
+
+ yaml.add_representer(str, str_presenter)
+ yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
+
+
+ def get_turn_substrings(state, instruct=False):
+     if instruct:
+         if 'turn_template' not in state or state['turn_template'] == '':
+             template = '<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n'
+         else:
+             template = state['turn_template'].replace(r'\n', '\n')
+     else:
+         template = '<|user|>: <|user-message|>\n<|bot|>: <|bot-message|>\n'
+
+     replacements = {
+         '<|user|>': state['name1_instruct' if instruct else 'name1'].strip(),
+         '<|bot|>': state['name2_instruct' if instruct else 'name2'].strip(),
+     }
+
+     output = {
+         'user_turn': template.split('<|bot|>')[0],
+         'bot_turn': '<|bot|>' + template.split('<|bot|>')[1],
+         'user_turn_stripped': template.split('<|bot|>')[0].split('<|user-message|>')[0],
+         'bot_turn_stripped': '<|bot|>' + template.split('<|bot|>')[1].split('<|bot-message|>')[0],
+     }
+
+     for k in output:
+         output[k] = replace_all(output[k], replacements)
+
+     return output
+
+
+ def generate_chat_prompt(user_input, state, **kwargs):
+     impersonate = kwargs.get('impersonate', False)
+     _continue = kwargs.get('_continue', False)
+     also_return_rows = kwargs.get('also_return_rows', False)
+     history = kwargs.get('history', state['history'])['internal']
+     is_instruct = state['mode'] == 'instruct'
+
+     # Find the maximum prompt size
+     max_length = get_max_prompt_length(state)
+     all_substrings = {
+         'chat': get_turn_substrings(state, instruct=False),
+         'instruct': get_turn_substrings(state, instruct=True)
+     }
+
+     substrings = all_substrings['instruct' if is_instruct else 'chat']
+
+     # Create the template for "chat-instruct" mode
+     if state['mode'] == 'chat-instruct':
+         wrapper = ''
+         command = state['chat-instruct_command'].replace('<|character|>', state['name2'] if not impersonate else state['name1'])
+         wrapper += state['context_instruct']
+         wrapper += all_substrings['instruct']['user_turn'].replace('<|user-message|>', command)
+         wrapper += all_substrings['instruct']['bot_turn_stripped']
+         if impersonate:
+             wrapper += substrings['user_turn_stripped'].rstrip(' ')
+         elif _continue:
+             wrapper += apply_extensions('bot_prefix', substrings['bot_turn_stripped'], state)
+             wrapper += history[-1][1]
+         else:
+             wrapper += apply_extensions('bot_prefix', substrings['bot_turn_stripped'].rstrip(' '), state)
+     else:
+         wrapper = '<|prompt|>'
+
+     if is_instruct:
+         context = state['context_instruct']
+     else:
+         context = replace_character_names(
+             f"{state['context'].strip()}\n",
+             state['name1'],
+             state['name2']
+         )
+
+     # Build the prompt
+     rows = [context]
+     min_rows = 3
+     i = len(history) - 1
+     while i >= 0 and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) < max_length:
+         if _continue and i == len(history) - 1:
+             if state['mode'] != 'chat-instruct':
+                 rows.insert(1, substrings['bot_turn_stripped'] + history[i][1].strip())
+         else:
+             rows.insert(1, substrings['bot_turn'].replace('<|bot-message|>', history[i][1].strip()))
+
+         string = history[i][0]
+         if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
+             rows.insert(1, replace_all(substrings['user_turn'], {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
+
+         i -= 1
+
+     if impersonate:
+         if state['mode'] == 'chat-instruct':
+             min_rows = 1
+         else:
+             min_rows = 2
+             rows.append(substrings['user_turn_stripped'].rstrip(' '))
+     elif not _continue:
+         # Add the user message
+         if len(user_input) > 0:
+             rows.append(replace_all(substrings['user_turn'], {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
+
+         # Add the character prefix
+         if state['mode'] != 'chat-instruct':
+             rows.append(apply_extensions('bot_prefix', substrings['bot_turn_stripped'].rstrip(' '), state))
+
+     while len(rows) > min_rows and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) >= max_length:
+         rows.pop(1)
+
+     prompt = wrapper.replace('<|prompt|>', ''.join(rows))
+     if also_return_rows:
+         return prompt, rows
+     else:
+         return prompt
+
+
+ def get_stopping_strings(state):
+     stopping_strings = []
+     if state['mode'] in ['instruct', 'chat-instruct']:
+         stopping_strings += [
+             state['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0] + '<|bot|>',
+             state['turn_template'].split('<|bot-message|>')[1] + '<|user|>'
+         ]
+
+         replacements = {
+             '<|user|>': state['name1_instruct'],
+             '<|bot|>': state['name2_instruct']
+         }
+
+         for i in range(len(stopping_strings)):
+             stopping_strings[i] = replace_all(stopping_strings[i], replacements).rstrip(' ').replace(r'\n', '\n')
+
+     if state['mode'] in ['chat', 'chat-instruct']:
+         stopping_strings += [
+             f"\n{state['name1']}:",
+             f"\n{state['name2']}:"
+         ]
+
+     if state['stop_at_newline']:
+         stopping_strings.append("\n")
+
+     return stopping_strings
+
+
+ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True):
+     history = state['history']
+     output = copy.deepcopy(history)
+     output = apply_extensions('history', output)
+     state = apply_extensions('state', state)
+     if shared.model_name == 'None' or shared.model is None:
+         logger.error("No model is loaded! Select one in the Model tab.")
+         yield output
+         return
+
+     # Defining some variables
+     just_started = True
+     visible_text = None
+     stopping_strings = get_stopping_strings(state)
+     is_stream = state['stream']
+
+     # Preparing the input
+     if not any((regenerate, _continue)):
+         visible_text = text
+         text, visible_text = apply_extensions('chat_input', text, visible_text, state)
+         text = apply_extensions('input', text, state)
+
+         # *Is typing...*
+         if loading_message:
+             yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
+     else:
+         text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
+         if regenerate:
+             output['visible'].pop()
+             output['internal'].pop()
+             # *Is typing...*
+             if loading_message:
+                 yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
+         elif _continue:
+             last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
+             if loading_message:
+                 yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']}
+
+     # Generating the prompt
+     kwargs = {
+         '_continue': _continue,
+         'history': output,
+     }
+
+     prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
+     if prompt is None:
+         prompt = generate_chat_prompt(text, state, **kwargs)
+
+     # Generate
+     cumulative_reply = ''
+     for i in range(state['chat_generation_attempts']):
+         reply = None
+         for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True)):
+             reply = cumulative_reply + reply
+
+             # Extract the reply
+             visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
+
+             # We need this global variable to handle the Stop event,
+             # otherwise gradio gets confused
+             if shared.stop_everything:
+                 output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state)
+                 yield output
+                 return
+
+             if just_started:
+                 just_started = False
+                 if not _continue:
+                     output['internal'].append(['', ''])
+                     output['visible'].append(['', ''])
+
+             if _continue:
+                 output['internal'][-1] = [text, last_reply[0] + reply]
+                 output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
+                 if is_stream:
+                     yield output
+             elif not (j == 0 and visible_reply.strip() == ''):
+                 output['internal'][-1] = [text, reply.lstrip(' ')]
+                 output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
+                 if is_stream:
+                     yield output
+
+         if reply in [None, cumulative_reply]:
+             break
+         else:
+             cumulative_reply = reply
+
+     output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state)
+     yield output
+
+
+ def impersonate_wrapper(text, start_with, state):
+     if shared.model_name == 'None' or shared.model is None:
+         logger.error("No model is loaded! Select one in the Model tab.")
+         yield ''
+         return
+
+     # Defining some variables
+     cumulative_reply = ''
+     prompt = generate_chat_prompt('', state, impersonate=True)
+     stopping_strings = get_stopping_strings(state)
+
+     yield text + '...'
+     cumulative_reply = text
+     for i in range(state['chat_generation_attempts']):
+         reply = None
+         for reply in generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True):
+             reply = cumulative_reply + reply
+             yield reply.lstrip(' ')
+             if shared.stop_everything:
+                 return
+
+         if reply in [None, cumulative_reply]:
+             break
+         else:
+             cumulative_reply = reply
+
+     yield cumulative_reply.lstrip(' ')
+
+
+ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True):
+     history = state['history']
+     if regenerate or _continue:
+         text = ''
+         if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
+             yield history
+             return
+
+     for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
+         yield history
+
+
+ # Same as above but returns HTML for the UI
+ def generate_chat_reply_wrapper(text, start_with, state, regenerate=False, _continue=False):
+     if start_with != '' and not _continue:
+         if regenerate:
+             text, state['history'] = remove_last_message(state['history'])
+             regenerate = False
+
+         _continue = True
+         send_dummy_message(text, state)
+         send_dummy_reply(start_with, state)
+
+     for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True)):
+         yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style']), history
+
+
+ def remove_last_message(history):
+     if len(history['visible']) > 0 and history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
+         last = history['visible'].pop()
+         history['internal'].pop()
+     else:
+         last = ['', '']
+
+     return last[0], history
+
+
+ def send_last_reply_to_input(history):
+     if len(history['internal']) > 0:
+         return history['internal'][-1][1]
+     else:
+         return ''
+
+
+ def replace_last_reply(text, state):
+     history = state['history']
+     if len(history['visible']) > 0:
+         history['visible'][-1][1] = text
+         history['internal'][-1][1] = apply_extensions('input', text, state)
+
+     return history
+
+
+ def send_dummy_message(text, state):
+     history = state['history']
+     history['visible'].append([text, ''])
+     history['internal'].append([apply_extensions('input', text, state), ''])
+     return history
+
+
+ def send_dummy_reply(text, state):
+     history = state['history']
+     if len(history['visible']) > 0 and not history['visible'][-1][1] == '':
+         history['visible'].append(['', ''])
+         history['internal'].append(['', ''])
+
+     history['visible'][-1][1] = text
+     history['internal'][-1][1] = apply_extensions('input', text, state)
+     return history
+
+
+ def clear_chat_log(state):
+     greeting = replace_character_names(state['greeting'], state['name1'], state['name2'])
+     mode = state['mode']
+     history = state['history']
+
+     history['visible'] = []
+     history['internal'] = []
+     if mode != 'instruct':
+         if greeting != '':
+             history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
+             history['visible'] += [['', apply_extensions('output', greeting, state)]]
+
+     return history
+
+
+ def redraw_html(history, name1, name2, mode, style, reset_cache=False):
+     return chat_html_wrapper(history, name1, name2, mode, style, reset_cache=reset_cache)
+
+
+ def save_history(history, path=None):
+     p = path or Path('logs/exported_history.json')
+     with open(p, 'w', encoding='utf-8') as f:
+         f.write(json.dumps(history, indent=4))
+
+     return p
+
+
+ def load_history(file, history):
+     try:
+         file = file.decode('utf-8')
+         j = json.loads(file)
+         if 'internal' in j and 'visible' in j:
+             return j
+         else:
+             return history
+     except:
+         return history
+
+
+ def save_history_at_user_request(history, character, mode):
+     def make_timestamp_path(character=None):
+         return f"logs/{character or ''}{'_' if character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
+
+     path = None
+     if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None]:
+         path = make_timestamp_path(character)
+     else:
+         # Try to use mode as the file name, otherwise just use the timestamp
+         try:
+             path = make_timestamp_path(mode.capitalize())
+         except:
+             path = make_timestamp_path()
+
+     return save_history(history, path)
+
+
+ def save_persistent_history(history, character, mode):
+     if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None] and not shared.args.multi_user:
+         save_history(history, path=Path(f'logs/{character}_persistent.json'))
+
+
+ def load_persistent_history(state):
+     if state['mode'] == 'instruct':
+         return state['history']
+
+     character = state['character_menu']
+     greeting = replace_character_names(state['greeting'], state['name1'], state['name2'])
+     p = Path(f'logs/{character}_persistent.json')
+     if not shared.args.multi_user and character not in ['None', '', None] and p.exists():
+         f = json.loads(open(p, 'rb').read())
+         if 'internal' in f and 'visible' in f:
+             history = f
+         else:
+             history = {'internal': [], 'visible': []}
+             history['internal'] = f['data']
+             history['visible'] = f['data_visible']
+     else:
+         history = {'internal': [], 'visible': []}
+         if greeting != "":
+             history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
+             history['visible'] += [['', apply_extensions('output', greeting, state)]]
+
+     return history
+
+
+ def replace_character_names(text, name1, name2):
+     text = text.replace('{{user}}', name1).replace('{{char}}', name2)
+     return text.replace('<USER>', name1).replace('<BOT>', name2)
+
+
+ def generate_pfp_cache(character):
+     cache_folder = Path("cache")
+     if not cache_folder.exists():
+         cache_folder.mkdir()
+
+     for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
+         if path.exists():
+             img = make_thumbnail(Image.open(path))
+             img.save(Path('cache/pfp_character.png'), format='PNG')
+             return img
+
+     return None
+
+
+ def load_character(character, name1, name2, instruct=False):
+     context = greeting = turn_template = ""
+     greeting_field = 'greeting'
+     picture = None
+
+     # Deleting the profile picture cache, if any
+     if Path("cache/pfp_character.png").exists():
+         Path("cache/pfp_character.png").unlink()
+
+     if character not in ['None', '', None]:
+         folder = 'characters' if not instruct else 'characters/instruction-following'
+         picture = generate_pfp_cache(character)
+         filepath = None
+         for extension in ["yml", "yaml", "json"]:
+             filepath = Path(f'{folder}/{character}.{extension}')
+             if filepath.exists():
+                 break
+
+         if filepath is None:
+             logger.error(f"Could not find character file for {character} in {folder} folder. Please check your spelling.")
+             return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")
+
+         file_contents = open(filepath, 'r', encoding='utf-8').read()
+         data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
+
+         # Finding the bot's name
+         for k in ['name', 'bot', '<|bot|>', 'char_name']:
+             if k in data and data[k] != '':
+                 name2 = data[k]
+                 break
+
+         # Find the user name (if any)
+         for k in ['your_name', 'user', '<|user|>']:
+             if k in data and data[k] != '':
+                 name1 = data[k]
+                 break
+
+         if 'context' in data:
+             context = data['context']
+             if not instruct:
+                 context = context.strip() + '\n'
+         elif "char_persona" in data:
+             context = build_pygmalion_style_context(data)
+             greeting_field = 'char_greeting'
+
+         if 'example_dialogue' in data:
+             context += f"{data['example_dialogue'].strip()}\n"
+
+         if greeting_field in data:
+             greeting = data[greeting_field]
+
+         if 'turn_template' in data:
+             turn_template = data['turn_template']
+
+     else:
+         context = shared.settings['context']
+         name2 = shared.settings['name2']
+         greeting = shared.settings['greeting']
+         turn_template = shared.settings['turn_template']
+
+     return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")
+
+
+ @functools.cache
+ def load_character_memoized(character, name1, name2, instruct=False):
+     return load_character(character, name1, name2, instruct=instruct)
+
+
+ def upload_character(file, img, tavern=False):
+     decoded_file = file if type(file) == str else file.decode('utf-8')
+     try:
+         data = json.loads(decoded_file)
+     except:
+         data = yaml.safe_load(decoded_file)
+
+     if 'char_name' in data:
+         name = data['char_name']
+         greeting = data['char_greeting']
+         context = build_pygmalion_style_context(data)
+         yaml_data = generate_character_yaml(name, greeting, context)
+     else:
+         name = data['name']
+         yaml_data = generate_character_yaml(data['name'], data['greeting'], data['context'])
+
+     outfile_name = name
+     i = 1
+     while Path(f'characters/{outfile_name}.yaml').exists():
+         outfile_name = f'{name}_{i:03d}'
+         i += 1
+
+     with open(Path(f'characters/{outfile_name}.yaml'), 'w', encoding='utf-8') as f:
+         f.write(yaml_data)
+
+     if img is not None:
+         img.save(Path(f'characters/{outfile_name}.png'))
+
+     logger.info(f'New character saved to "characters/{outfile_name}.yaml".')
+     return gr.update(value=outfile_name, choices=get_available_characters())
+
+
+ def build_pygmalion_style_context(data):
+     context = ""
+     if 'char_persona' in data and data['char_persona'] != '':
+         context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
+
+     if 'world_scenario' in data and data['world_scenario'] != '':
+         context += f"Scenario: {data['world_scenario']}\n"
+
+     context = f"{context.strip()}\n"
+     return context
+
+
+ def upload_tavern_character(img, _json):
+     _json = {'char_name': _json['name'], 'char_persona': _json['description'], 'char_greeting': _json['first_mes'], 'example_dialogue': _json['mes_example'], 'world_scenario': _json['scenario']}
+     return upload_character(json.dumps(_json), img, tavern=True)
+
+
+ def check_tavern_character(img):
+     if "chara" not in img.info:
+         return "Not a TavernAI card", None, None, gr.update(interactive=False)
+
+     decoded_string = base64.b64decode(img.info['chara']).replace(b'\\r\\n', b'\\n')
+     _json = json.loads(decoded_string)
+     if "data" in _json:
+         _json = _json["data"]
+
+     return _json['name'], _json['description'], _json, gr.update(interactive=True)
+
+
+ def upload_your_profile_picture(img):
+     cache_folder = Path("cache")
+     if not cache_folder.exists():
+         cache_folder.mkdir()
+
+     if img is None:
+         if Path("cache/pfp_me.png").exists():
+             Path("cache/pfp_me.png").unlink()
+     else:
+         img = make_thumbnail(img)
+         img.save(Path('cache/pfp_me.png'))
+         logger.info('Profile picture saved to "cache/pfp_me.png"')
+
+
+ def generate_character_yaml(name, greeting, context):
+     data = {
+         'name': name,
+         'greeting': greeting,
+         'context': context,
+     }
+
+     data = {k: v for k, v in data.items() if v}  # Strip falsy
+     return yaml.dump(data, sort_keys=False, width=float("inf"))
+
+
+ def generate_instruction_template_yaml(user, bot, context, turn_template):
+     data = {
+         'user': user,
+         'bot': bot,
+         'turn_template': turn_template,
+         'context': context,
+     }
+
+     data = {k: v for k, v in data.items() if v}  # Strip falsy
+     return yaml.dump(data, sort_keys=False, width=float("inf"))
+
+
+ def save_character(name, greeting, context, picture, filename):
+     if filename == "":
+         logger.error("The filename is empty, so the character will not be saved.")
+         return
+
+     data = generate_character_yaml(name, greeting, context)
+     filepath = Path(f'characters/{filename}.yaml')
+     save_file(filepath, data)
+     path_to_img = Path(f'characters/{filename}.png')
+     if picture is not None:
+         picture.save(path_to_img)
+         logger.info(f'Saved {path_to_img}.')
+
+
+ def delete_character(name, instruct=False):
+     for extension in ["yml", "yaml", "json"]:
+         delete_file(Path(f'characters/{name}.{extension}'))
+
+     delete_file(Path(f'characters/{name}.png'))
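To make the turn-template handling in generate_chat_prompt concrete, here is a small sketch of what get_turn_substrings returns for a typical instruct template. The state values are hypothetical placeholders, not taken from this commit.

# Hypothetical illustration of get_turn_substrings() (values made up for clarity).
from modules.chat import get_turn_substrings

state = {
    'turn_template': r'<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n',
    'name1_instruct': 'USER',
    'name2_instruct': 'ASSISTANT',
    'name1': 'You',
    'name2': 'Bot',
}

subs = get_turn_substrings(state, instruct=True)
# subs['user_turn']          == 'USER\n<|user-message|>\n'
# subs['bot_turn']           == 'ASSISTANT\n<|bot-message|>\n'
# subs['user_turn_stripped'] == 'USER\n'
# subs['bot_turn_stripped']  == 'ASSISTANT\n'
# generate_chat_prompt() then fills <|user-message|>/<|bot-message|> for each history row.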
modules/evaluate.py ADDED
@@ -0,0 +1,154 @@
+ import datetime
+ from pathlib import Path
+
+ import pandas as pd
+ import torch
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ from modules import shared
+ from modules.models import load_model, unload_model
+ from modules.models_settings import (
+     get_model_settings_from_yamls,
+     update_model_parameters
+ )
+ from modules.text_generation import encode
+
+
+ def load_past_evaluations():
+     if Path('logs/evaluations.csv').exists():
+         df = pd.read_csv(Path('logs/evaluations.csv'), dtype=str)
+         df['Perplexity'] = pd.to_numeric(df['Perplexity'])
+         return df
+     else:
+         return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
+
+
+ past_evaluations = load_past_evaluations()
+
+
+ def save_past_evaluations(df):
+     global past_evaluations
+     past_evaluations = df
+     filepath = Path('logs/evaluations.csv')
+     filepath.parent.mkdir(parents=True, exist_ok=True)
+     df.to_csv(filepath, index=False)
+
+
+ def calculate_perplexity(models, input_dataset, stride, _max_length):
+     '''
+     Based on:
+     https://huggingface.co/docs/transformers/perplexity#calculating-ppl-with-fixedlength-models
+     '''
+
+     global past_evaluations
+     cumulative_log = ''
+     cumulative_log += "Loading the input dataset...\n\n"
+     yield cumulative_log
+
+     # Copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/utils/datautils.py
+     if input_dataset == 'wikitext':
+         data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+         text = "\n\n".join(data['text'])
+     elif input_dataset == 'ptb':
+         data = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
+         text = "\n\n".join(data['sentence'])
+     elif input_dataset == 'ptb_new':
+         data = load_dataset('ptb_text_only', 'penn_treebank', split='test')
+         text = " ".join(data['sentence'])
+     else:
+         with open(Path(f'training/datasets/{input_dataset}.txt'), 'r', encoding='utf-8') as f:
+             text = f.read()
+
+     for model in models:
+         if is_in_past_evaluations(model, input_dataset, stride, _max_length):
+             cumulative_log += f"{model} has already been tested. Ignoring.\n\n"
+             yield cumulative_log
+             continue
+
+         if model != 'current model':
+             try:
+                 yield cumulative_log + f"Loading {model}...\n\n"
+                 model_settings = get_model_settings_from_yamls(model)
+                 shared.settings.update(model_settings)  # hijacking the interface defaults
+                 update_model_parameters(model_settings)  # hijacking the command-line arguments
+                 shared.model_name = model
+                 unload_model()
+                 shared.model, shared.tokenizer = load_model(shared.model_name)
+             except:
+                 cumulative_log += f"Failed to load {model}. Moving on.\n\n"
+                 yield cumulative_log
+                 continue
+
+         cumulative_log += f"Processing {shared.model_name}...\n\n"
+         yield cumulative_log + "Tokenizing the input dataset...\n\n"
+         encodings = encode(text, add_special_tokens=False)
+         seq_len = encodings.shape[1]
+         if _max_length:
+             max_length = _max_length
+         elif hasattr(shared.model.config, 'max_position_embeddings'):
+             max_length = shared.model.config.max_position_embeddings
+         else:
+             max_length = 2048
+
+         nlls = []
+         prev_end_loc = 0
+         for begin_loc in tqdm(range(0, seq_len, stride)):
+             yield cumulative_log + f"Evaluating... {100*begin_loc/seq_len:.2f}%"
+             end_loc = min(begin_loc + max_length, seq_len)
+             trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
+             input_ids = encodings[:, begin_loc:end_loc]
+             target_ids = input_ids.clone()
+             target_ids[:, :-trg_len] = -100
+
+             with torch.no_grad():
+                 outputs = shared.model(input_ids=input_ids, labels=target_ids)
+
+                 # loss is calculated using CrossEntropyLoss which averages over valid labels
+                 # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
+                 # to the left by 1.
+                 neg_log_likelihood = outputs.loss
+
+             nlls.append(neg_log_likelihood)
+
+             prev_end_loc = end_loc
+             if end_loc == seq_len:
+                 break
+
+         ppl = torch.exp(torch.stack(nlls).mean())
+         add_entry_to_past_evaluations(float(ppl), shared.model_name, input_dataset, stride, _max_length)
+         save_past_evaluations(past_evaluations)
+         cumulative_log += f"The perplexity for {shared.model_name} is: {float(ppl)}\n\n"
+         yield cumulative_log
+
+
+ def add_entry_to_past_evaluations(perplexity, model, dataset, stride, max_length):
+     global past_evaluations
+     entry = {
+         'Model': model,
+         'LoRAs': ', '.join(shared.lora_names) or '-',
+         'Dataset': dataset,
+         'Perplexity': perplexity,
+         'stride': str(stride),
+         'max_length': str(max_length),
+         'Date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+         'Comment': ''
+     }
+     past_evaluations = pd.concat([past_evaluations, pd.DataFrame([entry])], ignore_index=True)
+
+
+ def is_in_past_evaluations(model, dataset, stride, max_length):
+     entries = past_evaluations[(past_evaluations['Model'] == model) &
+                                (past_evaluations['Dataset'] == dataset) &
+                                (past_evaluations['max_length'] == str(max_length)) &
+                                (past_evaluations['stride'] == str(stride))]
+
+     if entries.shape[0] > 0:
+         return True
+     else:
+         return False
+
+
+ def generate_markdown_table():
+     sorted_df = past_evaluations.sort_values(by=['Dataset', 'stride', 'Perplexity', 'Date'])
+     return sorted_df
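As a quick sanity check of the sliding-window loop in calculate_perplexity, the snippet below (toy numbers, not from the code) prints the evaluation windows and how many trailing tokens of each window actually contribute to the loss.

# Toy walk-through of the stride/max_length windowing used above.
seq_len, stride, max_length = 10, 4, 6
prev_end_loc = 0
for begin_loc in range(0, seq_len, stride):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # only these trailing tokens are scored
    print(begin_loc, end_loc, trg_len)
    prev_end_loc = end_loc
    if end_loc == seq_len:
        break
# Output: "0 6 6" then "4 10 4" -- overlapping context, non-overlapping targets.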
modules/html_generator.py ADDED
@@ -0,0 +1,273 @@
+ import os
+ import re
+ import time
+ from pathlib import Path
+
+ import markdown
+ from PIL import Image, ImageOps
+
+ from modules.utils import get_available_chat_styles
+
+ # This is to store the paths to the thumbnails of the profile pictures
+ image_cache = {}
+
+ with open(Path(__file__).resolve().parent / '../css/html_readable_style.css', 'r') as f:
+     readable_css = f.read()
+ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r') as css_f:
+     _4chan_css = css_f.read()
+ with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
+     instruct_css = f.read()
+
+ # Custom chat styles
+ chat_styles = {}
+ for k in get_available_chat_styles():
+     chat_styles[k] = open(Path(f'css/chat_style-{k}.css'), 'r').read()
+
+
+ def fix_newlines(string):
+     string = string.replace('\n', '\n\n')
+     string = re.sub(r"\n{3,}", "\n\n", string)
+     string = string.strip()
+     return string
+
+
+ def replace_blockquote(m):
+     return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '')
+
+
+ def convert_to_markdown(string):
+
+     # Blockquote
+     pattern = re.compile(r'\\begin{blockquote}(.*?)\\end{blockquote}', re.DOTALL)
+     string = pattern.sub(replace_blockquote, string)
+
+     # Code
+     string = string.replace('\\begin{code}', '```')
+     string = string.replace('\\end{code}', '```')
+     string = re.sub(r"(.)```", r"\1\n```", string)
+
+     result = ''
+     is_code = False
+     for line in string.split('\n'):
+         if line.lstrip(' ').startswith('```'):
+             is_code = not is_code
+
+         result += line
+         if is_code or line.startswith('|'):  # Don't add an extra \n for tables or code
+             result += '\n'
+         else:
+             result += '\n\n'
+
+     if is_code:
+         result = result + '```'  # Unfinished code block
+
+     string = result.strip()
+     return markdown.markdown(string, extensions=['fenced_code', 'tables'])
+
+
+ def generate_basic_html(string):
+     string = convert_to_markdown(string)
+     string = f'<style>{readable_css}</style><div class="container">{string}</div>'
+     return string
+
+
+ def process_post(post, c):
+     t = post.split('\n')
+     number = t[0].split(' ')[1]
+     if len(t) > 1:
+         src = '\n'.join(t[1:])
+     else:
+         src = ''
+     src = re.sub('>', '&gt;', src)
+     src = re.sub('(&gt;&gt;[0-9]*)', '<span class="quote">\\1</span>', src)
+     src = re.sub('\n', '<br>\n', src)
+     src = f'<blockquote class="message">{src}\n'
+     src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
+     return src
+
+
+ def generate_4chan_html(f):
+     posts = []
+     post = ''
+     c = -2
+     for line in f.splitlines():
+         line += "\n"
+         if line == '-----\n':
+             continue
+         elif line.startswith('--- '):
+             c += 1
+             if post != '':
+                 src = process_post(post, c)
+                 posts.append(src)
+             post = line
+         else:
+             post += line
+     if post != '':
+         src = process_post(post, c)
+         posts.append(src)
+
+     for i in range(len(posts)):
+         if i == 0:
+             posts[i] = f'<div class="op">{posts[i]}</div>\n'
+         else:
+             posts[i] = f'<div class="reply">{posts[i]}</div>\n'
+
+     output = ''
+     output += f'<style>{_4chan_css}</style><div id="parent"><div id="container">'
+     for post in posts:
+         output += post
+     output += '</div></div>'
+     output = output.split('\n')
+     for i in range(len(output)):
+         output[i] = re.sub(r'^(&gt;(.*?)(<br>|</div>))', r'<span class="greentext">\1</span>', output[i])
+         output[i] = re.sub(r'^<blockquote class="message">(&gt;(.*?)(<br>|</div>))', r'<blockquote class="message"><span class="greentext">\1</span>', output[i])
+     output = '\n'.join(output)
+
+     return output
+
+
+ def make_thumbnail(image):
+     image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
+     if image.size[1] > 470:
+         image = ImageOps.fit(image, (350, 470), Image.LANCZOS)
+
+     return image
+
+
+ def get_image_cache(path):
+     cache_folder = Path("cache")
+     if not cache_folder.exists():
+         cache_folder.mkdir()
+
+     mtime = os.stat(path).st_mtime
+     if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
+         img = make_thumbnail(Image.open(path))
+         output_file = Path(f'cache/{path.name}_cache.png')
+         img.convert('RGB').save(output_file, format='PNG')
+         image_cache[path] = [mtime, output_file.as_posix()]
+
+     return image_cache[path][1]
+
+
+ def generate_instruct_html(history):
+     output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
+     for i, _row in enumerate(history[::-1]):
+         row = [convert_to_markdown(entry) for entry in _row]
+
+         output += f"""
+               <div class="assistant-message">
+                 <div class="text">
+                   <div class="message-body">
+                     {row[1]}
+                   </div>
+                 </div>
+               </div>
+             """
+
+         if len(row[0]) == 0:  # don't display empty user messages
+             continue
+
+         output += f"""
+               <div class="user-message">
+                 <div class="text">
+                   <div class="message-body">
+                     {row[0]}
+                   </div>
+                 </div>
+               </div>
+             """
+
+     output += "</div>"
+
+     return output
+
+
+ def generate_cai_chat_html(history, name1, name2, style, reset_cache=False):
+     output = f'<style>{chat_styles[style]}</style><div class="chat" id="chat">'
+
+     # We use ?name2 and ?time.time() to force the browser to reset caches
+     img_bot = f'<img src="file/cache/pfp_character.png?{name2}">' if Path("cache/pfp_character.png").exists() else ''
+     img_me = f'<img src="file/cache/pfp_me.png?{time.time() if reset_cache else ""}">' if Path("cache/pfp_me.png").exists() else ''
+
+     for i, _row in enumerate(history[::-1]):
+         row = [convert_to_markdown(entry) for entry in _row]
+
+         output += f"""
+               <div class="message">
+                 <div class="circle-bot">
+                   {img_bot}
+                 </div>
+                 <div class="text">
+                   <div class="username">
+                     {name2}
+                   </div>
+                   <div class="message-body">
+                     {row[1]}
+                   </div>
+                 </div>
+               </div>
+             """
+
+         if len(row[0]) == 0:  # don't display empty user messages
+             continue
+
+         output += f"""
+               <div class="message">
+                 <div class="circle-you">
+                   {img_me}
+                 </div>
+                 <div class="text">
+                   <div class="username">
+                     {name1}
+                   </div>
+                   <div class="message-body">
+                     {row[0]}
+                   </div>
+                 </div>
+               </div>
+             """
+
+     output += "</div>"
+     return output
+
+
+ def generate_chat_html(history, name1, name2, reset_cache=False):
+     output = f'<style>{chat_styles["wpp"]}</style><div class="chat" id="chat">'
+
+     for i, _row in enumerate(history[::-1]):
+         row = [convert_to_markdown(entry) for entry in _row]
+
+         output += f"""
+               <div class="message">
+                 <div class="text-bot">
+                   <div class="message-body">
+                     {row[1]}
+                   </div>
+                 </div>
+               </div>
+             """
+
+         if len(row[0]) == 0:  # don't display empty user messages
+             continue
+
+         output += f"""
+               <div class="message">
+                 <div class="text-you">
+                   <div class="message-body">
+                     {row[0]}
+                   </div>
+                 </div>
+               </div>
+             """
+
+     output += "</div>"
+     return output
+
+
+ def chat_html_wrapper(history, name1, name2, mode, style, reset_cache=False):
+     if mode == 'instruct':
+         return generate_instruct_html(history['visible'])
+     elif style == 'wpp':
+         return generate_chat_html(history['visible'], name1, name2)
+     else:
+         return generate_cai_chat_html(history['visible'], name1, name2, style, reset_cache)
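A minimal sketch of how chat_html_wrapper is meant to be called. The history contents and the 'cai-chat' style name are assumptions for illustration; any css/chat_style-*.css file present in the repo is a valid style.

# Hypothetical call (history values made up); requires the css/ folder from the repo.
from modules.html_generator import chat_html_wrapper

history = {
    'internal': [['Hello', 'Hi there! How can I help?']],
    'visible': [['Hello', 'Hi there! How can I help?']],
}

html = chat_html_wrapper(history, name1='You', name2='Assistant', mode='chat', style='cai-chat')
# `html` is a self-contained <style> + <div class="chat"> block that the UI renders directly.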
modules/loaders.py ADDED
@@ -0,0 +1,291 @@
+ import functools
+
+ import gradio as gr
+
+ from modules import shared
+
+ loaders_and_params = {
+     'AutoGPTQ': [
+         'triton',
+         'no_inject_fused_attention',
+         'no_inject_fused_mlp',
+         'no_use_cuda_fp16',
+         'wbits',
+         'groupsize',
+         'desc_act',
+         'gpu_memory',
+         'cpu_memory',
+         'cpu',
+         'disk',
+         'auto_devices',
+         'trust_remote_code',
+         'autogptq_info',
+     ],
+     'GPTQ-for-LLaMa': [
+         'wbits',
+         'groupsize',
+         'model_type',
+         'pre_layer',
+         'gptq_for_llama_info',
+     ],
+     'llama.cpp': [
+         'n_ctx',
+         'n_gqa',
+         'rms_norm_eps',
+         'n_gpu_layers',
+         'n_batch',
+         'threads',
+         'no_mmap',
+         'low_vram',
+         'mlock',
+         'llama_cpp_seed',
+         'compress_pos_emb',
+         'alpha_value',
+     ],
+     'llamacpp_HF': [
+         'n_ctx',
+         'n_gqa',
+         'rms_norm_eps',
+         'n_gpu_layers',
+         'n_batch',
+         'threads',
+         'no_mmap',
+         'low_vram',
+         'mlock',
+         'llama_cpp_seed',
+         'compress_pos_emb',
+         'alpha_value',
+         'llamacpp_HF_info',
+     ],
+     'Transformers': [
+         'cpu_memory',
+         'gpu_memory',
+         'trust_remote_code',
+         'load_in_8bit',
+         'bf16',
+         'cpu',
+         'disk',
+         'auto_devices',
+         'load_in_4bit',
+         'use_double_quant',
+         'quant_type',
+         'compute_dtype',
+         'trust_remote_code',
+         'transformers_info'
+     ],
+     'ExLlama': [
+         'gpu_split',
+         'max_seq_len',
+         'compress_pos_emb',
+         'alpha_value',
+         'exllama_info',
+     ],
+     'ExLlama_HF': [
+         'gpu_split',
+         'max_seq_len',
+         'compress_pos_emb',
+         'alpha_value',
+         'exllama_HF_info',
+     ]
+ }
+
+ loaders_samplers = {
+     'Transformers': {
+         'temperature',
+         'top_p',
+         'top_k',
+         'typical_p',
+         'epsilon_cutoff',
+         'eta_cutoff',
+         'tfs',
+         'top_a',
+         'repetition_penalty',
+         'repetition_penalty_range',
+         'encoder_repetition_penalty',
+         'no_repeat_ngram_size',
+         'min_length',
+         'seed',
+         'do_sample',
+         'penalty_alpha',
+         'num_beams',
+         'length_penalty',
+         'early_stopping',
+         'mirostat_mode',
+         'mirostat_tau',
+         'mirostat_eta',
+         'ban_eos_token',
+         'add_bos_token',
+         'skip_special_tokens',
+     },
+     'ExLlama_HF': {
+         'temperature',
+         'top_p',
+         'top_k',
+         'typical_p',
+         'epsilon_cutoff',
+         'eta_cutoff',
+         'tfs',
+         'top_a',
+         'repetition_penalty',
+         'repetition_penalty_range',
+         'encoder_repetition_penalty',
+         'no_repeat_ngram_size',
+         'min_length',
+         'seed',
+         'do_sample',
+         'mirostat_mode',
+         'mirostat_tau',
+         'mirostat_eta',
+         'ban_eos_token',
+         'add_bos_token',
+         'skip_special_tokens',
+     },
+     'ExLlama': {
+         'temperature',
+         'top_p',
+         'top_k',
+         'typical_p',
+         'repetition_penalty',
+         'repetition_penalty_range',
+         'seed',
+         'ban_eos_token',
+     },
+     'AutoGPTQ': {
+         'temperature',
+         'top_p',
+         'top_k',
+         'typical_p',
+         'epsilon_cutoff',
+         'eta_cutoff',
+         'tfs',
+         'top_a',
+         'repetition_penalty',
+         'repetition_penalty_range',
+         'encoder_repetition_penalty',
+         'no_repeat_ngram_size',
+         'min_length',
+         'seed',
+         'do_sample',
+         'penalty_alpha',
+         'num_beams',
+         'length_penalty',
+         'early_stopping',
+         'mirostat_mode',
+         'mirostat_tau',
+         'mirostat_eta',
+         'ban_eos_token',
+         'add_bos_token',
+         'skip_special_tokens',
+     },
+     'GPTQ-for-LLaMa': {
+         'temperature',
+         'top_p',
+         'top_k',
+         'typical_p',
+         'epsilon_cutoff',
+         'eta_cutoff',
+         'tfs',
+         'top_a',
+         'repetition_penalty',
+         'repetition_penalty_range',
+         'encoder_repetition_penalty',
+         'no_repeat_ngram_size',
+         'min_length',
+         'seed',
+         'do_sample',
+         'penalty_alpha',
+         'num_beams',
+         'length_penalty',
+         'early_stopping',
+         'mirostat_mode',
+         'mirostat_tau',
+         'mirostat_eta',
+         'ban_eos_token',
+         'add_bos_token',
+         'skip_special_tokens',
+     },
+     'llama.cpp': {
+         'temperature',
+         'top_p',
+         'top_k',
+         'tfs',
+         'repetition_penalty',
+         'mirostat_mode',
+         'mirostat_tau',
+         'mirostat_eta',
+         'ban_eos_token',
+     },
+     'llamacpp_HF': {
+         'temperature',
+         'top_p',
+         'top_k',
+         'typical_p',
+         'epsilon_cutoff',
+         'eta_cutoff',
+         'tfs',
+         'top_a',
+         'repetition_penalty',
+         'repetition_penalty_range',
+         'encoder_repetition_penalty',
+         'no_repeat_ngram_size',
+         'min_length',
+         'seed',
+         'do_sample',
+         'mirostat_mode',
+         'mirostat_tau',
+         'mirostat_eta',
+         'ban_eos_token',
+         'add_bos_token',
+         'skip_special_tokens',
+     },
+ }
+
+
+ @functools.cache
+ def list_all_samplers():
+     all_samplers = set()
+     for k in loaders_samplers:
+         for sampler in loaders_samplers[k]:
+             all_samplers.add(sampler)
+
+     return sorted(all_samplers)
+
+
+ def blacklist_samplers(loader):
+     all_samplers = list_all_samplers()
+     if loader == 'All':
+         return [gr.update(visible=True) for sampler in all_samplers]
+     else:
+         return [gr.update(visible=True) if sampler in loaders_samplers[loader] else gr.update(visible=False) for sampler in all_samplers]
+
+
+ def get_gpu_memory_keys():
+     return [k for k in shared.gradio if k.startswith('gpu_memory')]
+
+
+ @functools.cache
+ def get_all_params():
+     all_params = set()
+     for k in loaders_and_params:
+         for el in loaders_and_params[k]:
+             all_params.add(el)
+
+     if 'gpu_memory' in all_params:
+         all_params.remove('gpu_memory')
+         for k in get_gpu_memory_keys():
+             all_params.add(k)
+
+     return sorted(all_params)
+
+
+ def make_loader_params_visible(loader):
+     params = []
+     all_params = get_all_params()
+     if loader in loaders_and_params:
+         params = loaders_and_params[loader]
+
+         if 'gpu_memory' in params:
+             params.remove('gpu_memory')
+             params += get_gpu_memory_keys()
+
+     return [gr.update(visible=True) if k in params else gr.update(visible=False) for k in all_params]
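How these helpers are typically wired into the UI, as a rough sketch: the loader dropdown's change event maps onto the visibility updates returned by make_loader_params_visible and blacklist_samplers. The widget lookups below assume the usual shared.gradio registry and are illustrative only, not part of this upload.

# Hypothetical Gradio wiring sketch (commented out; component names are assumptions).
import gradio as gr

from modules import loaders, shared

# loader = gr.Dropdown(choices=list(loaders.loaders_and_params), value='Transformers')
# param_widgets = [shared.gradio[k] for k in loaders.get_all_params()]
# sampler_widgets = [shared.gradio[k] for k in loaders.list_all_samplers()]
# loader.change(loaders.make_loader_params_visible, loader, param_widgets)
# loader.change(loaders.blacklist_samplers, loader, sampler_widgets)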
modules/models.py ADDED
@@ -0,0 +1,343 @@
+ import gc
+ import os
+ import re
+ import time
+ from pathlib import Path
+ import hashlib
+
+ import torch
+ import transformers
+ from accelerate import infer_auto_device_map, init_empty_weights
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoModelForCausalLM,
+     AutoModelForSeq2SeqLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+ )
+
+ import modules.shared as shared
+ from modules import llama_attn_hijack, sampler_hijack
+ from modules.logging_colors import logger
+ from modules.models_settings import infer_loader
+
+ transformers.logging.set_verbosity_error()
+
+ local_rank = None
+ if shared.args.deepspeed:
+     import deepspeed
+     from transformers.deepspeed import (
+         HfDeepSpeedConfig,
+         is_deepspeed_zero3_enabled
+     )
+
+     from modules.deepspeed_parameters import generate_ds_config
+
+     # Distributed setup
+     local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
+     world_size = int(os.getenv("WORLD_SIZE", "1"))
+     torch.cuda.set_device(local_rank)
+     deepspeed.init_distributed()
+     ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
+     dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
+
+ sampler_hijack.hijack_samplers()
+
+
+ def load_model(model_name, loader=None):
+     logger.info(f"Loading {model_name}...")
+     t0 = time.time()
+
+     shared.is_seq2seq = False
+     load_func_map = {
+         'Transformers': huggingface_loader,
+         'AutoGPTQ': AutoGPTQ_loader,
+         'GPTQ-for-LLaMa': GPTQ_loader,
+         'llama.cpp': llamacpp_loader,
+         'llamacpp_HF': llamacpp_HF_loader,
+         'RWKV': RWKV_loader,
+         'ExLlama': ExLlama_loader,
+         'ExLlama_HF': ExLlama_HF_loader
+     }
+
+     p = Path(model_name)
+     if p.exists():
+         model_name = p.parts[-1]
+
+     if loader is None:
+         if shared.args.loader is not None:
+             loader = shared.args.loader
+         else:
+             loader = infer_loader(model_name)
+             if loader is None:
+                 logger.error('The path to the model does not exist. Exiting.')
+                 return None, None
+
+     shared.args.loader = loader
+     output = load_func_map[loader](model_name)
+     if type(output) is tuple:
+         model, tokenizer = output
+     else:
+         model = output
+         if model is None:
+             return None, None
+         else:
+             tokenizer = load_tokenizer(model_name, model)
+
+     # Hijack attention with xformers
+     if any((shared.args.xformers, shared.args.sdp_attention)):
+         llama_attn_hijack.hijack_llama_attention()
+
+     logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
+     return model, tokenizer
+
+
+ def load_tokenizer(model_name, model):
+     tokenizer = None
+     path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
+     if any(s in model_name.lower() for s in ['gpt-4chan', 'gpt4chan']) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+     elif path_to_model.exists():
+         try:
+             tokenizer = AutoTokenizer.from_pretrained(
+                 path_to_model,
+                 trust_remote_code=shared.args.trust_remote_code,
+                 use_fast=False
+             )
+         except ValueError:
+             tokenizer = AutoTokenizer.from_pretrained(
+                 path_to_model,
+                 trust_remote_code=shared.args.trust_remote_code,
+                 use_fast=True
+             )
+
+         if tokenizer.__class__.__name__ == 'LlamaTokenizer':
+             pairs = [
+                 ['tokenizer_config.json', '516c6167c884793a738c440e29ccb80c15e1493ffc965affc69a1a8ddef4572a'],
+                 ['special_tokens_map.json', 'ff3b4a612c4e447acb02d40071bddd989fe0da87eb5b7fe0dbadfc4f74de7531']
+             ]
+
+             for pair in pairs:
+                 p = path_to_model / pair[0]
+                 if p.exists():
+                     with open(p, "rb") as f:
+                         bytes = f.read()
+
+                     file_hash = hashlib.sha256(bytes).hexdigest()
+                     if file_hash != pair[1]:
+                         logger.warning(f"{p} is different from the original LlamaTokenizer file. It is either customized or outdated.")
+
+     return tokenizer
+
+
+ def huggingface_loader(model_name):
+     path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+     if 'chatglm' in model_name.lower():
+         LoaderClass = AutoModel
+     else:
+         config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
+         if config.to_dict().get("is_encoder_decoder", False):
+             LoaderClass = AutoModelForSeq2SeqLM
+             shared.is_seq2seq = True
+         else:
+             LoaderClass = AutoModelForCausalLM
+
+     # Load the model in simple 16-bit mode by default
+     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
+         model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
+         if torch.backends.mps.is_available():
+             device = torch.device('mps')
+             model = model.to(device)
+         else:
+             model = model.cuda()
+
+     # DeepSpeed ZeRO-3
+     elif shared.args.deepspeed:
+         model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+         model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
+         model.module.eval()  # Inference
+         logger.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
+
+     # Custom
+     else:
+         params = {
+             "low_cpu_mem_usage": True,
+             "trust_remote_code": shared.args.trust_remote_code
+         }
+
+         if not any((shared.args.cpu, torch.cuda.is_available(), torch.backends.mps.is_available())):
+             logger.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
+             shared.args.cpu = True
+
+         if shared.args.cpu:
+             params["torch_dtype"] = torch.float32
+         else:
+             params["device_map"] = 'auto'
+             if shared.args.load_in_4bit:
+
+                 # See https://github.com/huggingface/transformers/pull/23479/files
+                 # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
+                 quantization_config_params = {
+                     'load_in_4bit': True,
+                     'bnb_4bit_compute_dtype': eval("torch.{}".format(shared.args.compute_dtype)) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
+                     'bnb_4bit_quant_type': shared.args.quant_type,
+                     'bnb_4bit_use_double_quant': shared.args.use_double_quant,
+                 }
+
+                 logger.warning("Using the following 4-bit params: " + str(quantization_config_params))
+                 params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
+
+             elif shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
+                 params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
+             elif shared.args.load_in_8bit:
+                 params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
+             elif shared.args.bf16:
+                 params["torch_dtype"] = torch.bfloat16
+             else:
+                 params["torch_dtype"] = torch.float16
+
+             params['max_memory'] = get_max_memory_dict()
+             if shared.args.disk:
+                 params["offload_folder"] = shared.args.disk_cache_dir
+
+         checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
+         if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
+             config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=shared.args.trust_remote_code)
+             with init_empty_weights():
+                 model = LoaderClass.from_config(config, trust_remote_code=shared.args.trust_remote_code)
+
+             model.tie_weights()
+             params['device_map'] = infer_auto_device_map(
+                 model,
+                 dtype=torch.int8,
+                 max_memory=params['max_memory'],
+                 no_split_module_classes=model._no_split_modules
+             )
+
+         model = LoaderClass.from_pretrained(checkpoint, **params)
+
+     return model
+
+
+ def RWKV_loader(model_name):
+     from modules.RWKV import RWKVModel, RWKVTokenizer
+
+     model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+     tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
+     return model, tokenizer
+
+
+ def llamacpp_loader(model_name):
+     from modules.llamacpp_model import LlamaCppModel
+
+     path = Path(f'{shared.args.model_dir}/{model_name}')
+     if path.is_file():
+         model_file = path
+     else:
+         model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
+
+     logger.info(f"llama.cpp weights detected: {model_file}\n")
+     model, tokenizer = LlamaCppModel.from_pretrained(model_file)
+     return model, tokenizer
+
+
+ def llamacpp_HF_loader(model_name):
+     from modules.llamacpp_hf import LlamacppHF
+
+     for fname in ["oobabooga_llama-tokenizer", "llama-tokenizer"]:
+         path = Path(f'{shared.args.model_dir}/{fname}')
+         if path.exists():
+             break
+     else:
+         logger.error("Could not load the model because a tokenizer in transformers format was not found. Please download oobabooga/llama-tokenizer.")
+         return None, None
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         path,
+         trust_remote_code=shared.args.trust_remote_code,
+         use_fast=False
+     )
+
+     model = LlamacppHF.from_pretrained(model_name)
+     return model, tokenizer
+
+
+ def GPTQ_loader(model_name):
+
+     # Monkey patch
+     if shared.args.monkey_patch:
+         logger.warning("Applying the monkey patch for using LoRAs with GPTQ models. It may cause undefined behavior outside its intended scope.")
+         from modules.monkey_patch_gptq_lora import load_model_llama
+
+         model, _ = load_model_llama(model_name)
+
+     # No monkey patch
+     else:
+         import modules.GPTQ_loader
+
+         model = modules.GPTQ_loader.load_quantized(model_name)
+
+     return model
+
+
+ def AutoGPTQ_loader(model_name):
+     import modules.AutoGPTQ_loader
+
+     return modules.AutoGPTQ_loader.load_quantized(model_name)
+
+
+ def ExLlama_loader(model_name):
+     from modules.exllama import ExllamaModel
+
+     model, tokenizer = ExllamaModel.from_pretrained(model_name)
+     return model, tokenizer
+
+
+ def ExLlama_HF_loader(model_name):
298
+ from modules.exllama_hf import ExllamaHF
299
+
300
+ return ExllamaHF.from_pretrained(model_name)
301
+
302
+
303
+ def get_max_memory_dict():
304
+ max_memory = {}
305
+ if shared.args.gpu_memory:
306
+ memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
307
+ for i in range(len(memory_map)):
308
+ max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
309
+
310
+ max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
311
+ max_memory['cpu'] = f'{max_cpu_memory}GiB' if not re.match('.*ib$', max_cpu_memory.lower()) else max_cpu_memory
312
+
313
+ # If --auto-devices is provided standalone, try to get a reasonable value
314
+ # for the maximum memory of device :0
315
+ elif shared.args.auto_devices:
316
+ total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
317
+ suggestion = round((total_mem - 1000) / 1000) * 1000
318
+ if total_mem - suggestion < 800:
319
+ suggestion -= 1000
320
+
321
+ suggestion = int(round(suggestion / 1000))
322
+ logger.warning(f"Auto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
323
+ max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
324
+
325
+ return max_memory if len(max_memory) > 0 else None
326
+
327
+
328
+ def clear_torch_cache():
329
+ gc.collect()
330
+ if not shared.args.cpu:
331
+ torch.cuda.empty_cache()
332
+
333
+
334
+ def unload_model():
335
+ shared.model = shared.tokenizer = None
336
+ shared.lora_names = []
337
+ shared.model_dirty_from_training = False
338
+ clear_torch_cache()
339
+
340
+
341
+ def reload_model():
342
+ unload_model()
343
+ shared.model, shared.tokenizer = load_model(shared.model_name)
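Note: as a minimal, self-contained sketch of what the 4-bit branch of huggingface_loader() above ends up asking transformers for, the snippet below builds the same kind of BitsAndBytesConfig and passes it to from_pretrained. The model path and the concrete dtype/quant choices are illustrative assumptions (in the loader they come from shared.args), not values taken from this repo.

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Illustrative settings; huggingface_loader() derives these from
    # shared.args.compute_dtype, shared.args.quant_type and shared.args.use_double_quant.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_use_double_quant=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        'models/some-llama-model',        # hypothetical local path under --model-dir
        device_map='auto',
        low_cpu_mem_usage=True,
        quantization_config=quantization_config,
    )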
modules/models_settings.py ADDED
@@ -0,0 +1,137 @@
1
+ import re
2
+ from pathlib import Path
3
+
4
+ import yaml
5
+
6
+ from modules import loaders, shared, ui
7
+
8
+
9
+ def get_model_settings_from_yamls(model):
10
+ settings = shared.model_config
11
+ model_settings = {}
12
+ for pat in settings:
13
+ if re.match(pat.lower(), model.lower()):
14
+ for k in settings[pat]:
15
+ model_settings[k] = settings[pat][k]
16
+
17
+ return model_settings
18
+
19
+
20
+ def infer_loader(model_name):
21
+ path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
22
+ model_settings = get_model_settings_from_yamls(model_name)
23
+ if not path_to_model.exists():
24
+ loader = None
25
+ elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0):
26
+ loader = 'AutoGPTQ'
27
+ elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
28
+ loader = 'llama.cpp'
29
+ elif re.match('.*ggml.*\.bin', model_name.lower()):
30
+ loader = 'llama.cpp'
31
+ elif re.match('.*rwkv.*\.pth', model_name.lower()):
32
+ loader = 'RWKV'
33
+ else:
34
+ loader = 'Transformers'
35
+
36
+ return loader
37
+
38
+
39
+ # UI: update the command-line arguments based on the interface values
40
+ def update_model_parameters(state, initial=False):
41
+ elements = ui.list_model_elements() # the names of the parameters
42
+ gpu_memories = []
43
+
44
+ for i, element in enumerate(elements):
45
+ if element not in state:
46
+ continue
47
+
48
+ value = state[element]
49
+ if element.startswith('gpu_memory'):
50
+ gpu_memories.append(value)
51
+ continue
52
+
53
+ if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]:
54
+ continue
55
+
56
+ # Setting null defaults
57
+ if element in ['wbits', 'groupsize', 'model_type'] and value == 'None':
58
+ value = vars(shared.args_defaults)[element]
59
+ elif element in ['cpu_memory'] and value == 0:
60
+ value = vars(shared.args_defaults)[element]
61
+
62
+ # Making some simple conversions
63
+ if element in ['wbits', 'groupsize', 'pre_layer']:
64
+ value = int(value)
65
+ elif element == 'cpu_memory' and value is not None:
66
+ value = f"{value}MiB"
67
+
68
+ if element in ['pre_layer']:
69
+ value = [value] if value > 0 else None
70
+
71
+ setattr(shared.args, element, value)
72
+
73
+ found_positive = False
74
+ for i in gpu_memories:
75
+ if i > 0:
76
+ found_positive = True
77
+ break
78
+
79
+ if not (initial and vars(shared.args)['gpu_memory'] != vars(shared.args_defaults)['gpu_memory']):
80
+ if found_positive:
81
+ shared.args.gpu_memory = [f"{i}MiB" for i in gpu_memories]
82
+ else:
83
+ shared.args.gpu_memory = None
84
+
85
+
86
+ # UI: update the state variable with the model settings
87
+ def apply_model_settings_to_state(model, state):
88
+ model_settings = get_model_settings_from_yamls(model)
89
+ if 'loader' not in model_settings:
90
+ loader = infer_loader(model)
91
+ if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
92
+ loader = 'AutoGPTQ'
93
+
94
+ # If the user is using an alternative GPTQ loader, let them keep using it
95
+ if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama', 'ExLlama_HF']):
96
+ state['loader'] = loader
97
+
98
+ for k in model_settings:
99
+ if k in state:
100
+ if k in ['wbits', 'groupsize']:
101
+ state[k] = str(model_settings[k])
102
+ else:
103
+ state[k] = model_settings[k]
104
+
105
+ return state
106
+
107
+
108
+ # Save the settings for this model to models/config-user.yaml
109
+ def save_model_settings(model, state):
110
+ if model == 'None':
111
+ yield ("Not saving the settings because no model is loaded.")
112
+ return
113
+
114
+ with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
115
+ if p.exists():
116
+ user_config = yaml.safe_load(open(p, 'r').read())
117
+ else:
118
+ user_config = {}
119
+
120
+ model_regex = model + '$' # For exact matches
121
+ for _dict in [user_config, shared.model_config]:
122
+ if model_regex not in _dict:
123
+ _dict[model_regex] = {}
124
+
125
+ if model_regex not in user_config:
126
+ user_config[model_regex] = {}
127
+
128
+ for k in ui.list_model_elements():
129
+ if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
130
+ user_config[model_regex][k] = state[k]
131
+ shared.model_config[model_regex][k] = state[k]
132
+
133
+ output = yaml.dump(user_config, sort_keys=False)
134
+ with open(p, 'w') as f:
135
+ f.write(output)
136
+
137
+ yield (f"Settings for {model} saved to {p}")
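For orientation, save_model_settings() above keys config-user.yaml by a regex built from the model name (the trailing "$" makes it an exact match) and stores one settings dict per model. A small sketch of that structure, with a hypothetical model name and field values:

    import yaml

    # Hypothetical example of what ends up in models/config-user.yaml.
    user_config = {
        'my-model-4bit-128g$': {      # regex key: exact match on the model name
            'loader': 'AutoGPTQ',
            'wbits': 4,
            'groupsize': 128,
        }
    }
    print(yaml.dump(user_config, sort_keys=False))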
modules/presets.py ADDED
@@ -0,0 +1,66 @@
1
+ import functools
2
+ from pathlib import Path
3
+
4
+ import yaml
5
+
6
+
7
+ def default_preset():
8
+ return {
9
+ 'do_sample': True,
10
+ 'temperature': 1,
11
+ 'top_p': 1,
12
+ 'typical_p': 1,
13
+ 'epsilon_cutoff': 0,
14
+ 'eta_cutoff': 0,
15
+ 'tfs': 1,
16
+ 'top_a': 0,
17
+ 'repetition_penalty': 1,
18
+ 'repetition_penalty_range': 0,
19
+ 'encoder_repetition_penalty': 1,
20
+ 'top_k': 0,
21
+ 'num_beams': 1,
22
+ 'penalty_alpha': 0,
23
+ 'min_length': 0,
24
+ 'length_penalty': 1,
25
+ 'no_repeat_ngram_size': 0,
26
+ 'early_stopping': False,
27
+ 'mirostat_mode': 0,
28
+ 'mirostat_tau': 5.0,
29
+ 'mirostat_eta': 0.1,
30
+ }
31
+
32
+
33
+ def load_preset(name):
34
+ generate_params = default_preset()
35
+ if name not in ['None', None, '']:
36
+ with open(Path(f'presets/{name}.yaml'), 'r') as infile:
37
+ preset = yaml.safe_load(infile)
38
+
39
+ for k in preset:
40
+ generate_params[k] = preset[k]
41
+
42
+ generate_params['temperature'] = min(1.99, generate_params['temperature'])
43
+ return generate_params
44
+
45
+
46
+ @functools.cache
47
+ def load_preset_memoized(name):
48
+ return load_preset(name)
49
+
50
+
51
+ def load_preset_for_ui(name, state):
52
+ generate_params = load_preset(name)
53
+ state.update(generate_params)
54
+ return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
55
+
56
+
57
+ def generate_preset_yaml(state):
58
+ defaults = default_preset()
59
+ data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
60
+
61
+ # Remove entries that are identical to the defaults
62
+ for k in list(data.keys()):
63
+ if data[k] == defaults[k]:
64
+ del data[k]
65
+
66
+ return yaml.dump(data, sort_keys=False)
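The two directions in presets.py mirror each other: load_preset() overlays a preset file on top of default_preset(), and generate_preset_yaml() strips anything identical to the defaults before dumping. A self-contained sketch of that round trip, with made-up preset values and only a subset of the default keys:

    import yaml

    defaults = {'do_sample': True, 'temperature': 1, 'top_p': 1, 'top_k': 0}

    # Hypothetical contents of a preset YAML file (a subset of the keys above).
    preset = {'temperature': 0.7, 'top_p': 0.9}

    # load_preset(): start from the defaults, then overlay the preset values.
    params = dict(defaults)
    params.update(preset)

    # generate_preset_yaml(): drop entries identical to the defaults before dumping.
    data = {k: v for k, v in params.items() if v != defaults[k]}
    print(yaml.dump(data, sort_keys=False))   # -> temperature: 0.7 / top_p: 0.9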
modules/relative_imports.py ADDED
@@ -0,0 +1,13 @@
1
+ import sys
2
+ from pathlib import Path
3
+
4
+
5
+ class RelativeImport:
6
+ def __init__(self, path):
7
+ self.import_path = Path(path)
8
+
9
+ def __enter__(self):
10
+ sys.path.insert(0, str(self.import_path))
11
+
12
+ def __exit__(self, exc_type, exc_value, traceback):
13
+ sys.path.remove(str(self.import_path))
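RelativeImport is a small context manager that temporarily prepends a directory to sys.path so code vendored under that directory can be imported with plain import statements, then removes it again on exit. A hypothetical usage sketch (the path and module name are placeholders, not files known to exist in this repo):

    from modules.relative_imports import RelativeImport

    with RelativeImport("repositories/some_vendored_repo"):
        import some_vendored_module   # resolved against the directory inserted above
    # On exit the directory is removed from sys.path, so the import has no lasting side effect.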
modules/text_generation.py ADDED
@@ -0,0 +1,337 @@
1
+ import ast
2
+ import copy
3
+ import random
4
+ import re
5
+ import time
6
+ import traceback
7
+
8
+ import numpy as np
9
+ import torch
10
+ import transformers
11
+ from transformers import LogitsProcessorList
12
+
13
+ import modules.shared as shared
14
+ from modules.callbacks import (
15
+ Iteratorize,
16
+ Stream,
17
+ _StopEverythingStoppingCriteria
18
+ )
19
+ from modules.extensions import apply_extensions
20
+ from modules.html_generator import generate_4chan_html, generate_basic_html
21
+ from modules.logging_colors import logger
22
+ from modules.models import clear_torch_cache, local_rank
23
+
24
+
25
+ def generate_reply(*args, **kwargs):
26
+ shared.generation_lock.acquire()
27
+ try:
28
+ for result in _generate_reply(*args, **kwargs):
29
+ yield result
30
+ finally:
31
+ shared.generation_lock.release()
32
+
33
+
34
+ def get_max_prompt_length(state):
35
+ return state['truncation_length'] - state['max_new_tokens']
36
+
37
+
38
+ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
39
+ if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
40
+ input_ids = shared.tokenizer.encode(str(prompt))
41
+ input_ids = np.array(input_ids).reshape(1, len(input_ids))
42
+ return input_ids
43
+ else:
44
+ input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
45
+
46
+ # This is a hack for making replies more creative.
47
+ if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
48
+ input_ids = input_ids[:, 1:]
49
+
50
+ # Handling truncation
51
+ if truncation_length is not None:
52
+ input_ids = input_ids[:, -truncation_length:]
53
+
54
+ if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
55
+ return input_ids
56
+ elif shared.args.deepspeed:
57
+ return input_ids.to(device=local_rank)
58
+ elif torch.backends.mps.is_available():
59
+ device = torch.device('mps')
60
+ return input_ids.to(device)
61
+ else:
62
+ return input_ids.cuda()
63
+
64
+
65
+ def get_encoded_length(prompt):
66
+ length_after_extensions = apply_extensions('tokenized_length', prompt)
67
+ if length_after_extensions is not None:
68
+ return length_after_extensions
69
+
70
+ return len(encode(prompt)[0])
71
+
72
+
73
+ def decode(output_ids, skip_special_tokens=True):
74
+ return shared.tokenizer.decode(output_ids, skip_special_tokens)
75
+
76
+
77
+ # Removes empty replies from gpt4chan outputs
78
+ def fix_gpt4chan(s):
79
+ for i in range(10):
80
+ s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
81
+ s = re.sub("--- [0-9]*\n *\n---", "---", s)
82
+ s = re.sub("--- [0-9]*\n\n\n---", "---", s)
83
+
84
+ return s
85
+
86
+
87
+ # Fix the LaTeX equations in galactica
88
+ def fix_galactica(s):
89
+ s = s.replace(r'\[', r'$')
90
+ s = s.replace(r'\]', r'$')
91
+ s = s.replace(r'\(', r'$')
92
+ s = s.replace(r'\)', r'$')
93
+ s = s.replace(r'$$', r'$')
94
+ s = re.sub(r'\n', r'\n\n', s)
95
+ s = re.sub(r"\n{3,}", "\n\n", s)
96
+ return s
97
+
98
+
99
+ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False):
100
+ if shared.is_seq2seq:
101
+ reply = decode(output_ids, state['skip_special_tokens'])
102
+ else:
103
+ new_tokens = len(output_ids) - len(input_ids[0])
104
+ reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
105
+ # Prevent LlamaTokenizer from skipping a space
106
+ if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0:
107
+ if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith('▁'):
108
+ reply = ' ' + reply
109
+
110
+ return reply
111
+
112
+
113
+ def formatted_outputs(reply, model_name):
114
+ if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']):
115
+ reply = fix_gpt4chan(reply)
116
+ return reply, generate_4chan_html(reply)
117
+ else:
118
+ return reply, generate_basic_html(reply)
119
+
120
+
121
+ def set_manual_seed(seed):
122
+ seed = int(seed)
123
+ if seed == -1:
124
+ seed = random.randint(1, 2**31)
125
+
126
+ torch.manual_seed(seed)
127
+ if torch.cuda.is_available():
128
+ torch.cuda.manual_seed_all(seed)
129
+
130
+ return seed
131
+
132
+
133
+ def stop_everything_event():
134
+ shared.stop_everything = True
135
+
136
+
137
+ def generate_reply_wrapper(question, state, stopping_strings=None):
138
+ reply = question if not shared.is_seq2seq else ''
139
+ yield formatted_outputs(reply, shared.model_name)
140
+
141
+ for reply in generate_reply(question, state, stopping_strings, is_chat=False):
142
+ if not shared.is_seq2seq:
143
+ reply = question + reply
144
+
145
+ yield formatted_outputs(reply, shared.model_name)
146
+
147
+
148
+ def apply_stopping_strings(reply, all_stop_strings):
149
+ stop_found = False
150
+ for string in all_stop_strings:
151
+ idx = reply.find(string)
152
+ if idx != -1:
153
+ reply = reply[:idx]
154
+ stop_found = True
155
+ break
156
+
157
+ if not stop_found:
158
+ # If something like "\nYo" is generated just before "\nYou:"
159
+ # is completed, trim it
160
+ for string in all_stop_strings:
161
+ for j in range(len(string) - 1, 0, -1):
162
+ if reply[-j:] == string[:j]:
163
+ reply = reply[:-j]
164
+ break
165
+ else:
166
+ continue
167
+
168
+ break
169
+
170
+ return reply, stop_found
171
+
172
+
173
+ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
174
+ generate_func = apply_extensions('custom_generate_reply')
175
+ if generate_func is None:
176
+ if shared.model_name == 'None' or shared.model is None:
177
+ logger.error("No model is loaded! Select one in the Model tab.")
178
+ yield ''
179
+ return
180
+
181
+ if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
182
+ generate_func = generate_reply_custom
183
+ else:
184
+ generate_func = generate_reply_HF
185
+
186
+ # Preparing the input
187
+ original_question = question
188
+ if not is_chat:
189
+ state = apply_extensions('state', state)
190
+ question = apply_extensions('input', question, state)
191
+
192
+ # Finding the stopping strings
193
+ all_stop_strings = []
194
+ for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
195
+ if type(st) is list and len(st) > 0:
196
+ all_stop_strings += st
197
+
198
+ if shared.args.verbose:
199
+ print(f'\n\n{question}\n--------------------\n')
200
+
201
+ shared.stop_everything = False
202
+ clear_torch_cache()
203
+ seed = set_manual_seed(state['seed'])
204
+ last_update = -1
205
+ reply = ''
206
+ is_stream = state['stream']
207
+ if len(all_stop_strings) > 0 and not state['stream']:
208
+ state = copy.deepcopy(state)
209
+ state['stream'] = True
210
+
211
+ for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
212
+ reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
213
+ if is_stream:
214
+ cur_time = time.time()
215
+ if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
216
+ last_update = cur_time
217
+ yield reply
218
+
219
+ if stop_found:
220
+ break
221
+
222
+ if not is_chat:
223
+ reply = apply_extensions('output', reply, state)
224
+
225
+ yield reply
226
+
227
+
228
+ def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
229
+ generate_params = {}
230
+ for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
231
+ generate_params[k] = state[k]
232
+
233
+ for k in ['epsilon_cutoff', 'eta_cutoff']:
234
+ if state[k] > 0:
235
+ generate_params[k] = state[k] * 1e-4
236
+
237
+ if state['ban_eos_token']:
238
+ generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
239
+
240
+ if shared.args.no_cache:
241
+ generate_params.update({'use_cache': False})
242
+
243
+ if shared.args.deepspeed:
244
+ generate_params.update({'synced_gpus': True})
245
+
246
+ # Encode the input
247
+ input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
248
+ output = input_ids[0]
249
+ cuda = not any((shared.args.cpu, shared.args.deepspeed))
250
+
251
+ # Add the encoded tokens to generate_params
252
+ question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
253
+ original_input_ids = input_ids
254
+ generate_params.update({'inputs': input_ids})
255
+ if inputs_embeds is not None:
256
+ generate_params.update({'inputs_embeds': inputs_embeds})
257
+
258
+ # Stopping criteria / eos token
259
+ eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
260
+ generate_params['eos_token_id'] = eos_token_ids
261
+ generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
262
+ generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
263
+
264
+ processor = state.get('logits_processor', LogitsProcessorList([]))
265
+ # In case folks just pass in a processor by itself.
266
+ if type(processor) != LogitsProcessorList:
267
+ processor = LogitsProcessorList([processor])
268
+ apply_extensions('logits_processor', processor, input_ids)
269
+ generate_params['logits_processor'] = processor
270
+
271
+ t0 = time.time()
272
+ try:
273
+ if not is_chat and not shared.is_seq2seq:
274
+ yield ''
275
+
276
+ # Generate the entire reply at once.
277
+ if not state['stream']:
278
+ with torch.no_grad():
279
+ output = shared.model.generate(**generate_params)[0]
280
+ if cuda:
281
+ output = output.cuda()
282
+
283
+ yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
284
+
285
+ # Stream the reply 1 token at a time.
286
+ # This is based on the trick of using 'stopping_criteria' to create an iterator.
287
+ else:
288
+
289
+ def generate_with_callback(callback=None, *args, **kwargs):
290
+ kwargs['stopping_criteria'].append(Stream(callback_func=callback))
291
+ clear_torch_cache()
292
+ with torch.no_grad():
293
+ shared.model.generate(**kwargs)
294
+
295
+ def generate_with_streaming(**kwargs):
296
+ return Iteratorize(generate_with_callback, [], kwargs, callback=None)
297
+
298
+ with generate_with_streaming(**generate_params) as generator:
299
+ for output in generator:
300
+ yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
301
+ if output[-1] in eos_token_ids:
302
+ break
303
+
304
+ except Exception:
305
+ traceback.print_exc()
306
+ finally:
307
+ t1 = time.time()
308
+ original_tokens = len(original_input_ids[0])
309
+ new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
310
+ print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
311
+ return
312
+
313
+
314
+ def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
315
+ seed = set_manual_seed(state['seed'])
316
+
317
+ t0 = time.time()
318
+ reply = ''
319
+ try:
320
+ if not is_chat:
321
+ yield ''
322
+
323
+ if not state['stream']:
324
+ reply = shared.model.generate(question, state)
325
+ yield reply
326
+ else:
327
+ for reply in shared.model.generate_with_streaming(question, state):
328
+ yield reply
329
+
330
+ except Exception:
331
+ traceback.print_exc()
332
+ finally:
333
+ t1 = time.time()
334
+ original_tokens = len(encode(original_question)[0])
335
+ new_tokens = len(encode(original_question + reply)[0]) - original_tokens
336
+ print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
337
+ return
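As a quick illustration of apply_stopping_strings() defined above: it truncates the reply at the first completed stop string, and it also trims a dangling prefix of a stop string so partial matches never leak into the streamed output. The import path assumes this repo's layout; the strings are made up.

    from modules.text_generation import apply_stopping_strings

    reply, stop_found = apply_stopping_strings("I am fine.\nYou: and you?", ["\nYou:"])
    # reply == "I am fine." and stop_found == True: everything from the stop string onward is cut.

    reply, stop_found = apply_stopping_strings("I am fine.\nYo", ["\nYou:"])
    # reply == "I am fine." and stop_found == False: the dangling "\nYo" prefix is trimmed,
    # but generation is allowed to continue.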
modules/ui.py ADDED
@@ -0,0 +1,206 @@
1
+ import json
2
+ from pathlib import Path
3
+
4
+ import gradio as gr
5
+ import torch
6
+
7
+ from modules import shared
8
+
9
+
10
+ with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
11
+ css = f.read()
12
+ with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
13
+ chat_css = f.read()
14
+ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
15
+ main_js = f.read()
16
+ with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
17
+ chat_js = f.read()
18
+
19
+ refresh_symbol = '🔄'
20
+ delete_symbol = '🗑️'
21
+ save_symbol = '💾'
22
+
23
+ theme = gr.themes.Default(
24
+ font=['Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'],
25
+ font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
26
+ ).set(
27
+ border_color_primary='#c5c5d2',
28
+ button_large_padding='6px 12px',
29
+ body_text_color_subdued='#484848',
30
+ background_fill_secondary='#eaeaea'
31
+ )
32
+
33
+
34
+ def list_model_elements():
35
+ elements = [
36
+ 'loader',
37
+ 'cpu_memory',
38
+ 'auto_devices',
39
+ 'disk',
40
+ 'cpu',
41
+ 'bf16',
42
+ 'load_in_8bit',
43
+ 'trust_remote_code',
44
+ 'load_in_4bit',
45
+ 'compute_dtype',
46
+ 'quant_type',
47
+ 'use_double_quant',
48
+ 'wbits',
49
+ 'groupsize',
50
+ 'model_type',
51
+ 'pre_layer',
52
+ 'triton',
53
+ 'desc_act',
54
+ 'no_inject_fused_attention',
55
+ 'no_inject_fused_mlp',
56
+ 'no_use_cuda_fp16',
57
+ 'threads',
58
+ 'n_batch',
59
+ 'no_mmap',
60
+ 'low_vram',
61
+ 'mlock',
62
+ 'n_gpu_layers',
63
+ 'n_ctx',
64
+ 'n_gqa',
65
+ 'rms_norm_eps',
66
+ 'llama_cpp_seed',
67
+ 'gpu_split',
68
+ 'max_seq_len',
69
+ 'compress_pos_emb',
70
+ 'alpha_value'
71
+ ]
72
+
73
+ for i in range(torch.cuda.device_count()):
74
+ elements.append(f'gpu_memory_{i}')
75
+
76
+ return elements
77
+
78
+
79
+ def list_interface_input_elements():
80
+ elements = [
81
+ 'max_new_tokens',
82
+ 'seed',
83
+ 'temperature',
84
+ 'top_p',
85
+ 'top_k',
86
+ 'typical_p',
87
+ 'epsilon_cutoff',
88
+ 'eta_cutoff',
89
+ 'repetition_penalty',
90
+ 'repetition_penalty_range',
91
+ 'encoder_repetition_penalty',
92
+ 'no_repeat_ngram_size',
93
+ 'min_length',
94
+ 'do_sample',
95
+ 'penalty_alpha',
96
+ 'num_beams',
97
+ 'length_penalty',
98
+ 'early_stopping',
99
+ 'mirostat_mode',
100
+ 'mirostat_tau',
101
+ 'mirostat_eta',
102
+ 'add_bos_token',
103
+ 'ban_eos_token',
104
+ 'truncation_length',
105
+ 'custom_stopping_strings',
106
+ 'skip_special_tokens',
107
+ 'stream',
108
+ 'tfs',
109
+ 'top_a',
110
+ ]
111
+
112
+ if shared.args.chat:
113
+ elements += [
114
+ 'character_menu',
115
+ 'history',
116
+ 'name1',
117
+ 'name2',
118
+ 'greeting',
119
+ 'context',
120
+ 'chat_generation_attempts',
121
+ 'stop_at_newline',
122
+ 'mode',
123
+ 'instruction_template',
124
+ 'name1_instruct',
125
+ 'name2_instruct',
126
+ 'context_instruct',
127
+ 'turn_template',
128
+ 'chat_style',
129
+ 'chat-instruct_command',
130
+ ]
131
+ else:
132
+ elements.append('textbox')
133
+ if not shared.args.notebook:
134
+ elements.append('output_textbox')
135
+
136
+ elements += list_model_elements()
137
+ return elements
138
+
139
+
140
+ def gather_interface_values(*args):
141
+ output = {}
142
+ for i, element in enumerate(list_interface_input_elements()):
143
+ output[element] = args[i]
144
+
145
+ if not shared.args.multi_user:
146
+ shared.persistent_interface_state = output
147
+ Path('logs').mkdir(exist_ok=True)
148
+ with open(Path(f'logs/session_{shared.get_mode()}_autosave.json'), 'w') as f:
149
+ f.write(json.dumps(output, indent=4))
150
+
151
+ return output
152
+
153
+
154
+ def apply_interface_values(state, use_persistent=False):
155
+ if use_persistent:
156
+ state = shared.persistent_interface_state
157
+
158
+ elements = list_interface_input_elements()
159
+ if len(state) == 0:
160
+ return [gr.update() for k in elements] # Dummy, do nothing
161
+ else:
162
+ return [state[k] if k in state else gr.update() for k in elements]
163
+
164
+
165
+ class ToolButton(gr.Button, gr.components.IOComponent):
166
+ """
167
+ Small button with single emoji as text, fits inside gradio forms
168
+ Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui
169
+ """
170
+
171
+ def __init__(self, **kwargs):
172
+ super().__init__(**kwargs)
173
+
174
+ def get_block_name(self):
175
+ return "button"
176
+
177
+
178
+ def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class):
179
+ """
180
+ Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui
181
+ """
182
+ def refresh():
183
+ refresh_method()
184
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
185
+
186
+ for k, v in args.items():
187
+ setattr(refresh_component, k, v)
188
+
189
+ return gr.update(**(args or {}))
190
+
191
+ refresh_button = ToolButton(value=refresh_symbol, elem_classes=elem_class)
192
+ refresh_button.click(
193
+ fn=refresh,
194
+ inputs=[],
195
+ outputs=[refresh_component]
196
+ )
197
+
198
+ return refresh_button
199
+
200
+
201
+ def create_delete_button(**kwargs):
202
+ return ToolButton(value=delete_symbol, **kwargs)
203
+
204
+
205
+ def create_save_button(**kwargs):
206
+ return ToolButton(value=save_symbol, **kwargs)
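A hypothetical sketch of how create_refresh_button() above is typically wired to a dropdown inside a Blocks layout. The component names are illustrative, and it assumes it runs inside the web UI's own process (so utils.get_available_models() can read --model-dir); only the helper itself comes from the code above.

    import gradio as gr
    from modules import ui, utils

    with gr.Blocks():
        model_menu = gr.Dropdown(choices=utils.get_available_models(), label='Model')
        ui.create_refresh_button(
            model_menu,                                          # component to refresh
            lambda: None,                                        # refresh_method: nothing extra to reload here
            lambda: {'choices': utils.get_available_models()},   # refreshed_args: new kwargs for the dropdown
            'refresh-button'                                     # elem_class used for styling the button
        )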
modules/utils.py ADDED
@@ -0,0 +1,127 @@
1
+ import os
2
+ import re
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+
6
+ from modules import shared
7
+ from modules.logging_colors import logger
8
+
9
+
10
+ # Helper function to get multiple values from shared.gradio
11
+ def gradio(*keys):
12
+ if len(keys) == 1 and type(keys[0]) is list:
13
+ keys = keys[0]
14
+
15
+ return [shared.gradio[k] for k in keys]
16
+
17
+
18
+ def save_file(fname, contents):
19
+ if fname == '':
20
+ logger.error('File name is empty!')
21
+ return
22
+
23
+ root_folder = Path(__file__).resolve().parent.parent
24
+ abs_path = Path(fname).resolve()
25
+ rel_path = abs_path.relative_to(root_folder)
26
+ if rel_path.parts[0] == '..':
27
+ logger.error(f'Invalid file path: {fname}')
28
+ return
29
+
30
+ with open(abs_path, 'w', encoding='utf-8') as f:
31
+ f.write(contents)
32
+
33
+ logger.info(f'Saved {abs_path}.')
34
+
35
+
36
+ def delete_file(fname):
37
+ if fname == '':
38
+ logger.error('File name is empty!')
39
+ return
40
+
41
+ root_folder = Path(__file__).resolve().parent.parent
42
+ abs_path = Path(fname).resolve()
43
+ rel_path = abs_path.relative_to(root_folder)
44
+ if rel_path.parts[0] == '..':
45
+ logger.error(f'Invalid file path: {fname}')
46
+ return
47
+
48
+ if abs_path.exists():
49
+ abs_path.unlink()
50
+ logger.info(f'Deleted {fname}.')
51
+
52
+
53
+ def current_time():
54
+ return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
55
+
56
+
57
+ def atoi(text):
58
+ return int(text) if text.isdigit() else text.lower()
59
+
60
+
61
+ # Replace multiple string pairs in a string
62
+ def replace_all(text, dic):
63
+ for i, j in dic.items():
64
+ text = text.replace(i, j)
65
+
66
+ return text
67
+
68
+
69
+ def natural_keys(text):
70
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
71
+
72
+
73
+ def get_available_models():
74
+ return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml'))], key=natural_keys)
75
+
76
+
77
+ def get_available_presets():
78
+ return sorted(set((k.stem for k in Path('presets').glob('*.yaml'))), key=natural_keys)
79
+
80
+
81
+ def get_available_prompts():
82
+ prompts = []
83
+ files = set((k.stem for k in Path('prompts').glob('*.txt')))
84
+ prompts += sorted([k for k in files if re.match('^[0-9]', k)], key=natural_keys, reverse=True)
85
+ prompts += sorted([k for k in files if re.match('^[^0-9]', k)], key=natural_keys)
86
+ prompts += ['Instruct-' + k for k in get_available_instruction_templates() if k != 'None']
87
+ prompts += ['None']
88
+ return prompts
89
+
90
+
91
+ def get_available_characters():
92
+ paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
93
+ return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=natural_keys)
94
+
95
+
96
+ def get_available_instruction_templates():
97
+ path = "characters/instruction-following"
98
+ paths = []
99
+ if os.path.exists(path):
100
+ paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
101
+
102
+ return ['None'] + sorted(set((k.stem for k in paths)), key=natural_keys)
103
+
104
+
105
+ def get_available_extensions():
106
+ return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
107
+
108
+
109
+ def get_available_loras():
110
+ return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys)
111
+
112
+
113
+ def get_datasets(path: str, ext: str):
114
+ # include subdirectories for raw txt files to allow training from a subdirectory of txt files
115
+ if ext == "txt":
116
+ return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('*.txt')) + list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
117
+
118
+ return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
119
+
120
+
121
+ def get_available_chat_styles():
122
+ return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys)
123
+
124
+
125
+ def get_available_sessions():
126
+ items = sorted(set(k.stem for k in Path('logs').glob(f'session_{shared.get_mode()}*')), key=natural_keys, reverse=True)
127
+ return [item for item in items if 'autosave' in item] + [item for item in items if 'autosave' not in item]
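Most of the listing helpers above sort with natural_keys(), which splits names into text and integer chunks so numeric parts compare numerically rather than character by character. A small illustration with made-up model names:

    from modules.utils import natural_keys

    names = ['model-10b', 'model-2b', 'model-7b']
    print(sorted(names))                    # ['model-10b', 'model-2b', 'model-7b']  (plain lexicographic order)
    print(sorted(names, key=natural_keys))  # ['model-2b', 'model-7b', 'model-10b']  (numeric-aware order)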