Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	| import jpype | |
| import pygame as pg | |
| import multiprocessing as mp | |
| from src.smb.proxy import JVMPath | |
| from src.smb.level import MarioLevel | |
| from src.olgen.olg_policy import RLGenPolicy | |
| from src.gan.gankits import get_decoder | |
| from src.olgen.ol_generator import OnlineGenerator | |
| from src.utils.filesys import getpath | |
def _ol_gen_worker(remote, parent_remote, d_path, g_path, g_device):
    """Serve online-generation results over a pipe; runs in a child process.

    Builds an ``OnlineGenerator`` from a designer policy and a decoder, sends
    the stringified result of one ``step()`` immediately, and then answers
    ``(cmd, data)`` tuples received on ``remote``:

    * ``'step'``  -- send back ``str(ol_generator.step())``.
    * ``'close'`` -- leave the loop and exit.

    Any other command is ignored. The payload ``data`` is not used.

    Args:
        remote: This process's end of the pipe.
        parent_remote: The parent's end of the pipe; it is closed right away
            because this process never uses it.
        d_path: Path handed to ``RLGenPolicy.from_path``.
        g_path: Decoder path resolved through ``getpath``; an empty string
            selects ``'models/decoder.pth'`` instead.
        g_device: Device argument forwarded to ``OnlineGenerator``.
    """
    parent_remote.close()
    try:
        designer = RLGenPolicy.from_path(d_path)
        if g_path == '':
            generator = get_decoder('models/decoder.pth')
        else:
            generator = get_decoder(getpath(g_path))
        ol_generator = OnlineGenerator(designer, generator, g_device)
        try:
            # The first result is pushed without waiting for a request.
            remote.send(str(ol_generator.step()))
            while True:
                cmd, _ = remote.recv()
                if cmd == 'step':
                    remote.send(str(ol_generator.step()))
                elif cmd == 'close':
                    break
        except (EOFError, ConnectionError):
            # The other end of the pipe is gone (closed or broken while we
            # were receiving or sending); nothing is left to serve.
            pass
    finally:
        # Release our pipe end on every exit path, including a failure while
        # loading the models, so a parent blocked in recv() is not left
        # waiting on a pipe nobody will write to.
        remote.close()
class MarioOnlineGenGame:
    """Play a Mario game whose level is extended while it is being played.

    The Java game (``MarioOnlineGenGame`` from the Mario-AI-Framework jar) runs
    in this process through JPype. Segment strings come from a worker process
    running ``_ol_gen_worker``, which is started by ``play`` and stopped by
    ``close``.
    """

    def __init__(self, d_path, g_path='', g_device='cuda:0'):
        """Start the JVM if it is not running yet and remember worker settings.

        Args:
            d_path: Designer-policy path forwarded to the worker process.
            g_path: Decoder path forwarded to the worker process; the worker
                falls back to its default decoder when this is ``''``.
            g_device: Device string forwarded to the worker process.
        """
        if not jpype.isJVMStarted():
            jar_path = getpath('smb/Mario-AI-Framework.jar')
            jpype.startJVM(
                jpype.getDefaultJVMPath() if JVMPath is None else JVMPath,
                f"-Djava.class.path={jar_path}", '-Xmx4g'
            )
        self.d_path, self.g_path, self.g_device = d_path, g_path, g_device
        # Both stay None while no worker process is running.
        self.ol_gen_remote, self.process = None, None

    def play(self, max_length):
        """Run one game until ``gameStep()`` reports that it has finished.

        Args:
            max_length: Maximum number of segments the level may consist of,
                counting the initial one.

        The worker is always shut down before this method returns, including
        when the game loop raises.
        """
        self.__init_ol_gen_remote()
        try:
            # The worker sends the first segment unprompted.
            seg_str = self.ol_gen_remote.recv()
            print(seg_str)
            game = jpype.JClass("MarioOnlineGenGame")(jpype.JString(seg_str))
            clk = pg.time.Clock()
            finish = False
            n_seg = 1
            # Request the next segment ahead of time so it is (ideally) ready
            # by the time the remaining distance to the exit gets short.
            self.ol_gen_remote.send(('step', None))
            while not finish:
                finish = bool(game.gameStep())
                near_exit = int(game.getTileDistantToExit()) < MarioLevel.default_seg_width
                if n_seg < max_length and near_exit:
                    seg_str = self.ol_gen_remote.recv()
                    game.appendSegment(jpype.JString(seg_str))
                    n_seg += 1
                    self.ol_gen_remote.send(('step', None))
                clk.tick(30)  # cap the loop at 30 iterations per second
        finally:
            self.close()

    def __init_ol_gen_remote(self):
        """Start the worker process and keep the parent's end of its pipe."""
        forkserver_available = "forkserver" in mp.get_all_start_methods()
        start_method = "forkserver" if forkserver_available else "spawn"
        ctx = mp.get_context(start_method)
        self.ol_gen_remote, work_remote = ctx.Pipe()
        args = (work_remote, self.ol_gen_remote, self.d_path, self.g_path, self.g_device)
        self.process = ctx.Process(target=_ol_gen_worker, args=args, daemon=True)
        self.process.start()
        # The child now owns its own copy of the worker end. Closing ours means
        # recv() on the parent end raises EOFError if the worker exits, rather
        # than depending on garbage collection to drop the extra handle.
        work_remote.close()

    def close(self):
        """Ask the worker to exit, wait for it, and release the pipe.

        Safe to call when no worker is running (before ``play`` or after an
        earlier ``close``); in that case it does nothing.
        """
        remote, process = self.ol_gen_remote, self.process
        self.ol_gen_remote, self.process = None, None
        if remote is not None:
            try:
                remote.send(('close', None))
            except OSError:
                # The worker side of the pipe is already gone; there is
                # nobody left to notify.
                pass
        if process is not None:
            process.join()
        if remote is not None:
            # Closed only after the join so a reply the worker may still be
            # writing does not hit a closed pipe.
            remote.close()