Source code for pml.supervised.decision_trees.trees

# Copyright (C) 2012 David Rusk
#
# Permission is hereby granted, free of charge, to any person obtaining a copy 
# of this software and associated documentation files (the "Software"), to 
# deal in the Software without restriction, including without limitation the 
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 
# sell copies of the Software, and to permit persons to whom the Software is 
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in 
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 
# IN THE SOFTWARE.
"""
Data structures for representing trees.

@author: drusk
"""

class Tree(object):
    """
    A rooted tree: a set of nodes joined by directed edges, all of which
    are reachable from a single root node.
    """

    def __init__(self, root_node):
        """
        Builds a tree around the given root.

        The root's descendants are gathered once, here, and cached together
        with the root itself. Nodes attached after construction are therefore
        not part of the cached node list used by get_leaves().

        Args:
          root_node: Node
            The node that sits at the top of the tree.
        """
        self._root_node = root_node

        # The list handed back by the root is extended in place so that the
        # cache holds every descendant followed by the root itself.
        known_nodes = root_node.get_all_descendants()
        known_nodes.append(root_node)
        self._all_nodes = known_nodes

    def get_root_node(self):
        """
        Gives access to the node at the top of the tree, i.e. the node the
        tree was constructed with.

        Returns:
          root: Node
            The root node.
        """
        return self._root_node

    def get_leaves(self):
        """
        Collects every cached node that currently reports itself as a leaf.

        Returns:
          leaves: list(Node)
            The leaf nodes, in the order they are stored in the cache
            (descendants first, root last).
        """
        leaves = []
        for candidate in self._all_nodes:
            if candidate.is_leaf():
                leaves.append(candidate)
        return leaves

    def get_num_leaves(self):
        """
        Reports how many leaves the tree has.

        Returns:
          num_leaves: int
        """
        leaves = self.get_leaves()
        return len(leaves)

    def get_depth(self):
        """
        Computes the number of nodes along the longest root-to-leaf path.

        The root's height counts edges, so one is added to count nodes.

        Returns:
          depth: int
        """
        root_height = self._root_node.get_height()
        return root_height + 1
class Node(object):
    """
    A node in a tree. Holds a value and may have branches connecting it to
    other nodes.
    """

    def __init__(self, value):
        """
        Constructs a new node with no children.

        Args:
          value:
            The data value to be associated with this node.
        """
        self._value = value
        # Maps branch identifier -> child Node.
        self._children = {}

    def get_value(self):
        """
        Retrieves the data value associated with this node.

        Returns:
          value:
            The value associated with this node.
        """
        return self._value

    def add_child(self, branch, child):
        """
        Creates a branch from this node to another node which will be the
        child. If the branch identifier is already in use, the child
        previously attached to it is replaced.

        Args:
          branch:
            The identifier for the branch connecting this node to the child
            node. It is used as a dictionary key, so it must be hashable.
          child: Node
            Another node which will be a child of the current node.

        Returns:
          void
        """
        self._children[branch] = child

    def get_child(self, branch):
        """
        Retrieves the child node connected by the specified branch.

        Args:
          branch:
            The identifier that was used to associate a child node with the
            current node.

        Returns:
          child: Node
            The child node found by following the specified branch.

        Raises:
          KeyError if the specified branch does not exist.
        """
        return self._children[branch]

    def get_branches(self):
        """
        Retrieves all the branches to children of the current node.

        Returns:
          branches: list
            A list of all the branches to child nodes. Note that this means
            branches TO this node are not included. The list is a snapshot:
            it is not affected by children added later, and modifying it
            does not modify the node.
        """
        # list(...) rather than .keys(): on Python 3 .keys() is a live view,
        # not the list this method documents.
        return list(self._children)

    def is_leaf(self):
        """
        Checks if this node is a leaf (has no children).

        Returns:
          is_leaf: boolean
            True if this node has no children.
        """
        return not self._children

    def get_height(self):
        """
        Determines the node's height, i.e. the maximum number of edges
        between it and a leaf node. A leaf has height 0.

        Returns:
          height: int
        """
        # Each child contributes its own height plus the edge leading to it.
        heights_via_children = [
            child.get_height() + 1 for child in self._children.values()
        ]
        if not heights_via_children:
            return 0
        return max(heights_via_children)

    def get_all_descendants(self):
        """
        Retrieves all descendants of the current node, i.e. nodes which can
        eventually be reached by following outgoing branches from the
        current node. The node itself is not included.

        Returns:
          descendants: list(Node)
            A new list in which each child is immediately followed by that
            child's own descendants.
        """
        descendants = []
        for child_node in self._children.values():
            descendants.append(child_node)
            descendants.extend(child_node.get_all_descendants())
        return descendants

Project Versions

This Page