| |
| |
| |
|
|
| import TextStoryAI.model |
| import os |
|
|
def load_model(load_weights=True, debug=False):
    """Build the TextStoryAI model and optionally load saved weights into it.

    Args:
        load_weights: If True, copy weights into the freshly created model
            via ``loadWeights`` using that function's default weights file.
        debug: Forwarded to ``loadWeights`` to enable progress printing.

    Returns:
        The object returned by ``TextStoryAI.model.create_model()``.
    """
    # ``import TextStoryAI.model`` binds only the top-level name
    # ``TextStoryAI``; a bare ``model`` is not defined at module scope, so
    # the submodule has to be reached through its package.
    m = TextStoryAI.model.create_model()
    if load_weights:
        loadWeights(m, debug=debug)
    return m
|
|
| |
|
|
| import tensorflow as tf |
| import h5py |
|
|
def loadWeights(model, filename=os.path.join(__package__, "weights.h5"), debug=False):
    """Load per-layer weights from an HDF5 file into ``model``.

    Each top-level HDF5 group in the file describes one layer. Its ``Name``
    attribute is the target layer's name, and its ``NumVars`` attribute is
    how many variables it holds. Top-level members that are not groups are
    ignored. Within a group, each dataset's ``WeightNum`` attribute gives
    the position in ``layer.variables`` that its data is assigned to. After
    a layer's variables are assigned, ``layer.finalize_state()`` is called
    if the layer has that method.

    Args:
        model: Object with a ``layers`` sequence whose items expose ``name``
            and ``variables`` (presumably a Keras model -- TODO confirm).
        filename: Path to the HDF5 weights file. NOTE(review): the default
            is computed once at import time from ``__package__``, so it is a
            path relative to the current working directory, not to this
            module's directory.
        debug: If True, print progress information while loading.

    Raises:
        ValueError: If the file names a layer that ``model`` does not
            contain, or a layer group has an out-of-range or missing
            ``WeightNum``.
    """
    with h5py.File(filename, 'r') as f:
        for g in f:
            group = f[g]
            if not isinstance(group, h5py.Group):
                continue
            layerName = group.attrs['Name']
            numVars = int(group.attrs['NumVars'])
            if debug:
                print("layerName:", layerName)
                print(" numVars:", numVars)

            layerIdx = layerNum(model, layerName)
            if layerIdx < 0:
                # layerNum reports "not found" as -1, which is also a valid
                # Python index: without this check the weights would be
                # written into the model's *last* layer.
                raise ValueError(
                    "Model does not contain a layer named %r" % (layerName,))
            layer = model.layers[layerIdx]
            if debug:
                print(" layerIdx=", layerIdx)

            weightList = _readLayerWeights(group, numVars, layerName, debug)

            for w, value in enumerate(weightList):
                if debug:
                    print("Copying variable of shape:")
                    print(value.shape)
                layer.variables[w].assign(value)
                if debug:
                    print("Assignment successful.")
                    print("Set variable value:")
                    print(layer.variables[w])

            if hasattr(layer, 'finalize_state'):
                layer.finalize_state()


def _readLayerWeights(group, numVars, layerName, debug=False):
    """Read one layer group's datasets into a list ordered by ``WeightNum``.

    Each dataset is converted with ``tf.constant`` using the dataset's
    ``Shape`` attribute. Raises ValueError if a ``WeightNum`` is outside
    ``[0, numVars)`` or if any slot is left unfilled.
    """
    weightList = [None] * numVars
    for d in group:
        dataset = group[d]
        varName = dataset.attrs['Name']
        shp = intList(dataset.attrs['Shape'])
        weightNum = int(dataset.attrs['WeightNum'])

        if debug:
            print(" varName:", varName)
            print(" shp:", shp)
            print(" weightNum:", weightNum)
        # Reject negative indices explicitly: they would otherwise silently
        # overwrite a slot counted from the end of the list.
        if not 0 <= weightNum < numVars:
            raise ValueError(
                "Layer %r: WeightNum %d out of range for NumVars=%d"
                % (layerName, weightNum, numVars))
        weightList[weightNum] = tf.constant(dataset[()], shape=shp)

    missing = [i for i, value in enumerate(weightList) if value is None]
    if missing:
        raise ValueError(
            "Layer %r: no data for variable index(es) %s" % (layerName, missing))
    return weightList
|
|
def layerNum(model, layerName):
    """Return the index in ``model.layers`` of the first layer named ``layerName``.

    Args:
        model: Object with a ``layers`` sequence whose items have a ``name``.
        layerName: The layer name to look up.

    Returns:
        The index of the first layer whose name equals ``layerName``. If no
        layer matches, an error banner is printed and ``-1`` is returned.
        Because ``-1`` is also a valid Python index, callers must check for
        it before indexing ``model.layers``.
    """
    for i, layer in enumerate(model.layers):
        if layerName == layer.name:
            return i
    print("")
    print("WEIGHT LOADING FAILED. MODEL DOES NOT CONTAIN LAYER WITH NAME: ", layerName)
    print("")
    return -1
|
|
def intList(myList):
    """Return a new list holding ``int(x)`` for each element ``x`` of ``myList``."""
    return [int(item) for item in myList]
|
|
|
|