From 4d378c17e074592bada2d64225229940625b9b04 Mon Sep 17 00:00:00 2001 From: Samira <48542514+SquaredS44@users.noreply.github.com> Date: Fri, 17 Jul 2020 17:37:48 +0200 Subject: [PATCH] Add support for multi-class random forest. --- conifer/converters/sklearn.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/conifer/converters/sklearn.py b/conifer/converters/sklearn.py index 65496612..42782fbe 100644 --- a/conifer/converters/sklearn.py +++ b/conifer/converters/sklearn.py @@ -26,17 +26,29 @@ def convert_random_forest(bdt): 'n_classes' : bdt.n_classes_, 'trees' : [], 'init_predict' : [0] * bdt.n_classes_, 'norm' : 1} - for tree in bdt.estimators_: - treesl = [] - tree = treeToDict(bdt, tree.tree_) - tree = padTree(ensembleDict, tree) - # Random forest takes the mean prediction, do that here - # Also need to scale the values by their sum - v = np.array(tree['value']) - tree['value'] = (v / v.sum(axis=2)[:, np.newaxis] / bdt.n_estimators)[:,0,0].tolist() - treesl.append(tree) - ensembleDict['trees'].append(treesl) - + if bdt.n_classes_ == 2: + for tree in bdt.estimators_: + treesl = [] + tree = treeToDict(bdt, tree.tree_) + tree = padTree(ensembleDict, tree) + # Random forest takes the mean prediction, do that here + # Also need to scale the values by their sum + v = np.array(tree['value']) + tree['value'] = (v / v.sum(axis=2)[:, np.newaxis] / bdt.n_estimators)[:,0,0].tolist() + treesl.append(tree) + ensembleDict['trees'].append(treesl) + else: + for tree in bdt.estimators_: + trees_list = [] + for c in range(bdt.n_classes_): + tree_dict = treeToDict(bdt, tree.tree_) + padded_tree = padTree(ensembleDict, tree_dict) + # Random forest takes the mean prediction, do that here + # Also need to scale the values by their sum + v = np.array(padded_tree['value']) + padded_tree['value'] = (v / v.sum(axis=2)[:, np.newaxis] / bdt.n_estimators)[:, 0, c].tolist() + trees_list.append(padded_tree) + ensembleDict['trees'].append(trees_list) return ensembleDict def convert(bdt):