File size: 5,501 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
from abc import ABC, abstractmethod
import logging
from os import path
import os
from threading import Thread
from time import sleep, time
from typing import Callable, Optional
import uuid
import torch.multiprocessing as mp
import torch
from ding.data.storage.file import FileModelStorage
from ding.data.storage.storage import Storage
from ding.framework import Supervisor
from ding.framework.supervisor import ChildType, SendPayload
class ModelWorker:
    """
    Overview:
        Small wrapper around a torch module that exposes a ``save`` method writing the
        module's state dict into a storage object.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module whose state dict will be saved.
        """
        self._model = model

    def save(self, storage: Storage) -> Storage:
        """
        Overview:
            Write the wrapped model's current state dict into ``storage``.
        Arguments:
            - storage (:obj:`Storage`): Storage object that receives the state dict.
        Returns:
            - storage (:obj:`Storage`): The same storage object that was passed in.
        """
        state = self._model.state_dict()
        storage.save(state)
        return storage
class ModelLoader(Supervisor, ABC):
    """
    Overview:
        Abstract supervisor that saves a model asynchronously through a registered
        ``ModelWorker`` child and loads models synchronously from a storage object.
        Subclasses implement ``save`` for a concrete storage medium.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Overview:
            Save and send models asynchronously and load them synchronously.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
        """
        # A module without parameters yields an empty iterator; use a default so that
        # StopIteration does not escape from the constructor.
        first_param = next(model.parameters(), None)
        if first_param is not None and first_param.is_cuda:
            # CUDA parameters: use the "spawn" multiprocessing context
            # (PyTorch documents that CUDA is not supported with the fork start method).
            super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
        else:
            super().__init__(type_=ChildType.PROCESS)
        self._model = model
        # Background thread that dispatches replies to the callbacks below.
        self._send_callback_loop = None
        # Maps request id -> callback waiting for the reply to that request.
        self._send_callbacks = {}
        self._model_worker = ModelWorker(self._model)

    def start(self):
        """
        Overview:
            Share the model's memory, register the model worker, start the supervisor
            and launch the callback-dispatching thread. Does nothing if already running.
        """
        if not self._running:
            self._model.share_memory()
            self.register(self._model_worker)
            self.start_link()
            self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
            self._send_callback_loop.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the supervisor and drop the callback thread reference and all
            pending callbacks.
        Arguments:
            - timeout (:obj:`Optional[float]`): Forwarded unchanged to the parent ``shutdown``.
        """
        super().shutdown(timeout)
        self._send_callback_loop = None
        self._send_callbacks = {}

    def _loop_send_callback(self):
        """
        Overview:
            Receive replies forever and invoke the callback registered under the reply's
            request id. Replies carrying an error are logged and their callback is discarded.
        """
        while True:
            payload = self.recv(ignore_err=True)
            # The callback is removed in both the error and the success case.
            callback = self._send_callbacks.pop(payload.req_id, None)
            if payload.err:
                logging.warning("Got error when saving model: {}".format(payload.err))
                continue
            if callback is None:
                continue
            try:
                callback(payload.data)
            except Exception:
                # A failing user callback must not kill this thread, otherwise no later
                # callback would ever be delivered.
                logging.exception("Got error in callback after saving model")

    def load(self, storage: Storage) -> object:
        """
        Overview:
            Load model synchronously.
        Arguments:
            - storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
        Returns:
            - object (:obj:`object`): Whatever ``storage.load()`` returns.
        """
        return storage.load()

    @abstractmethod
    def save(self, callback: Callable) -> Storage:
        """
        Overview:
            Save model asynchronously.
        Arguments:
            - callback (:obj:`Callable`): The callback function after saving model.
        Returns:
            - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
        """
        raise NotImplementedError
class FileModelLoader(ModelLoader):
    """
    Overview:
        Model loader that stores each saved model in its own file under a directory and
        removes those files once they are older than ``ttl`` seconds.
    """

    def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
        """
        Overview:
            Model loader using files as storage media.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
            - dirname (:obj:`str`): The directory for saving files.
            - ttl (:obj:`int`): Files will be automatically cleaned after ttl seconds. Note that \
                files that do not time out when the process is stopped are not cleaned up \
                (to avoid errors when other processes read the file), so you may need to \
                clean up the remaining files manually.
        """
        super().__init__(model)
        self._dirname = dirname
        self._ttl = ttl
        # Entries are [timestamp, file_path], appended in completion order (oldest first).
        self._files = []
        self._cleanup_thread = None

    def _start_cleanup(self):
        """
        Overview:
            Start a cleanup thread to clean up files that are taking up too much time on the disk.
            Does nothing if a cleanup thread reference is already held.
        """
        if self._cleanup_thread is None:
            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
            self._cleanup_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the loader and drop the reference to the cleanup thread.
        Arguments:
            - timeout (:obj:`Optional[float]`): Forwarded unchanged to the parent ``shutdown``.
        """
        super().shutdown(timeout)
        # NOTE(review): only the reference is cleared; the loop in ``_loop_cleanup`` has no exit
        # condition, so a later ``save`` would start a second cleanup thread - confirm intended.
        self._cleanup_thread = None

    def _loop_cleanup(self):
        """
        Overview:
            Poll once per second and delete the oldest recorded file once its age reaches ``ttl``.
        """
        while True:
            if not self._files or time() - self._files[0][0] < self._ttl:
                sleep(1)
                continue
            _, file_path = self._files.pop(0)
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # Already removed by someone else; nothing left to clean.
                pass
            except OSError as e:
                # One undeletable file must not stop the cleanup of all later files.
                logging.warning("Failed to remove expired model file {}: {}".format(file_path, e))

    def save(self, callback: Callable) -> Optional[FileModelStorage]:
        """
        Overview:
            Save model asynchronously into a uniquely named file under ``dirname``.
        Arguments:
            - callback (:obj:`Callable`): Called with the storage object once the save reply arrives.
        Returns:
            - storage (:obj:`Optional[FileModelStorage]`): The storage object pointing at the target \
                file, or ``None`` (with a warning) if the loader has not been started.
        """
        if not self._running:
            logging.warning("Please start model loader before saving model.")
            return
        # exist_ok avoids the check-then-create race between concurrent savers and
        # also creates missing parent directories.
        os.makedirs(self._dirname, exist_ok=True)
        file_name = "model_{}.pth.tar".format(uuid.uuid1())
        file_path = path.join(self._dirname, file_name)
        model_storage = FileModelStorage(file_path)
        payload = SendPayload(proc_id=0, method="save", args=[model_storage])
        self.send(payload)

        def clean_callback(storage: Storage):
            # Record the completion time so the cleanup thread can expire the file after ttl.
            self._files.append([time(), file_path])
            callback(storage)

        # NOTE(review): the callback is registered after ``send``; a reply that arrived before this
        # line would be dropped by the callback loop. Confirm whether ``payload.req_id`` is already
        # set before sending, in which case registration could safely come first.
        self._send_callbacks[payload.req_id] = clean_callback
        self._start_cleanup()
        return model_storage
|