# Test2/net.py — commit 2f500b5 by AlterM ("Duplicate from RisticksAI/ProfNet4"),
# 1.71 kB.
import pickle
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.preprocessing.sequence import pad_sequences
class SetLine:
    """A named record paired with the vector that ``embedding.getvec`` returns.

    NOTE(review): the script below reads ``.inp`` from every object unpickled
    from set.pckl, so those objects are presumably ``SetLine`` instances; if
    so, this class must stay defined under this name for unpickling to work —
    confirm.
    """

    def __init__(self, name, inp):
        # NOTE(review): ``inp`` is accepted but ignored — the vector is always
        # recomputed from ``name``. ``embedding`` is not imported in the
        # visible code; confirm where it is expected to come from.
        vector = embedding.getvec(name)
        self.name, self.inp = name, vector
# Load the pickled dataset. pickle.load can execute arbitrary code while
# deserializing, so only load a set.pckl that comes from a trusted source.
with open("set.pckl", "rb") as f:
    dset = pickle.load(f)

# One training sequence: the ``inp`` vector of every record, in file order.
sequences = [[x.inp for x in dset]]
if not sequences[0]:
    raise ValueError("set.pckl contains no records; nothing to train on")
# Length of a single vector, taken from the first record.
vec_size = len(sequences[0][0])
window_size = 3
# Build (context window -> next vector) training pairs. For every position t
# the target is seq[t] and the context is the up-to-window_size vectors that
# precede it. The first window_size targets get shorter (or empty) contexts,
# which are left-padded with zeros below.
sliding_windows = []
target_vectors = []
for seq in sequences:
    for t in range(len(seq)):
        # max(0, ...) keeps the slice start non-negative: a negative start
        # wraps to the end of the list and silently yields an empty window.
        sliding_windows.append(np.array(seq[max(0, t - window_size):t]))
        target_vectors.append(seq[t])
if not sliding_windows:
    raise ValueError("no training windows could be built from the dataset")

# Pad every window to a common length with leading zeros. pad_sequences
# defaults to dtype='int32', which would truncate non-integer vector
# components (the vectors are presumably float embeddings), so request float.
max_seq_length = max(len(window) for window in sliding_windows)
padded_windows = pad_sequences(
    sliding_windows, maxlen=max_seq_length, padding='pre', dtype='float32'
)
# Feed-forward regressor: flatten the padded context window, run it through
# a stack of dense layers, and emit a single vector of length vec_size.
model = Sequential([
    Input(shape=(max_seq_length, vec_size)),
    Flatten(),
    Dense(512, activation='tanh'),
    Dense(256, activation='tanh'),
    Dense(512, activation='relu'),
    Dense(300, activation='tanh'),
    # Linear output layer so predicted components are unbounded.
    Dense(vec_size, activation='linear'),
])
# Compile the model. This is vector regression (MSE against target vectors):
# 'accuracy' would resolve to argmax-based categorical accuracy, which is
# meaningless here, so report mean absolute error instead.
model.compile(optimizer=Adam(learning_rate=0.0001), loss=MeanSquaredError(), metrics=['mae'])
# Train on (padded context window -> next vector) pairs.
X = np.asarray(padded_windows, dtype='float32')
y = np.asarray(target_vectors, dtype='float32')
model.fit(X, y, epochs=128, batch_size=4)
# Persist the trained model (HDF5 format, selected by the .h5 extension).
model.save("net.h5")