File size: 5,064 Bytes
eaf2e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import jpype
from math import ceil
from enum import Enum
from root import PRJROOT
from jpype import JString, JInt, JBoolean, JLong
from typing import Union, Dict
from src.smb.level import MarioLevel, LevelRender
from src.utils.filesys import getpath

# Optional path to the JVM shared library handed to jpype.startJVM.
# When left as None, MarioProxy uses jpype.getDefaultJVMPath() instead.
JVMPath: Union[str, None] = None

class MarioJavaAgents(Enum):
    """Java agent packages that MarioProxy can instantiate through jpype.

    Each member's value is a Java package name; ``str(member)`` appends
    ``.Agent`` to give the class name passed to ``jpype.JClass``.
    """

    Runner = 'agents.robinBaumgarten'
    Killer = 'agents.killer'
    Collector = 'agents.collector'

    def __str__(self):
        package = self.value
        return f'{package}.Agent'


class MarioProxy:
    """Python front-end for the Java ``MarioProxy`` class in Mario-AI-Framework.jar.

    Constructing an instance starts the JVM if it is not already running (with
    the jar on the classpath), then instantiates the Java-side ``MarioProxy``
    that plays games and runs agent simulations.
    """

    def __init__(self):
        if not jpype.isJVMStarted():
            jar_path = getpath('smb/Mario-AI-Framework.jar')
            jpype.startJVM(
                jpype.getDefaultJVMPath() if JVMPath is None else JVMPath,
                f"-Djava.class.path={jar_path}",
                '-Xmx2g',  # -Xmx{size} sets the maximum JVM heap size.
            )
        # Set the JVM's user.dir to <PRJROOT>/smb; presumably the Java side
        # resolves its resource paths relative to it — TODO confirm.
        jpype.JClass("java.lang.System").setProperty('user.dir', os.path.join(PRJROOT, 'smb'))
        self.__proxy = jpype.JClass("MarioProxy")()

    @staticmethod
    def __extract_res(jresult):
        """Convert a Java result object into a plain Python dict.

        :param jresult: result object returned by the Java ``MarioProxy``.
        :return: dict with 'status', 'completing-ratio', '#kills',
            '#kills-by-fire', '#kills-by-stomp', '#kills-by-shell',
            'trace' (one [marioX, marioY] pair per agent event) and
            'JAgentEvents' (the raw Java event collection).
        """
        events = jresult.getAgentEvents()  # fetched once, used for both keys
        return {
            'status': str(jresult.getGameStatus().toString()),
            'completing-ratio': float(jresult.getCompletionPercentage()),
            '#kills': int(jresult.getKillsTotal()),
            '#kills-by-fire': int(jresult.getKillsByFire()),
            '#kills-by-stomp': int(jresult.getKillsByStomp()),
            '#kills-by-shell': int(jresult.getKillsByShell()),
            'trace': [
                [float(item.getMarioX()), float(item.getMarioY())]
                for item in events
            ],
            'JAgentEvents': events,
        }

    def play_game(self, level: Union[str, MarioLevel], lives=0, verbose=False, scale=2):
        """Run the Java ``playGame`` on ``level`` and return the extracted results.

        :param level: a MarioLevel, or the path of a valid level file.
        :param lives: forwarded to Java as an int.
        :param verbose: forwarded to Java as a boolean.
        :param scale: forwarded to Java as an int (presumably a render scale
            factor — TODO confirm).
        :return: dictionary of the results (see ``__extract_res``).
        """
        if isinstance(level, str):
            level = MarioLevel.from_file(level)
        jresult = self.__proxy.playGame(JString(str(level)), JInt(lives), JBoolean(verbose), JInt(scale))
        return MarioProxy.__extract_res(jresult)

    def simulate_game(
        self,
        level: Union[str, MarioLevel],
        agent: MarioJavaAgents = MarioJavaAgents.Runner,
        render: bool = False,
        realTimeLim: int = 0,
    ) -> Dict:
        """Run one simulation of ``agent`` on ``level`` and return its results.

        :param level: a MarioLevel, or the path of a valid level file.
        :param agent: which Java agent to instantiate for the run.
        :param render: if True the run is rendered at 24 fps, otherwise fps 0 is sent.
        :param realTimeLim: real-time limit. NOTE(review): documented as
            microseconds, but the value is multiplied by 1000 before being sent
            to Java as a long — confirm the unit the Java side expects.
        :return: dictionary of the results (see ``__extract_res``).
        """
        jagent = jpype.JClass(str(agent))()
        if isinstance(level, str):
            level = MarioLevel.from_file(level)
        fps = 24 if render else 0
        jresult = self.__proxy.simulateGame(
            JString(str(level)), jagent, JBoolean(render), JInt(fps), JLong(realTimeLim * 1000)
        )
        return MarioProxy.__extract_res(jresult)

    def simulate_complete(
        self,
        level: Union[str, MarioLevel],
        agent: MarioJavaAgents = MarioJavaAgents.Runner,
        segTimeK: int = 80,
    ) -> Dict:
        """Simulate ``level`` to its end, restarting from wherever the agent stopped.

        Each attempt runs the Java ``simulateWithSegmentwiseTimeout`` on the
        level sliced from the furthest reached tile. If an attempt's status is
        not 'WIN', the reached tile is recorded as a restart and the next attempt
        starts there. Attempts repeat until one wins or the reached tile is at
        least ``level.w - 1``.

        :param level: a MarioLevel, or the path of a valid level file.
        :param agent: agent used for every attempt (one Java instance is reused).
        :param segTimeK: forwarded to Java as the segment-wise timeout argument.
        :return: {'restarts': [reached tile index of each failed attempt],
                  'trace': [[x, y], ...]} where x is shifted by the slice
                  offset (reached_tile * LevelRender.tex_size) of its attempt.
        """
        ts = LevelRender.tex_size
        jagent = jpype.JClass(str(agent))()
        if isinstance(level, str):
            level = MarioLevel.from_file(level)
        reached_tile = 0
        res = {'restarts': [], 'trace': []}
        dx = 0
        win = False
        while not win and reached_tile < level.w - 1:
            jresult = self.__proxy.simulateWithSegmentwiseTimeout(
                JString(str(level[:, reached_tile:])), jagent, JInt(segTimeK))
            pyresult = MarioProxy.__extract_res(jresult)
            seg_trace = pyresult['trace']
            # An attempt with no agent events counts as zero progress instead of crashing.
            reached = seg_trace[-1][0] if seg_trace else 0.0
            # Advance by at least one tile: an attempt with no forward progress
            # could otherwise re-run the same slice indefinitely.
            reached_tile += max(1, ceil(reached / ts))
            if pyresult['status'] != 'WIN':
                res['restarts'].append(reached_tile)
            else:
                win = True
            res['trace'] += [[dx + item[0], item[1]] for item in seg_trace]
            dx = reached_tile * ts
        return res

    @staticmethod
    def get_seg_infos(full_info, check_points=None):
        """Split a whole-level trace into per-segment traces.

        :param full_info: dict with 'restarts' (ascending tile indices) and
            'trace' ([[x, y], ...] in ascending x), e.g. the output of
            ``simulate_complete``.
        :param check_points: ascending tile indices that close each segment.
            Defaults to every ``MarioLevel.seg_width`` tiles plus
            ceil(last x / tex_size). Defaults to no segments when the trace is
            empty.
        :return: one dict per check point. 'trace' holds the not-yet-assigned
            trace points with x < check_point * tex_size, with x re-based to
            the segment's first point. 'playable' is False if any restart tile
            lies below this check point and at or above the previous one.
        """
        restarts, trace = full_info['restarts'], full_info['trace']
        W = MarioLevel.seg_width
        ts = LevelRender.tex_size
        if check_points is None:
            if trace:
                end = ceil(trace[-1][0] / ts)
                check_points = list(range(W, end, W))
                check_points.append(end)
            else:
                check_points = []
        res = []
        seg_start, seg_end = 0, 0
        restart_pointer = 0
        for cp in check_points:
            bound = ts * cp
            while seg_end < len(trace) and trace[seg_end][0] < bound:
                seg_end += 1
            # Consume restarts independently of trace points so that a restart
            # falling in a segment with no trace points is not charged to a
            # later segment.
            playable = True
            while restart_pointer < len(restarts) and restarts[restart_pointer] < cp:
                playable = False
                restart_pointer += 1
            segment = trace[seg_start:seg_end]
            # Guard: an empty segment (possibly at the end of the trace) has no first point.
            x0 = segment[0][0] if segment else 0
            res.append({
                'trace': [[item[0] - x0, item[1]] for item in segment],
                'playable': playable,
            })
            seg_start = seg_end
        return res

if __name__ == '__main__':
    # Smoke check: building the proxy starts the JVM if needed and loads the Java MarioProxy.
    proxy = MarioProxy()