Source code for pml.supervised.decision_trees.id3

# Copyright (C) 2012 David Rusk
#
# Permission is hereby granted, free of charge, to any person obtaining a copy 
# of this software and associated documentation files (the "Software"), to 
# deal in the Software without restriction, including without limitation the 
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 
# sell copies of the Software, and to permit persons to whom the Software is 
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in 
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 
# IN THE SOFTWARE.
"""
Implements the ID3 decision tree algorithm.

@author: drusk
"""

from pml.supervised.decision_trees.trees import Node, Tree
from pml.tools.info_theory import info_gain
from pml.utils.collection_utils import (get_key_with_highest_value, 
                                        get_most_common)

def build_tree(dataset):
    """
    Builds the decision tree for a data set using the ID3 algorithm.

    Args:
      dataset: model.DataSet
        The data for which the decision tree will be built.

    Returns:
      tree: Tree
        The decision tree that was built.
    """
    return Tree(_build_tree_recursively(dataset))
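# Example usage (a minimal sketch; it assumes you already have a
# model.DataSet with labels attached, e.g. from pml's data loading
# utilities, which are not part of this module):
#
#     tree = build_tree(dataset)  # dataset: model.DataSet with labels
#     print(tree)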
def _build_tree_recursively(dataset):
    """
    Private function used to build the decision tree in a recursive fashion.

    Args:
      dataset: model.DataSet
        The data at the current level of the tree.  Lower levels of the
        tree have filtered subsets of the original data set.

    Returns:
      current_root: Node
        The node which is the root of the level being processed.  For
        example, on the first/outermost call to this function the root
        node will be returned.  Subsequent calls will return the various
        child nodes.
    """
    label_set = set(dataset.get_labels())
    if len(label_set) == 1:
        # All remaining samples have the same label, no need to split further
        return Node(label_set.pop())

    if len(dataset.feature_list()) == 0:
        # No more features to split on
        return Node(get_most_common(dataset.get_labels()))

    # We can still split further
    split_feature = choose_feature_to_split(dataset)
    node = Node(split_feature)

    for value in dataset.get_feature_values(split_feature):
        subset = dataset.value_filter(split_feature,
                                      value).drop_column(split_feature)
        node.add_child(value, _build_tree_recursively(subset))

    return node
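# get_most_common above is imported from pml.utils.collection_utils.  A
# stand-in with the same intent (an assumption about its behaviour, not
# the library's actual code) could be written as:
#
#     from collections import Counter
#
#     def get_most_common(items):
#         # Most frequent item; ties broken arbitrarily by Counter.
#         return Counter(items).most_common(1)[0][0]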
def choose_feature_to_split(dataset):
    """
    Choose the root to be the feature which has the highest information
    gain.

    Args:
      dataset: model.DataSet
        The data set being used to build the decision tree.

    Returns:
      feature: string
        The feature which should be the root.
    """
    gains = {}
    for feature in dataset.feature_list():
        gains[feature] = info_gain(feature, dataset)

    return get_key_with_highest_value(gains)
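
# The demo below illustrates the quantity choose_feature_to_split
# maximises.  It is a self-contained stand-in for the entropy and
# information gain computed by pml.tools.info_theory (an assumption
# about that module's behaviour, not its actual code), using plain
# lists instead of a DataSet.
if __name__ == "__main__":
    import math
    from collections import Counter

    def entropy(labels):
        # H = -sum(p * log2(p)) over the label distribution.
        total = float(len(labels))
        return -sum((count / total) * math.log(count / total, 2)
                    for count in Counter(labels).values())

    # Toy data: one feature ("outlook") and binary class labels.
    outlook = ["sunny", "sunny", "overcast", "rain", "rain", "rain"]
    labels = ["no", "no", "yes", "yes", "yes", "no"]

    # Information gain = H(labels) - weighted sum of H(labels | value).
    gain = entropy(labels)
    for value in set(outlook):
        subset = [lbl for feat, lbl in zip(outlook, labels) if feat == value]
        gain -= (len(subset) / float(len(labels))) * entropy(subset)

    print("info gain of 'outlook': %.3f" % gain)  # 0.541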
