|
import functools
from typing import TYPE_CHECKING

from tensorboardX import SummaryWriter
|
|
|
if TYPE_CHECKING: |
|
|
|
|
|
|
|
|
|
from ding.framework import Parallel |
|
|
|
|
|
class DistributedWriter(SummaryWriter):
    """
    Overview:
        A simple subclass of SummaryWriter that supports writing to one process in multi-process mode.
        The best way is to use it in conjunction with the ``router`` to take advantage of the message \
        and event components of the router (see ``writer.plugin``).
    Interfaces:
        ``get_instance``, ``plugin``, ``initialize``, ``__del__``
    """
    # Class-level root instance, set by the first ``get_instance`` call that passes arguments.
    root = None

    def __init__(self, *args, **kwargs):
        """
        Overview:
            Initialize the DistributedWriter object. Writing to disk is disabled at construction
            time and only enabled later by ``initialize`` (called eagerly for the main writer in
            parallel mode, or lazily on the first logging call otherwise).
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
        """
        # Remember the caller's preference (default True) so ``initialize`` can restore it later.
        self._default_writer_to_disk = kwargs.get("write_to_disk", True)
        # Force-disable disk writing for now; non-writer processes in parallel mode
        # must never create event files of their own.
        kwargs["write_to_disk"] = False
        super().__init__(*args, **kwargs)
        self._in_parallel = False  # whether an active router has been plugged in
        self._router = None  # the plugged-in ``Parallel`` router, if any
        self._is_writer = False  # whether this process is the one that writes to disk
        self._lazy_initialized = False  # whether initialization has already been resolved

    @classmethod
    def get_instance(cls, *args, **kwargs) -> "DistributedWriter":
        """
        Overview:
            Get an instance and set the root-level instance on the first call with arguments.
            If args and kwargs are both empty, this method returns the root instance
            (``None`` if no instance has been created through this method yet).
        Arguments:
            - args (:obj:`Tuple`): The arguments forwarded to the ``DistributedWriter`` constructor.
            - kwargs (:obj:`Dict`): The keyword arguments forwarded to the ``DistributedWriter`` \
                constructor.
        Returns:
            - instance (:obj:`DistributedWriter`): A fresh instance when arguments are given, \
                otherwise the cached root instance.
        """
        if args or kwargs:
            ins = cls(*args, **kwargs)
            if cls.root is None:
                cls.root = ins
            return ins
        else:
            return cls.root

    def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWriter":
        """
        Overview:
            Plugin ``router``, so when using this writer with active router, it will automatically send requests\
            to the main writer instead of writing it to the disk. So we can collect data from multiple processes\
            and write them into one file.
        Arguments:
            - router (:obj:`Parallel`): The router to be plugged in.
            - is_writer (:obj:`bool`): Whether this writer is the main writer.
        Returns:
            - self (:obj:`DistributedWriter`): This writer, to allow call chaining.
        Examples:
            >>> DistributedWriter().plugin(router, is_writer=True)
        """
        if router.is_active:
            self._in_parallel = True
            self._router = router
            self._is_writer = is_writer
            if is_writer:
                # The main writer opens its file writer right away.
                self.initialize()
            # Mark initialization as resolved on every process: non-writers forward
            # their logging calls instead of creating a local file writer.
            self._lazy_initialized = True
            router.on("distributed_writer", self._on_distributed_writer)
        return self

    def _on_distributed_writer(self, fn_name: str, *args, **kwargs):
        """
        Overview:
            Handler called when the router receives a request to write data; only the main
            writer executes the requested method locally.
        Arguments:
            - fn_name (:obj:`str`): The name of the function to be called.
            - args (:obj:`Tuple`): The arguments passed to the function to be called.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the function to be called.
        """
        if self._is_writer:
            getattr(self, fn_name)(*args, **kwargs)

    def initialize(self):
        """
        Overview:
            Initialize the file writer: restore the disk-writing preference recorded at
            construction time and (re)create the underlying file writer.
        """
        self.close()
        self._write_to_disk = self._default_writer_to_disk
        self._get_file_writer()
        self._lazy_initialized = True

    def __del__(self):
        """
        Overview:
            Close the file writer when this object is garbage collected.
        """
        self.close()
|
|
|
|
|
def enable_parallel(fn_name, fn):
    """
    Overview:
        Wrap a ``SummaryWriter`` logging method so that, in parallel mode, non-writer
        processes forward the call to the main writer through the router instead of
        executing it locally.
    Arguments:
        - fn_name (:obj:`str`): The name of the wrapped method (sent as the event payload).
        - fn (:obj:`Callable`): The original unbound method to wrap.
    Returns:
        - wrapper (:obj:`Callable`): The wrapped method, preserving ``fn``'s metadata.
    """

    # ``wraps`` keeps the original method's __name__/__doc__ so patched methods
    # (add_scalar, add_image, ...) remain introspectable.
    @functools.wraps(fn)
    def _parallel_fn(self: "DistributedWriter", *args, **kwargs):
        # Lazily set up the file writer on the first logging call (single-process usage);
        # in parallel mode ``plugin`` has already resolved initialization.
        if not self._lazy_initialized:
            self.initialize()
        if self._in_parallel and not self._is_writer:
            # Non-writer process: forward the call to the main writer via the router.
            self._router.emit("distributed_writer", fn_name, *args, **kwargs)
        else:
            fn(self, *args, **kwargs)

    return _parallel_fn
|
|
|
|
|
# All SummaryWriter logging methods that should be routed through ``enable_parallel``.
ready_to_parallel_fns = [
    'add_audio',
    'add_custom_scalars',
    'add_custom_scalars_marginchart',
    'add_custom_scalars_multilinechart',
    'add_embedding',
    'add_figure',
    'add_graph',
    'add_graph_deprecated',
    'add_histogram',
    'add_histogram_raw',
    'add_hparams',
    'add_image',
    'add_image_with_boxes',
    'add_images',
    'add_mesh',
    'add_onnx_graph',
    'add_openvino_graph',
    'add_pr_curve',
    'add_pr_curve_raw',
    'add_scalar',
    'add_scalars',
    'add_text',
    'add_video',
]
# Patch each method that this tensorboardX version actually provides; missing
# names (version differences) are silently skipped.
for fn_name in ready_to_parallel_fns:
    original_fn = getattr(DistributedWriter, fn_name, None)
    if original_fn is not None:
        setattr(DistributedWriter, fn_name, enable_parallel(fn_name, original_fn))
|
|
|
|
|
|
|
|
|
# Module-level shared writer instance; importers are expected to configure it
# (e.g. via ``plugin``) before logging with it.
distributed_writer = DistributedWriter()
|
|