File size: 1,503 Bytes
df1aa82
66e8b15
4708837
e7aeb95
 
 
df1aa82
66e8b15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4708837
66e8b15
 
 
 
 
e7aeb95
 
 
 
 
 
 
 
18f1fae
e7aeb95
66e8b15
 
 
 
df1aa82
66e8b15
df1aa82
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr

from giskard.ml_worker.ml_worker import MLWorker
from pydantic import AnyHttpUrl
from giskard.settings import settings
from urllib.parse import urlparse

import asyncio
import threading

# Module-level state shared across Gradio callback invocations:
# previous_url — backend URL of the currently running worker ("" until one starts)
# ml_worker    — the live MLWorker instance, or None before the first start
previous_url = ""
ml_worker = None


def run_ml_worker(ml_worker: MLWorker):
    """Run the worker's async ``start()`` coroutine to completion.

    Intended as a ``threading.Thread`` target: non-main threads have no
    default asyncio event loop, so a fresh one is created and installed
    for this thread, then blocked on until the worker finishes.

    :param ml_worker: the worker instance whose ``start()`` coroutine to run.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Blocks this thread until ml_worker.start() returns or raises.
        loop.run_until_complete(ml_worker.start())
    finally:
        # Close the loop even when start() raises; the original skipped
        # close() on exception and leaked the loop's resources.
        loop.close()


def start_ml_worker(url, api_key, hf_token):
    """Gradio callback: stop any running ML worker, then start one for *url*.

    The three parameters map to the interface's three text inputs: the
    Giskard backend URL, its API key, and a Hugging Face token. Returns a
    status string displayed in the UI.
    """
    global ml_worker, previous_url
    # Always run an external ML worker
    if ml_worker is not None:
        print(f"Stopping ML worker for {previous_url}")
        # NOTE(review): presumably stop() tears down the worker started on the
        # background thread — confirm it is synchronous and thread-safe here.
        ml_worker.stop()
        print("ML worker stopped")

    parsed_url = urlparse(url)
    # Rebuild the URL as a pydantic AnyHttpUrl, defaulting the scheme to
    # "http" and the path to settings.ws_path when the user omitted them.
    # NOTE(review): this keyword form (url=/scheme=/host=/path=) matches
    # pydantic v1's AnyHttpUrl constructor; pydantic v2 rejects it — verify
    # the pinned pydantic version.
    # NOTE(review): urlparse paths keep their leading "/", so the literal "/"
    # in the f-string likely yields "host//path" when a path is given —
    # confirm whether the backend tolerates the double slash.
    backend_url = AnyHttpUrl(
        url=f"{parsed_url.scheme if parsed_url.scheme else 'http'}://{parsed_url.hostname}"
            f"/{parsed_url.path if parsed_url.path and len(parsed_url.path) else settings.ws_path}",
        scheme=parsed_url.scheme,
        host=parsed_url.hostname,
        path=parsed_url.path if parsed_url.path and len(parsed_url.path) else settings.ws_path,
    )
    print(f"Starting ML worker for {backend_url}")
    ml_worker = MLWorker(False, backend_url, api_key, hf_token)
    previous_url = backend_url
    # Run the worker's event loop on a background thread so this callback
    # returns immediately and the Gradio UI stays responsive.
    thread = threading.Thread(target=run_ml_worker, args=(ml_worker,))
    thread.start()
    return f"ML worker running for {backend_url}"

# Minimal UI: three text fields (backend URL, API key, HF token) feeding
# start_ml_worker; the returned status string is rendered as the output.
iface = gr.Interface(
    fn=start_ml_worker,
    inputs=["text", "text", "text"],
    outputs="text",
)
iface.launch()