# ml-worker / app.py
# Commit 44ac6bc by inoki-giskard: "Add GPU detection before running for logging"
# (Hugging Face file-viewer header: raw / history / blame, 4.75 kB)
import gradio as gr
import logging
import subprocess
import threading
import sys
import os
from giskard.settings import settings
# Logger for messages emitted by this module.
logger = logging.getLogger(__name__)

# Show INFO-level records from the root logger and from the "giskard" logger.
logging.root.setLevel(logging.INFO)
logging.getLogger("giskard").setLevel(logging.INFO)

# Names of the environment variables this app reads.
GSK_HUB_URL = 'GSK_HUB_URL'
GSK_API_KEY = 'GSK_API_KEY'
HF_SPACE_HOST = 'SPACE_HOST'
HF_SPACE_TOKEN = 'GSK_HUB_HFS'

# Read-only mode is on whenever READONLY is set to any non-empty string
# (the raw string is kept as the value); otherwise it is False.
READONLY = os.environ.get("READONLY") or False

# File that receives the worker subprocess output and is shown in the UI.
LOG_FILE = "output.log"
def read_logs():
    """Return the current contents of the ML worker log file.

    Flushes this process's stdout, then reads ``LOG_FILE`` in full.

    Returns:
        The log text, or ``"ML worker not running"`` when the file cannot be
        opened or read (for example, it does not exist yet).
    """
    sys.stdout.flush()
    try:
        # The log is written by another process while it is being polled, so
        # a read can land on bytes that do not decode cleanly. Replace them
        # rather than fail: a decoding error must not be reported as
        # "ML worker not running".
        with open(LOG_FILE, "r", errors="replace") as f:
            return f.read()
    except OSError:
        # Missing or unreadable file: treated as "no worker output available".
        return "ML worker not running"
def detect_gpu():
    """Log whether PyTorch and TensorFlow report an available GPU.

    Each framework is probed independently. A framework that cannot be
    imported is reported with a warning instead of raising, so the app still
    starts when neither is installed.
    """
    try:
        import torch
        logger.info("PyTorch GPU: %s", torch.cuda.is_available())
    except ImportError:
        # Logger.warn is a deprecated alias; Logger.warning is the supported API.
        logger.warning("No PyTorch installed")
    try:
        import tensorflow as tf
        logger.info(
            "Tensorflow GPU: %s",
            len(tf.config.list_physical_devices('GPU')) > 0,
        )
    except ImportError:
        logger.warning("No Tensorflow installed")
# Log GPU availability once, at import time.
detect_gpu()
# URL given to the most recent worker launch; "" until a launch is attempted.
previous_url = ""
# Handle of the worker subprocess currently tracked, or None when there is none.
ml_worker = None
def read_status():
    """Describe the worker state as a short human-readable string.

    Returns:
        A "serving" message while a worker process is tracked, an "exited"
        message when a URL was used before but no process is tracked any
        more, and "ML worker not started" otherwise.
    """
    if ml_worker:
        return f"ML worker serving {previous_url}"
    if previous_url:
        # A URL was recorded by an earlier launch, but the process is gone.
        return f"ML worker exited for {previous_url}"
    return "ML worker not started"
def run_ml_worker(url, api_key, hf_token):
    """Launch the Giskard worker subprocess and block until it exits.

    Records ``url`` in ``previous_url``, runs ``giskard worker stop``, then
    starts ``giskard worker start`` with its stdout and stderr redirected to
    ``LOG_FILE`` (truncated on each launch). While the process is alive its
    handle is published in the module-level ``ml_worker``.

    This call blocks for the lifetime of the worker, so callers that must
    stay responsive should run it in a separate thread.

    Args:
        url: Giskard Hub URL, passed as ``-u``.
        api_key: Giskard Hub API key, passed as ``-k``.
        hf_token: Hugging Face Spaces token, passed as ``-t``.
    """
    global ml_worker, previous_url

    previous_url = url
    # Presumably stops a worker left over from an earlier run; the exit code
    # is deliberately not checked.
    subprocess.run(["giskard", "worker", "stop"])

    # The context manager closes this process's handle on the log file once
    # the worker has exited (or if launching it fails), instead of leaking it.
    with open(LOG_FILE, "w") as log_file:
        # NOTE(review): f-string formatting turns a None token into the
        # literal string "None"; confirm whether "-t" should be omitted then.
        process = subprocess.Popen(
            [
                "giskard", "worker", "start",
                "-u", f"{url}", "-k", f"{api_key}", "-t", f"{hf_token}"
            ],
            stdout=log_file, stderr=subprocess.STDOUT
        )
        ml_worker = process

        # Log only the command name parts, never the URL, key or token.
        args = process.args[:3]
        logging.info(f"Process {args} exited with {process.wait()}")

    # Clear the handle only if it is still ours: after a stop and a new start,
    # ml_worker may already refer to a newer process that must stay tracked.
    if ml_worker is process:
        ml_worker = None
def stop_ml_worker():
    """Terminate the tracked worker process, if there is one.

    Returns:
        "ML worker stopped" after a tracked process was sent a terminate
        request and forgotten, or "ML worker not started" when no process
        is tracked.
    """
    global ml_worker

    if ml_worker is None:
        return "ML worker not started"

    logging.info(f"Stopping ML worker for {previous_url}")
    ml_worker.terminate()
    # Forget the handle right away; this function does not wait for the exit.
    ml_worker = None
    logging.info("ML worker stopped")
    return "ML worker stopped"
def start_ml_worker(url, api_key, hf_token):
    """Start a worker for ``url`` in a background thread.

    Args:
        url: Giskard Hub URL; must be non-empty.
        api_key: Giskard Hub API key forwarded to the worker.
        hf_token: Hugging Face Spaces token forwarded to the worker.

    Returns:
        A status message: a request for a URL when ``url`` is empty, a
        "still running" notice when a worker is already tracked, otherwise
        a confirmation that the worker was launched.
    """
    if not url:
        return "Please provide URL of Giskard"
    if ml_worker is not None:
        return f"ML worker is still running for {previous_url}"

    # Always run an external ML worker
    stop_ml_worker()
    logging.info(f"Starting ML worker for {url}")
    # run_ml_worker blocks until the process exits, so it gets its own thread.
    threading.Thread(target=run_ml_worker, args=(url, api_key, hf_token)).start()
    return f"ML worker running for {url}"
# Gradio front end: connection settings on the left, status on the right,
# Run/Stop buttons below, then a live view of the worker log.
# NOTE(review): the nesting of the layout blocks below was reconstructed;
# confirm it matches the intended layout.
theme = gr.themes.Soft(
    primary_hue="green",
)

with gr.Blocks(theme=theme) as iface:
    with gr.Row():
        with gr.Column():
            # Prefer the URL from the environment; otherwise fall back to the
            # host and port from the Giskard settings.
            url = os.environ.get(GSK_HUB_URL) or f"http://{settings.host}:{settings.ws_port}"
            url_input = gr.Textbox(
                label="Giskard Hub URL",
                interactive=not READONLY,
                value=url,
            )
            api_key_input = gr.Textbox(
                label="Giskard Hub API Key",
                interactive=not READONLY,
                type="password",
                value=os.environ.get(GSK_API_KEY),
                placeholder="gsk-xxxxxxxxxxxxxxxxxxxxxxxxxxxx",
            )
            hf_token_input = gr.Textbox(
                label="Hugging Face Spaces Token",
                interactive=not READONLY,
                type="password",
                value=os.environ.get(HF_SPACE_TOKEN),
                info="if using a private Giskard Hub on Hugging Face Spaces",
            )

        with gr.Column():
            output = gr.Textbox(label="Status")
            if READONLY:
                # Read-only instances explain themselves and offer duplication.
                gr.Textbox("You are browsing a read-only 🐢 Giskard ML worker instance. ", container=False)
                gr.Textbox("Please duplicate this space to configure your own Giskard ML worker.", container=False)
                gr.DuplicateButton(value="Duplicate Space for 🐢 Giskard ML worker", size='lg', variant="primary")

    with gr.Row():
        run_btn = gr.Button("Run", variant="primary")
        run_btn.click(start_ml_worker, [url_input, api_key_input, hf_token_input], output)
        stop_btn = gr.Button("Stop", variant="stop", interactive=not READONLY)
        stop_btn.click(stop_ml_worker, None, output)

    logs = gr.Textbox(label="Giskard ML worker log:")
    # Poll the log file twice a second and the worker status every 5 seconds.
    iface.load(read_logs, None, logs, every=0.5)
    iface.load(read_status, None, output, every=5)

    # Start the worker automatically when both the Hub URL and the API key
    # are provided through the environment.
    if os.environ.get(GSK_HUB_URL) and os.environ.get(GSK_API_KEY):
        start_ml_worker(os.environ.get(GSK_HUB_URL), os.environ.get(GSK_API_KEY), os.environ.get(HF_SPACE_TOKEN))

iface.queue()
iface.launch()