Resume & Use Model Checkpoints

This folder contains checkpoints for models and their weights. They are serialized with PyTorch's pickle-based format and are read back with torch.load.

Model specifications are in each folder's README.

Files with "model" in their names contain entire pickled models. They can be used as frozen modules by calling the "forward_checkpoint" function to generate images.

Example:

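import torch
# No need to reconstruct the model, but the model's class definition must be importable for unpickling.
# On PyTorch 2.6 or newer, torch.load defaults to weights_only=True; pass weights_only=False
# to load a full pickled model (only for files you trust).
model = torch.load("./DCSCN/DCSCN_model_387epos_L12_noise_1.pt")
# The input is a tuple of two tensors
x = torch.randn((1,3,10,10)), torch.randn((1,3,20,20))
out = model.forward_checkpoint(x)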

Files with "weights" in their names contain only the model weights, stored as a state dictionary that maps parameter names to tensors.

Example:

model = DCSCN(*)    # the settings must match those used when the checkpoint was saved
model.load_state_dict(torch.load("./DCSCN/DCSCN_weights_387epos_L12_noise_1.pt"))
# then you can resume training the model
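
If the model settings differ from those used when the weights were saved, load_state_dict will typically fail because the parameter names no longer line up. The sketch below shows one way to compare the two sets of names before loading. The helper diff_state_dict_keys and the layer names are hypothetical and not part of this repository; plain dictionaries stand in for real state dictionaries so the example runs without PyTorch.

```python
from collections import OrderedDict


def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) parameter names.

    missing:    names the model expects but the checkpoint lacks
    unexpected: names the checkpoint has but the model does not
    """
    model_keys = list(model_keys)
    checkpoint_keys = list(checkpoint_keys)
    model_set = set(model_keys)
    checkpoint_set = set(checkpoint_keys)
    missing = [k for k in model_keys if k not in checkpoint_set]
    unexpected = [k for k in checkpoint_keys if k not in model_set]
    return missing, unexpected


# Stand-in dictionaries with made-up layer names. With real objects you
# would pass model.state_dict().keys() and the keys of the loaded weights.
model_state = OrderedDict(
    [("conv1.weight", None), ("conv1.bias", None), ("conv2.weight", None)]
)
checkpoint_state = OrderedDict(
    [("conv1.weight", None), ("conv1.bias", None), ("conv3.weight", None)]
)

missing, unexpected = diff_state_dict_keys(model_state, checkpoint_state)
print("missing:", missing)
print("unexpected:", unexpected)
```

Matching names are necessary but not sufficient: tensor shapes must also agree, so this check only catches structural differences such as a different number of layers.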

Model checkpoints in Upconv_7 and vgg_7 are from waifu2x's repo. To load these weights into a model, use the load_pre_train_weights function.

Example:

model = UpConv_7()
model.load_pre_train_weights(json_file=...)
# then the model is ready to use
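
For the .pt files described earlier, the file name indicates how a checkpoint should be loaded: "model" marks a full pickled module and "weights" marks a state dictionary. A minimal sketch of a helper that reads this from the name follows. The function checkpoint_kind is hypothetical, and it assumes the names are underscore-separated as in the two DCSCN examples above.

```python
from pathlib import Path


def checkpoint_kind(filename):
    """Classify a checkpoint file by the tokens in its name.

    Returns "weights" for state dictionaries, "model" for full pickled
    modules, and "unknown" when neither token is present.
    """
    # Drop the directory and extension, then split on underscores.
    tokens = Path(filename).stem.lower().split("_")
    if "weights" in tokens:
        return "weights"
    if "model" in tokens:
        return "model"
    return "unknown"


for path in (
    "./DCSCN/DCSCN_model_387epos_L12_noise_1.pt",
    "./DCSCN/DCSCN_weights_387epos_L12_noise_1.pt",
):
    print(path, "->", checkpoint_kind(path))
```

A file classified as "model" would be loaded directly with torch.load, while a "weights" file needs a model constructed with matching settings first.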