|
import gradio as gr |
|
|
|
import logging |
|
|
|
import subprocess |
|
import threading |
|
|
|
import sys |
|
import os |
|
|
|
from giskard.settings import settings |
|
|
|
logger = logging.getLogger(__name__)
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("giskard").setLevel(logging.INFO)

# Names of the environment variables this app reads its configuration from.
GSK_HUB_URL = 'GSK_HUB_URL'
GSK_API_KEY = 'GSK_API_KEY'
HF_SPACE_HOST = 'SPACE_HOST'
HF_SPACE_TOKEN = 'GSK_HUB_HFS'


def env_flag(name, default=False):
    """Read a boolean flag from the environment.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or empty.

    Returns:
        ``False`` for the usual "off" spellings ("0", "false", "no", "off",
        case-insensitive), ``True`` for any other non-empty value.
    """
    raw = os.environ.get(name)
    if not raw:
        return default
    # Previously any non-empty string (including "false") enabled the flag.
    return raw.strip().lower() not in ("0", "false", "no", "off")


# When true, the UI inputs are locked and a "duplicate this Space" hint is shown.
READONLY = env_flag("READONLY")

# File the ML worker subprocess writes its combined stdout/stderr to.
LOG_FILE = "output.log"
|
|
|
def read_logs(log_file=None):
    """Return the current contents of the ML worker log file.

    Polled by the UI to refresh the log textbox.

    Args:
        log_file: Path of the log to read; ``None`` (the default) means
            ``LOG_FILE``.

    Returns:
        The full log text, or ``"ML worker not running"`` when the file
        cannot be opened (e.g. no worker has been started yet).
    """
    sys.stdout.flush()
    path = LOG_FILE if log_file is None else log_file
    try:
        # errors="replace": undecodable bytes from the worker (or a partially
        # written multi-byte character) must not break the log view.
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    except OSError:
        return "ML worker not running"
|
|
|
def detect_gpu():
    """Log whether PyTorch and/or TensorFlow report an available GPU.

    Each framework is probed independently. A framework that is not
    installed is logged as a warning instead of raising.
    """
    try:
        import torch
    except ImportError:
        # Logger.warn is a deprecated alias of Logger.warning.
        logger.warning("No PyTorch installed")
    else:
        logger.info(f"PyTorch GPU: {torch.cuda.is_available()}")

    try:
        import tensorflow as tf
    except ImportError:
        logger.warning("No Tensorflow installed")
    else:
        logger.info(f"Tensorflow GPU: {len(tf.config.list_physical_devices('GPU')) > 0}")
|
|
|
# Probe GPU availability when the module is loaded so the result shows up in the logs.
detect_gpu()

# Shared worker state, read by the UI callbacks and written by the worker thread:
# the Hub URL of the most recent launch ("" until one happens) and the running
# subprocess handle (None while no worker is running).
previous_url, ml_worker = "", None
|
|
|
def read_status():
    """Describe the ML worker state for the "Status" textbox.

    Returns:
        A message saying the worker is serving, has exited, or was never
        started, derived from ``ml_worker`` and ``previous_url``.
    """
    if not ml_worker:
        if not len(previous_url):
            return "ML worker not started"
        return f"ML worker exited for {previous_url}"
    return f"ML worker serving {previous_url}"
|
|
|
|
|
def run_ml_worker(url, api_key, hf_token):
    """Run a Giskard ML worker subprocess and block until it exits.

    Meant to run on a background thread. First stops any previously started
    worker, then launches ``giskard worker start`` with stdout and stderr
    redirected to ``LOG_FILE``. While the process runs it is published in the
    global ``ml_worker``.

    Args:
        url: Giskard Hub URL, passed with ``-u``.
        api_key: Hub API key, passed with ``-k``. ``None`` is sent as ``""``
            rather than the literal string "None".
        hf_token: Optional Hugging Face Spaces token, passed with ``-t``.
            Omitted from the command line when empty or ``None``.
    """
    global ml_worker, previous_url
    previous_url = url

    # Make sure no previously started worker is still connected.
    subprocess.run(["giskard", "worker", "stop"])

    command = [
        "giskard", "worker", "start",
        "-u", f"{url}",
        "-k", "" if api_key is None else f"{api_key}",
    ]
    if hf_token:
        # Only needed for private Hubs; never send the literal "None".
        command += ["-t", f"{hf_token}"]

    # The context manager closes our copy of the log handle once the worker
    # exits; previously the handle leaked on every run.
    with open(LOG_FILE, "w") as log_file:
        worker = subprocess.Popen(command, stdout=log_file, stderr=subprocess.STDOUT)
        ml_worker = worker
        returncode = worker.wait()

    # Log only the command name, never the credentials that follow it.
    logging.info(f"Process {worker.args[:3]} exited with {returncode}")

    # Clear the shared reference only if it still points at *this* process:
    # after a stop, a newer worker may already have been started.
    if ml_worker is worker:
        ml_worker = None
|
|
|
|
|
def stop_ml_worker():
    """Terminate the ML worker tracked in ``ml_worker``, if any.

    Sends ``Popen.terminate()`` to the worker process and clears the shared
    reference. Reaping the process is left to the thread that started it,
    which is waiting on it.

    Returns:
        ``"ML worker stopped"`` if a worker was running, otherwise
        ``"ML worker not started"``.
    """
    global ml_worker
    # Snapshot the shared reference: the worker thread resets ``ml_worker`` to
    # None when the process exits, which could otherwise happen between the
    # None check and the terminate() call.
    worker = ml_worker
    if worker is None:
        return "ML worker not started"

    logging.info(f"Stopping ML worker for {previous_url}")
    worker.terminate()
    ml_worker = None
    logging.info("ML worker stopped")
    return "ML worker stopped"
|
|
|
|
|
def start_ml_worker(url, api_key, hf_token):
    """Launch ``run_ml_worker`` on a background thread.

    Args:
        url: Giskard Hub URL. Required; an empty value is rejected.
        api_key: Hub API key forwarded to the worker.
        hf_token: Hugging Face Spaces token forwarded to the worker.

    Returns:
        A status message for the "Status" textbox.
    """
    if not (url and len(url) >= 1):
        return "Please provide URL of Giskard"

    if ml_worker is not None:
        return f"ML worker is still running for {previous_url}"

    # Effectively a no-op here, since ml_worker was just seen to be None.
    # NOTE(review): ml_worker stays None until the new thread calls Popen, so
    # two quick "Run" clicks may start two workers; consider a lock.
    stop_ml_worker()

    logging.info(f"Starting ML worker for {url}")
    launcher = threading.Thread(
        target=run_ml_worker,
        kwargs={"url": url, "api_key": api_key, "hf_token": hf_token},
    )
    launcher.start()
    return f"ML worker running for {url}"
|
|
|
theme = gr.themes.Soft(
    primary_hue="green",
)

with gr.Blocks(theme=theme) as iface:
    with gr.Row():
        with gr.Column():
            # Prefill from the environment, falling back to the local Giskard
            # settings when GSK_HUB_URL is unset or empty.
            url = os.environ.get(GSK_HUB_URL) or f"http://{settings.host}:{settings.ws_port}"
            url_input = gr.Textbox(
                label="Giskard Hub URL",
                interactive=not READONLY,
                value=url,
            )
            api_key_input = gr.Textbox(
                label="Giskard Hub API Key",
                interactive=not READONLY,
                type="password",
                value=os.environ.get(GSK_API_KEY),
                placeholder="gsk-xxxxxxxxxxxxxxxxxxxxxxxxxxxx",
            )
            hf_token_input = gr.Textbox(
                label="Hugging Face Spaces Token",
                interactive=not READONLY,
                type="password",
                value=os.environ.get(HF_SPACE_TOKEN),
                info="if using a private Giskard Hub on Hugging Face Spaces",
            )

        with gr.Column():
            output = gr.Textbox(label="Status")
            if READONLY:
                # Read-only instances invite visitors to duplicate the Space instead.
                gr.Textbox("You are browsing a read-only 🐢 Giskard ML worker instance. ", container=False)
                gr.Textbox("Please duplicate this space to configure your own Giskard ML worker.", container=False)
                gr.DuplicateButton(value="Duplicate Space for 🐢 Giskard ML worker", size='lg', variant="primary")

    with gr.Row():
        run_btn = gr.Button("Run", variant="primary")
        run_btn.click(start_ml_worker, [url_input, api_key_input, hf_token_input], output)

        stop_btn = gr.Button("Stop", variant="stop", interactive=not READONLY)
        stop_btn.click(stop_ml_worker, None, output)

    logs = gr.Textbox(label="Giskard ML worker log:")
    # Poll the log file and the worker status periodically while the page is open.
    iface.load(read_logs, None, logs, every=0.5)
    iface.load(read_status, None, output, every=5)
|
|
|
# Auto-start a worker when the Space is configured through environment
# variables (both the Hub URL and the API key must be set).
hub_url = os.environ.get(GSK_HUB_URL)
hub_api_key = os.environ.get(GSK_API_KEY)
if hub_url and hub_api_key:
    start_ml_worker(hub_url, hub_api_key, os.environ.get(HF_SPACE_TOKEN))

# Presumably needed for the periodic ``every=`` callbacks registered on iface
# (Gradio runs periodic events through the queue). Confirm for the pinned
# Gradio version.
iface.queue()
iface.launch()
|
|