# OpenSLU / common / saver.py
# Uploaded by LightChen2333 ("Upload 78 files", commit 223340a)
# File size: 2.97 kB
'''
Author: Qiguang Chen
LastEditors: Qiguang Chen
Date: 2023-02-12 22:23:58
LastEditTime: 2023-02-19 14:14:56
Description: Utilities for saving the tokenizer, label lists, model checkpoints and outputs of a run.
'''
import json
import os
import queue
import shutil
import torch
import dill
from common import utils
class Saver():
    """Write training artifacts (tokenizer, labels, checkpoints, outputs) to disk.

    Configuration keys read from ``config``:

    * ``save_dir``     -- target directory; when missing or empty,
      ``"save/" + start_time`` is used instead.
    * ``save_mode``    -- defaults to ``"save-by-eval"``; ``"save-by-step"``
      enables :meth:`auto_save_step`.
    * ``max_save_num`` -- defaults to ``1``. When it is ``1`` every checkpoint
      overwrites the files in the save directory; otherwise each checkpoint
      goes to a ``<save_dir>/<step>`` sub-directory and the oldest one is
      removed once ``max_save_num`` directories are being tracked. Values
      ``<= 0`` make the tracking queue unbounded, so nothing is ever removed.
    * ``save_step``    -- step interval used by :meth:`auto_save_step`.
    """

    def __init__(self, config, start_time=None) -> None:
        """Resolve the save directory, create it, and set up checkpoint rotation.

        Args:
            config: dict-like object supporting ``get`` and item access.
            start_time: string appended to ``"save/"`` when ``config`` has no
                ``save_dir``.

        Raises:
            ValueError: if neither ``save_dir`` nor ``start_time`` is given.
        """
        self.config = config
        save_dir = self.config.get("save_dir")
        if save_dir:
            self.model_save_dir = save_dir
        else:
            if start_time is None:
                # Fail with a clear message rather than an opaque
                # TypeError from ``"save/" + None``.
                raise ValueError(
                    "Saver needs either config['save_dir'] or a start_time "
                    "to build the save directory."
                )
            self.model_save_dir = "save/" + start_time
        # makedirs also creates missing parents (including "save/") and
        # tolerates an already existing directory.
        os.makedirs(self.model_save_dir, exist_ok=True)

        save_mode = config.get("save_mode")
        self.save_mode = save_mode if save_mode is not None else "save-by-eval"
        max_save_num = self.config.get("max_save_num")
        self.max_save_num = max_save_num if max_save_num is not None else 1
        # Use the defaulted value: the raw config value may be None, which
        # would leave the queue with an unusable ``maxsize``.
        self.save_pool = queue.Queue(maxsize=self.max_save_num)

    def save_tokenizer(self, tokenizer):
        """Serialize ``tokenizer`` with dill to ``tokenizer.pkl`` in the save directory."""
        with open(os.path.join(self.model_save_dir, "tokenizer.pkl"), 'wb') as f:
            dill.dump(tokenizer, f)

    def save_label(self, intent_list, slot_list):
        """Write the intent and slot label lists to ``label.json`` in the save directory."""
        utils.save_json(
            os.path.join(self.model_save_dir, "label.json"),
            {"intent": intent_list, "slot": slot_list},
        )

    def save_model(self, model, train_state, accelerator=None):
        """Save a checkpoint of ``model`` and the training state.

        Args:
            model: the model object to persist.
            train_state: mapping that contains at least the key ``"step"``.
            accelerator: optional object exposing ``wait_for_everyone``,
                ``unwrap_model``, ``save`` and ``save_state``
                (presumably a HuggingFace ``accelerate.Accelerator`` --
                TODO confirm). When given, saving is delegated to it.
        """
        step = train_state["step"]
        if self.max_save_num != 1:
            model_save_dir = os.path.join(self.model_save_dir, str(step))
            if self.save_pool.full():
                stale_dir = self.save_pool.get()
                # The same step may have been queued more than once; never
                # delete the directory we are about to write into, and do
                # not fail on a directory that was already removed.
                if stale_dir != model_save_dir and os.path.isdir(stale_dir):
                    shutil.rmtree(stale_dir)
            self.save_pool.put(model_save_dir)
        else:
            # Single-checkpoint mode: always overwrite in the base directory.
            model_save_dir = self.model_save_dir
        os.makedirs(model_save_dir, exist_ok=True)

        if accelerator is None:
            torch.save(model, os.path.join(model_save_dir, "model.pkl"))
            torch.save(train_state, os.path.join(model_save_dir, "train_state.pkl"), pickle_module=dill)
        else:
            # NOTE(review): in this branch ``train_state`` is not written as
            # ``train_state.pkl``; only ``accelerator.save_state`` is called.
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            accelerator.save(unwrapped_model, os.path.join(model_save_dir, "model.pkl"))
            accelerator.save_state(output_dir=model_save_dir)

    def auto_save_step(self, model, train_state, accelerator=None):
        """Save a checkpoint when the current step hits the configured interval.

        A checkpoint is written only if ``save_mode`` is ``"save-by-step"``,
        the step is non-zero, and it is a multiple of ``config["save_step"]``.

        Returns:
            bool: ``True`` if a checkpoint was saved, ``False`` otherwise.
        """
        step = train_state["step"]
        if self.save_mode == "save-by-step" and step % self.config.get("save_step") == 0 and step != 0:
            self.save_model(model, train_state, accelerator)
            return True
        return False

    def save_output(self, outputs, dataset):
        """Delegate to ``outputs.save`` with the save directory and ``dataset``."""
        outputs.save(self.model_save_dir, dataset)