// ai-creature / worker.js
// Uploaded by Artod ("Upload 16 files"), revision b396e7b.
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.12.0/dist/tf.min.js')
importScripts('agent_sac.js')
importScripts('reply_buffer.js')
;(async () => {
  /** When true, the worker ignores incoming transitions and never trains. */
  const DISABLED = false
  /**
   * Debug switch: when true, logs the minimum of the two target critics'
   * Q-values for every incoming transition.
   */
  const LOG_TARGET_Q = false

  const agent = new AgentSac({batchSize: 100, verbose: true})
  await agent.init()
  await agent.checkpoint() // overwrite
  agent.actor.summary()

  /**
   * Reads the actor's current weights as plain nested arrays so they can be
   * sent to the main thread via postMessage.
   *
   * @returns {Promise<Array>} one nested array per weight tensor
   */
  const readActorWeights = () => Promise.all(agent.actor.getWeights().map(w => w.array()))

  self.postMessage({weights: await readActorWeights()}) // synchronize

  // Replay buffer with capacity 50000. The callback disposes every tensor a
  // stored transition owns; presumably ReplyBuffer calls it when it drops a
  // transition (TODO confirm in reply_buffer.js).
  const rb = new ReplyBuffer(50000, ({ state: [telemetry, frameL, frameR], action, reward }) => {
    frameL.dispose()
    frameR.dispose()
    telemetry.dispose()
    action.dispose()
    reward.dispose()
  })

  /**
   * Worker.
   *
   * Samples one batch from the replay buffer, runs a single training step and
   * posts the updated actor weights to the main thread.
   *
   * @returns {Promise<number>} delay in ms to get ready for the next job
   */
  const job = async () => {
    if (DISABLED) return 99999
    // Do not train until the buffer holds at least 10 batches of transitions.
    if (rb.size < agent._batchSize * 10) return 1000

    const samples = rb.sample(agent._batchSize) // time fast
    if (!samples.length) return 1000

    const framesL = []
    const framesR = []
    const telemetries = []
    const actions = []
    const rewards = []
    const nextFramesL = []
    const nextFramesR = []
    const nextTelemetries = []

    for (const {
      state: [telemetry, frameL, frameR],
      action,
      reward,
      nextState: [nextTelemetry, nextFrameL, nextFrameR]
    } of samples) {
      framesL.push(frameL)
      framesR.push(frameR)
      telemetries.push(telemetry)
      actions.push(action)
      rewards.push(reward)
      nextFramesL.push(nextFrameL)
      nextFramesR.push(nextFrameR)
      nextTelemetries.push(nextTelemetry)
    }

    // tidy disposes the stacked batch tensors once the training step is done.
    tf.tidy(() => {
      console.time('train')
      try {
        agent.train({
          state: [tf.stack(telemetries), tf.stack(framesL), tf.stack(framesR)],
          action: tf.stack(actions),
          reward: tf.stack(rewards),
          nextState: [tf.stack(nextTelemetries), tf.stack(nextFramesL), tf.stack(nextFramesR)]
        })
      } finally {
        // Always close the timer so a failed step does not leave the label open.
        console.timeEnd('train')
      }
    })

    console.time('train postMessage')
    try {
      self.postMessage({weights: await readActorWeights()})
    } finally {
      console.timeEnd('train postMessage')
    }
    return 1
  }

  /**
   * Executes job, then reschedules itself using the delay the job returned.
   * Errors are logged, and the next attempt is scheduled 5 s later.
   */
  const tick = async () => {
    try {
      setTimeout(tick, await job())
    } catch (e) {
      console.error(e)
      setTimeout(tick, 5000) // show must go on (҂◡_◡) ᕤ
    }
  }

  setTimeout(tick, 1000)

  /**
   * Decode transition from the main thread.
   *
   * Converts the plain arrays received via postMessage into tensors. The
   * telemetry and action become 1-D tensors. Each frame becomes a 3-D tensor
   * shaped by agent._frameStackShape. The reward becomes a 1-D tensor of
   * length 1. `id` and `priority` are passed through unchanged.
   *
   * @param {{ id, state, action, reward, priority }} transition
   * @returns {{ id, state: tf.Tensor[], action: tf.Tensor, reward: tf.Tensor, priority }}
   */
  const decodeTransition = ({ id, state: [telemetry, frameL, frameR], action, reward, priority }) =>
    tf.tidy(() => ({
      id,
      state: [
        tf.tensor1d(telemetry),
        tf.tensor3d(frameL, agent._frameStackShape),
        tf.tensor3d(frameR, agent._frameStackShape)
      ],
      action: tf.tensor1d(action),
      reward: tf.tensor1d([reward]),
      priority
    }))

  /**
   * Debug helper: logs min(Q1targ, Q2targ) for a single transition.
   * It batches the transition's tensors into a batch of size 1.
   */
  const logTargetQ = ({ state: [telemetry, frameL, frameR], action }) => tf.tidy(() => {
    const state = [tf.stack([telemetry]), tf.stack([frameL]), tf.stack([frameR])]
    const actionBatch = tf.stack([action])
    const q1TargValue = agent.q1Targ.predict([...state, actionBatch], {batchSize: 1})
    const q2TargValue = agent.q2Targ.predict([...state, actionBatch], {batchSize: 1})
    console.log('value', Math.min(q1TargValue.arraySync()[0][0], q2TargValue.arraySync()[0][0]).toFixed(5))
  })

  let messageCount = 0

  self.addEventListener('message', async e => {
    messageCount++
    if (DISABLED) return
    if (messageCount % 50 === 0) console.log('RBSIZE: ', rb.size)

    switch (e.data.action) {
      case 'newTransition': {
        const transition = decodeTransition(e.data.transition)
        rb.add(transition)
        if (LOG_TARGET_Q) logTargetQ(transition)
        break
      }
      default:
        console.warn('Unknown action')
        break
    }

    // Checkpoint once every rb._limit messages (timer ~ 500ms). This is not
    // awaited, so the listener is not blocked, but failures are still logged
    // instead of surfacing as unhandled rejections.
    if (messageCount % rb._limit === 0) {
      Promise.resolve(agent.checkpoint()).catch(err => console.error('checkpoint failed', err))
    }
  })
})()