import time import os from ding.interaction import Slave, TaskFail from ding.utils import lists_to_dicts class NaiveLearner(Slave): def __init__(self, *args, prefix='', **kwargs): super().__init__(*args, **kwargs) self._prefix = prefix def _process_task(self, task): task_name = task['name'] if task_name == 'resource': return {'cpu': 'xxx', 'gpu': 'xxx'} elif task_name == 'learner_start_task': time.sleep(1) self.task_info = task['task_info'] self.count = 0 return {'message': 'learner task has started'} elif task_name == 'learner_get_data_task': time.sleep(0.01) return { 'task_id': self.task_info['task_id'], 'buffer_id': self.task_info['buffer_id'], 'batch_size': 2, 'cur_learner_iter': 1 } elif task_name == 'learner_learn_task': data = task['data'] if data is None: raise TaskFail(result={'message': 'no data'}) time.sleep(0.1) data = lists_to_dicts(data) assert 'data_id' in data.keys() priority_keys = ['replay_unique_id', 'replay_buffer_idx', 'priority'] self.count += 1 ret = { 'info': { 'learner_step': self.count }, 'task_id': self.task_info['task_id'], 'buffer_id': self.task_info['buffer_id'] } ret['info']['priority_info'] = {k: data[k] for k in priority_keys} if self.count > 5: ret['info']['learner_done'] = True os.popen('touch {}_final_model.pth'.format(self._prefix)) return ret elif task_name == 'learner_close_task': return {'task_id': self.task_info['task_id'], 'buffer_id': self.task_info['buffer_id']} else: raise TaskFail( result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name) )