mirror of
https://github.com/optim-enterprises-bv/nDPId.git
synced 2025-11-01 02:37:48 +00:00
py-machine-learning: load and save trained models
* added link to a pre-trained model

Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import csv
|
||||
import joblib
|
||||
import matplotlib.pyplot
|
||||
import numpy
|
||||
import os
|
||||
@@ -177,7 +178,11 @@ def isProtoClass(proto_class, line):
|
||||
|
||||
if __name__ == '__main__':
|
||||
argparser = nDPIsrvd.defaultArgumentParser()
|
||||
argparser.add_argument('--csv', action='store', required=True,
|
||||
argparser.add_argument('--load-model', action='store',
|
||||
help='Load a pre-trained model file.')
|
||||
argparser.add_argument('--save-model', action='store',
|
||||
help='Save the trained model to a file.')
|
||||
argparser.add_argument('--csv', action='store',
|
||||
help='Input CSV file generated with nDPIsrvd-analysed.')
|
||||
argparser.add_argument('--proto-class', action='append', required=True,
|
||||
help='nDPId protocol class of interest used for training and prediction. ' +
|
||||
@@ -211,6 +216,14 @@ if __name__ == '__main__':
|
||||
args = argparser.parse_args()
|
||||
address = nDPIsrvd.validateAddress(args)
|
||||
|
||||
if args.csv is None and args.load_model is None:
|
||||
sys.stderr.write('{}: Either `--csv` or `--load-model` required!\n'.format(sys.argv[0]))
|
||||
sys.exit(1)
|
||||
|
||||
if args.csv is None and args.generate_feature_importance is True:
|
||||
sys.stderr.write('{}: `--generate-feature-importance` requires `--csv`.\n'.format(sys.argv[0]))
|
||||
sys.exit(1)
|
||||
|
||||
ENABLE_FEATURE_IAT = args.enable_iat
|
||||
ENABLE_FEATURE_PKTLEN = args.enable_pktlen
|
||||
ENABLE_FEATURE_DIRS = args.disable_dirs is False
|
||||
@@ -222,40 +235,50 @@ if __name__ == '__main__':
|
||||
for i in range(len(args.proto_class)):
|
||||
args.proto_class[i] = args.proto_class[i].lower()
|
||||
|
||||
sys.stderr.write('Learning via CSV..\n')
|
||||
with open(args.csv, newline='\n') as csvfile:
|
||||
reader = csv.DictReader(csvfile, delimiter=',', quotechar='"')
|
||||
X = list()
|
||||
y = list()
|
||||
if args.load_model is not None:
|
||||
sys.stderr.write('Loading model from {}\n'.format(args.load_model))
|
||||
model = joblib.load(args.load_model)
|
||||
|
||||
for line in reader:
|
||||
N_DIRS = len(getFeaturesFromArray(line['directions']))
|
||||
N_BINS = len(getFeaturesFromArray(line['bins_c_to_s']))
|
||||
break
|
||||
if args.csv is not None:
|
||||
sys.stderr.write('Learning via CSV..\n')
|
||||
with open(args.csv, newline='\n') as csvfile:
|
||||
reader = csv.DictReader(csvfile, delimiter=',', quotechar='"')
|
||||
X = list()
|
||||
y = list()
|
||||
|
||||
for line in reader:
|
||||
try:
|
||||
X += getRelevantFeaturesCSV(line)
|
||||
y += [isProtoClass(args.proto_class, line['proto'])]
|
||||
except RuntimeError as err:
|
||||
print('Error: `{}\'\non line: {}'.format(err, line))
|
||||
for line in reader:
|
||||
N_DIRS = len(getFeaturesFromArray(line['directions']))
|
||||
N_BINS = len(getFeaturesFromArray(line['bins_c_to_s']))
|
||||
break
|
||||
|
||||
sys.stderr.write('CSV data set contains {} entries.\n'.format(len(X)))
|
||||
for line in reader:
|
||||
try:
|
||||
X += getRelevantFeaturesCSV(line)
|
||||
y += [isProtoClass(args.proto_class, line['proto'])]
|
||||
except RuntimeError as err:
|
||||
print('Error: `{}\'\non line: {}'.format(err, line))
|
||||
|
||||
model = sklearn.ensemble.RandomForestClassifier(bootstrap=False,
|
||||
class_weight = args.sklearn_class_weight,
|
||||
n_jobs = args.sklearn_jobs,
|
||||
n_estimators = args.sklearn_estimators,
|
||||
verbose = args.sklearn_verbosity,
|
||||
min_samples_leaf = args.sklearn_min_samples_leaf,
|
||||
max_features = args.sklearn_max_features
|
||||
)
|
||||
sys.stderr.write('Training model..\n')
|
||||
model.fit(X, y)
|
||||
sys.stderr.write('CSV data set contains {} entries.\n'.format(len(X)))
|
||||
|
||||
if args.generate_feature_importance is True:
|
||||
sys.stderr.write('Generating feature importance .. this may take some time')
|
||||
plotPermutatedImportance(model, X, y)
|
||||
if args.load_model is None:
|
||||
model = sklearn.ensemble.RandomForestClassifier(bootstrap=False,
|
||||
class_weight = args.sklearn_class_weight,
|
||||
n_jobs = args.sklearn_jobs,
|
||||
n_estimators = args.sklearn_estimators,
|
||||
verbose = args.sklearn_verbosity,
|
||||
min_samples_leaf = args.sklearn_min_samples_leaf,
|
||||
max_features = args.sklearn_max_features
|
||||
)
|
||||
sys.stderr.write('Training model..\n')
|
||||
model.fit(X, y)
|
||||
|
||||
if args.generate_feature_importance is True:
|
||||
sys.stderr.write('Generating feature importance .. this may take some time\n')
|
||||
plotPermutatedImportance(model, X, y)
|
||||
|
||||
if args.save_model is not None:
|
||||
sys.stderr.write('Saving model to {}\n'.format(args.save_model))
|
||||
joblib.dump(model, args.save_model)
|
||||
|
||||
print('Map[*] -> [0]')
|
||||
for x in range(len(args.proto_class)):
|
||||
|
||||
Reference in New Issue
Block a user