# baiyanlali-zhao's picture
# init
# eaf2e33
# raw
# history blame
# 2.76 kB
import jpype
import pygame as pg
import multiprocessing as mp
from src.smb.proxy import JVMPath
from src.smb.level import MarioLevel
from src.olgen.olg_policy import RLGenPolicy
from src.gan.gankits import get_decoder
from src.olgen.ol_generator import OnlineGenerator
from src.utils.filesys import getpath
def _ol_gen_worker(remote, parent_remote, d_path, g_path, g_device):
    """Worker-process entry point that serves generated segments over a pipe.

    One segment (``str(OnlineGenerator.step())``) is sent as soon as the
    generator is built. After that the worker answers messages of the form
    ``(command, payload)`` received on ``remote``:

    * ``'step'``  -- send one more segment.
    * ``"close"`` -- close ``remote`` and return.

    Any other command is ignored, and the payload is never used. The worker
    also returns when ``remote.recv()`` raises ``EOFError``.

    Args:
        remote: This process's end of the pipe.
        parent_remote: The other end of the pipe; it is closed immediately
            because this process does not use it.
        d_path: Passed to ``RLGenPolicy.from_path`` to load the designer.
        g_path: Decoder location, resolved through ``getpath``. An empty
            string selects ``'models/decoder.pth'`` instead.
        g_device: Passed through to ``OnlineGenerator``.
    """
    parent_remote.close()

    policy = RLGenPolicy.from_path(d_path)
    decoder_path = 'models/decoder.pth' if g_path == '' else getpath(g_path)
    online_gen = OnlineGenerator(policy, get_decoder(decoder_path), g_device)

    def push_segment():
        # Segments travel as plain strings.
        remote.send(str(online_gen.step()))

    # The first segment is sent without waiting for a request.
    push_segment()
    try:
        while True:
            command, _payload = remote.recv()
            if command == 'step':
                push_segment()
            elif command == "close":
                remote.close()
                return
    except EOFError:
        return
class MarioOnlineGenGame:
    """Play a Mario game whose level is extended with generated segments.

    The game is the Java class ``MarioOnlineGenGame`` driven through JPype.
    Level segments are produced by a separate worker process running
    ``_ol_gen_worker`` and are exchanged as strings over a pipe.
    """

    def __init__(self, d_path, g_path='', g_device='cuda:0'):
        """Start the JVM if it is not running and store the worker settings.

        Args:
            d_path: Forwarded to the worker process as its ``d_path``.
            g_path: Forwarded to the worker process as its ``g_path``.
            g_device: Forwarded to the worker process as its ``g_device``.
        """
        if not jpype.isJVMStarted():
            jar_path = getpath('smb/Mario-AI-Framework.jar')
            jvm_path = jpype.getDefaultJVMPath() if JVMPath is None else JVMPath
            jpype.startJVM(jvm_path, f"-Djava.class.path={jar_path}", '-Xmx4g')
        self.d_path, self.g_path, self.g_device = d_path, g_path, g_device
        # Both stay None while no worker is running.
        self.ol_gen_remote, self.process = None, None

    def play(self, max_length):
        """Run one game until ``gameStep()`` reports that it has finished.

        A new segment is appended whenever the tile distance to the exit
        drops below ``MarioLevel.default_seg_width``, until the level holds
        ``max_length`` segments (the initial segment counts as one). The
        loop is capped at 30 iterations per second. The worker is always
        shut down on the way out, including when an exception escapes.

        Args:
            max_length: Maximum number of segments in the level.

        Raises:
            EOFError: If the worker exits before delivering a segment that
                was requested from it.
        """
        self.__init_ol_gen_remote()
        try:
            # The worker sends the first segment without being asked.
            seg_str = self.ol_gen_remote.recv()
            print(seg_str)
            game = jpype.JClass("MarioOnlineGenGame")(jpype.JString(seg_str))
            clk = pg.time.Clock()
            n_seg = 1
            # Ask for the next segment ahead of time, but only if it can
            # still be used; a request is pending exactly when
            # n_seg < max_length.
            if n_seg < max_length:
                self.ol_gen_remote.send(('step', None))
            finish = False
            while not finish:
                finish = bool(game.gameStep())
                if n_seg < max_length and int(game.getTileDistantToExit()) < MarioLevel.default_seg_width:
                    seg_str = self.ol_gen_remote.recv()
                    game.appendSegment(jpype.JString(seg_str))
                    n_seg += 1
                    if n_seg < max_length:
                        self.ol_gen_remote.send(('step', None))
                clk.tick(30)
        finally:
            self.close()

    def __init_ol_gen_remote(self):
        """Start the worker process and keep the parent end of its pipe."""
        forkserver_available = "forkserver" in mp.get_all_start_methods()
        start_method = "forkserver" if forkserver_available else "spawn"
        ctx = mp.get_context(start_method)
        self.ol_gen_remote, work_remote = ctx.Pipe()
        args = (work_remote, self.ol_gen_remote, self.d_path, self.g_path, self.g_device)
        self.process = ctx.Process(target=_ol_gen_worker, args=args, daemon=True)
        try:
            self.process.start()
        finally:
            # The worker owns this end now. If the parent kept its copy
            # open, recv() would block forever after a worker crash instead
            # of raising EOFError.
            work_remote.close()

    def close(self):
        """Stop the worker and release the pipe.

        Safe to call when no worker is running and safe to call repeatedly.
        """
        remote, process = self.ol_gen_remote, self.process
        self.ol_gen_remote, self.process = None, None
        if remote is not None:
            try:
                remote.send(('close', None))
            except OSError:
                # The worker end is already gone (BrokenPipeError is an
                # OSError), so there is nobody left to notify.
                pass
        if process is not None:
            process.join()
        if remote is not None:
            remote.close()