Spaces:
Runtime error
Runtime error
File size: 1,826 Bytes
ef9fd1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import csv
import os
import modules.textual_inversion.textual_inversion
from modules import shared
delayed_values = {}
def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
if (step + 1) % shared.opts.training_write_csv_every != 0:
return
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
try:
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
if write_csv_header:
csv_writer.writeheader()
if log_directory + filename in delayed_values:
delayed = delayed_values[log_directory + filename]
for step, epoch, epoch_step, values in delayed:
csv_writer.writerow({
"step": step,
"epoch": epoch,
"epoch_step": epoch_step + 1,
**values,
})
delayed.clear()
epoch = step // epoch_len
epoch_step = step % epoch_len
csv_writer.writerow({
"step": step + 1,
"epoch": epoch,
"epoch_step": epoch_step + 1,
**values,
})
except OSError:
epoch, epoch_step = divmod(step, epoch_len)
if log_directory + filename in delayed_values:
delayed_values[log_directory + filename].append((step + 1, epoch, epoch_step, values))
else:
delayed_values[log_directory + filename] = [(step+1, epoch, epoch_step, values)]
modules.textual_inversion.textual_inversion.write_loss = write_loss |