import argparse |
from dataclasses import dataclass, field |
import json |
import copy |
import multiprocessing as mp |
import uuid |
from datetime import datetime, timedelta |
from collections import defaultdict, deque |
import io |
import zipfile |
import queue |
import time |
import random |
import logging |
from tensordict import TensorDict |
import cv2 |
from flask import Flask, request, make_response, send_file |
from PIL import Image |
import torchvision.transforms as T |
import numpy as np |
import torch as th |
from wham.utils import load_model_from_checkpoint, POS_BINS_BOUNDARIES, POS_BINS_MIDDLE |
logging.basicConfig(level=logging.INFO)


def _build_arg_parser():
    """Build the command-line parser for the dreaming server."""
    arg_parser = argparse.ArgumentParser(description="Simple Dreamer")

    # Model and runtime options.
    arg_parser.add_argument("--model", type=str, required=True, help="Path to the model file for the local runs")
    arg_parser.add_argument("--debug", action="store_true", help="Enable flask debug mode.")
    # NOTE(review): nothing in this file reads args.random_model; setup_and_load_model_be_model
    # always loads args.model. Either wire the flag up or remove it.
    arg_parser.add_argument("--random_model", action="store_true", help="Use randomly initialized model instead of the provided one")
    arg_parser.add_argument("--port", type=int, default=5000)

    # Server limits.
    arg_parser.add_argument("--max_concurrent_jobs", type=int, default=30, help="Maximum number of jobs that can be run concurrently on this server.")
    arg_parser.add_argument("--max_dream_steps_per_job", type=int, default=10, help="Maximum number of dream steps each job can request.")
    arg_parser.add_argument("--max_job_lifespan", type=int, default=60 * 10, help="Maximum number of seconds we keep run around if not polled.")

    # Input image geometry (frames are resized to this before being queued).
    arg_parser.add_argument("--image_width", type=int, default=300, help="Width of the image")
    arg_parser.add_argument("--image_height", type=int, default=180, help="Height of the image")

    # Worker batching.
    arg_parser.add_argument("--max_batch_size", type=int, default=3, help="Maximum batch size for the dreamer workers")
    return arg_parser


parser = _build_arg_parser()
# Name of the JSON member, inside the results zip, that lists per-step metadata.
PREDICTION_JSON_FILENAME = "predictions.json"
# Minimum interval between scans for stale, never-polled job results.
JOB_CLEANUP_CHECK_RATE = timedelta(seconds=10)
# How many recently cancelled job ids each worker remembers when filtering its queue.
# NOTE(review): dreamer_worker references this name but its definition was missing from the file;
# 100 is a reconstructed value - confirm against the intended setting.
MAX_CANCELLED_ID_QUEUE_SIZE = 100
# Sampling settings used when a /new_job request does not override them. "max_context_length" is
# consumed by the workers; the rest are passed to the model's predict_next_step.
DEFAULT_SAMPLING_SETTINGS = {
    "temperature": 0.9,
    "top_k": None,
    "top_p": 1.0,
    "max_context_length": 10,
}
def float_or_none(string):
    """Parse a numeric setting, mapping ``None`` or the text "none" to ``None``.

    Accepts strings (the "none" check ignores case and surrounding whitespace) as well as numbers
    that were already decoded from JSON.

    Raises:
        ValueError / TypeError: if the value cannot be converted to float.
    """
    if string is None:
        return None
    # Only strings can spell "none"; numbers go straight to float() instead of crashing on .lower().
    if isinstance(string, str) and string.strip().lower() == "none":
        return None
    return float(string)
def be_image_preprocess(image, target_width, target_height):
    """Resize an HWC image array to (target_height, target_width) if needed and return it as CHW.

    Resizing only happens when both targets are given and the current size differs; otherwise the
    image keeps its size. cv2.resize takes its size argument as (width, height).
    """
    needs_resize = (
        target_width is not None
        and target_height is not None
        and (image.shape[1] != target_width or image.shape[0] != target_height)
    )
    if needs_resize:
        image = cv2.resize(image, (target_width, target_height))
    # Move the channel axis first: (H, W, C) -> (C, H, W).
    return np.transpose(image, axes=(2, 0, 1))
def action_vector_to_be_action_vector(action):
    """Replace the last four entries of *action* (in place) with their bin indices.

    The bin index of a value is ``np.digitize(value, POS_BINS_BOUNDARIES) - 1``. The mutated
    *action* is also returned.

    NOTE(review): a value below POS_BINS_BOUNDARIES[0] gets index -1 here - confirm the boundaries
    cover the whole stick range.
    """
    stick_values = action[-4:]
    bin_indices = np.digitize(stick_values, bins=POS_BINS_BOUNDARIES) - 1
    action[-4:] = bin_indices
    return action
def be_action_vector_to_action_vector(action):
    """Map the last four entries of *action* (bin indices) back to bin centre values, in place.

    Inverse of action_vector_to_be_action_vector: each index selects a value from
    POS_BINS_MIDDLE. The mutated *action* is also returned.
    """
    for offset in (-4, -3, -2, -1):
        bin_index = int(action[offset])
        action[offset] = POS_BINS_MIDDLE[bin_index]
    return action
@dataclass
class DreamJob:
    """One pending dream step of a client job, passed through the multiprocessing job queue.

    A worker predicts one step for it and, while steps remain, puts a new DreamJob back on the
    queue with the dreamt frame appended to the context.
    """
    job_id: str  # uuid4 hex created by DreamerServer.add_new_job
    sampling_settings: dict  # model sampling kwargs plus "max_context_length" (read by the worker)
    num_predictions_remaining: int  # decremented each time the worker re-queues the job
    num_predictions_done: int  # reported as DreamJobResult.dream_step_index
    context_images: th.Tensor  # presumably (1, T, C, H, W) as stacked in add_new_job - TODO confirm
    context_actions: th.Tensor  # presumably (1, T, action_dim) - TODO confirm
    context_tokens: list  # one token entry per context frame
    actions_to_take: th.Tensor = None  # optional forced future actions; None when there are none left
@dataclass
class DreamJobResult:
    """Output of one dream step of one job, sent from a worker to DreamerServer via the results queue."""
    job_id: str
    dream_step_index: int  # 0-based step index within the job (the job's num_predictions_done)
    dreamt_image: th.Tensor  # predicted frame; get_new_results encodes [0, 0] of it as a PNG
    dreamt_action: th.Tensor  # predicted action, or the forced action when actions_to_take supplied one
    dreamt_tokens: th.Tensor  # predicted image tokens; get_new_results serialises [0, 0] of it
    # Creation timestamp; the server expires a job whose newest result is older than max_job_lifespan.
    result_creation_time: datetime = field(default_factory=datetime.now)
def setup_and_load_model_be_model(args):
    """Load the model from the checkpoint at ``args.model`` and enable TF32 float32 matmuls.

    Returns the loaded model; moving it to a device is left to the caller.
    """
    loaded_model = load_model_from_checkpoint(args.model)
    # Trade a little float32 matmul precision for TF32 speed on supporting GPUs.
    th.set_float32_matmul_precision("high")
    th.backends.cuda.matmul.allow_tf32 = True
    return loaded_model
def get_job_batchable_information(job):
    """Return a comparable key for *job*; jobs with equal keys may share one model batch.

    The key is ``(context_length, sampling_settings)``, where context_length is the size of
    dim 1 of ``job.context_images``.
    """
    return job.context_images.shape[1], job.sampling_settings
def fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size, timeout=1, fill_timeout=0.01):
    """Return a list of jobs (or empty list) that can be batched together.

    Waits up to *timeout* seconds for the first job. After that it waits at most *fill_timeout*
    seconds for each further job. Previously every extra slot waited the full *timeout*, so a lone
    job paid that delay on every dream step.

    Args:
        job_queue: queue of DreamJob objects (a worker also puts unfinished jobs back on it).
        cancelled_ids_set: job ids to drop instead of batching.
        max_batch_size: upper bound on the number of jobs returned.
        timeout: seconds to wait for the first job.
        fill_timeout: seconds to wait for each additional job.

    The first job whose get_job_batchable_information() differs from the batch's is put back on the
    queue and ends the batch.
    """
    batchable_jobs = []
    required_job_info = None
    while len(batchable_jobs) < max_batch_size:
        wait_seconds = fill_timeout if batchable_jobs else timeout
        try:
            job = job_queue.get(timeout=wait_seconds)
        except queue.Empty:
            break
        except OSError:
            # Presumably the queue's underlying pipe was closed (e.g. during shutdown).
            break
        if job.job_id in cancelled_ids_set:
            # Cancelled jobs are simply dropped.
            continue
        job_info = get_job_batchable_information(job)
        if required_job_info is None:
            required_job_info = job_info
        elif required_job_info != job_info:
            # Incompatible with this batch: hand it back for a later iteration.
            job_queue.put(job)
            break
        batchable_jobs.append(job)
    return batchable_jobs
def update_cancelled_jobs(cancelled_ids_queue, cancelled_ids_deque, cancelled_ids_set):
    """IN-PLACE: drain new ids from *cancelled_ids_queue* into the deque and resync the set.

    The deque bounds how many recent ids are remembered. The set is rebuilt from the deque only
    when at least one new id arrived, and then serves as the fast membership test.
    """
    received_any = False
    while not cancelled_ids_queue.empty():
        try:
            new_id = cancelled_ids_queue.get_nowait()
        except queue.Empty:
            break
        cancelled_ids_deque.append(new_id)
        received_any = True
    if not received_any:
        return
    cancelled_ids_set.clear()
    cancelled_ids_set.update(cancelled_ids_deque)
def predict_step(context_data, sampling_settings, model, tokens=None):
    """Run ``model.predict_next_step`` once with autograd disabled and return its result.

    *sampling_settings* is forwarded as keyword arguments, alongside ``min_tokens_to_keep=1`` and
    *tokens*.
    """
    with th.no_grad():
        return model.predict_next_step(
            context_data,
            min_tokens_to_keep=1,
            tokens=tokens,
            **sampling_settings,
        )
def _normalise_context_length(value):
    """Return max_context_length as a positive int, or None meaning "use the whole context".

    Request-supplied settings may arrive as floats (e.g. 10.0), which cannot be used as slice bounds.
    """
    if value is None:
        return None
    length = int(value)
    if length < 1:
        raise ValueError(f"max_context_length must be >= 1, got {value!r}")
    return length


def _keep_last_frames(context, max_context_length):
    """Return the newest *max_context_length* entries along dim 1 (everything when it is None)."""
    if max_context_length is None:
        return context
    return context[:, -max_context_length:]


def _keep_last_tokens(tokens, max_context_length):
    """Return the newest *max_context_length* per-frame token entries (everything when it is None)."""
    if max_context_length is None:
        return tokens
    return tokens[-max_context_length:]


def _run_dream_batch(batchable_jobs, job_queue, result_queue, model, device_to_use):
    """Predict one step for a batch of compatible jobs, publish results and re-queue unfinished jobs."""
    # Work on a copy: popping from the job's own dict would leave it without max_context_length
    # if anything below raised.
    sampling_settings = dict(batchable_jobs[0].sampling_settings)
    max_context_length = _normalise_context_length(sampling_settings.pop("max_context_length"))

    # Dreamt frames are appended at the END of the context, so the newest frames are the ones the
    # model must see. Tokens are trimmed with the same window to stay aligned with the images.
    images = th.concat(
        [_keep_last_frames(job.context_images, max_context_length) for job in batchable_jobs], dim=0
    ).to(device_to_use)
    actions = th.concat(
        [_keep_last_frames(job.context_actions, max_context_length) for job in batchable_jobs], dim=0
    ).to(device_to_use)
    tokens = [_keep_last_tokens(job.context_tokens, max_context_length) for job in batchable_jobs]

    context_data = TensorDict({
        "images": images,
        "actions_output": actions
    }, batch_size=images.shape[:2])
    predicted_step, predicted_image_tokens = predict_step(context_data, sampling_settings, model, tokens)
    predicted_step = predicted_step.cpu()
    predicted_images = predicted_step["images"]
    predicted_actions = predicted_step["actions_output"]
    predicted_image_tokens = predicted_image_tokens.cpu()

    for job_i, job in enumerate(batchable_jobs):
        dreamt_image = predicted_images[job_i].unsqueeze(0)
        dreamt_action = predicted_actions[job_i].unsqueeze(0)
        dreamt_tokens = predicted_image_tokens[job_i].unsqueeze(0)

        # Client-supplied future actions replace the predicted action, one per step, until used up.
        actions_to_take = job.actions_to_take
        if actions_to_take is not None and actions_to_take.shape[1] > 0:
            dreamt_action = actions_to_take[:, 0:1]
            actions_to_take = actions_to_take[:, 1:]
            if actions_to_take.shape[1] == 0:
                actions_to_take = None

        result_queue.put(DreamJobResult(
            job_id=job.job_id,
            dream_step_index=job.num_predictions_done,
            dreamt_image=dreamt_image,
            dreamt_action=dreamt_action,
            dreamt_tokens=dreamt_tokens
        ))

        # num_predictions_remaining includes the step just produced. Re-queuing while it is 1 would
        # yield one more step than the client asked for.
        if job.num_predictions_remaining <= 1:
            continue

        # Append the new frame and keep the stored context bounded by the same window.
        image_context = _keep_last_frames(th.cat([job.context_images, dreamt_image], dim=1), max_context_length)
        action_context = _keep_last_frames(th.cat([job.context_actions, dreamt_action], dim=1), max_context_length)
        token_context = _keep_last_tokens(job.context_tokens + [dreamt_tokens[0, 0].tolist()], max_context_length)
        job_queue.put(DreamJob(
            job_id=job.job_id,
            sampling_settings=job.sampling_settings,
            num_predictions_remaining=job.num_predictions_remaining - 1,
            num_predictions_done=job.num_predictions_done + 1,
            context_images=image_context,
            context_actions=action_context,
            context_tokens=token_context,
            actions_to_take=actions_to_take
        ))


def dreamer_worker(job_queue, result_queue, cancelled_jobs_queue, quit_flag, device_to_use, args):
    """Worker-process loop: run batched dream steps on *device_to_use* until *quit_flag* is set.

    Each iteration:
      1. refreshes the cancelled ids;
      2. pulls a batch of compatible jobs from *job_queue*;
      3. predicts one step for all of them and puts one DreamJobResult per job on *result_queue*;
      4. puts jobs with steps remaining back on *job_queue*.
    """
    logger = logging.getLogger(f"dreamer_worker {device_to_use}")
    logger.info("Loading up model...")
    model = setup_and_load_model_be_model(args)
    model = model.to(device_to_use)
    logger.info("Model loaded. Fetching results")
    cancelled_ids_deque = deque(maxlen=MAX_CANCELLED_ID_QUEUE_SIZE)
    cancelled_ids_set = set()
    while not quit_flag.is_set():
        update_cancelled_jobs(cancelled_jobs_queue, cancelled_ids_deque, cancelled_ids_set)
        batchable_jobs = fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size=args.max_batch_size)
        if not batchable_jobs:
            continue
        try:
            _run_dream_batch(batchable_jobs, job_queue, result_queue, model, device_to_use)
        except Exception:
            # Process boundary: an uncaught exception would end this worker process and its device
            # would stop serving jobs. Log it and drop only the offending batch.
            logger.exception("Dream step failed; dropping jobs %s", [job.job_id for job in batchable_jobs])
class DreamerServer:
    """HTTP-facing state of the dreaming server.

    Owns the multiprocessing queues shared with the dreamer_worker processes:
      * ``jobs`` carries work in;
      * ``results_queue`` carries results out;
      * one cancellation queue per worker.
    Worker results are buffered in ``local_results`` until a client collects them with
    get_new_results.
    """

    # Sampling settings that must be whole numbers. The workers use max_context_length as a slice
    # bound; top_k presumably counts candidate tokens in the model's sampler (confirm).
    _INTEGER_SAMPLING_SETTINGS = frozenset({"top_k", "max_context_length"})

    def __init__(self, num_workers, args):
        self.num_workers = num_workers
        self.args = args
        self.model = None
        # Bounded so add_new_job can refuse work once max_concurrent_jobs are queued.
        self.jobs = mp.Queue(maxsize=args.max_concurrent_jobs)
        self.results_queue = mp.Queue()
        # Results for these ids are discarded by _gather_new_results.
        self.cancelled_jobs = set()
        self.cancelled_jobs_queues = [mp.Queue() for _ in range(num_workers)]
        self._last_result_cleanup = datetime.now()
        self._max_job_lifespan_datetime = timedelta(seconds=args.max_job_lifespan)
        # job_id -> list of DreamJobResult not yet returned to the client.
        self.local_results = defaultdict(list)
        self.logger = logging.getLogger("DreamerServer")

    def get_details(self):
        """Return a JSON string describing the model file and the server limits."""
        details = {
            "model_file": self.args.model,
            "max_concurrent_jobs": self.args.max_concurrent_jobs,
            "max_dream_steps_per_job": self.args.max_dream_steps_per_job,
            "max_job_lifespan": self.args.max_job_lifespan,
        }
        return json.dumps(details)

    def _check_if_should_remove_old_jobs(self):
        """Drop buffered results of jobs whose newest result is older than max_job_lifespan.

        Does real work at most once per JOB_CLEANUP_CHECK_RATE; earlier calls return immediately.
        """
        time_now = datetime.now()
        if time_now - self._last_result_cleanup < JOB_CLEANUP_CHECK_RATE:
            return
        self._last_result_cleanup = time_now
        self._gather_new_results()
        # Snapshot the keys: entries are deleted while iterating.
        for job_id in list(self.local_results.keys()):
            results = self.local_results[job_id]
            if time_now - results[-1].result_creation_time > self._max_job_lifespan_datetime:
                self.logger.info(f"Deleted job {job_id} because it was too old. Last result was {results[-1].result_creation_time}")
                del self.local_results[job_id]

    @staticmethod
    def _parse_sampling_value(key, value):
        """Convert one request-supplied sampling setting to the type the workers expect.

        ``None`` or the text "none" (any case) becomes None; anything else becomes float, or int
        for keys in _INTEGER_SAMPLING_SETTINGS.

        Raises:
            ValueError / TypeError: for values that cannot be converted, a non-integral value for
                an integer setting, or a missing/non-positive max_context_length.
        """
        if value is None or (isinstance(value, str) and value.strip().lower() == "none"):
            parsed = None
        else:
            parsed = float(value)
        if key == "max_context_length" and (parsed is None or parsed < 1):
            raise ValueError("max_context_length must be a positive integer")
        if parsed is not None and key in DreamerServer._INTEGER_SAMPLING_SETTINGS:
            if not parsed.is_integer():
                raise ValueError(f"{key} must be an integer, got {value!r}")
            parsed = int(parsed)
        return parsed

    def _decode_context_steps(self, request, steps):
        """Decode request steps into (images, actions, tokens) tensors/lists for a DreamJob.

        Images and actions are stacked per step and get a leading batch dimension of 1; tokens are
        kept as a plain list with one entry per step.

        Raises:
            ValueError: if *steps* is missing or empty.
            KeyError: if a step lacks a field or references a file not in the request.
            OSError: if an uploaded file is not a readable image.
        """
        if not steps:
            raise ValueError("steps missing or empty")
        images, actions, tokens = [], [], []
        for step in steps:
            image_file = request.files[step["image_name"]]
            # Convert palette/grayscale/RGBA uploads to 3-channel RGB so the HWC -> CHW transpose
            # works; presumably the model expects RGB frames (confirm).
            image = np.array(Image.open(image_file.stream).convert("RGB"))
            image = be_image_preprocess(image, target_width=self.args.image_width, target_height=self.args.image_height)
            images.append(th.from_numpy(image))
            actions.append(th.tensor(action_vector_to_be_action_vector(step["action"])))
            tokens.append(step["tokens"])
        return th.stack(images).unsqueeze(0), th.stack(actions).unsqueeze(0), tokens

    @staticmethod
    def _decode_future_actions(future_steps):
        """Return forced future actions stacked with a leading batch dim of 1, or None if none are given."""
        if not future_steps:
            # th.stack([]) would raise, so an empty list means "no forced actions".
            return None
        actions = [th.tensor(action_vector_to_be_action_vector(step["action"])) for step in future_steps]
        return th.stack(actions).unsqueeze(0)

    def add_new_job(self, request, request_json):
        """
        Add new dreaming job(s) to the queue.

        Request JSON fields read here:
            num_steps_to_predict (required): dream steps per job, 1..max_dream_steps_per_job.
            num_parallel_predictions (optional, default 1): independent jobs started from the same context.
            steps (required): context steps, each with "image_name" (key into request.files),
                "action" and "tokens".
            future_actions (optional): list of {"action": ...} forced onto the first dream steps.
            temperature / top_k / top_p / max_context_length (optional): override
                DEFAULT_SAMPLING_SETTINGS; "none" means None.

        Returns: JSON string {"job_ids": [...], "current_jobs_in_queue": int}, or a 400 response
        for invalid input or a full queue.
        """
        self._check_if_should_remove_old_jobs()

        if "num_steps_to_predict" not in request_json:
            return make_response("num_steps_to_predict not in request", 400)
        try:
            num_steps_to_predict = int(request_json["num_steps_to_predict"])
        except (TypeError, ValueError):
            return make_response("num_steps_to_predict must be an integer", 400)
        if num_steps_to_predict < 1:
            return make_response("num_steps_to_predict must be at least 1", 400)
        if num_steps_to_predict > self.args.max_dream_steps_per_job:
            return make_response(f"num_steps_to_predict too large. Max {self.args.max_dream_steps_per_job}", 400)

        try:
            num_parallel_predictions = int(request_json.get("num_parallel_predictions", 1))
        except (TypeError, ValueError):
            return make_response("num_parallel_predictions must be an integer", 400)
        if num_parallel_predictions < 1:
            return make_response("num_parallel_predictions must be at least 1", 400)
        # NOTE(review): workers put unfinished jobs back on this same bounded queue, so a full queue
        # can block them; ">=" keeps at least one slot free at submission time.
        if (self.jobs.qsize() + num_parallel_predictions) >= self.args.max_concurrent_jobs:
            return make_response(f"Too many jobs already running. Max {self.args.max_concurrent_jobs}", 400)

        sampling_settings = copy.deepcopy(DEFAULT_SAMPLING_SETTINGS)
        try:
            for key in sampling_settings:
                if key in request_json:
                    sampling_settings[key] = self._parse_sampling_value(key, request_json[key])
        except (TypeError, ValueError) as err:
            return make_response(f"Invalid sampling setting: {err}", 400)

        try:
            context_images, context_actions, context_tokens = self._decode_context_steps(request, request_json.get("steps"))
            future_actions = self._decode_future_actions(request_json.get("future_actions"))
        except (KeyError, TypeError, ValueError, OSError) as err:
            return make_response(f"Invalid request: {err}", 400)

        list_of_job_ids = []
        for _ in range(num_parallel_predictions):
            job_id = uuid.uuid4().hex
            self.jobs.put(DreamJob(
                job_id=job_id,
                sampling_settings=sampling_settings,
                num_predictions_remaining=num_steps_to_predict,
                num_predictions_done=0,
                context_images=context_images,
                context_actions=context_actions,
                context_tokens=context_tokens,
                actions_to_take=future_actions
            ))
            list_of_job_ids.append(job_id)
        job_queue_size = self.jobs.qsize()
        return json.dumps({"job_ids": list_of_job_ids, "current_jobs_in_queue": job_queue_size})

    def _gather_new_results(self):
        """Move every result currently available from the workers into local_results.

        Drains with get_nowait() until Empty instead of trusting qsize(). qsize() is only approximate
        for multiprocessing queues, and a blocking get() after an over-estimate would stall the
        request handler. Results of cancelled jobs are discarded.
        """
        while True:
            try:
                result = self.results_queue.get_nowait()
            except queue.Empty:
                break
            if result.job_id in self.cancelled_jobs:
                continue
            self.local_results[result.job_id].append(result)

    def get_new_results(self, request, request_json):
        """Return all buffered results for the requested job ids as a zip, and forget them.

        The zip holds:
          * one PNG per result, named "<job_id>_<dream_step_index>.png";
          * PREDICTION_JSON_FILENAME, listing job_id, dream_step_index, action, tokens and
            image_filename for each result.
        Returns 400 if job_ids is missing, and 204 if none of the ids has buffered results.
        """
        if "job_ids" not in request_json:
            return make_response("job_ids not in request", 400)
        self._gather_new_results()
        job_ids = request_json["job_ids"]
        if not isinstance(job_ids, list):
            job_ids = [job_ids]

        # pop() hands each job's results over exactly once; unknown/already-collected ids are skipped.
        return_results = [self.local_results.pop(job_id) for job_id in job_ids if job_id in self.local_results]
        if not return_results:
            return make_response("No new responses", 204)

        to_pil_image = T.ToPILImage()
        output_json = []
        output_image_bytes = {}
        for job_results in return_results:
            for result in job_results:
                action = be_action_vector_to_action_vector(result.dreamt_action.numpy()[0, 0].tolist())
                dreamt_tokens = result.dreamt_tokens[0, 0].tolist()
                image_filename = f"{result.job_id}_{result.dream_step_index}.png"
                output_json.append({
                    "job_id": result.job_id,
                    "dream_step_index": result.dream_step_index,
                    "action": action,
                    "tokens": dreamt_tokens,
                    "image_filename": image_filename
                })
                image_buffer = io.BytesIO()
                to_pil_image(result.dreamt_image[0, 0]).save(image_buffer, format="PNG")
                output_image_bytes[image_filename] = image_buffer.getvalue()

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w") as archive:
            for filename, image_data in output_image_bytes.items():
                archive.writestr(filename, image_data)
            archive.writestr(PREDICTION_JSON_FILENAME, json.dumps(output_json))
        zip_buffer.seek(0)
        return send_file(
            zip_buffer,
            mimetype="application/zip",
            as_attachment=True,
            download_name=f"dreaming_results_{timestamp}.zip"
        )

    def cancel_job(self, request, request_json):
        """Cancel a job: discard its future results here and tell every worker to drop it."""
        if "job_id" not in request_json:
            return make_response("job_id not in request", 400)
        job_id = request_json["job_id"]
        self.cancelled_jobs.add(job_id)
        for job_queue in self.cancelled_jobs_queues:
            job_queue.put(job_id)
        return make_response("OK", 200)
def main_run(args):
    """Start one dreamer_worker process per visible CUDA device and serve the HTTP API.

    Routes:
        GET  /                 server details
        POST /new_job          multipart form with a "json" field plus the referenced image files
        GET  /get_job_results  ?job_ids=... (repeatable)
        GET  /cancel_job       ?job_id=...

    Workers are signalled to stop and joined when the Flask server exits, including when it exits
    via an exception such as KeyboardInterrupt.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    app = Flask(__name__)
    num_workers = th.cuda.device_count()
    if num_workers == 0:
        raise RuntimeError("No CUDA devices found. Cannot run Dreamer.")
    server = DreamerServer(num_workers, args)
    quit_flag = mp.Event()
    dreamer_worker_processes = []
    for device_i in range(num_workers):
        device = f"cuda:{device_i}"
        dreamer_worker_process = mp.Process(
            target=dreamer_worker,
            args=(server.jobs, server.results_queue, server.cancelled_jobs_queues[device_i], quit_flag, device, args)
        )
        dreamer_worker_process.daemon = True
        dreamer_worker_process.start()
        dreamer_worker_processes.append(dreamer_worker_process)

    @app.route('/')
    def details():
        return server.get_details()

    @app.route('/new_job', methods=['POST'])
    def new_job():
        try:
            request_json = json.loads(request.form["json"])
        except (KeyError, ValueError):
            return make_response("Request must contain a 'json' form field holding valid JSON", 400)
        if not isinstance(request_json, dict):
            return make_response("The 'json' form field must hold a JSON object", 400)
        return server.add_new_job(request, request_json)

    @app.route('/get_job_results', methods=['GET'])
    def get_results():
        request_json = {"job_ids": request.args.getlist("job_ids")}
        return server.get_new_results(request, request_json)

    @app.route('/cancel_job', methods=['GET'])
    def cancel_job():
        request_json = request.args.to_dict()
        return server.cancel_job(request, request_json)

    try:
        # use_reloader=False: with debug=True, Flask's reloader re-executes this script in a child
        # process, which would run main_run again and spawn a second set of GPU workers.
        app.run(host="", port=args.port, debug=args.debug, use_reloader=False)
    finally:
        quit_flag.set()
        for dreamer_worker_process in dreamer_worker_processes:
            # Bounded wait: workers notice quit_flag between batches. Daemonic workers still alive
            # afterwards are terminated when this process exits.
            dreamer_worker_process.join(timeout=30)
if __name__ == '__main__':
    # Parse the command line and hand straight over to the server entry point.
    main_run(parser.parse_args())