Spaces:
No application file
No application file
import importlib
import json
import os.path
import shlex
import subprocess
from enum import Enum

from setup_tools.os import is_windows
# Both locations are relative paths, i.e. resolved against the current
# working directory at the time they are used.
ext_folder = 'extensions'  # one sub-folder per extension
extension_states = os.path.join('data', 'extensions.json')  # saved enabled flags
def git_ready():
    """Return True if ``git --version`` can be run and exits successfully.

    A missing or non-executable ``git`` makes :func:`subprocess.run` raise
    ``OSError`` (``FileNotFoundError``); that case is reported as ``False``
    rather than propagated, since it is exactly what this check is for.
    """
    try:
        # An argument list works on every platform without going through a shell.
        result = subprocess.run(['git', '--version'], capture_output=True)
    except OSError:
        return False
    return result.returncode == 0
class UpdateStatus(Enum):
    """Result of checking an extension's git checkout for upstream changes."""

    no_git = -1     # a git command could not be run successfully
    unmanaged = 0   # the extension folder has no .git directory
    updated = 1     # git reported the branch as up to date
    outdated = 2    # git did not report the branch as up to date
class Extension: | |
def __init__(self, ext_name, load_states): | |
self.enabled = (ext_name not in load_states.keys()) or load_states[ext_name] | |
self.extname = ext_name | |
# self.abspath = os.path.abspath(os.path.join(ext_folder, ext_name)) | |
self.path = os.path.join(ext_folder, ext_name) | |
self.main_file = os.path.join(self.path, 'main.py') | |
self.req_file = os.path.join(self.path, 'requirements.py') # Optional | |
self.style_file = os.path.join(self.path, 'style.py') | |
self.js_file = os.path.join(self.path, 'scripts', 'script.js') | |
self.git_dir = os.path.join(self.path, '.git') | |
self.update_el = None | |
extinfo = os.path.join(self.path, 'extension.json') | |
if os.path.isfile(extinfo): | |
with open(extinfo, 'r', encoding='utf8') as info_file: | |
self.info = json.load(info_file) | |
for k in ['name', 'description', 'author']: | |
if k not in self.info: | |
self.info[k] = 'Not provided' | |
if 'tags' not in self.info: | |
self.info['tags'] = [] | |
else: | |
raise FileNotFoundError(f'No extension.json file for {ext_name} extension.') | |
def activate(self): | |
if self.enabled and os.path.isfile(self.main_file): | |
__import__(os.path.splitext(self.main_file)[0].replace(os.path.sep, '.'), fromlist=['']) | |
def get_style_rules(self): | |
if self.enabled and os.path.isfile(self.style_file): | |
__import__(os.path.splitext(self.style_file)[0].replace(os.path.sep, '.'), fromlist=['']) | |
def get_requirements(self): | |
if self.enabled and os.path.isfile(self.req_file): | |
return __import__(os.path.splitext(self.req_file)[0].replace(os.path.sep, '.'), fromlist=['']).requirements() | |
return [] | |
def get_javascript(self) -> str | bool: | |
if self.enabled and os.path.isfile(self.js_file): | |
return self.js_file | |
return False | |
def set_enabled(self, new): | |
self.enabled = new | |
set_load_states() | |
try: | |
import gradio | |
return gradio.update(value=new) | |
except: | |
return new | |
def check_updates(self) -> UpdateStatus: | |
if not os.path.isdir(self.git_dir): | |
return UpdateStatus.unmanaged | |
command1 = 'git fetch' | |
command1 = command1 if is_windows() else shlex.split(command1) | |
command2 = 'git status -uno' | |
command2 = command2 if is_windows() else shlex.split(command2) | |
search_string = 'git pull' # Included in message from git if not up to date | |
neg_search_string = 'Your branch is up to date' | |
a = subprocess.run(command1, capture_output=True, cwd=self.path) | |
if a.returncode != 0: | |
return UpdateStatus.no_git | |
b = subprocess.run(command2, capture_output=True, cwd=self.path) | |
if a.returncode != 0: | |
return UpdateStatus.no_git | |
out_string = b.stdout.decode() | |
if search_string in out_string: | |
return UpdateStatus.outdated | |
if neg_search_string in out_string: | |
return UpdateStatus.updated | |
return UpdateStatus.outdated | |
def update(self): | |
if not os.path.isdir(self.git_dir): | |
return | |
command = 'git pull' | |
command = command if is_windows() else shlex.split(command) | |
output = subprocess.run(command, capture_output=True, cwd=self.path) | |
if output.returncode != 0: | |
print(f'Something went wrong during git pull for {self.extname}') | |
def get_valid_extensions(folder=None):
    """Return the sorted names of extension folders found in *folder*.

    An entry counts as an extension when it is a directory containing an
    ``extension.json`` file. Results are sorted so extensions are loaded in a
    deterministic order (``os.listdir`` order is arbitrary). A missing
    *folder* yields an empty list instead of raising ``FileNotFoundError``.

    :param folder: directory to scan; defaults to the module-level ``ext_folder``.
    """
    if folder is None:
        folder = ext_folder
    if not os.path.isdir(folder):
        return []
    return sorted(
        name for name in os.listdir(folder)
        if os.path.isdir(os.path.join(folder, name))
        and os.path.isfile(os.path.join(folder, name, 'extension.json'))
    )
# Registry of discovered extensions, keyed by extension folder name;
# populated by init_extensions().
states: dict[str, Extension] = dict()
def set_load_states(path=None):
    """Write every registered extension's enabled flag to a JSON file.

    The file holds a ``{extension name: enabled}`` object.

    :param path: file to write; defaults to the module-level ``extension_states``.
    """
    if path is None:
        path = extension_states
    snapshot = {name: ext.enabled for name, ext in states.items()}
    parent = os.path.dirname(path)
    if parent:
        # Create the parent directory if needed instead of failing in open().
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf8') as states_file:
        json.dump(snapshot, states_file)
def get_load_states(path=None):
    """Read the saved ``{extension name: enabled}`` mapping.

    :param path: file to read; defaults to the module-level ``extension_states``.
    :returns: the decoded mapping, or ``{}`` when the file does not exist or
        does not contain a valid JSON object (a warning is printed in that case).
    """
    if path is None:
        path = extension_states
    if not os.path.isfile(path):
        return {}
    try:
        with open(path, 'r', encoding='utf8') as states_file:
            loaded = json.load(states_file)
    except json.JSONDecodeError as err:
        print(f'Ignoring unreadable extension states file {path}: {err}')
        return {}
    if not isinstance(loaded, dict):
        print(f'Ignoring extension states file {path}: expected a JSON object')
        return {}
    return loaded
# Callback identifiers that init_extensions() registers before it builds the
# Extension objects.
register_callbacks = [
    'webui.init',
    'webui.settings',
    'webui.tabs',
    'webui.tabs.utils',
    'webui.tts.list',
]
def init_extensions():
    """Register the default callbacks and build an Extension per valid folder.

    Fills the module-level ``states`` dict using the saved enabled flags.
    This only reads each extension's metadata; it does not call
    ``Extension.activate``. An extension whose metadata cannot be read (e.g.
    malformed ``extension.json``) is reported and skipped so one broken
    extension does not abort start-up.
    """
    # Register default callbacks
    from webui.extensionlib.callbacks import register_new as register
    for cb in register_callbacks:
        register(cb)

    # Instantiate every discovered extension (enabled or not).
    saved_states = get_load_states()
    exts = get_valid_extensions()
    print(f'Found extensions: {", ".join(exts)}')
    for ext in exts:
        try:
            states[ext] = Extension(ext, saved_states)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f'Failed to load extension {ext}: {err}')
def get_scripts() -> list[str]:
    """Return the JavaScript file paths of registered extensions that provide one.

    Each extension's ``get_javascript()`` yields a path or a falsy value;
    only the truthy results are kept, in registry order.
    """
    candidates = [ext.get_javascript() for ext in states.values()]
    return [js_path for js_path in candidates if js_path]
def get_requirements():
    """Concatenate the requirements reported by every registered extension.

    Each extension's ``get_requirements()`` result is appended in registry
    order; empty results are skipped.
    """
    per_extension = [ext.get_requirements() for ext in states.values()]
    combined = []
    for reqs in filter(None, per_extension):
        combined.extend(reqs)
    return combined