|
import os |
|
import json |
|
import requests |
|
import time |
|
from threading import Thread, Lock |
|
import re |
|
import multiprocessing |
|
import subprocess |
|
|
|
# Compiled once at import time; capture_output() searches every formatted log
# line for this literal marker to decide whether an error was reported.
ERROR_PATTERN = re.compile(r"ERROR:")
|
|
|
|
|
def get_image_name():
    """Return an image name derived from the current working directory.

    The base name of the working directory is returned unchanged when it
    already contains the substring ``"cog"``; otherwise it is returned with a
    ``"cog-"`` prefix.
    """
    folder = os.path.basename(os.getcwd())
    return folder if "cog" in folder else f"cog-{folder}"
|
|
|
|
|
def process_log_line(line):
    """Decode one raw log line and pretty-print it when it is valid JSON.

    Args:
        line: Raw bytes of a single line, as read from a subprocess pipe.

    Returns:
        The decoded text with surrounding whitespace stripped. If that text
        parses as JSON, the parsed value re-serialized with ``indent=2`` is
        returned instead.
    """
    # Replace undecodable bytes instead of raising: one malformed line from
    # the child process must not abort the loop that is draining its pipe.
    text = line.decode("utf-8", errors="replace").strip()
    try:
        return json.dumps(json.loads(text), indent=2)
    except json.JSONDecodeError:
        # Not JSON (or empty) -- pass the plain text through unchanged.
        return text
|
|
|
|
|
def capture_output(pipe, print_lock, logs=None, error_detected=None):
    """Drain ``pipe`` line by line, echoing and optionally recording each line.

    Args:
        pipe: Binary stream; reading stops when ``readline()`` returns ``b""``.
        print_lock: Context-manager lock held while a line is printed and
            recorded.
        logs: Optional list-like object; each formatted line is appended.
        error_detected: Optional mutable sequence; index 0 is set to ``True``
            when a formatted line matches ``ERROR_PATTERN``.
    """
    while True:
        raw = pipe.readline()
        if raw == b"":
            # End of stream.
            break

        text = process_log_line(raw)
        with print_lock:
            print(text)
            if logs is not None:
                logs.append(text)
            if error_detected is not None and ERROR_PATTERN.search(text):
                error_detected[0] = True
|
|
|
|
|
def wait_for_server_to_be_ready(
    url, timeout=300, *, poll_interval=5, request_timeout=10
):
    """Poll a health-check URL until the server reports it is ready.

    Args:
        url: The health check URL to poll. The response is expected to be a
            JSON object with a ``"status"`` field.
        timeout: Maximum time (in seconds) to wait for the server to be ready.
        poll_interval: Seconds to sleep between polls.
        request_timeout: Per-request timeout (in seconds) passed to
            ``requests.get`` so a hung connection cannot stall the loop.

    Raises:
        RuntimeError: If the server reports status ``SETUP_FAILED``.
        TimeoutError: If the server is not ready within ``timeout`` seconds.
    """
    # Monotonic clock: the deadline is immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout

    while True:
        try:
            response = requests.get(url, timeout=request_timeout)
            data = response.json()
        except (requests.RequestException, ValueError):
            # Server not reachable yet, or it answered with a non-JSON body
            # (older requests versions raise a plain ValueError for that).
            status = None
        else:
            # A payload without a "status" field is treated as "not ready yet".
            status = data.get("status") if isinstance(data, dict) else None

        if status == "READY":
            return
        if status == "SETUP_FAILED":
            raise RuntimeError(
                "Server initialization failed with status: SETUP_FAILED"
            )

        if time.monotonic() > deadline:
            raise TimeoutError("Server did not become ready in the expected time.")

        time.sleep(poll_interval)
|
|
|
|
|
def run_training_subprocess(command):
    """Run ``command``, stream its output, and return the captured log lines.

    stdout and stderr of the child process are each drained by a reader
    thread running ``capture_output``, which prints every line and records it.

    Args:
        command: The command passed to ``subprocess.Popen``.

    Returns:
        A list of the formatted log lines captured from both streams.

    Raises:
        Exception: If any captured line was flagged as an error by
            ``capture_output``.
    """
    # Threads (not processes) drain the pipes: the work is I/O-bound, plain
    # lists can be shared directly, and pipe file objects cannot be pickled
    # for child processes under the spawn/forkserver start methods.
    print_lock = Lock()
    logs = []
    error_detected = [False]

    # The context manager closes both pipes and reaps the child on exit.
    with subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    ) as process:
        readers = [
            Thread(
                target=capture_output,
                args=(stream, print_lock, logs, error_detected),
            )
            for stream in (process.stdout, process.stderr)
        ]
        for reader in readers:
            reader.start()

        process.wait()

        # Readers finish once they hit end-of-stream on their pipe.
        for reader in readers:
            reader.join()

    if error_detected[0]:
        raise Exception("Error detected in training logs! Check logs for details")

    return list(logs)
|
|