Source code for pml.supervised.decision_trees.interface

# Copyright (C) 2012 David Rusk
#
# Permission is hereby granted, free of charge, to any person obtaining a copy 
# of this software and associated documentation files (the "Software"), to 
# deal in the Software without restriction, including without limitation the 
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 
# sell copies of the Software, and to permit persons to whom the Software is 
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in 
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 
# IN THE SOFTWARE.
"""
Decision trees classification algorithm.

@author: drusk
"""

from pml.supervised.classifiers import AbstractClassifier
from pml.supervised.decision_trees import id3
from pml.supervised.decision_trees.tree_plotting import MatplotlibAnnotationTreePlotter

class DecisionTree(AbstractClassifier):
    """
    Decision tree classifier.

    Builds a tree which is like a flow chart. It allows a decision to be
    reached by checking the values for various features and following the
    appropriate branches until a destination is reached.

    In addition to being useful as a classifier, the structure of the
    decision tree can lend insight into the data.
    """

    def __init__(self, training_set):
        """
        Constructs a new decision tree.

        The tree is built immediately from the training data using the id3
        module, and a plotter is created for the resulting tree.

        Args:
          training_set: model.DataSet
            The training data to use when building the decision tree.
        """
        self.training_set = training_set
        self._tree = id3.build_tree(training_set)
        self._plotter = MatplotlibAnnotationTreePlotter(self._tree)

    def _classify(self, sample):
        """
        Predicts a sample's classification based on the decision tree that
        was built from the training data.

        Args:
          sample:
            The sample or observation to be classified. It is indexed by
            the value stored in each non-leaf node visited.

        Returns:
          The value stored in the leaf node that the sample's feature
          values lead to, i.e. the sample's classification.
        """
        node = self._tree.get_root_node()

        # Walk down from the root: at each non-leaf node, the node's value
        # names the feature to inspect, and the sample's value for that
        # feature selects which child branch to follow.
        while not node.is_leaf():
            feature = node.get_value()
            branch = sample[feature]
            node = node.get_child(branch)

        # A leaf has been reached; its value is the classification.
        return node.get_value()

    def plot(self):
        """
        Generates a plot of the decision tree to visualize its structure.

        Returns:
          None
        """
        self._plotter.plot()

Project Versions

This Page