File size: 22,906 Bytes
cc9dfd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
"Provides convenient callbacks for Learners that write model images, metrics/losses, stats and histograms to Tensorboard"
from ..basic_train import Learner
from ..basic_data import DatasetType, DataBunch
from ..vision import Image
from ..vision.gan import GANLearner
from ..callbacks import LearnerCallback
from ..core import *
from ..torch_core import *
from threading import Thread, Event
from time import sleep
from queue import Queue
import statistics
import torchvision.utils as vutils
from abc import ABC
#This is an optional dependency in fastai.  Must install separately.
# Only a failed import means the optional package is missing.  A bare `except:` would also
# swallow KeyboardInterrupt/SystemExit and hide unrelated errors raised while importing it.
try: from tensorboardX import SummaryWriter
except ImportError: print("To use this tracker, please run 'pip install tensorboardx'. Also you must have Tensorboard running to see results")

__all__=['LearnerTensorboardWriter', 'GANTensorboardWriter', 'ImageGenTensorboardWriter']

#---Example usage (applies to any of the callbacks)--- 
# proj_id = 'Colorize'
# tboard_path = Path('data/tensorboard/' + proj_id)
# learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=tboard_path, name='GanLearner'))

class LearnerTensorboardWriter(LearnerCallback):
    "General-purpose Tensorboard callback for Learners.  Logs losses/metrics, model weight histograms and gradient stats."
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
        super().__init__(learn=learn)
        self.base_dir = base_dir
        self.name = name
        # How often (in iterations) losses, weight histograms and gradient stats are written.
        self.loss_iters = loss_iters
        self.hist_iters = hist_iters
        self.stats_iters = stats_iters
        # The SummaryWriter is pointed at base_dir/name.
        self.tbwriter = SummaryWriter(str(base_dir/name))
        self.hist_writer = HistogramTBWriter()
        self.stats_writer = ModelStatsTBWriter()
        #self.graph_writer = GraphTBWriter()
        self.data = None
        self.metrics_root = '/metrics/'
        self._update_batches_if_needed()

    def _get_new_batch(self, ds_type:DatasetType)->Collection[Tensor]:
        "Fetches one batch of `ds_type` through `one_batch`, called with detach=True, denorm=False and cpu=False."
        databunch = self.learn.data
        return databunch.one_batch(ds_type=ds_type, detach=True, denorm=False, cpu=False)

    def _update_batches_if_needed(self)->None:
        "Caches one train and one valid batch per `learn.data` object, since `one_batch` is extremely slow with large datasets."
        current_data = self.learn.data
        # No validation dataloader means the learning rate finder is running, so there is nothing to cache.
        if current_data.valid_dl is None: return
        # Still the same data object that was cached earlier, so the stored batches remain valid.
        if self.data is current_data: return
        self.data = current_data
        self.trn_batch = self._get_new_batch(ds_type=DatasetType.Train)
        self.val_batch = self._get_new_batch(ds_type=DatasetType.Valid)

    def _write_model_stats(self, iteration:int)->None:
        "Hands the Learner's model to the stats writer, which logs its gradient statistics."
        self.stats_writer.write(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)

    def _write_training_loss(self, iteration:int, last_loss:Tensor)->None:
        "Logs the latest training loss as the scalar `train_loss` under the metrics root."
        loss_value = to_np(last_loss)
        loss_tag = self.metrics_root + 'train_loss'
        self.tbwriter.add_scalar(tag=loss_tag, scalar_value=loss_value, global_step=iteration)

    def _write_weight_histograms(self, iteration:int)->None:
        "Hands the Learner's model to the histogram writer, which logs its weight histograms."
        self.hist_writer.write(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)

    def _write_scalar(self, name:str, scalar_value, iteration:int)->None:
        "Logs one scalar under the metrics root, tagged with `name`."
        self.tbwriter.add_scalar(tag=self.metrics_root + name, scalar_value=scalar_value, global_step=iteration)

    #TODO:  Relying on a specific hardcoded start_idx here isn't great.  Is there a better solution?
    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2)->None:
        "Logs `last_metrics`, pairing each value with the recorder name found at the same offset past `start_idx`."
        metric_names = self.learn.recorder.names[start_idx:]
        for idx, metric_name in enumerate(metric_names):
            # Stop as soon as there is no value for this name (no metrics at all, or fewer values than names).
            if last_metrics is None or len(last_metrics) <= idx: return
            self._write_scalar(name=metric_name, scalar_value=last_metrics[idx], iteration=iteration)

    def on_train_begin(self, **kwargs: Any) -> None:
        "Currently a no-op: the model graph write below is commented out."
        #self.graph_writer.write(model=self.learn.model, tbwriter=self.tbwriter,
                                #input_to_model=next(iter(self.learn.data.dl(DatasetType.Single)))[0])
        return None

    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs)->None:
        "At batch end, logs the training loss every `loss_iters` and weight histograms every `hist_iters` iterations."
        if iteration == 0: return
        self._update_batches_if_needed()
        if iteration % self.loss_iters == 0:
            self._write_training_loss(iteration=iteration, last_loss=last_loss)
        if iteration % self.hist_iters == 0:
            self._write_weight_histograms(iteration=iteration)

    # Doing stuff here that requires gradient info, because they get zeroed out afterwards in training loop
    def on_backward_end(self, iteration:int, **kwargs)->None:
        "After the backward pass, logs gradient statistics every `stats_iters` iterations."
        if iteration == 0: return
        self._update_batches_if_needed()
        if iteration % self.stats_iters == 0:
            self._write_model_stats(iteration=iteration)

    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None:
        "At epoch end, logs the metrics received in `last_metrics`."
        self._write_metrics(iteration=iteration, last_metrics=last_metrics)

# TODO:  We're overriding almost everything here.  Seems like a good idea to question that ("is a" vs "has a")
class GANTensorboardWriter(LearnerTensorboardWriter):
    "Tensorboard callback for GANLearners.  Builds on LearnerTensorboardWriter and additionally logs output images."
    def __init__(self, learn:GANLearner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
                stats_iters:int=100, visual_iters:int=100):
        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
                         stats_iters=stats_iters)
        self.visual_iters = visual_iters
        self.img_gen_vis = ImageTBWriter()
        # True means "no stats write pending" for that network; on_backward_end resets both to False.
        self.gen_stats_updated = self.crit_stats_updated = True

    def _write_weight_histograms(self, iteration:int)->None:
        "Logs weight histograms for the generator and then the critic, each under its own name."
        trainer = self.learn.gan_trainer
        networks = (('generator', trainer.generator), ('critic', trainer.critic))
        for label, network in networks:
            self.hist_writer.write(model=network, iteration=iteration, tbwriter=self.tbwriter, name=label)

    def _write_gen_model_stats(self, iteration:int)->None:
        "Logs the generator's gradient statistics and marks them as written."
        self.stats_writer.write(model=self.learn.gan_trainer.generator, iteration=iteration, tbwriter=self.tbwriter,
                                name='gen_model_stats')
        self.gen_stats_updated = True

    def _write_critic_model_stats(self, iteration:int)->None:
        "Logs the critic's gradient statistics and marks them as written."
        self.stats_writer.write(model=self.learn.gan_trainer.critic, iteration=iteration, tbwriter=self.tbwriter,
                                name='crit_model_stats')
        self.crit_stats_updated = True

    def _write_model_stats(self, iteration:int)->None:
        "Logs gradient statistics for whichever network `gen_mode` selects, if its write is still pending."
        # We don't want to write stats when model is not iterated on and hence has zeroed out gradients
        if self.learn.gan_trainer.gen_mode:
            if not self.gen_stats_updated: self._write_gen_model_stats(iteration=iteration)
        elif not self.crit_stats_updated:
            self._write_critic_model_stats(iteration=iteration)

    def _write_training_loss(self, iteration:int, last_loss:Tensor)->None:
        "Logs the most recent loss held by the GAN trainer's recorder; `last_loss` itself is not used."
        losses = self.learn.gan_trainer.recorder.losses
        if len(losses) == 0: return
        loss_value = to_np(losses[-1])
        loss_tag = self.metrics_root + 'train_loss'
        self.tbwriter.add_scalar(tag=loss_tag, scalar_value=loss_value, global_step=iteration)

    def _write_images(self, iteration:int)->None:
        "Logs model generated, original and real images, switching the trainer to generator mode while doing so."
        trainer = self.learn.gan_trainer
        #TODO:  Switching gen_mode temporarily seems a bit hacky here.  Certainly not a good side-effect.  Is there a better way?
        previous_mode = trainer.gen_mode
        try:
            trainer.switch(gen_mode=True)
            self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
                                   iteration=iteration, tbwriter=self.tbwriter)
        finally:
            # Restore whatever mode the trainer was in, even if the image write raised.
            trainer.switch(gen_mode=previous_mode)

    def on_batch_end(self, iteration:int, **kwargs)->None:
        "At batch end, does the parent's writes and additionally logs images every `visual_iters` iterations."
        super().on_batch_end(iteration=iteration, **kwargs)
        if iteration != 0 and iteration % self.visual_iters == 0:
            self._write_images(iteration=iteration)

    def on_backward_end(self, iteration:int, **kwargs)->None:
        "After the backward pass, schedules stats writes every `stats_iters` iterations and performs any still pending."
        if iteration == 0: return
        self._update_batches_if_needed()
        #TODO:  This could perhaps be implemented as queues of requests instead but that seemed like overkill.
        # But I'm not the biggest fan of maintaining these boolean flags either... Review pls.
        if iteration % self.stats_iters == 0:
            self.gen_stats_updated, self.crit_stats_updated = False, False
        if not self.gen_stats_updated or not self.crit_stats_updated:
            self._write_model_stats(iteration=iteration)

class ImageGenTensorboardWriter(LearnerTensorboardWriter):
    "Tensorboard callback for non-GAN image generating Learners.  Builds on LearnerTensorboardWriter and additionally logs output images."
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100,
                 visual_iters:int=100):
        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters,
                         hist_iters=hist_iters, stats_iters=stats_iters)
        self.visual_iters = visual_iters
        self.img_gen_vis = ImageTBWriter()

    def _write_images(self, iteration:int)->None:
        "Logs model generated, original and real images for the cached train and valid batches."
        self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
                               iteration=iteration, tbwriter=self.tbwriter)

    def on_batch_end(self, iteration:int, **kwargs)->None:
        "At batch end, does the parent's writes and additionally logs images every `visual_iters` iterations."
        super().on_batch_end(iteration=iteration, **kwargs)
        if iteration != 0 and iteration % self.visual_iters == 0: self._write_images(iteration=iteration)

class TBWriteRequest(ABC):
    "Abstract base for a single Tensorboard write.  Instances can be queued up and executed asynchronously."
    def __init__(self, tbwriter: SummaryWriter, iteration:int):
        super().__init__()
        # The writer to use and the iteration this request belongs to.
        self.tbwriter, self.iteration = tbwriter, iteration

    @abstractmethod
    def write(self)->None:
        "Performs the write; every concrete request must override this."

# SummaryWriter writes tend to block quite a bit.  This gets around that and greatly boosts performance.
# Not all tensorboard writes are using this- just the ones that take a long time.  Note that the 
# SummaryWriter does actually use a threadsafe consumer/producer design ultimately to write to Tensorboard, 
# so writes done outside of this async loop should be fine.
class AsyncTBWriter():
    "Executes queued `TBWriteRequest`s on a background daemon thread, so slow Tensorboard writes don't block the caller."
    def __init__(self):
        super().__init__()
        self.stop_request = Event()
        self.queue = Queue()
        # The worker is started right away; it is a daemon thread.
        self.thread = Thread(target=self._queue_processor, daemon=True)
        self.thread.start()

    def request_write(self, request: TBWriteRequest)->None:
        "Queues up an asynchronous write request to Tensorboard.  Requests made after a stop was requested are dropped."
        if self.stop_request.is_set(): return
        self.queue.put(request)

    def _queue_processor(self)->None:
        "Worker loop: runs queued write requests, then sleeps 0.2s before polling again, until a stop is requested."
        while not self.stop_request.is_set():
            while not self.queue.empty():
                # Re-check between requests so that a stop doesn't have to wait for the whole backlog.
                if self.stop_request.is_set(): return
                request = self.queue.get()
                request.write()
            sleep(0.2)

    # Provided this to stop the thread explicitly or by context management (with statement), but the thread should
    # end on its own upon program exit, due to being a daemon.  So using this is probably unnecessary.
    def close(self)->None:
        "Stops the asynchronous request queue processing thread and waits for it to finish."
        self.stop_request.set()
        self.thread.join()

    # Nothing to start here, the thread is already running.  Could start the thread here to enforce use of the
    # context manager (but that sounds like a pain and a bit unwieldy and unnecessary for actual usage).
    def __enter__(self):
        "Returns this writer, so that `with AsyncTBWriter() as writer:` binds the writer rather than None."
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        "Closes the writer when the `with` block ends; returns None, so exceptions are not suppressed."
        self.close()

# Module-level AsyncTBWriter instance that the *TBWriter classes in this module submit their requests to.
# Creating it here starts its worker thread as soon as this module is imported.
asyncTBWriter = AsyncTBWriter()

class ModelImageSet():
    "Bundles the original, real(target) and generated versions of a single image fed to a model."
    def __init__(self, orig:Image, real:Image, gen:Image):
        self.orig = orig
        self.real = real
        self.gen = gen

    @staticmethod
    def get_list_from_model(learn:Learner, ds_type:DatasetType, batch:Tuple)->[]:
        "Factory method: runs `learn.pred_batch` on `batch` and returns one ModelImageSet per (input, target, prediction)."
        inputs, targets = batch[0], batch[1]
        generated = learn.pred_batch(ds_type=ds_type, batch=(inputs, targets), reconstruct=True)
        # Inputs and targets are wrapped in Image; predictions are stored as returned by pred_batch.
        return [ModelImageSet(orig=Image(px=inp_px), real=Image(px=tgt_px), gen=pred)
                for inp_px, tgt_px, pred in zip(inputs, targets, generated)]

class HistogramTBRequest(TBWriteRequest):
    "Write request that logs a histogram of every named parameter of a model."
    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        # Copy every parameter (clone + detach + cpu) at construction time; `write` works from these copies.
        self.params = [(param_name, weights.clone().detach().cpu())
                       for param_name, weights in model.named_parameters()]
        self.name = name

    def _write_histogram(self, param_name:str, values)->None:
        "Logs one parameter's histogram, tagged `<name>/weights/<param_name>`."
        self.tbwriter.add_histogram(tag=self.name + '/weights/' + param_name, values=values, global_step=self.iteration)

    def write(self)->None:
        "Logs a histogram for each stored parameter."
        for param_name, values in self.params:
            self._write_histogram(param_name=param_name, values=values)

#If this isn't done async then this is sloooooow
class HistogramTBWriter():
    "Submits model histogram writes to the asynchronous Tensorboard writer."
    def __init__(self):
        super().__init__()

    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model')->None:
        "Builds a HistogramTBRequest for `model` and queues it on `asyncTBWriter`."
        histogram_request = HistogramTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
        asyncTBWriter.request_write(histogram_request)

class ModelStatsTBRequest(TBWriteRequest):
    "Write request that logs summary statistics of a model's gradients."
    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        # Copy each existing gradient (clone + detach + cpu) at construction time; parameters without one are skipped.
        self.gradients = [param.grad.clone().detach().cpu() for param in model.parameters() if param.grad is not None]
        self.name = name

    def _add_gradient_scalar(self, name:str, scalar_value)->None:
        "Logs one gradient statistic as a scalar, tagged `<self.name>/gradients/<name>`."
        self.tbwriter.add_scalar(tag=self.name + '/gradients/' + name, scalar_value=scalar_value, global_step=self.iteration)

    def _write_avg_norm(self, norms:[])->None:
        "Logs the sum of `norms` divided by the number of stored gradients."
        self._add_gradient_scalar('avg_norm', scalar_value=sum(norms)/len(self.gradients))

    def _write_median_norm(self, norms:[])->None:
        "Logs the median of `norms`."
        self._add_gradient_scalar('median_norm', scalar_value=statistics.median(norms))

    def _write_max_norm(self, norms:[])->None:
        "Logs the largest of `norms`."
        self._add_gradient_scalar('max_norm', scalar_value=max(norms))

    def _write_min_norm(self, norms:[])->None:
        "Logs the smallest of `norms`."
        self._add_gradient_scalar('min_norm', scalar_value=min(norms))

    def _write_num_zeros(self)->None:
        "Logs how many gradient entries are exactly 0.0, summed over all stored gradients."
        zero_count = sum((np.asarray(to_np(grad.data)) == 0.0).sum() for grad in self.gradients)
        self._add_gradient_scalar('num_zeros', scalar_value=zero_count)

    def _write_avg_gradient(self)->None:
        "Logs the average of the per-gradient means."
        mean_total = sum(grad.data.mean() for grad in self.gradients)
        self._add_gradient_scalar('avg_gradient', scalar_value=mean_total/len(self.gradients))

    def _write_median_gradient(self)->None:
        "Logs the median of the per-gradient medians."
        per_gradient_medians = (grad.data.median() for grad in self.gradients)
        self._add_gradient_scalar('median_gradient', scalar_value=statistics.median(per_gradient_medians))

    def _write_max_gradient(self)->None:
        "Logs the largest of the per-gradient maxima."
        self._add_gradient_scalar('max_gradient', scalar_value=max(grad.data.max() for grad in self.gradients))

    def _write_min_gradient(self)->None:
        "Logs the smallest of the per-gradient minima."
        self._add_gradient_scalar('min_gradient', scalar_value=min(grad.data.min() for grad in self.gradients))

    def write(self)->None:
        "Logs all gradient statistics; does nothing when no gradients were stored."
        if not self.gradients: return
        norms = [grad.data.norm() for grad in self.gradients]
        # Statistics derived from the per-gradient norms, in the order they are written.
        norm_writers = (self._write_avg_norm, self._write_median_norm, self._write_max_norm, self._write_min_norm)
        for write_norm_stat in norm_writers:
            write_norm_stat(norms=norms)
        # Statistics computed directly from the gradient values.
        value_writers = (self._write_num_zeros, self._write_avg_gradient, self._write_median_gradient,
                         self._write_max_gradient, self._write_min_gradient)
        for write_value_stat in value_writers:
            write_value_stat()

class ModelStatsTBWriter():
    "Submits model gradient statistics writes to the asynchronous Tensorboard writer."
    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model_stats')->None:
        "Builds a ModelStatsTBRequest for `model` and queues it on `asyncTBWriter`."
        stats_request = ModelStatsTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
        asyncTBWriter.request_write(stats_request)

class ImageTBRequest(TBWriteRequest):
    "Write request that logs the original, generated and real(target) images of one batch."
    def __init__(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        # The image sets are built here, at construction time, rather than in `write`.
        self.image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)
        self.ds_type = ds_type

    def _write_images(self, name:str, images:[Tensor])->None:
        "Logs `images` as one normalized grid, tagged with the dataset type name followed by `name`."
        tag = self.ds_type.name + ' ' + name
        grid = vutils.make_grid(images, normalize=True)
        self.tbwriter.add_image(tag=tag, img_tensor=grid, global_step=self.iteration)

    def _get_image_tensors(self)->([Tensor], [Tensor], [Tensor]):
        "Returns the `px` of every image set as three lists, in the order original, generated, real(target)."
        orig_images = [image_set.orig.px for image_set in self.image_sets]
        gen_images = [image_set.gen.px for image_set in self.image_sets]
        real_images = [image_set.real.px for image_set in self.image_sets]
        return orig_images, gen_images, real_images

    def write(self)->None:
        "Logs the original, generated and real(target) images, in that order."
        orig_images, gen_images, real_images = self._get_image_tensors()
        labelled = (('orig images', orig_images), ('gen images', gen_images), ('real images', real_images))
        for label, tensors in labelled:
            self._write_images(name=label, images=tensors)

#If this isn't done async then this is noticeably slower
class ImageTBWriter():
    "Submits model image output writes to the asynchronous Tensorboard writer."
    def __init__(self):
        super().__init__()

    def write(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iteration:int, tbwriter:SummaryWriter)->None:
        "Queues image writes for the validation batch first and then for the training batch."
        for batch, ds_type in ((val_batch, DatasetType.Valid), (trn_batch, DatasetType.Train)):
            self._write_for_dstype(learn=learn, batch=batch, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)

    def _write_for_dstype(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType)->None:
        "Builds an ImageTBRequest for `batch` of the given DatasetType and queues it on `asyncTBWriter`."
        image_request = ImageTBRequest(learn=learn, batch=batch, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)
        asyncTBWriter.request_write(image_request)

class GraphTBRequest(TBWriteRequest):
    "Write request that logs a model's network graph."
    def __init__(self, model:nn.Module, tbwriter:SummaryWriter, input_to_model:torch.Tensor):
        # The base class is always given iteration 0 for graph requests.
        super().__init__(tbwriter=tbwriter, iteration=0)
        self.model = model
        self.input_to_model = input_to_model

    def write(self)->None:
        "Passes the stored model and its example input to the writer's `add_graph`."
        self.tbwriter.add_graph(model=self.model, input_to_model=self.input_to_model)

class GraphTBWriter():
    "Submits model network graph writes to the asynchronous Tensorboard writer."
    def write(self, model:nn.Module, tbwriter:SummaryWriter, input_to_model:torch.Tensor)->None:
        "Builds a GraphTBRequest for `model` and queues it on `asyncTBWriter`."
        graph_request = GraphTBRequest(model=model, tbwriter=tbwriter, input_to_model=input_to_model)
        asyncTBWriter.request_write(graph_request)