3v324v23 committed
Commit 5014604
1 Parent(s): 0fdada3
.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,20 @@
+ accelerate==0.19.0
+ colorama
+ datasets
+ flexgen==0.1.7
+ gradio_client==0.2.5
+ gradio==3.31.0
+ markdown
+ numpy
+ pandas
+ Pillow>=9.5.0
+ pyyaml
+ requests
+ rwkv==0.7.3
+ safetensors==0.3.1
+ sentencepiece
+ tqdm
+ git+https://github.com/huggingface/peft
+ transformers==4.29.1
+ bitsandbytes==0.38.1
+ llama-cpp-python==0.1.51
.ipynb_checkpoints/server-checkpoint.py ADDED
@@ -0,0 +1,892 @@
+ import logging
+ import os
+ import requests
+ import warnings
+ import modules.logging_colors
+
+ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+ logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
+
+ # This is a hack to prevent Gradio from phoning home when it gets imported
+ def my_get(url, **kwargs):
+     logging.info('Gradio HTTP request redirected to localhost :)')
+     kwargs.setdefault('allow_redirects', True)
+     return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
+
+
+ original_get = requests.get
+ requests.get = my_get
+ import gradio as gr
+ requests.get = original_get
+
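# [Editor's note] The block above temporarily replaces requests.get so that
# Gradio's import-time telemetry request is redirected to localhost, then
# restores the original. A minimal sketch of the same idea using the standard
# library's unittest.mock.patch, which restores the attribute automatically
# even if the import raises:
from unittest import mock
with mock.patch('requests.get', my_get):
    import gradio as gr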
+ import matplotlib
+ matplotlib.use('Agg')  # This fixes LaTeX rendering on some systems
+
+ import importlib
+ import io
+ import json
+ import math
+ import os
+ import re
+ import sys
+ import time
+ import traceback
+ import zipfile
+ from datetime import datetime
+ from functools import partial
+ from pathlib import Path
+
+ import psutil
+ import torch
+ import yaml
+ from PIL import Image
+
+ import modules.extensions as extensions_module
+ from modules import chat, shared, ui, utils
+ from modules.extensions import apply_extensions
+ from modules.html_generator import chat_html_wrapper
+ #from modules.LoRA import add_lora_to_model
+ from modules.models import load_model, load_soft_prompt, unload_model
+ from modules.text_generation import generate_reply_wrapper, get_encoded_length, stop_everything_event
+
+
+ def load_model_wrapper(selected_model, autoload=False):
+     if not autoload:
+         yield f"The settings for {selected_model} have been updated.\nClick on \"Load the model\" to load it."
+         return
+
+     if selected_model == 'None':
+         yield "No model selected"
+     else:
+         try:
+             yield f"Loading {selected_model}..."
+             shared.model_name = selected_model
+             unload_model()
+             if selected_model != '':
+                 shared.model, shared.tokenizer = load_model(shared.model_name)
+
+             yield f"Successfully loaded {selected_model}"
+         except:
+             yield traceback.format_exc()
+
+
+ def load_preset_values(preset_menu, state, return_dict=False):
+     generate_params = {
+         'do_sample': True,
+         'temperature': 1,
+         'top_p': 1,
+         'typical_p': 1,
+         'repetition_penalty': 1,
+         'encoder_repetition_penalty': 1,
+         'top_k': 50,
+         'num_beams': 1,
+         'penalty_alpha': 0,
+         'min_length': 0,
+         'length_penalty': 1,
+         'no_repeat_ngram_size': 0,
+         'early_stopping': False,
+     }
+     with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
+         preset = infile.read()
+     for i in preset.splitlines():
+         i = i.rstrip(',').strip().split('=')
+         if len(i) == 2 and i[0].strip() != 'tokens':
+             generate_params[i[0].strip()] = eval(i[1].strip())
+     generate_params['temperature'] = min(1.99, generate_params['temperature'])
+
+     if return_dict:
+         return generate_params
+     else:
+         state.update(generate_params)
+         return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+
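# [Editor's note] Presets are plain-text files of "key=value" lines (optionally
# comma-terminated, hence the rstrip(',') above), parsed here with eval(). For
# preset files from untrusted sources, ast.literal_eval is a safer drop-in for
# plain literal values:
import ast
print(ast.literal_eval('0.7'))    # 0.7
print(ast.literal_eval('False'))  # False; arbitrary expressions raise ValueError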
+
+ def upload_soft_prompt(file):
+     with zipfile.ZipFile(io.BytesIO(file)) as zf:
+         zf.extract('meta.json')
+         j = json.loads(open('meta.json', 'r').read())
+         name = j['name']
+         Path('meta.json').unlink()
+
+     with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
+         f.write(file)
+
+     return name
+
+
+ def open_save_prompt():
+     fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
+     return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True)
+
+
+ def save_prompt(text, fname):
+     if fname != "":
+         with open(Path(f'prompts/{fname}.txt'), 'w', encoding='utf-8') as f:
+             f.write(text)
+
+         message = f"Saved to prompts/{fname}.txt"
+     else:
+         message = "Error: No prompt name given."
+
+     return message, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
+
+
+ def load_prompt(fname):
+     if fname in ['None', '']:
+         return ''
+     elif fname.startswith('Instruct-'):
+         fname = re.sub('^Instruct-', '', fname)
+         with open(Path(f'characters/instruction-following/{fname}.yaml'), 'r', encoding='utf-8') as f:
+             data = yaml.safe_load(f)
+             output = ''
+             if 'context' in data:
+                 output += data['context']
+
+             replacements = {
+                 '<|user|>': data['user'],
+                 '<|bot|>': data['bot'],
+                 '<|user-message|>': 'Input',
+             }
+
+             output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements)
+             return output.rstrip(' ')
+     else:
+         with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f:
+             text = f.read()
+             if text[-1] == '\n':
+                 text = text[:-1]
+
+             return text
+
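# [Editor's note] For 'Instruct-*' prompts, load_prompt returns the template's
# context plus everything in turn_template before '<|bot-message|>', with the
# placeholders filled in. A worked example with hypothetical template values:
template = '<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n'
head = template.split('<|bot-message|>')[0]
for placeholder, value in {'<|user|>': 'USER:', '<|bot|>': 'ASSISTANT:', '<|user-message|>': 'Input'}.items():
    head = head.replace(placeholder, value)
print(repr(head))  # 'USER: Input\nASSISTANT: ' (the trailing space gets rstripped)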
+
+ def count_tokens(text):
+     tokens = get_encoded_length(text)
+     return f'{tokens} tokens in the input.'
+
+
+ def download_model_wrapper(repo_id):
+     try:
+         downloader = importlib.import_module("download-model")
+         repo_id_parts = repo_id.split(":")
+         model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id
+         branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"
+         check = False
+
+         yield ("Cleaning up the model/branch names")
+         model, branch = downloader.sanitize_model_and_branch_names(model, branch)
+
+         yield ("Getting the download links from Hugging Face")
+         links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
+
+         yield ("Getting the output folder")
+         output_folder = downloader.get_output_folder(model, branch, is_lora)
+
+         if check:
+             yield ("Checking previously downloaded files")
+             downloader.check_model_files(model, branch, links, sha256, output_folder)
+         else:
+             yield (f"Downloading files to {output_folder}")
+             downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
+             yield ("Done!")
+     except:
+         yield traceback.format_exc()
+
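# [Editor's note] The downloader accepts an optional branch name after a colon,
# e.g. 'facebook/opt-1.3b:main'; without a colon the branch defaults to 'main':
model, _, branch = 'facebook/opt-1.3b:main'.partition(':')
branch = branch or 'main'  # -> ('facebook/opt-1.3b', 'main')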
+
+ # Update the command-line arguments based on the interface values
+ def update_model_parameters(state, initial=False):
+     elements = ui.list_model_elements()  # the names of the parameters
+     gpu_memories = []
+
+     for i, element in enumerate(elements):
+         if element not in state:
+             continue
+
+         value = state[element]
+         if element.startswith('gpu_memory'):
+             gpu_memories.append(value)
+             continue
+
+         if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]:
+             continue
+
+         # Setting null defaults
+         if element in ['wbits', 'groupsize', 'model_type'] and value == 'None':
+             value = vars(shared.args_defaults)[element]
+         elif element in ['cpu_memory'] and value == 0:
+             value = vars(shared.args_defaults)[element]
+
+         # Making some simple conversions
+         if element in ['wbits', 'groupsize', 'pre_layer']:
+             value = int(value)
+         elif element == 'cpu_memory' and value is not None:
+             value = f"{value}MiB"
+
+         if element in ['pre_layer']:
+             value = [value] if value > 0 else None
+
+         setattr(shared.args, element, value)
+
+     found_positive = False
+     for i in gpu_memories:
+         if i > 0:
+             found_positive = True
+             break
+
+     if not (initial and vars(shared.args)['gpu_memory'] != vars(shared.args_defaults)['gpu_memory']):
+         if found_positive:
+             shared.args.gpu_memory = [f"{i}MiB" for i in gpu_memories]
+         else:
+             shared.args.gpu_memory = None
+
+
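# [Editor's note] The found_positive loop above is equivalent to the built-in
# any(); a self-contained one-liner with hypothetical per-GPU values:
gpu_memories = [0, 3500]  # hypothetical values gathered from the interface
found_positive = any(i > 0 for i in gpu_memories)  # same result as the loop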
+ def get_model_specific_settings(model):
+     settings = shared.model_config
+     model_settings = {}
+
+     for pat in settings:
+         if re.match(pat.lower(), model.lower()):
+             for k in settings[pat]:
+                 model_settings[k] = settings[pat][k]
+
+     return model_settings
+
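# [Editor's note] Keys of shared.model_config are treated as case-insensitive
# regular expressions matched against the start of the model name, so a
# hypothetical pattern like 'llama-(7b|13b)' matches a model named 'LLaMA-7b-hf':
import re
assert re.match('llama-(7b|13b)', 'LLaMA-7b-hf'.lower())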
+
+ def load_model_specific_settings(model, state, return_dict=False):
+     model_settings = get_model_specific_settings(model)
+     for k in model_settings:
+         if k in state:
+             state[k] = model_settings[k]
+
+     return state
+
+
+ def save_model_settings(model, state):
+     if model == 'None':
+         yield ("Not saving the settings because no model is loaded.")
+         return
+
+     with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
+         if p.exists():
+             user_config = yaml.safe_load(open(p, 'r').read())
+         else:
+             user_config = {}
+
+         model_regex = model + '$'  # For exact matches
+         if model_regex not in user_config:
+             user_config[model_regex] = {}
+
+         for k in ui.list_model_elements():
+             user_config[model_regex][k] = state[k]
+
+         with open(p, 'w') as f:
+             f.write(yaml.dump(user_config))
+
+         yield (f"Settings for {model} saved to {p}")
+
+
+ def create_model_menus():
+     # Finding the default values for the GPU and CPU memories
+     total_mem = []
+     for i in range(torch.cuda.device_count()):
+         total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
+
+     default_gpu_mem = []
+     if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
+         for i in shared.args.gpu_memory:
+             if 'mib' in i.lower():
+                 default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)))
+             else:
+                 default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000)
+     while len(default_gpu_mem) < len(total_mem):
+         default_gpu_mem.append(0)
+
+     total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024))
+     if shared.args.cpu_memory is not None:
+         default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory)
+     else:
+         default_cpu_mem = 0
+
+
+ def create_settings_menus(default_preset):
+
+     generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 shared.gradio['preset_menu'] = gr.Dropdown(choices=utils.get_available_presets(), value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
+                 ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': utils.get_available_presets()}, 'refresh-button')
+         with gr.Column():
+             shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Box():
+                 gr.Markdown('Custom generation parameters ([click here to view technical documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
+                 with gr.Row():
+                     with gr.Column():
+                         shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature', info='Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.')
+                         shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p', info='If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.')
+                         shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k', info='Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.')
+                         shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p', info='If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.')
+                     with gr.Column():
+                         shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.')
+                         shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.')
+                         shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.')
+                         shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.')
+                         shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
+         with gr.Column():
+             with gr.Box():
+                 gr.Markdown('Contrastive search')
+                 shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
+
+                 gr.Markdown('Beam search (uses a lot of VRAM)')
+                 with gr.Row():
+                     with gr.Column():
+                         shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
+                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
+                     with gr.Column():
+                         shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
+
+     with gr.Box():
+         with gr.Row():
+             with gr.Column():
+                 shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
+                 shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
+             with gr.Column():
+                 shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
+                 shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
+
+                 shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
+                 shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
+
+     with gr.Accordion('Soft prompt', open=False):
+         with gr.Row():
+             shared.gradio['softprompts_menu'] = gr.Dropdown(choices=utils.get_available_softprompts(), value='None', label='Soft prompt')
+             ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': utils.get_available_softprompts()}, 'refresh-button')
+
+         gr.Markdown('Upload a soft prompt (.zip format):')
+         with gr.Row():
+             shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
+
+     shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+     shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
+     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
+
+
+ def set_interface_arguments(interface_mode, extensions, bool_active):
+     modes = ["default", "notebook", "chat", "cai_chat"]
+     cmd_list = vars(shared.args)
+     bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
+
+     shared.args.extensions = extensions
+     for k in modes[1:]:
+         setattr(shared.args, k, False)
+     if interface_mode != "default":
+         setattr(shared.args, interface_mode, True)
+
+     for k in bool_list:
+         setattr(shared.args, k, False)
+     for k in bool_active:
+         setattr(shared.args, k, True)
+
+     shared.need_restart = True
+
+
+ def create_interface():
+
+     # Defining some variables
+     gen_events = []
+     default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
+     if len(shared.lora_names) == 1:
+         default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.lora_names[0].lower())), 'default')])
+     else:
+         default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
+     title = 'Text generation web UI'
+
+     # Authentication variables
+     auth = None
+     if shared.args.gradio_auth_path is not None:
+         gradio_auth_creds = []
+         with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
+             for line in file.readlines():
+                 gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
+         auth = [tuple(cred.split(':')) for cred in gradio_auth_creds]
+
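# [Editor's note] The --gradio-auth-path file holds user:password pairs separated
# by commas and/or newlines. Parsing a hypothetical line the same way as above:
creds = [x.strip() for x in 'alice:wonderland, bob:builder'.split(',') if x.strip()]
print([tuple(c.split(':')) for c in creds])  # [('alice', 'wonderland'), ('bob', 'builder')]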
+     # Importing the extension files and executing their setup() functions
+     if shared.args.extensions is not None and len(shared.args.extensions) > 0:
+         extensions_module.load_extensions()
+
+     # css/js strings
+     css = ui.css if not shared.is_chat() else ui.css + ui.chat_css
+     js = ui.main_js if not shared.is_chat() else ui.main_js + ui.chat_js
+     css += apply_extensions('css')
+     js += apply_extensions('js')
+
+     with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']:
+
+         # Create chat mode interface
+         if shared.is_chat():
+             shared.input_elements = ui.list_interface_input_elements(chat=True)
+             shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
+             shared.gradio['Chat input'] = gr.State()
+             shared.gradio['dummy'] = gr.State()
+
+             with gr.Tab('Text generation', elem_id='main'):
+                 shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'chat', 'cai-chat'))
+                 shared.gradio['textbox'] = gr.Textbox(label='Input')
+                 with gr.Row():
+                     shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
+                     shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate', variant='primary')
+                     shared.gradio['Continue'] = gr.Button('Continue')
+
+                 with gr.Row():
+                     shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
+                     shared.gradio['Regenerate'] = gr.Button('Regenerate')
+                     shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
+
+                 with gr.Row():
+                     shared.gradio['Impersonate'] = gr.Button('Impersonate')
+                     shared.gradio['Send dummy message'] = gr.Button('Send dummy message')
+                     shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply')
+
+                 with gr.Row():
+                     shared.gradio['Remove last'] = gr.Button('Remove last')
+                     shared.gradio['Clear history'] = gr.Button('Clear history')
+                     shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant='stop', visible=False)
+                     shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
+
+                 shared.gradio['mode'] = gr.Radio(choices=['chat', 'chat-instruct', 'instruct'], value=shared.settings['mode'] if shared.settings['mode'] in ['chat', 'instruct', 'chat-instruct'] else 'chat', label='Mode', info='Defines how the chat prompt is generated. In instruct and chat-instruct modes, the instruction template selected under "Chat settings" must match the current model.')
+                 shared.gradio['chat_style'] = gr.Dropdown(choices=utils.get_available_chat_styles(), label='Chat style', value=shared.settings['chat_style'], visible=shared.settings['mode'] != 'instruct')
+
+             with gr.Tab('Chat settings', elem_id='chat-settings'):
+                 with gr.Row():
+                     shared.gradio['character_menu'] = gr.Dropdown(choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.')
+                     ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button')
+
+                 with gr.Row():
+                     with gr.Column(scale=8):
+                         shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
+                         shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
+                         shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
+                         shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
+
+                     with gr.Column(scale=1):
+                         shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil')
+                         shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None)
+
+                 shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Instruction template', value='None', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.')
+                 shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string')
+                 shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string')
+                 shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context')
+                 shared.gradio['turn_template'] = gr.Textbox(value=shared.settings['turn_template'], lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.')
+                 with gr.Row():
+                     shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.')
+
+                 with gr.Row():
+                     with gr.Tab('Chat history'):
+                         with gr.Row():
+                             with gr.Column():
+                                 gr.Markdown('## Upload')
+                                 shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
+
+                             with gr.Column():
+                                 gr.Markdown('## Download')
+                                 shared.gradio['download'] = gr.File()
+                                 shared.gradio['download_button'] = gr.Button(value='Click me')
+
+                     with gr.Tab('Upload character'):
+                         gr.Markdown('## JSON format')
+                         with gr.Row():
+                             with gr.Column():
+                                 gr.Markdown('1. Select the JSON file')
+                                 shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
+
+                             with gr.Column():
+                                 gr.Markdown('2. Select your character\'s profile picture (optional)')
+                                 shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
+
+                         shared.gradio['Upload character'] = gr.Button(value='Submit')
+                         gr.Markdown('## TavernAI PNG format')
+                         shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
+
+             with gr.Tab("Parameters", elem_id="parameters"):
+                 with gr.Box():
+                     gr.Markdown("Chat parameters")
+                     with gr.Row():
+                         with gr.Column():
+                             shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+                             shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
+
+                         with gr.Column():
+                             shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)', info='New generations will be called until either this number is reached or no new content is generated between two iterations')
+                             shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
+
+                 create_settings_menus(default_preset)
+
+         # Create notebook mode interface
+         elif shared.args.notebook:
+             shared.input_elements = ui.list_interface_input_elements(chat=False)
+             shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
+             shared.gradio['last_input'] = gr.State('')
+             with gr.Tab("Text generation", elem_id="main"):
+                 with gr.Row():
+                     with gr.Column(scale=4):
+                         with gr.Tab('Raw'):
+                             shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox", lines=27)
+
+                         with gr.Tab('Markdown'):
+                             shared.gradio['markdown'] = gr.Markdown()
+
+                         with gr.Tab('HTML'):
+                             shared.gradio['html'] = gr.HTML()
+
+                         with gr.Row():
+                             shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button")
+                             shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button")
+                             shared.gradio['Undo'] = gr.Button('Undo', elem_classes="small-button")
+                             shared.gradio['Regenerate'] = gr.Button('Regenerate', elem_classes="small-button")
+
+                     with gr.Column(scale=1):
+                         gr.HTML('<div style="padding-bottom: 13px"></div>')
+                         shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+                         with gr.Row():
+                             shared.gradio['prompt_menu'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt')
+                             ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 'refresh-button')
+
+                         shared.gradio['open_save_prompt'] = gr.Button('Save prompt')
+                         shared.gradio['save_prompt'] = gr.Button('Confirm save prompt', visible=False)
+                         shared.gradio['prompt_to_save'] = gr.Textbox(elem_classes="textbox_default", lines=1, label='Prompt name:', interactive=True, visible=False)
+                         shared.gradio['count_tokens'] = gr.Button('Count tokens')
+                         shared.gradio['status'] = gr.Markdown('')
+
+             with gr.Tab("Parameters", elem_id="parameters"):
+                 create_settings_menus(default_preset)
+
+         # Create default mode interface
+         else:
+             shared.input_elements = ui.list_interface_input_elements(chat=False)
+             shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
+             shared.gradio['last_input'] = gr.State('')
+             with gr.Tab("Text generation", elem_id="main"):
+                 with gr.Row():
+                     with gr.Column():
+                         shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox_default", lines=27, label='Input')
+                         shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+                         with gr.Row():
+                             shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button")
+                             shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button")
+                             shared.gradio['Continue'] = gr.Button('Continue', elem_classes="small-button")
+                             shared.gradio['open_save_prompt'] = gr.Button('Save prompt', elem_classes="small-button")
+                             shared.gradio['save_prompt'] = gr.Button('Confirm save prompt', visible=False, elem_classes="small-button")
+                             shared.gradio['count_tokens'] = gr.Button('Count tokens', elem_classes="small-button")
+
+                         with gr.Row():
+                             with gr.Column():
+                                 with gr.Row():
+                                     shared.gradio['prompt_menu'] = gr.Dropdown(choices=utils.get_available_prompts(), value='None', label='Prompt')
+                                     ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 'refresh-button')
+
+                             with gr.Column():
+                                 shared.gradio['prompt_to_save'] = gr.Textbox(elem_classes="textbox_default", lines=1, label='Prompt name:', interactive=True, visible=False)
+                                 shared.gradio['status'] = gr.Markdown('')
+
+                     with gr.Column():
+                         with gr.Tab('Raw'):
+                             shared.gradio['output_textbox'] = gr.Textbox(elem_classes="textbox_default_output", lines=27, label='Output')
+
+                         with gr.Tab('Markdown'):
+                             shared.gradio['markdown'] = gr.Markdown()
+
+                         with gr.Tab('HTML'):
+                             shared.gradio['html'] = gr.HTML()
+
+             with gr.Tab("Parameters", elem_id="parameters"):
+                 create_settings_menus(default_preset)
+
+         # Model tab
+         # with gr.Tab("Model", elem_id="model-tab"):
+         #     create_model_menus()
+
+         # Training tab
+         # with gr.Tab("Training", elem_id="training-tab"):
+         #     training.create_train_interface()
+
+         # Interface mode tab
+         # with gr.Tab("Interface mode", elem_id="interface-mode"):
+         #     modes = ["default", "notebook", "chat"]
+         #     current_mode = "default"
+         #     for mode in modes[1:]:
+         #         if getattr(shared.args, mode):
+         #             current_mode = mode
+         #             break
+
+         #     cmd_list = vars(shared.args)
+         #     bool_list = sorted([k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes + ui.list_model_elements()])
+         #     bool_active = [k for k in bool_list if vars(shared.args)[k]]
+
+         #     shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
+         #     shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=utils.get_available_extensions(), value=shared.args.extensions, label="Available extensions")
+         #     shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
+         #     shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface")
+
+         # # Reset interface event
+         # shared.gradio['reset_interface'].click(
+         #     set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then(
+         #     lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
+
+         # chat mode event handlers
+         if shared.is_chat():
+             shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]
+             clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
+             shared.reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode', 'chat_style']]
+
+             gen_events.append(shared.gradio['Generate'].click(
+                 ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                 lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
+                 chat.generate_chat_reply_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
+                 chat.save_history, shared.gradio['mode'], None, show_progress=False)
+             )
+
+             gen_events.append(shared.gradio['textbox'].submit(
+                 ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                 lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
+                 chat.generate_chat_reply_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
+                 chat.save_history, shared.gradio['mode'], None, show_progress=False)
+             )
+
+             gen_events.append(shared.gradio['Regenerate'].click(
+                 ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                 partial(chat.generate_chat_reply_wrapper, regenerate=True), shared.input_params, shared.gradio['display'], show_progress=False).then(
+                 chat.save_history, shared.gradio['mode'], None, show_progress=False)
+             )
+
+             gen_events.append(shared.gradio['Continue'].click(
+                 ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                 partial(chat.generate_chat_reply_wrapper, _continue=True), shared.input_params, shared.gradio['display'], show_progress=False).then(
+                 chat.save_history, shared.gradio['mode'], None, show_progress=False)
+             )
+
+             gen_events.append(shared.gradio['Impersonate'].click(
+                 ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                 lambda x: x, shared.gradio['textbox'], shared.gradio['Chat input'], show_progress=False).then(
+                 chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False)
+             )
+
+             shared.gradio['Replace last reply'].click(
+                 chat.replace_last_reply, shared.gradio['textbox'], None).then(
+                 lambda: '', None, shared.gradio['textbox'], show_progress=False).then(
+                 chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
+                 chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
+
+             shared.gradio['Send dummy message'].click(
+                 chat.send_dummy_message, shared.gradio['textbox'], None).then(
+                 lambda: '', None, shared.gradio['textbox'], show_progress=False).then(
+                 chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
+                 chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
+
+             shared.gradio['Send dummy reply'].click(
+                 chat.send_dummy_reply, shared.gradio['textbox'], None).then(
+                 lambda: '', None, shared.gradio['textbox'], show_progress=False).then(
+                 chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
+                 chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
+
+             shared.gradio['Clear history-confirm'].click(
+                 lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
+                 chat.clear_chat_log, [shared.gradio[k] for k in ['greeting', 'mode']], None).then(
+                 chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
+                 chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
+
+             shared.gradio['Stop'].click(
+                 stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
+                 chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
+
+             shared.gradio['mode'].change(
+                 lambda x: gr.update(visible=x != 'instruct'), shared.gradio['mode'], shared.gradio['chat_style'], show_progress=False).then(
+                 chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
+
+
+             shared.gradio['chat_style'].change(chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
+             shared.gradio['instruction_template'].change(
+                 partial(chat.load_character, instruct=True), [shared.gradio[k] for k in ['instruction_template', 'name1_instruct', 'name2_instruct']], [shared.gradio[k] for k in ['name1_instruct', 'name2_instruct', 'dummy', 'dummy', 'context_instruct', 'turn_template']])
+
+             shared.gradio['upload_chat_history'].upload(
+                 chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
+                 chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
+
+             shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=False)
+             shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
+             shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+             shared.gradio['Remove last'].click(
+                 chat.remove_last_message, None, shared.gradio['textbox'], show_progress=False).then(
+                 chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
+                 chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
+
+             shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
+             shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
+             shared.gradio['character_menu'].change(
+                 partial(chat.load_character, instruct=False), [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'dummy']]).then(
+                 chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
+
+             shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
+             shared.gradio['your_picture'].change(
+                 chat.upload_your_profile_picture, shared.gradio['your_picture'], None).then(
+                 partial(chat.redraw_html, reset_cache=True), shared.reload_inputs, shared.gradio['display'])
+
+         # notebook/default modes event handlers
+         else:
+             shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']]
+             if shared.args.notebook:
+                 output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
+             else:
+                 output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
+
+             gen_events.append(shared.gradio['Generate'].click(
+                 lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then(
+                 ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                 generate_reply_wrapper, shared.input_params, output_params, show_progress=False)  # .then(
+                 # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
+             )
+
+             gen_events.append(shared.gradio['textbox'].submit(
+                 lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then(
+                 ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                 generate_reply_wrapper, shared.input_params, output_params, show_progress=False)  # .then(
+                 # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
+             )
+
+             if shared.args.notebook:
+                 shared.gradio['Undo'].click(lambda x: x, shared.gradio['last_input'], shared.gradio['textbox'], show_progress=False)
+                 gen_events.append(shared.gradio['Regenerate'].click(
+                     lambda x: x, shared.gradio['last_input'], shared.gradio['textbox'], show_progress=False).then(
+                     ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                     generate_reply_wrapper, shared.input_params, output_params, show_progress=False)  # .then(
+                     # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
+                 )
+             else:
+                 gen_events.append(shared.gradio['Continue'].click(
+                     ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+                     generate_reply_wrapper, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False)  # .then(
+                     # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
+                 )
+
+             shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
+             shared.gradio['prompt_menu'].change(load_prompt, shared.gradio['prompt_menu'], shared.gradio['textbox'], show_progress=False)
+             shared.gradio['open_save_prompt'].click(open_save_prompt, None, [shared.gradio[k] for k in ['prompt_to_save', 'open_save_prompt', 'save_prompt']], show_progress=False)
+             shared.gradio['save_prompt'].click(save_prompt, [shared.gradio[k] for k in ['textbox', 'prompt_to_save']], [shared.gradio[k] for k in ['status', 'prompt_to_save', 'open_save_prompt', 'save_prompt']], show_progress=False)
+             shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
+
+         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{js}}}")
+         shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False)
+         # Extensions tabs
+         extensions_module.create_extensions_tabs()
+
+         # Extensions block
+         extensions_module.create_extensions_block()
+
+     print("Launching the interface")
+     # Launch the interface
+     shared.gradio['interface'].queue()
+     if shared.args.listen:
+         shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name=shared.args.listen_host or '0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
+     else:
+         shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
+
+
+ if __name__ == "__main__":
+     # Loading custom settings
+     settings_file = None
+     if shared.args.settings is not None and Path(shared.args.settings).exists():
+         settings_file = Path(shared.args.settings)
+     elif Path('settings.json').exists():
+         settings_file = Path('settings.json')
+
+     if settings_file is not None:
+         logging.info(f"Loading settings from {settings_file}...")
+         new_settings = json.loads(open(settings_file, 'r').read())
+         for item in new_settings:
+             shared.settings[item] = new_settings[item]
+
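# [Editor's note] settings.json is a flat JSON object whose keys overwrite the
# matching entries of shared.settings. A minimal hypothetical example:
import json
example = json.loads('{"mode": "chat", "seed": -1, "max_new_tokens": 200}')
# -> {'mode': 'chat', 'seed': -1, 'max_new_tokens': 200}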
+     # Set default model settings based on settings.json
+     shared.model_config['.*'] = {
+         'wbits': 'None',
+         'model_type': 'None',
+         'groupsize': 'None',
+         'pre_layer': 0,
+         'mode': shared.settings['mode'],
+         'skip_special_tokens': shared.settings['skip_special_tokens'],
+         'custom_stopping_strings': shared.settings['custom_stopping_strings'],
+     }
+
+     shared.model_config.move_to_end('.*', last=False)  # Move to the beginning
+
+     # Default extensions
+     extensions_module.available_extensions = utils.get_available_extensions()
+     if shared.is_chat():
+         for extension in shared.settings['chat_default_extensions']:
+             shared.args.extensions = shared.args.extensions or []
+             if extension not in shared.args.extensions:
+                 shared.args.extensions.append(extension)
+     else:
+         for extension in shared.settings['default_extensions']:
+             shared.args.extensions = shared.args.extensions or []
+             if extension not in shared.args.extensions:
+                 shared.args.extensions.append(extension)
+
+     available_models = utils.get_available_models()
+
+     # Model defined through --model
+     if shared.args.model is not None:
+         shared.model_name = shared.args.model
+
+     # Only one model is available
+     elif len(available_models) == 1:
+         shared.model_name = available_models[0]
+
+     # Select the model from a command-line menu
+     elif shared.args.model_menu:
+         if len(available_models) == 0:
+             logging.error('No models are available! Please download at least one.')
+             sys.exit(0)
+         else:
+             print('The following models are available:\n')
+             for i, model in enumerate(available_models):
+                 print(f'{i+1}. {model}')
+
+             print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
+             i = int(input()) - 1
+             print()
+
+         shared.model_name = available_models[i]
+
+     # If any model has been selected, load it
+     if shared.model_name != 'None':
+         model_settings = get_model_specific_settings(shared.model_name)
+         shared.settings.update(model_settings)  # hijacking the interface defaults
+         update_model_parameters(model_settings, initial=True)  # hijacking the command-line arguments
+
+         # Load the model
+         shared.model, shared.tokenizer = load_model(shared.model_name)
+         #if shared.args.lora:
+         #    add_lora_to_model(shared.args.lora)
+
+     # Force a character to be loaded
+     if shared.is_chat():
+         shared.persistent_interface_state.update({
+             'mode': shared.settings['mode'],
+             'character_menu': shared.args.character or shared.settings['character'],
+             'instruction_template': shared.settings['instruction_template']
+         })
+
+     print("Creating the interface")
+     # Launch the web UI
+     create_interface()
+     while True:
+         time.sleep(0.5)
+         if shared.need_restart:
+             shared.need_restart = False
+             shared.gradio['interface'].close()
+             time.sleep(0.5)
+             create_interface()
requirements.txt CHANGED
@@ -17,4 +17,4 @@ tqdm
  git+https://github.com/huggingface/peft
  transformers==4.29.1
  bitsandbytes==0.38.1
- llama-cpp-python==0.1.50
+ llama-cpp-python==0.1.51
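
Note: this change only bumps llama-cpp-python from 0.1.50 to 0.1.51. A quick sketch for confirming which version actually ends up installed (standard library only):

import importlib.metadata
print(importlib.metadata.version('llama-cpp-python'))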
server.py CHANGED
@@ -44,10 +44,10 @@ import yaml
  from PIL import Image
 
  import modules.extensions as extensions_module
- from modules import chat, shared, training, ui, utils
+ from modules import chat, shared, ui, utils
  from modules.extensions import apply_extensions
  from modules.html_generator import chat_html_wrapper
- from modules.LoRA import add_lora_to_model
+ #from modules.LoRA import add_lora_to_model
  from modules.models import load_model, load_soft_prompt, unload_model
  from modules.text_generation import generate_reply_wrapper, get_encoded_length, stop_everything_event
 
@@ -72,12 +72,6 @@ def load_model_wrapper(selected_model, autoload=False):
              yield traceback.format_exc()
 
 
- def load_lora_wrapper(selected_loras):
-     yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras)))
-     add_lora_to_model(selected_loras)
-     yield ("Successfuly applied the LoRAs")
-
-
  def load_preset_values(preset_menu, state, return_dict=False):
      generate_params = {
          'do_sample': True,
 
@@ -875,8 +869,8 @@ if __name__ == "__main__":
 
          # Load the model
          shared.model, shared.tokenizer = load_model(shared.model_name)
-         if shared.args.lora:
-             add_lora_to_model(shared.args.lora)
+         #if shared.args.lora:
+         #    add_lora_to_model(shared.args.lora)
 
      # Force a character to be loaded
      if shared.is_chat():
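
Note: this commit disables LoRA support by commenting out the modules.LoRA import, removing the load_lora_wrapper helper, and commenting out the --lora startup hook. If the intent were to keep the feature working whenever the module is available, a guarded import is a common pattern (a sketch, not part of this commit):

try:
    from modules.LoRA import add_lora_to_model
except ImportError:
    add_lora_to_model = None  # LoRA support unavailable in this build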