Spaces:
Build error
Build error
import tensorflow as tf | |
from keras.layers import Input, Dense, Dropout | |
from keras.models import Model | |
from keras.losses import binary_crossentropy | |
from load_data import read_data | |
import joblib | |
import numpy as np | |
KL = tf.keras.layers | |
def perceptual_label_predictor():
    """Build the dense network mapping a 20-value input to 5 sigmoid outputs.

    Layers, in order: Dense(20, relu) -> Dropout(0.2) -> Dense(16, relu)
    -> Dropout(0.2) -> Dense(5, sigmoid).

    The model summary is printed as a side effect. The model is returned
    without being compiled.

    Returns:
        The assembled Keras ``Model``.
    """
    features_in = Input((20,))

    # Hidden stack, listed in the order it is applied to the input.
    hidden_layers = (
        Dense(20, activation='relu'),
        Dropout(0.2),
        Dense(16, activation='relu'),
        Dropout(0.2),
    )
    activations = features_in
    for layer in hidden_layers:
        activations = layer(activations)

    # Five independent sigmoid units form the output.
    label_scores = Dense(5, activation='sigmoid')(activations)

    predictor = Model(features_in, label_scores)
    predictor.summary()
    return predictor
def train_perceptual_label_predictor(perceptual_label_predictor, encoder):
    """Load the labeled data, fit the predictor on encoded inputs, save it.

    Args:
        perceptual_label_predictor: Model to train. It is compiled here with
            the ``'adam'`` optimizer and binary cross-entropy loss. Inside
            this function the name refers to this argument.
        encoder: Object exposing ``predict``; element ``[0]`` of its result
            is used as the training features (presumably the first of
            several encoder outputs -- TODO confirm against the encoder).

    Side effects:
        Reads the datasets and label files under ``./data`` and writes the
        trained model to
        ``./models/new_trained_models/perceptual_label_predictor.h5``.
    """
    synthetic_inputs = read_data("./data/labeled_dataset/synthetic_data")
    # Only the first 100 ARTURIA items are used.
    # NOTE(review): assumes the ARTURIA label file has matching rows -- confirm.
    arturia_inputs = read_data("./data/external_data/ARTURIA_data")[:100]

    # NOTE(review): joblib.load unpickles its input; only load trusted files.
    arturia_labels = joblib.load("./data/labeled_dataset/ARTURIA_labels")
    synthetic_labels = joblib.load("./data/labeled_dataset/synthetic_data_labels")

    # Encode both datasets, keeping element [0] of each predict() result.
    arturia_codes = encoder.predict(arturia_inputs)[0]
    synthetic_codes = encoder.predict(synthetic_inputs)[0]

    model = perceptual_label_predictor
    model.compile(optimizer='adam', loss=binary_crossentropy)

    # ARTURIA rows first, then synthetic rows, for features and labels alike.
    train_features = np.vstack([arturia_codes, synthetic_codes])
    train_labels = np.vstack([arturia_labels, synthetic_labels])
    model.fit(
        train_features,
        train_labels,
        epochs=140,
        validation_split=0.05,
        batch_size=16,
    )

    model.save("./models/new_trained_models/perceptual_label_predictor.h5")