File size: 4,765 Bytes
b396e7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.12.0/dist/tf.min.js')
importScripts('agent_sac.js')
importScripts('reply_buffer.js')

;(async () => {
    // Set to true to keep the worker loaded but idle: no training, no buffering.
    const DISABLED = false

    const agent = new AgentSac({batchSize: 100, verbose: true})
    await agent.init()
    await agent.checkpoint() // overwrite
    agent.actor.summary()

    /**
     * Sends the actor's current weights to the main thread as plain nested
     * arrays (each weight tensor converted with `.array()`).
     */
    const postActorWeights = async () => {
        const weights = await Promise.all(agent.actor.getWeights().map(w => w.array()))
        self.postMessage({ weights })
    }

    await postActorWeights() // synchronize the main thread's actor with this one

    // Replay buffer of up to 50000 transitions. The callback disposes every tensor
    // of the transition it receives; presumably ReplyBuffer calls it when an
    // entry is dropped (TODO confirm in reply_buffer.js).
    const rb = new ReplyBuffer(50000, ({ state: [telemetry, frameL, frameR], action, reward }) => {
        frameL.dispose()
        frameR.dispose()
        telemetry.dispose()
        action.dispose()
        reward.dispose()
    })

    /**
     * Worker job: trains the agent on one sampled batch, then publishes the
     * updated actor weights to the main thread.
     *
     * @returns {Promise<number>} delay in ms to get ready for the next job
     */
    const job = async () => {
        if (DISABLED) return 99999
        // Do not train until the buffer holds at least ten batches' worth of transitions.
        if (rb.size < agent._batchSize * 10) return 1000

        const samples = rb.sample(agent._batchSize)
        if (!samples.length) return 1000

        const framesL = []
        const framesR = []
        const telemetries = []
        const actions = []
        const rewards = []
        const nextFramesL = []
        const nextFramesR = []
        const nextTelemetries = []

        for (const {
            state: [telemetry, frameL, frameR],
            action,
            reward,
            nextState: [nextTelemetry, nextFrameL, nextFrameR]
        } of samples) {
            framesL.push(frameL)
            framesR.push(frameR)
            telemetries.push(telemetry)
            actions.push(action)
            rewards.push(reward)
            nextFramesL.push(nextFrameL)
            nextFramesR.push(nextFrameR)
            nextTelemetries.push(nextTelemetry)
        }

        // tidy() disposes the stacked batch tensors once training returns.
        // NOTE(review): agent.train is not awaited here; if it is async, tensors it
        // allocates after its first await escape this tidy — confirm in agent_sac.js.
        tf.tidy(() => {
            console.time('train')
            try {
                agent.train({
                    state:     [tf.stack(telemetries), tf.stack(framesL), tf.stack(framesR)],
                    action:     tf.stack(actions),
                    reward:     tf.stack(rewards),
                    nextState: [tf.stack(nextTelemetries), tf.stack(nextFramesL), tf.stack(nextFramesR)]
                })
            } finally {
                // Always close the timer so a failed step doesn't leave the label open.
                console.timeEnd('train')
            }
        })

        console.time('train postMessage')
        try {
            await postActorWeights()
        } finally {
            console.timeEnd('train postMessage')
        }

        return 1
    }

    /**
     * Runs one job and schedules the next run. On failure, logs the error and
     * retries after 5 s so the training loop keeps going.
     */
    const tick = async () => {
        try {
            setTimeout(tick, await job())
        } catch (e) {
            console.error(e)
            setTimeout(tick, 5000) // show must go on (҂◡_◡) ᕤ
        }
    }

    setTimeout(tick, 1000)

    /**
     * Decodes a transition received from the main thread into tensors.
     *
     * @param {{ id, state: Array, action: Array, reward: number, priority }} transition
     *   `state` is `[telemetry, frameL, frameR]` as plain arrays; frames are
     *   reshaped to `agent._frameStackShape`.
     * @returns {{ id, state: Array, action, reward, priority }} the same transition
     *   with `state` as `[tensor1d, tensor3d, tensor3d]`, `action` as a tensor1d and
     *   `reward` as a one-element tensor1d. `id` and `priority` pass through unchanged.
     */
    const decodeTransition = ({ id, state: [telemetry, frameL, frameR], action, reward, priority }) =>
        // The returned tensors are kept alive by tidy(); ownership passes to the caller.
        tf.tidy(() => ({
            id,
            state: [
                tf.tensor1d(telemetry),
                tf.tensor3d(frameL, agent._frameStackShape),
                tf.tensor3d(frameR, agent._frameStackShape),
            ],
            action: tf.tensor1d(action),
            reward: tf.tensor1d([reward]),
            priority,
        }))

    let messageCount = 0
    self.addEventListener('message', async e => {
        messageCount++

        if (DISABLED) return
        if (messageCount % 50 === 0) console.log('RBSIZE: ', rb.size)

        switch (e.data.action) {
            case 'newTransition': {
                rb.add(decodeTransition(e.data.transition))
                break
            }
            default:
                console.warn('Unknown action', e.data.action)
                break
        }

        // Checkpoint once every rb._limit messages. It takes ~500ms, so it is
        // intentionally not awaited. Failures are logged instead of becoming
        // unhandled rejections.
        if (messageCount % rb._limit === 0)
            Promise.resolve(agent.checkpoint()).catch(err => console.error(err))
    })
})()