# File size: 4,361 Bytes
# 16f428a
import os
import re
from datetime import datetime
from pathlib import Path
from modules import github, shared
from modules.logging_colors import logger
# Helper function to get multiple values from shared.gradio
def gradio(*keys):
    """Look up several entries of ``shared.gradio`` in one call.

    The keys may be given as separate positional arguments or as a single
    list/tuple holding them. The matching values are returned as a list, in
    the same order as the keys.
    """
    # A lone list/tuple argument is treated as the collection of keys.
    if len(keys) == 1:
        only = keys[0]
        if type(only) in (list, tuple):
            keys = only

    components = []
    for key in keys:
        components.append(shared.gradio[key])
    return components
def save_file(fname, contents):
    """Write ``contents`` to ``fname`` as UTF-8 text.

    The write is refused (and an error is logged) when ``fname`` is empty or
    when it does not resolve to a location strictly inside the root folder,
    i.e. the parent of the directory that contains this module.

    Args:
        fname: Path of the file to write; relative paths are resolved
            against the current working directory.
        contents: Text to write to the file.

    Returns:
        None. Success and rejection are both reported through ``logger``.
    """
    if fname == '':
        logger.error('File name is empty!')
        return

    root_folder = Path(__file__).resolve().parent.parent
    abs_path_str = os.path.abspath(fname)
    try:
        rel_parts = Path(os.path.relpath(abs_path_str, root_folder)).parts
    except ValueError:
        # On Windows, relpath() raises when the target is on another drive;
        # such a path is necessarily outside the root folder.
        rel_parts = ('..',)

    # Empty parts means the target is the root folder itself (relpath is
    # '.'), which is a directory and therefore not a valid file to write.
    if not rel_parts or rel_parts[0] == '..':
        logger.error(f'Invalid file path: \"{fname}\"')
        return

    with open(abs_path_str, 'w', encoding='utf-8') as f:
        f.write(contents)

    logger.info(f'Saved \"{abs_path_str}\".')
def delete_file(fname):
    """Delete the file ``fname`` if it exists inside the root folder.

    The deletion is refused (and an error is logged) when ``fname`` is empty
    or when it does not resolve to a location strictly inside the root
    folder, i.e. the parent of the directory that contains this module.
    A valid path that does not exist is silently ignored.

    Args:
        fname: Path of the file to delete; relative paths are resolved
            against the current working directory.

    Returns:
        None. Success and rejection are both reported through ``logger``.
    """
    if fname == '':
        logger.error('File name is empty!')
        return

    root_folder = Path(__file__).resolve().parent.parent
    abs_path_str = os.path.abspath(fname)
    try:
        rel_parts = Path(os.path.relpath(abs_path_str, root_folder)).parts
    except ValueError:
        # On Windows, relpath() raises when the target is on another drive;
        # such a path is necessarily outside the root folder.
        rel_parts = ('..',)

    # Empty parts means the target is the root folder itself (relpath is
    # '.'), which must never be deleted through this helper.
    if not rel_parts or rel_parts[0] == '..':
        logger.error(f'Invalid file path: \"{fname}\"')
        return

    # Operate on the absolute path: a root-relative path would be resolved
    # against the current working directory, which may differ from the root
    # folder and would then point at the wrong file.
    target = Path(abs_path_str)
    if target.exists():
        target.unlink()
        logger.info(f'Deleted \"{fname}\".')
def current_time():
    """Return the local date and time as a ``YYYY-MM-DD-HHMMSS`` string."""
    now = datetime.now()
    return now.strftime('%Y-%m-%d-%H%M%S')
def atoi(text):
    """Convert a chunk of text into a natural-sort key component.

    Args:
        text: A string chunk.

    Returns:
        ``int(text)`` when the chunk consists solely of decimal digits,
        otherwise the chunk lower-cased.
    """
    # isdecimal() rather than isdigit(): isdigit() is also true for
    # characters such as superscript digits ('²') that int() cannot parse,
    # which would raise ValueError instead of falling back to the string.
    if text.isdecimal():
        return int(text)
    return text.lower()
# Replace multiple string pairs in a string
def replace_all(text, dic):
    """Apply every ``old -> new`` pair of ``dic`` to ``text`` and return it.

    The replacements are applied one after another in the dictionary's
    iteration order, so a later pair also sees the output of earlier ones.
    """
    result = text
    for old, new in dic.items():
        result = result.replace(old, new)
    return result
def natural_keys(text):
    """Build a sort key for ``text`` from its digit and non-digit runs.

    The text is split so that runs of digits become separate chunks, and
    each chunk is passed through ``atoi``.
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))
def get_available_models():
    """List the model names found in ``shared.args.model_dir``.

    Entries whose name ends with one of the skipped suffixes, or whose name
    contains ``llama-tokenizer``, are ignored. A trailing ``.pth`` extension
    is stripped from the remaining names.

    Returns:
        ``['None']`` followed by the names sorted with ``natural_keys``.
    """
    skipped_suffixes = ('.txt', '-np', '.pt', '.json', '.yaml', '.py')
    model_list = []
    for item in Path(f'{shared.args.model_dir}/').glob('*'):
        name = item.name
        if name.endswith(skipped_suffixes) or 'llama-tokenizer' in name:
            continue
        # The dot is escaped so only a literal ".pth" extension is removed;
        # an unescaped '.' would also eat e.g. the "epth" of "depth".
        model_list.append(re.sub(r'\.pth$', '', name))

    return ['None'] + sorted(model_list, key=natural_keys)
def get_available_presets():
    """Return the stems of the ``*.yaml`` files in ``presets``, naturally sorted."""
    preset_names = {preset.stem for preset in Path('presets').glob('*.yaml')}
    return sorted(preset_names, key=natural_keys)
def get_available_prompts():
    """List the prompt names taken from the ``*.txt`` files in ``prompts``.

    Names starting with an ASCII digit come first in reverse natural order,
    then the remaining names in natural order, and finally ``'None'``.
    """
    stems = set(txt.stem for txt in Path('prompts').glob('*.txt'))

    numbered = [stem for stem in stems if re.match('^[0-9]', stem)]
    named = [stem for stem in stems if re.match('^[^0-9]', stem)]

    numbered.sort(key=natural_keys, reverse=True)
    named.sort(key=natural_keys)

    return numbered + named + ['None']
def get_available_characters():
    """Return the character names found in ``characters``, naturally sorted.

    A name is the stem of any entry whose suffix is ``.json``, ``.yaml`` or
    ``.yml``; duplicates across extensions are collapsed.
    """
    accepted_suffixes = ('.json', '.yaml', '.yml')
    names = set()
    for entry in Path('characters').iterdir():
        if entry.suffix in accepted_suffixes:
            names.add(entry.stem)
    return sorted(names, key=natural_keys)
def get_available_instruction_templates():
    """List the instruction template names found in ``instruction-templates``.

    A name is the stem of any entry whose suffix is ``.json``, ``.yaml`` or
    ``.yml``. When the folder does not exist, no names are found.

    Returns:
        ``['Select template to load...']`` followed by the naturally sorted
        names.
    """
    folder = "instruction-templates"
    accepted_suffixes = ('.json', '.yaml', '.yml')

    names = set()
    if os.path.exists(folder):
        for entry in Path(folder).iterdir():
            if entry.suffix in accepted_suffixes:
                names.add(entry.stem)

    return ['Select template to load...'] + sorted(names, key=natural_keys)
def get_available_extensions():
    """List the extensions that ship a ``script.py`` under ``extensions``.

    The name of each extension is the folder directly below ``extensions``.
    Names present in ``github.new_extensions`` are left out, and the rest is
    returned in natural order.
    """
    folder_names = {script.parts[1] for script in Path('extensions').glob('*/script.py')}

    available = []
    for name in sorted(folder_names, key=natural_keys):
        if name in github.new_extensions:
            continue
        available.append(name)
    return available
def get_available_loras():
    """List the entries of ``shared.args.lora_dir`` usable as LoRA names.

    Entries whose name ends with ``.txt``, ``-np``, ``.pt`` or ``.json`` are
    skipped.

    Returns:
        ``['None']`` followed by the remaining names in natural order.
    """
    ignored_suffixes = ('.txt', '-np', '.pt', '.json')

    names = []
    for entry in Path(shared.args.lora_dir).glob('*'):
        if entry.name.endswith(ignored_suffixes):
            continue
        names.append(entry.name)

    names.sort(key=natural_keys)
    return ['None'] + names
def get_datasets(path: str, ext: str):
    """List the dataset names available in ``path`` for the extension ``ext``.

    Args:
        path: Folder to scan.
        ext: File extension without the leading dot (e.g. ``"json"``).
            For ``"txt"``, sub-directories of ``path`` are listed as well,
            to allow training from a sub-directory of txt files.

    Returns:
        ``['None']`` followed by the unique stems in natural order. The
        placeholder entry ``put-trainer-datasets-here`` is never listed.
    """
    folder = Path(path)
    candidates = list(folder.glob(f'*.{ext}'))

    if ext == "txt":
        # Filter directories explicitly: a trailing-slash glob ('*/') only
        # restricts matches to directories from Python 3.11 onwards, and
        # matches plain files of any extension on older versions.
        candidates += [entry for entry in folder.glob('*') if entry.is_dir()]

    names = {entry.stem for entry in candidates if entry.stem != 'put-trainer-datasets-here'}
    return ['None'] + sorted(names, key=natural_keys)
def get_available_chat_styles():
    """Return the chat style names derived from ``css/chat_style*.css``.

    A style name is whatever follows the first ``-`` in the file's stem
    (an empty string when the stem has no ``-``). Names are de-duplicated
    and returned in natural order.
    """
    style_names = set()
    for sheet in Path('css').glob('chat_style*.css'):
        _prefix, _dash, style = sheet.stem.partition('-')
        style_names.add(style)
    return sorted(style_names, key=natural_keys)
def get_available_grammars():
    """Return ``['None']`` plus the ``*.gbnf`` file names in ``grammars``, naturally sorted."""
    grammar_files = [grammar.name for grammar in Path('grammars').glob('*.gbnf')]
    grammar_files.sort(key=natural_keys)
    return ['None'] + grammar_files
|