Spaces:
Runtime error
Runtime error
import csv | |
import os | |
import modules.textual_inversion.textual_inversion | |
from modules import shared | |
delayed_values = {} | |
def write_loss(log_directory, filename, step, epoch_len, values): | |
if shared.opts.training_write_csv_every == 0: | |
return | |
if (step + 1) % shared.opts.training_write_csv_every != 0: | |
return | |
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True | |
try: | |
with open(os.path.join(log_directory, filename), "a+", newline='') as fout: | |
csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) | |
if write_csv_header: | |
csv_writer.writeheader() | |
if log_directory + filename in delayed_values: | |
delayed = delayed_values[log_directory + filename] | |
for step, epoch, epoch_step, values in delayed: | |
csv_writer.writerow({ | |
"step": step, | |
"epoch": epoch, | |
"epoch_step": epoch_step + 1, | |
**values, | |
}) | |
delayed.clear() | |
epoch = step // epoch_len | |
epoch_step = step % epoch_len | |
csv_writer.writerow({ | |
"step": step + 1, | |
"epoch": epoch, | |
"epoch_step": epoch_step + 1, | |
**values, | |
}) | |
except OSError: | |
epoch, epoch_step = divmod(step, epoch_len) | |
if log_directory + filename in delayed_values: | |
delayed_values[log_directory + filename].append((step + 1, epoch, epoch_step, values)) | |
else: | |
delayed_values[log_directory + filename] = [(step+1, epoch, epoch_step, values)] | |
modules.textual_inversion.textual_inversion.write_loss = write_loss |