# audio-webui / webui / extensionlib / extensionmanager.py
# mrtroydev's picture
# Upload folder using huggingface_hub
# 3883c60 verified
import json
import os.path
import shlex
import subprocess
from enum import Enum
from setup_tools.os import is_windows
# JSON file holding the saved enabled flag of each extension, keyed by name.
extension_states = os.path.join('data', 'extensions.json')
# Folder that contains one sub-folder per extension.
ext_folder = 'extensions'
def git_ready():
    """Report whether a working ``git`` executable can be launched.

    Returns:
        bool: True when ``git --version`` runs and exits with status 0;
        False when it exits non-zero or cannot be started at all.
    """
    # An argument list is launched without a shell on every platform, so the
    # command line needs no per-OS handling.
    try:
        completed = subprocess.run(['git', '--version'], capture_output=True)
    except OSError:
        # FileNotFoundError / PermissionError: git is missing or not
        # executable, which is exactly the "not ready" answer.
        return False
    return completed.returncode == 0
class UpdateStatus(Enum):
    """Result of checking an extension folder for upstream changes."""

    # A git command could not be completed successfully.
    no_git = -1
    # The extension folder has no ``.git`` directory.
    unmanaged = 0
    # ``git status`` reported the branch as up to date.
    updated = 1
    # ``git status`` suggested a pull, or gave no up-to-date confirmation.
    outdated = 2
class Extension:
def __init__(self, ext_name, load_states):
self.enabled = (ext_name not in load_states.keys()) or load_states[ext_name]
self.extname = ext_name
# self.abspath = os.path.abspath(os.path.join(ext_folder, ext_name))
self.path = os.path.join(ext_folder, ext_name)
self.main_file = os.path.join(self.path, 'main.py')
self.req_file = os.path.join(self.path, 'requirements.py') # Optional
self.style_file = os.path.join(self.path, 'style.py')
self.js_file = os.path.join(self.path, 'scripts', 'script.js')
self.git_dir = os.path.join(self.path, '.git')
self.update_el = None
extinfo = os.path.join(self.path, 'extension.json')
if os.path.isfile(extinfo):
with open(extinfo, 'r', encoding='utf8') as info_file:
self.info = json.load(info_file)
for k in ['name', 'description', 'author']:
if k not in self.info:
self.info[k] = 'Not provided'
if 'tags' not in self.info:
self.info['tags'] = []
else:
raise FileNotFoundError(f'No extension.json file for {ext_name} extension.')
def activate(self):
if self.enabled and os.path.isfile(self.main_file):
__import__(os.path.splitext(self.main_file)[0].replace(os.path.sep, '.'), fromlist=[''])
def get_style_rules(self):
if self.enabled and os.path.isfile(self.style_file):
__import__(os.path.splitext(self.style_file)[0].replace(os.path.sep, '.'), fromlist=[''])
def get_requirements(self):
if self.enabled and os.path.isfile(self.req_file):
return __import__(os.path.splitext(self.req_file)[0].replace(os.path.sep, '.'), fromlist=['']).requirements()
return []
def get_javascript(self) -> str | bool:
if self.enabled and os.path.isfile(self.js_file):
return self.js_file
return False
def set_enabled(self, new):
self.enabled = new
set_load_states()
try:
import gradio
return gradio.update(value=new)
except:
return new
def check_updates(self) -> UpdateStatus:
if not os.path.isdir(self.git_dir):
return UpdateStatus.unmanaged
command1 = 'git fetch'
command1 = command1 if is_windows() else shlex.split(command1)
command2 = 'git status -uno'
command2 = command2 if is_windows() else shlex.split(command2)
search_string = 'git pull' # Included in message from git if not up to date
neg_search_string = 'Your branch is up to date'
a = subprocess.run(command1, capture_output=True, cwd=self.path)
if a.returncode != 0:
return UpdateStatus.no_git
b = subprocess.run(command2, capture_output=True, cwd=self.path)
if a.returncode != 0:
return UpdateStatus.no_git
out_string = b.stdout.decode()
if search_string in out_string:
return UpdateStatus.outdated
if neg_search_string in out_string:
return UpdateStatus.updated
return UpdateStatus.outdated
def update(self):
if not os.path.isdir(self.git_dir):
return
command = 'git pull'
command = command if is_windows() else shlex.split(command)
output = subprocess.run(command, capture_output=True, cwd=self.path)
if output.returncode != 0:
print(f'Something went wrong during git pull for {self.extname}')
def get_valid_extensions():
    """List the sub-folders of ``ext_folder`` that contain an ``extension.json``.

    Returns:
        list[str]: Folder names in the order ``os.listdir`` yields them;
        an empty list when ``ext_folder`` does not exist.
    """
    if not os.path.isdir(ext_folder):
        # Without the folder there is nothing to list, and os.listdir would
        # raise FileNotFoundError.
        return []
    # The manifest being a regular file already implies the entry is a
    # directory, so no separate isdir() check is needed.
    return [entry for entry in os.listdir(ext_folder)
            if os.path.isfile(os.path.join(ext_folder, entry, 'extension.json'))]
# Registry of Extension objects keyed by extension folder name; filled by init_extensions().
states: dict[str, Extension] = {}
def set_load_states():
    """Write the enabled flag of every extension in ``states`` to ``extension_states`` as JSON."""
    enabled_by_name = {name: ext.enabled for name, ext in states.items()}
    # The context manager guarantees the file is flushed and closed, so the
    # data is on disk as soon as this function returns.
    with open(extension_states, 'w', encoding='utf8') as state_file:
        json.dump(enabled_by_name, state_file)
def get_load_states():
    """Read the saved load states from ``extension_states``.

    Returns:
        The parsed JSON content of the file, or an empty dict when the file
        does not exist.
    """
    if not os.path.isfile(extension_states):
        return {}
    # Close the handle deterministically instead of leaving it to the
    # garbage collector.
    with open(extension_states, 'r', encoding='utf8') as state_file:
        return json.load(state_file)
# Callback names that init_extensions() registers before creating any Extension.
register_callbacks = [
    'webui.init',
    'webui.settings',
    'webui.tabs',
    'webui.tabs.utils',
    'webui.tts.list',
]
def init_extensions():
    """Register the default callbacks, then create an Extension for every valid folder.

    Each created Extension is stored in ``states`` under its folder name.
    """
    from webui.extensionlib.callbacks import register_new as register

    # Default callbacks first.
    for callback_name in register_callbacks:
        register(callback_name)

    # Then one Extension per discovered folder, seeded with the saved states.
    saved_states = get_load_states()
    discovered = get_valid_extensions()
    print(f'Found extensions: {", ".join(discovered)}')
    for folder_name in discovered:
        states[folder_name] = Extension(folder_name, saved_states)
def get_scripts() -> list[str]:
    """Collect the JavaScript file paths reported by the extensions in ``states``.

    Extensions whose ``get_javascript()`` result is falsy are left out.
    """
    return [
        script_path
        for extension in states.values()
        if (script_path := extension.get_javascript())
    ]
def get_requirements():
    """Concatenate the requirements reported by every extension in ``states``.

    Returns:
        list: All items from each extension's non-empty ``get_requirements()``
        result, in registry order.
    """
    # Query every extension before merging, so all calls happen up front.
    per_extension = [extension.get_requirements() for extension in states.values()]
    combined = []
    for requirement_list in per_extension:
        if not requirement_list:
            continue
        combined += requirement_list
    return combined