Spaces:
Sleeping
Sleeping
import jpype | |
import pygame as pg | |
import multiprocessing as mp | |
from src.smb.proxy import JVMPath | |
from src.smb.level import MarioLevel | |
from src.olgen.olg_policy import RLGenPolicy | |
from src.gan.gankits import get_decoder | |
from src.olgen.ol_generator import OnlineGenerator | |
from src.utils.filesys import getpath | |
def _ol_gen_worker(remote, parent_remote, d_path, g_path, g_device):
    """Worker-process entry point: produce generator steps on request over a pipe.

    The worker builds an ``OnlineGenerator`` from a designer policy and a
    decoder, immediately sends ``str(ol_generator.step())`` once, and then
    answers commands received on ``remote``:

    * ``('step', _)``  -- send ``str(ol_generator.step())`` back.
    * ``('close', _)`` -- stop serving and exit.

    Any other command is ignored. The second element of a command tuple is
    not used.

    Args:
        remote: This process's end of the pipe (receives commands, sends results).
        parent_remote: The parent's end of the pipe as inherited by this
            process; it is closed right away because only the parent uses it.
        d_path: Path handed to ``RLGenPolicy.from_path``.
        g_path: Decoder path resolved through ``getpath``; the empty string
            selects ``'models/decoder.pth'`` instead.
        g_device: Device argument forwarded to ``OnlineGenerator``.
    """
    # Only the parent process talks through this end.
    parent_remote.close()
    try:
        designer = RLGenPolicy.from_path(d_path)
        decoder_path = 'models/decoder.pth' if g_path == '' else getpath(g_path)
        generator = get_decoder(decoder_path)
        ol_generator = OnlineGenerator(designer, generator, g_device)
        # The first result is pushed without waiting for a request.
        remote.send(str(ol_generator.step()))
        while True:
            cmd, _ = remote.recv()
            if cmd == 'step':
                remote.send(str(ol_generator.step()))
            elif cmd == 'close':
                break
    except (EOFError, ConnectionError):
        # The parent's end is gone (EOF on recv, broken pipe on send):
        # there is nobody left to serve, so just shut down.
        pass
    finally:
        # Release our end of the pipe on every exit path, including errors
        # raised while loading the models or stepping the generator.
        remote.close()
class MarioOnlineGenGame:
    """Play the Java ``MarioOnlineGenGame`` while a worker process supplies level segments.

    The JVM is started on construction if it is not already running. Each
    call to :meth:`play` spawns a worker process running ``_ol_gen_worker``,
    feeds the strings it produces into the Java game, and shuts the worker
    down again when the game ends.
    """

    def __init__(self, d_path, g_path='', g_device='cuda:0'):
        """Start the JVM if needed and remember the worker configuration.

        Args:
            d_path: Designer policy path, forwarded to the worker.
            g_path: Decoder path, forwarded to the worker ('' is passed
                through unchanged).
            g_device: Device string, forwarded to the worker.
        """
        if not jpype.isJVMStarted():
            jar_path = getpath('smb/Mario-AI-Framework.jar')
            jpype.startJVM(
                jpype.getDefaultJVMPath() if JVMPath is None else JVMPath,
                f"-Djava.class.path={jar_path}", '-Xmx4g'
            )
        self.d_path, self.g_path, self.g_device = d_path, g_path, g_device
        # Parent end of the pipe and the worker process; both are None
        # whenever no worker is running.
        self.ol_gen_remote, self.process = None, None

    def play(self, max_length):
        """Run one game, appending worker-produced segments as play progresses.

        The first string received from the worker creates the Java game. While
        the game is running, a further segment is appended whenever fewer than
        ``max_length`` segments have been used so far and
        ``game.getTileDistantToExit()`` is below ``MarioLevel.default_seg_width``.
        The next segment is always requested one step ahead of when it is
        consumed.

        Args:
            max_length: Maximum number of segments to use, counting the
                initial one.

        Raises:
            EOFError: If the worker process exits before delivering a segment.
        """
        self.__init_ol_gen_remote()
        try:
            seg_str = self.ol_gen_remote.recv()
            print(seg_str)
            game = jpype.JClass("MarioOnlineGenGame")(jpype.JString(seg_str))
            clk = pg.time.Clock()
            finish = False
            n_seg = 1
            # Ask for the next segment now so it is ready when needed.
            self.ol_gen_remote.send(('step', None))
            while not finish:
                finish = bool(game.gameStep())
                if n_seg < max_length and int(game.getTileDistantToExit()) < MarioLevel.default_seg_width:
                    seg_str = self.ol_gen_remote.recv()
                    game.appendSegment(jpype.JString(seg_str))
                    n_seg += 1
                    self.ol_gen_remote.send(('step', None))
                # Limit the loop to 30 iterations per second.
                clk.tick(30)
        finally:
            # Always stop the worker, even if the game or the pipe raised.
            self.close()

    def __init_ol_gen_remote(self):
        """Spawn the worker process and connect to it through a pipe.

        Uses the "forkserver" start method when available, otherwise "spawn".
        """
        forkserver_available = "forkserver" in mp.get_all_start_methods()
        start_method = "forkserver" if forkserver_available else "spawn"
        ctx = mp.get_context(start_method)
        self.ol_gen_remote, work_remote = ctx.Pipe()
        args = (work_remote, self.ol_gen_remote, self.d_path, self.g_path, self.g_device)
        self.process = ctx.Process(target=_ol_gen_worker, args=args, daemon=True)
        self.process.start()
        # The worker owns this end now. Dropping the parent's copy lets
        # recv() on our end raise EOFError if the worker dies, instead of
        # blocking forever.
        work_remote.close()

    def close(self):
        """Stop the worker process and release the pipe.

        Safe to call when no worker is running, and safe to call repeatedly.
        """
        remote, process = self.ol_gen_remote, self.process
        self.ol_gen_remote, self.process = None, None
        if remote is not None:
            try:
                remote.send(('close', None))
            except OSError:
                # The worker has already gone away; nothing to notify.
                pass
        if process is not None:
            process.join()
        if remote is not None:
            # Closed only after join() so that a reply the worker is still
            # sending does not hit a closed peer.
            remote.close()