mirror of
https://github.com/optim-enterprises-bv/nDPId.git
synced 2025-10-29 17:32:23 +00:00
Improved Keras Autoencoder hyper parameter.
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
This commit is contained in:
@@ -29,12 +29,12 @@ import nDPIsrvd
|
||||
from nDPIsrvd import nDPIsrvdSocket, TermColor
|
||||
|
||||
INPUT_SIZE = nDPIsrvd.nDPId_PACKETS_PLEN_MAX
|
||||
LATENT_SIZE = 16
|
||||
TRAINING_SIZE = 8192
|
||||
EPOCH_COUNT = 50
|
||||
BATCH_SIZE = 512
|
||||
LEARNING_RATE = 0.0000001
|
||||
ES_PATIENCE = 10
|
||||
LATENT_SIZE = 8
|
||||
TRAINING_SIZE = 512
|
||||
EPOCH_COUNT = 3
|
||||
BATCH_SIZE = 16
|
||||
LEARNING_RATE = 0.000001
|
||||
ES_PATIENCE = 3
|
||||
PLOT = False
|
||||
PLOT_HISTORY = 100
|
||||
TENSORBOARD = False
|
||||
@@ -164,8 +164,12 @@ def keras_worker(load_model, save_model, shared_shutdown_event, shared_training_
|
||||
sys.stderr.flush()
|
||||
encoder, _, autoencoder = get_autoencoder()
|
||||
autoencoder.summary()
|
||||
additional_callbacks = []
|
||||
if TENSORBOARD is True:
|
||||
tensorboard = TensorBoard(log_dir=TB_LOGPATH, histogram_freq=1)
|
||||
additional_callbacks += [tensorboard]
|
||||
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=ES_PATIENCE, restore_best_weights=True, start_from_epoch=0, verbose=0, mode='auto')
|
||||
additional_callbacks += [early_stopping]
|
||||
shared_training_event.clear()
|
||||
|
||||
try:
|
||||
@@ -188,7 +192,7 @@ def keras_worker(load_model, save_model, shared_shutdown_event, shared_training_
|
||||
tmp, tmp, epochs=EPOCH_COUNT, batch_size=BATCH_SIZE,
|
||||
validation_split=0.2,
|
||||
shuffle=True,
|
||||
callbacks=[tensorboard, early_stopping]
|
||||
callbacks=[additional_callbacks]
|
||||
)
|
||||
reconstructed_data = autoencoder.predict(tmp)
|
||||
mse = np.mean(np.square(tmp - reconstructed_data))
|
||||
@@ -295,15 +299,15 @@ if __name__ == '__main__':
|
||||
help='Load a pre-trained model file.')
|
||||
argparser.add_argument('--save-model', action='store',
|
||||
help='Save the trained model to a file.')
|
||||
argparser.add_argument('--training-size', action='store', default=TRAINING_SIZE,
|
||||
argparser.add_argument('--training-size', action='store', type=int, default=TRAINING_SIZE,
|
||||
help='Set the amount of captured packets required to start the training phase.')
|
||||
argparser.add_argument('--batch-size', action='store', default=BATCH_SIZE,
|
||||
argparser.add_argument('--batch-size', action='store', type=int, default=BATCH_SIZE,
|
||||
help='Set the batch size used for the training phase.')
|
||||
argparser.add_argument('--learning-rate', action='store', default=LEARNING_RATE,
|
||||
argparser.add_argument('--learning-rate', action='store', type=float, default=LEARNING_RATE,
|
||||
help='Set the (initial) learning rate for the optimizer.')
|
||||
argparser.add_argument('--plot', action='store_true', default=PLOT,
|
||||
help='Show some model metrics using pyplot.')
|
||||
argparser.add_argument('--plot-history', action='store', default=PLOT_HISTORY,
|
||||
argparser.add_argument('--plot-history', action='store', type=int, default=PLOT_HISTORY,
|
||||
help='Set the history size of Line plots. Requires --plot')
|
||||
argparser.add_argument('--tensorboard', action='store_true', default=TENSORBOARD,
|
||||
help='Enable TensorBoard compatible logging callback.')
|
||||
@@ -313,7 +317,7 @@ if __name__ == '__main__':
|
||||
help='Use SGD optimizer instead of Adam.')
|
||||
argparser.add_argument('--use-kldiv', action='store_true', default=VAE_USE_KLDIV,
|
||||
help='Use Kullback-Leibler loss function instead of Mean-Squared-Error.')
|
||||
argparser.add_argument('--patience', action='store', default=ES_PATIENCE,
|
||||
argparser.add_argument('--patience', action='store', type=int, default=ES_PATIENCE,
|
||||
help='Epoch value for EarlyStopping. This value forces VAE fitting to if no improvment achieved.')
|
||||
args = argparser.parse_args()
|
||||
address = nDPIsrvd.validateAddress(args)
|
||||
|
||||
Reference in New Issue
Block a user