py-machine-learning: load and save trained models

* added link to a pre-trained model

Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
This commit is contained in:
Toni Uhlig
2022-10-14 08:49:25 +02:00
parent 80f8448834
commit 6292102f93
3 changed files with 58 additions and 30 deletions

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import csv
import joblib
import matplotlib.pyplot
import numpy
import os
@@ -177,7 +178,11 @@ def isProtoClass(proto_class, line):
if __name__ == '__main__':
argparser = nDPIsrvd.defaultArgumentParser()
argparser.add_argument('--csv', action='store', required=True,
argparser.add_argument('--load-model', action='store',
help='Load a pre-trained model file.')
argparser.add_argument('--save-model', action='store',
help='Save the trained model to a file.')
argparser.add_argument('--csv', action='store',
help='Input CSV file generated with nDPIsrvd-analysed.')
argparser.add_argument('--proto-class', action='append', required=True,
help='nDPId protocol class of interest used for training and prediction. ' +
@@ -211,6 +216,14 @@ if __name__ == '__main__':
args = argparser.parse_args()
address = nDPIsrvd.validateAddress(args)
if args.csv is None and args.load_model is None:
sys.stderr.write('{}: Either `--csv` or `--load-model` required!\n'.format(sys.argv[0]))
sys.exit(1)
if args.csv is None and args.generate_feature_importance is True:
sys.stderr.write('{}: `--generate-feature-importance` requires `--csv`.\n'.format(sys.argv[0]))
sys.exit(1)
ENABLE_FEATURE_IAT = args.enable_iat
ENABLE_FEATURE_PKTLEN = args.enable_pktlen
ENABLE_FEATURE_DIRS = args.disable_dirs is False
@@ -222,40 +235,50 @@ if __name__ == '__main__':
for i in range(len(args.proto_class)):
args.proto_class[i] = args.proto_class[i].lower()
sys.stderr.write('Learning via CSV..\n')
with open(args.csv, newline='\n') as csvfile:
reader = csv.DictReader(csvfile, delimiter=',', quotechar='"')
X = list()
y = list()
if args.load_model is not None:
sys.stderr.write('Loading model from {}\n'.format(args.load_model))
model = joblib.load(args.load_model)
for line in reader:
N_DIRS = len(getFeaturesFromArray(line['directions']))
N_BINS = len(getFeaturesFromArray(line['bins_c_to_s']))
break
if args.csv is not None:
sys.stderr.write('Learning via CSV..\n')
with open(args.csv, newline='\n') as csvfile:
reader = csv.DictReader(csvfile, delimiter=',', quotechar='"')
X = list()
y = list()
for line in reader:
try:
X += getRelevantFeaturesCSV(line)
y += [isProtoClass(args.proto_class, line['proto'])]
except RuntimeError as err:
print('Error: `{}\'\non line: {}'.format(err, line))
for line in reader:
N_DIRS = len(getFeaturesFromArray(line['directions']))
N_BINS = len(getFeaturesFromArray(line['bins_c_to_s']))
break
sys.stderr.write('CSV data set contains {} entries.\n'.format(len(X)))
for line in reader:
try:
X += getRelevantFeaturesCSV(line)
y += [isProtoClass(args.proto_class, line['proto'])]
except RuntimeError as err:
print('Error: `{}\'\non line: {}'.format(err, line))
model = sklearn.ensemble.RandomForestClassifier(bootstrap=False,
class_weight = args.sklearn_class_weight,
n_jobs = args.sklearn_jobs,
n_estimators = args.sklearn_estimators,
verbose = args.sklearn_verbosity,
min_samples_leaf = args.sklearn_min_samples_leaf,
max_features = args.sklearn_max_features
)
sys.stderr.write('Training model..\n')
model.fit(X, y)
sys.stderr.write('CSV data set contains {} entries.\n'.format(len(X)))
if args.generate_feature_importance is True:
sys.stderr.write('Generating feature importance .. this may take some time')
plotPermutatedImportance(model, X, y)
if args.load_model is None:
model = sklearn.ensemble.RandomForestClassifier(bootstrap=False,
class_weight = args.sklearn_class_weight,
n_jobs = args.sklearn_jobs,
n_estimators = args.sklearn_estimators,
verbose = args.sklearn_verbosity,
min_samples_leaf = args.sklearn_min_samples_leaf,
max_features = args.sklearn_max_features
)
sys.stderr.write('Training model..\n')
model.fit(X, y)
if args.generate_feature_importance is True:
sys.stderr.write('Generating feature importance .. this may take some time\n')
plotPermutatedImportance(model, X, y)
if args.save_model is not None:
sys.stderr.write('Saving model to {}\n'.format(args.save_model))
joblib.dump(model, args.save_model)
print('Map[*] -> [0]')
for x in range(len(args.proto_class)):