平衡二叉树

平衡二索树

平衡二叉树(AVL树)

空间和时间复杂度

空间时间复杂为O(N)

时间复杂度:

操作

插入

插入会破坏平衡,以下分四种情况说明,假设第一个破坏平衡的节点为k

avl_tree_3

删除

分为以下4中情况

  1. 当前节点为删除节点,且为叶子节点,直接删除
  2. 当前节点为删除节点有且只有一个子树,将当前子树替换为当前节点
  3. 当前节点为删除节点并且左右子树同时存在,则查找左子树最大值或者右子树最小值,进行替换,并递归删除,这样操作不会影响节点平衡
  4. 当前节点不是删除节点,当删除的值小于当前节点值,则在左子树递归删除,反之则在右子树递归删除,该操作会破坏平衡,需要进行平衡操作
class TreeNode:
    def __init__(self, val: int, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
        self.height = 0

    def __str__(self) -> str:
        return str(self.val)


class AVLTree:
    def __init__(self):
        self.root = None

    def _height(self, node: TreeNode):
        return node.height if node else -1

    def findMin(self) -> TreeNode:
        if self.root is None:
            return None
        return self.findMin(self.left)

    def _findMin(self, root: TreeNode) -> TreeNode:
        if root.left is None:
            return root
        return self._findMin(root.left)

    def findMax(self) -> TreeNode:
        if self.root is None:
            return None
        return self._findMax(self.right)

    def _findMax(self, root: TreeNode) -> TreeNode:
        if root.right is None:
            return root
        return self._findMax(root.right)

    def _unbalanced(self, root: TreeNode):
        return abs(self._height(root.left) - self._height(root.right)) == 2

    def insert(self, val: int):
        if self.root is None:
            self.root = TreeNode(val)
            return
        self.root = self._insert(self.root, val)

    def _insert(self, root: TreeNode, val: int):
        if root is None:
            root = TreeNode(val)
        elif val < root.val:
            root.left = self._insert(root.left, val)
            if self._unbalanced(root):
                if val < root.left.val:
                    root = self.LLSpinning(root)
                else:
                    root = self.RLSpinning(root)
        elif val > root.val:
            root.right = self._insert(root.right, val)
            if self._unbalanced(root):
                if val > root.right.val:
                    root = self.RRSpinning(root)
                else:
                    root = self.LRSpinning(root)

        root.height = max(self._height(root.left), self._height(
            root.right)) + 1
        return root

    def LLSpinning(self, root: TreeNode) -> TreeNode:
        ''' LL 右旋处理
                7                    6
               /                  /    \
              6                  5       7
            /
           5
        '''
        left = root.left
        root.left, left.right = left.right, root
        root.height = max(self._height(root.right), self._height(
            root.left)) + 1
        left.height = max(self._height(left.left), root.height) + 1
        return left

    def RRSpinning(self, root: TreeNode) -> TreeNode:
        ''' RR 左旋处理
                7                    8
                 \                /    \
                  8              7      9
                    9
        '''
        right = root.right
        root.right, right.left = right.left, root
        root.height = max(self._height(root.right), self._height(
            root.left)) + 1
        right.height = max(self._height(right.right), root.height) + 1
        return right

    def LRSpinning(self, root: TreeNode) -> TreeNode:
        ''' LR 双向旋转,先右旋在左旋
                7           7                8
                 \          \             /    \
                  9           8           7      9
                 /              \
                8                9
        '''

        root.right = self.LLSpinning(root.right)
        return self.RRSpinning(root)

    def RLSpinning(self, root: TreeNode) -> TreeNode:
        ''' RL 双向旋转,先左旋在右旋
                7         7               6
               /         /               /  \
             5          6              5     7
              \        /
                6      5
        '''
        root.left = self.RRSpinning(root.left)
        return self.LLSpinning(root)

    def remove(self, val: int):
        self.root = self._remove(self.root, val)

    def _remove(self, root: TreeNode, val: int) -> TreeNode:
        if root is None:
            return root
        if root.val > val:
            # 在左子树查找,会破坏平衡,需要自旋
            root.left = self._remove(root.left, val)
            if self._unbalanced(root):
                if self._height(root.right.right) >= self._height(
                        root.right.left):
                    '''
                        左子树删除后,右子树是RR型,需要进行左旋
                            5
                            \
                             6
                             \
                              7
                    '''

                    root = self.LLSpinning(root)
                else:
                    '''
                        左子树删除后
                           5
                            \
                              7
                             /
                             6
                    '''

                    root = self.LRSpinning(root)
            root.height = max(self._height(root.left), self._height(
                root.right)) + 1
        elif root.val < val:
            # 在右子树查找,会破坏平衡,需要自旋
            root.right = self._remove(root.right, val)
            if self._unbalanced(root):
                if self._height(root.left.left) >= self._height(
                        root.left.right):
                    root = self.RRSpinning(root)
                else:
                    root = self.RLSpinning(root)
            root.height = max(self._height(root.left), self._height(
                root.right)) + 1
        else:
            if root.left and root.right:
                if self._height(root.left) <= self._height(root.right):
                    # 左子树高度低于右子树,取右子树最小值,然后进行替换删除
                    minNode = self._findMin(root.right)
                    root.val = minNode.val
                    root.right = self._remove(root.right, minNode.val)
                else:
                    # 左子树高度高于右子树,取左子树最大值,然后进行替换删除
                    maxNode = self._findMax(root.left)
                    root.val = maxNode.val
                    root.left = self._remove(root.left, maxNode.val)
                root.height = max(self._height(root.left),
                                  self._height(root.right)) + 1
            elif root.left:
                root = root.left
            else:
                root = root.right

        return root

    def search(self, val: int) -> TreeNode:
        return self._search(self.root, val)

    def _search(self, root: TreeNode, val: int) -> TreeNode:
        if root is None:
            return root
        if root.val == val:
            return root
        elif root.val > val:
            return self._search(root.left, val)
        else:
            return self._search(root.right, val)

    def travel(self):
        self._travel(self.root)

    def _travel(self, root: TreeNode):
        if root is None:
            return
        self._travel(root.left)
        print("值: ", root.val, " 高度: ", root.height)
        self._travel(root.right)