Implementing an AVL tree in Java

Published 2020-06-07 02:04

Question:

I want to implement an AVL Tree in Java, here is what I have so far:

public class AVLNode {

  /** Presumably the number of nodes in the subtree rooted at this node; a new node starts at 1. */
  private int size;

  /** Presumably the height of the subtree rooted at this node; a new (leaf) node starts at 0. */
  private int height;

  /** The key of the current node. */
  private Object key;

  /** The data of the current node. */
  private Object data;

  /**
   * The {@link Comparator} used by the node.
   * NOTE(review): raw type; java.util.Comparator must be imported (the import is not shown).
   */
  private Comparator comp;

  /*
   * Links from the current node to its neighbours. succ/pred presumably hold the in-order
   * successor/predecessor; no code visible here sets any of these links to a non-null value.
   */
  private AVLNode left,right,parent,succ,pred;

  /**
   * Instantiates a new AVL node with no parent.
   *
   * @param key the key of the node
   * @param data the data that the node should keep
   * @param comp the comparator to be used in the tree
   */
  public AVLNode(Object key, Object data, Comparator comp) {
    this(key,data,comp,null);
  }

  /**
   * Instantiates a new AVL node with no children, size 1 and height 0.
   *
   * @param key the key of the node
   * @param data the data that the node should keep
   * @param comp the comparator to be used in the tree
   * @param parent the parent of the created node (null for a root)
   */
  public AVLNode(Object key, Object data, Comparator comp, AVLNode parent) {
    this.data = data;
    this.key = key;
    this.comp = comp;
    this.parent = parent;

    this.left = null;
    this.right = null;
    this.succ = null;
    this.pred = null;

    this.size = 1;
    this.height = 0;
 }

  /**
   * Adds the given data to the tree.
   * TODO: not implemented yet; this stub always returns null.
   *
   * @param key the key
   * @param data the data
   * @return the root of the tree after insertion and rotations
   * @author <b>students</b>
   */
  public AVLNode add(Object key,Object data) {
    return null;
  }

  /**
   * Removes a node whose key is equal (by {@link Comparator}) to the given argument.
   * TODO: not implemented yet; this stub always returns null.
   *
   * @param key the key
   * @return the root after deletion and rotations
   * @author <b>students</b>
   */
  public AVLNode remove(Object key) {
    return null;    
  }
  // NOTE(review): the closing brace of class AVLNode is missing from this snippet.

I need to implement the add and remove functions. Both should run in O(log(n)) time, and both should return the root of the whole tree. Here is what I have so far:

/**
 * Inserts a key/data pair into the tree.
 *
 * <p>An empty tree receives a brand-new root node. Otherwise the work is delegated to the
 * current root, and the node it returns (the root after any rotations) becomes the new root.
 *
 * @param key the key of the new node
 * @param data the data of the new node
 */
public void add(Object key, Object data) {
    root = isEmpty()
            ? new AVLNode(key, data, comp)
            : root.add(key, data);
}

/**
 * Removes the node whose key equals {@code key} according to the tree's {@link Comparator}.
 *
 * <p>An empty tree is left untouched. Otherwise the removal is delegated to the root, and the
 * node it returns (the root after any rotations) is stored as the new root.
 *
 * @param key the key
 */
public void remove(Object key) {
    if (!isEmpty()) {
        root = root.remove(key);
    }
}

I need help writing the add and remove functions.

Is there a guide that describes how the add and remove operations work? Or a library I could copy from or study to figure out how and why AVL trees work?

Answer 1:

You can try my AVL tree, which is linked here. Let me know if you have any additional questions.

Source, in case the link goes down:

package com.jwetherell.algorithms.data_structures;

import java.util.ArrayList;
import java.util.List;

/**
 * An AVL tree is a self-balancing binary search tree, and it was the first such
 * data structure to be invented. In an AVL tree, the heights of the two child
 * subtrees of any node differ by at most one. AVL trees are often compared with
 * red-black trees because they support the same set of operations and because
 * red-black trees also take O(log n) time for the basic operations. Because AVL
 * trees are more rigidly balanced, they are faster than red-black trees for
 * lookup intensive applications. However, red-black trees are faster for
 * insertion and removal.
 *
 * <p>Height convention used throughout this class: a leaf has height 1 and a missing
 * child counts as height 0. The balance factor of a node is
 * {@code height(greater) - height(lesser)}, so it is negative when the lesser side is taller.
 *
 * http://en.wikipedia.org/wiki/AVL_tree
 *
 * @author Justin Wetherell <phishman3579@gmail.com>
 */
public class AVLTree<T extends Comparable<T>> extends BinarySearchTree<T> implements BinarySearchTree.INodeCreator<T> {

    /** Imbalance shapes, named by the path from the unbalanced node to its taller grandchild. */
    private enum Balance {
        LEFT_LEFT, LEFT_RIGHT, RIGHT_LEFT, RIGHT_RIGHT
    }

    /**
     * Default constructor. The tree acts as its own node creator, so every node it creates
     * is an {@link AVLNode}.
     */
    public AVLTree() {
        this.creator = this;
    }

    /**
     * Constructor with external Node creator.
     *
     * <p>NOTE(review): the methods of this class cast nodes to {@link AVLNode}, so the
     * creator is expected to produce AVLNode instances.
     *
     * @param creator factory used to create new nodes
     */
    public AVLTree(INodeCreator<T> creator) {
        super(creator);
    }

    /**
     * {@inheritDoc}
     *
     * <p>After the superclass insertion, walks from the new node up to the root, refreshing
     * each node's height and rotating wherever the AVL invariant is violated.
     */
    @Override
    protected Node<T> addValue(T id) {
        Node<T> nodeToReturn = super.addValue(id);
        AVLNode<T> nodeAdded = (AVLNode<T>) nodeToReturn;

        while (nodeAdded != null) {
            nodeAdded.updateHeight();
            balanceAfterInsert(nodeAdded);
            // After a rotation node.parent is the new subtree root, so its height is
            // recomputed on the next iteration.
            nodeAdded = (AVLNode<T>) nodeAdded.parent;
        }

        return nodeToReturn;
    }

    /**
     * Balance the tree according to the AVL post-insert algorithm.
     *
     * <p>{@code parent} is the taller child of {@code node} and {@code child} its taller
     * grandchild; the shape of that path selects a single or double rotation.
     *
     * @param node
     *            Root of tree to balance.
     */
    private void balanceAfterInsert(AVLNode<T> node) {
        int balanceFactor = node.getBalanceFactor();
        if (balanceFactor > 1 || balanceFactor < -1) {
            AVLNode<T> parent = null;
            AVLNode<T> child = null;
            Balance balance = null;
            if (balanceFactor < 0) {
                parent = (AVLNode<T>) node.lesser;
                balanceFactor = parent.getBalanceFactor();
                if (balanceFactor < 0) {
                    child = (AVLNode<T>) parent.lesser;
                    balance = Balance.LEFT_LEFT;
                } else {
                    child = (AVLNode<T>) parent.greater;
                    balance = Balance.LEFT_RIGHT;
                }
            } else {
                parent = (AVLNode<T>) node.greater;
                balanceFactor = parent.getBalanceFactor();
                if (balanceFactor < 0) {
                    child = (AVLNode<T>) parent.lesser;
                    balance = Balance.RIGHT_LEFT;
                } else {
                    child = (AVLNode<T>) parent.greater;
                    balance = Balance.RIGHT_RIGHT;
                }
            }

            if (balance == Balance.LEFT_RIGHT) {
                // Left-Right (Left rotation, right rotation)
                rotateLeft(parent);
                rotateRight(node);
            } else if (balance == Balance.RIGHT_LEFT) {
                // Right-Left (Right rotation, left rotation)
                rotateRight(parent);
                rotateLeft(node);
            } else if (balance == Balance.LEFT_LEFT) {
                // Left-Left (Right rotation)
                rotateRight(node);
            } else {
                // Right-Right (Left rotation)
                rotateLeft(node);
            }

            node.updateHeight(); // New child node
            child.updateHeight(); // New child node
            parent.updateHeight(); // New Parent node
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Unlinks the node holding {@code value}, then walks from the lowest structurally
     * affected node up to the root, refreshing heights and rebalancing.
     */
    @Override
    protected Node<T> removeValue(T value) {
        // Find node to remove
        Node<T> nodeToRemoved = this.getNode(value);
        if (nodeToRemoved != null) {
            // Find the replacement node
            Node<T> replacementNode = this.getReplacementNode(nodeToRemoved);

            // Find the parent of the replacement node to re-factor the
            // height/balance of the tree
            AVLNode<T> nodeToRefactor = null;
            if (replacementNode != null)
                nodeToRefactor = (AVLNode<T>) replacementNode.parent;
            if (nodeToRefactor == null)
                nodeToRefactor = (AVLNode<T>) nodeToRemoved.parent;
            // If the replacement was a direct child of the removed node, that child takes the
            // removed node's place and is the lowest node whose height may have changed.
            if (nodeToRefactor != null && nodeToRefactor.equals(nodeToRemoved))
                nodeToRefactor = (AVLNode<T>) replacementNode;

            // Replace the node
            replaceNodeWithNode(nodeToRemoved, replacementNode);

            // Re-balance the tree all the way up the tree
            while (nodeToRefactor != null) {
                nodeToRefactor.updateHeight();
                balanceAfterDelete(nodeToRefactor);
                nodeToRefactor = (AVLNode<T>) nodeToRefactor.parent;
            }
        }
        return nodeToRemoved;
    }

    /**
     * Balance the tree according to the AVL post-delete algorithm.
     *
     * <p>Unlike insertion, the taller child may itself be balanced after a delete; in that
     * case a single rotation is used.
     *
     * @param node
     *            Root of tree to balance.
     */
    private void balanceAfterDelete(AVLNode<T> node) {
        int balanceFactor = node.getBalanceFactor();
        if (balanceFactor == -2 || balanceFactor == 2) {
            if (balanceFactor == -2) {
                AVLNode<T> ll = (AVLNode<T>) node.lesser.lesser;
                int lesser = (ll != null) ? ll.height : 0;
                AVLNode<T> lr = (AVLNode<T>) node.lesser.greater;
                int greater = (lr != null) ? lr.height : 0;
                if (lesser >= greater) {
                    rotateRight(node);
                    node.updateHeight();
                    if (node.parent != null)
                        ((AVLNode<T>) node.parent).updateHeight();
                } else {
                    rotateLeft(node.lesser);
                    rotateRight(node);

                    AVLNode<T> p = (AVLNode<T>) node.parent;
                    if (p.lesser != null)
                        ((AVLNode<T>) p.lesser).updateHeight();
                    if (p.greater != null)
                        ((AVLNode<T>) p.greater).updateHeight();
                    p.updateHeight();
                }
            } else if (balanceFactor == 2) {
                AVLNode<T> rr = (AVLNode<T>) node.greater.greater;
                int greater = (rr != null) ? rr.height : 0;
                AVLNode<T> rl = (AVLNode<T>) node.greater.lesser;
                int lesser = (rl != null) ? rl.height : 0;
                if (greater >= lesser) {
                    rotateLeft(node);
                    node.updateHeight();
                    if (node.parent != null)
                        ((AVLNode<T>) node.parent).updateHeight();
                } else {
                    rotateRight(node.greater);
                    rotateLeft(node);

                    AVLNode<T> p = (AVLNode<T>) node.parent;
                    if (p.lesser != null)
                        ((AVLNode<T>) p.lesser).updateHeight();
                    if (p.greater != null)
                        ((AVLNode<T>) p.greater).updateHeight();
                    p.updateHeight();
                }
            }
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>In addition to the BST checks, requires a balance factor in [-1, 1] and a cached
     * height exactly one more than the taller child's height (a missing child counts as 0,
     * so a leaf must have height 1).
     */
    @Override
    protected boolean validateNode(Node<T> node) {
        boolean bst = super.validateNode(node);
        if (!bst)
            return false;

        AVLNode<T> avlNode = (AVLNode<T>) node;
        int balanceFactor = avlNode.getBalanceFactor();
        if (balanceFactor > 1 || balanceFactor < -1) {
            return false;
        }

        // Previously a missing child was treated as height 1 and either child could satisfy
        // the check, so a wrong cached height on a one-child node went undetected. Use the
        // same convention as updateHeight(): height == max(child heights) + 1, null == 0.
        int lesserHeight = AVLNode.heightOf(avlNode.lesser);
        int greaterHeight = AVLNode.heightOf(avlNode.greater);
        return avlNode.height == Math.max(lesserHeight, greaterHeight) + 1;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public String toString() {
        return AVLTreePrinter.getString(this);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public Node<T> createNewNode(Node<T> parent, T id) {
        return (new AVLNode<T>(parent, id));
    }

    protected static class AVLNode<T extends Comparable<T>> extends Node<T> {

        /** Cached height of the subtree rooted here; 1 for a leaf. */
        protected int height = 1;

        /**
         * Constructor for an AVL node
         *
         * @param parent
         *            Parent of the node in the tree, can be NULL.
         * @param value
         *            Value of the node in the tree.
         */
        protected AVLNode(Node<T> parent, T value) {
            super(parent, value);
        }

        /**
         * Returns the cached height of {@code node}, treating a missing (null) node as 0.
         * Every node in this tree is an AVLNode, hence the cast.
         */
        private static <U extends Comparable<U>> int heightOf(Node<U> node) {
            return (node == null) ? 0 : ((AVLNode<U>) node).height;
        }

        /**
         * Determines if this node is a leaf (has no children).
         *
         * @return True if this node is a leaf.
         */
        protected boolean isLeaf() {
            return ((lesser == null) && (greater == null));
        }

        /**
         * Updates the height of this node from its children's cached heights.
         */
        protected void updateHeight() {
            height = Math.max(heightOf(lesser), heightOf(greater)) + 1;
        }

        /**
         * Get the balance factor for this node.
         *
         * @return An integer representing the balance factor for this node. It
         *         will be negative if the lesser branch is longer than the
         *         greater branch.
         */
        protected int getBalanceFactor() {
            return heightOf(greater) - heightOf(lesser);
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public String toString() {
            return "value=" + id + " height=" + height + " parent=" + ((parent != null) ? parent.id : "NULL")
                    + " lesser=" + ((lesser != null) ? lesser.id : "NULL") + " greater="
                    + ((greater != null) ? greater.id : "NULL");
        }
    }

    protected static class AVLTreePrinter {

        public static <T extends Comparable<T>> String getString(AVLTree<T> tree) {
            if (tree.root == null)
                return "Tree has no nodes.";
            return getString((AVLNode<T>) tree.root, "", true);
        }

        public static <T extends Comparable<T>> String getString(AVLNode<T> node) {
            if (node == null)
                return "Sub-tree has no nodes.";
            return getString(node, "", true);
        }

        private static <T extends Comparable<T>> String getString(AVLNode<T> node, String prefix, boolean isTail) {
            StringBuilder builder = new StringBuilder();

            builder.append(prefix + (isTail ? "└── " : "├── ") + "(" + node.height + ") " + node.id + "\n");
            List<Node<T>> children = null;
            if (node.lesser != null || node.greater != null) {
                children = new ArrayList<Node<T>>(2);
                if (node.lesser != null)
                    children.add(node.lesser);
                if (node.greater != null)
                    children.add(node.greater);
            }
            if (children != null) {
                for (int i = 0; i < children.size() - 1; i++) {
                    builder.append(getString((AVLNode<T>) children.get(i), prefix + (isTail ? "    " : "│   "), false));
                }
                if (children.size() >= 1) {
                    builder.append(getString((AVLNode<T>) children.get(children.size() - 1), prefix
                            + (isTail ? "    " : "│   "), true));
                }
            }

            return builder.toString();
        }
    }
}


Answer 2:

Yet another Java implementation of an AVL tree, with insert, search, and delete.

It also prints the parent value and height of each node during an in-order traversal, which makes it easy to see the effect of each operation.

The code is runnable out of the box and should be especially helpful for CS students struggling with homework :-)

/**
 * AVL tree of {@code int} values with insert, search and delete.
 *
 * <p>Each node caches its height in edges: a leaf has height 0 and an empty subtree counts as
 * -1. Inserting a value that is already present leaves the tree unchanged. After every change
 * the heights on the affected path are recomputed and single or double rotations keep the
 * heights of any node's two subtrees within one of each other.
 * {@link #traverseInOrder()} prints each node with its height and parent.
 */
public class AVLTree {

    private static class Node {
        Node left, right;
        Node parent;
        int value;
        int height = 0;

        public Node(int data, Node parent) {
            this.value = data;
            this.parent = parent;
        }

        @Override
        public String toString() {
            return value + " height " + height + " parent " + (parent == null ?
                    "NULL" : parent.value) + " | ";
        }

        /** Makes {@code child} (possibly null) the left child and points its parent link here. */
        void setLeftChild(Node child) {
            if (child != null) {
                child.parent = this;
            }

            this.left = child;
        }

        /** Makes {@code child} (possibly null) the right child and points its parent link here. */
        void setRightChild(Node child) {
            if (child != null) {
                child.parent = this;
            }

            this.right = child;
        }
    }

    private Node root = null;

    /** Inserts {@code data}; a value already in the tree is ignored. */
    public void insert(int data) {
        insert(root, data);
    }

    /** Returns the cached height of {@code node}, or -1 for an empty (null) subtree. */
    private int height(Node node) {
        return node == null ? -1 : node.height;
    }

    /**
     * Recursively inserts {@code value} below {@code node}. On the way back up, it rotates at
     * {@code node} if one side became two levels taller than the other.
     */
    private void insert(Node node, int value) {
        if (root == null) {
            root = new Node(value, null);
            return;
        }

        if (value < node.value) {
            if (node.left != null) {
                insert(node.left, value);
            } else {
                node.left = new Node(value, node);
            }

            if (height(node.left) - height(node.right) == 2) { //left heavier
                if (value < node.left.value) {
                    rotateRight(node);
                } else {
                    rotateLeftThenRight(node);
                }
            }
        } else if (value > node.value) {
            if (node.right != null) {
                insert(node.right, value);
            } else {
                node.right = new Node(value, node);
            }

            if (height(node.right) - height(node.left) == 2) { //right heavier
                if (value > node.right.value) {
                    rotateLeft(node);
                } else {
                    rotateRightThenLeft(node);
                }
            }
        }

        reHeight(node);
    }

    /**
     * Rotates right around {@code pivot}: its left child takes its place and {@code pivot}
     * becomes that child's right child. Both heights are recomputed, including when
     * {@code pivot} was the root (previously skipped, leaving the new root's height stale).
     */
    private void rotateRight(Node pivot) {
        Node parent = pivot.parent;
        Node leftChild = pivot.left;
        pivot.setLeftChild(leftChild.right);
        leftChild.setRightChild(pivot);
        replaceChild(parent, pivot, leftChild);

        // Bottom-up: pivot now sits below leftChild.
        reHeight(pivot);
        reHeight(leftChild);
    }

    /**
     * Rotates left around {@code pivot}: its right child takes its place and {@code pivot}
     * becomes that child's left child. Both heights are recomputed, including at the root.
     */
    private void rotateLeft(Node pivot) {
        Node parent = pivot.parent;
        Node rightChild = pivot.right;
        pivot.setRightChild(rightChild.left);
        rightChild.setLeftChild(pivot);
        replaceChild(parent, pivot, rightChild);

        // Bottom-up: pivot now sits below rightChild.
        reHeight(pivot);
        reHeight(rightChild);
    }

    /**
     * Puts {@code newChild} (possibly null) where {@code oldChild} hung under {@code parent}.
     * A null {@code parent} means {@code oldChild} was the root.
     */
    private void replaceChild(Node parent, Node oldChild, Node newChild) {
        if (parent == null) {
            root = newChild;
            if (newChild != null) {
                newChild.parent = null;
            }
        } else if (parent.left == oldChild) {
            parent.setLeftChild(newChild);
        } else {
            parent.setRightChild(newChild);
        }
    }

    /** Recomputes {@code node}'s height from its children's cached heights. */
    private void reHeight(Node node) {
        node.height = Math.max(height(node.left), height(node.right)) + 1;
    }

    private void rotateLeftThenRight(Node node) {
        rotateLeft(node.left);
        rotateRight(node);
    }

    private void rotateRightThenLeft(Node node) {
        rotateRight(node.right);
        rotateLeft(node);
    }

    /**
     * Deletes {@code key} from the tree.
     *
     * @return true if the key was present and has been removed, false otherwise
     */
    public boolean delete(int key) {
        Node target = search(key);
        if (target == null) return false;
        Node removed = deleteNode(target);
        // Rebalance from the parent of the node that was physically unlinked; that parent is
        // null when the unlinked node was the root, and balanceTree then does nothing.
        balanceTree(removed.parent);
        return true;
    }

    /**
     * Unlinks {@code target} from the tree. A node with two children instead takes its in-order
     * predecessor's value, and the predecessor is unlinked. Works for the root too (previously
     * it threw NullPointerException).
     *
     * @return the node that was physically unlinked; its {@code parent} field still points at
     *         its former parent (null if it was the root)
     */
    private Node deleteNode(Node target) {
        if (target.left != null && target.right != null) { // 2 children
            Node predecessor = immediatePredInOrder(target);
            target.value = predecessor.value;
            return deleteNode(predecessor);
        }

        // 0 or 1 child: splice the (possibly null) child into target's place.
        Node child = (target.left != null) ? target.left : target.right;
        replaceChild(target.parent, target, child);
        if (target.parent != null) {
            reHeight(target.parent);
        }
        return target;
    }

    /** Returns the right-most node of {@code node}'s left subtree (which must exist). */
    private Node immediatePredInOrder(Node node) {
        Node current = node.left;
        while (current.right != null) {
            current = current.right;
        }

        return current;
    }

    /** Returns height(right) - height(left) for {@code node}. */
    private int calDifference(Node node) {
        int rightHeight = height(node.right);
        int leftHeight = height(node.left);
        return rightHeight - leftHeight;
    }

    /**
     * Walks from {@code node} (may be null) up to the root. At each step it recomputes the
     * height BEFORE the balance check and before moving on. The old version recursed to the
     * parent first, so stale heights reached ancestors. Any node whose subtree heights differ
     * by 2 is rotated.
     */
    private void balanceTree(Node node) {
        while (node != null) {
            reHeight(node);
            Node parent = node.parent; // saved: a rotation moves node below its replacement
            int difference = calDifference(node);
            if (difference == -2) {
                if (height(node.left.left) >= height(node.left.right)) {
                    rotateRight(node);
                } else {
                    rotateLeftThenRight(node);
                }
            } else if (difference == 2) {
                if (height(node.right.right) >= height(node.right.left)) {
                    rotateLeft(node);
                } else {
                    rotateRightThenLeft(node);
                }
            }
            node = parent;
        }
    }

    public Node search(int key) {
        return binarySearch(root, key);
    }

    private Node binarySearch(Node node, int key) {
        if (node == null) return null;

        if (key == node.value) {
            return node;
        }

        if (key < node.value && node.left != null) {
            return binarySearch(node.left, key);
        }

        if (key > node.value && node.right != null) {
            return binarySearch(node.right, key);
        }

        return null;
    }

    /** Prints the root, then every node in order; an empty tree prints "ROOT NULL". */
    public void traverseInOrder() {
        System.out.println("ROOT " + (root == null ? "NULL" : root.toString()));
        inorder(root);
        System.out.println();
    }

    private void inorder(Node node) {
        if (node != null) {
            inorder(node.left);
            System.out.print(node.toString());
            inorder(node.right);
        }
    }

    public static void main(String[] args) {
        AVLTree avl = new AVLTree();
        avl.insert(1);
        avl.traverseInOrder();
        avl.insert(2);
        avl.traverseInOrder();
        avl.insert(3);
        avl.traverseInOrder();
        avl.insert(4);
        avl.traverseInOrder();
        avl.delete(1);
        avl.traverseInOrder();
        avl.insert(5);
        avl.traverseInOrder();
        avl.insert(6);
        avl.traverseInOrder();
        avl.delete(3);
        avl.traverseInOrder();
        avl.delete(5);
        avl.traverseInOrder();
    }

}


Answer 3:

Well, this Java code may help you; it extends a BST implementation by Michael Goodrich:

To see the complete data structure, go here: AVLTree.java (the link is no longer available).

import java.util.Comparator;    
 /**
  * AVLTree class - implements an AVL Tree by extending a binary
  * search tree.
  *
  * <p>Each node caches its height in {@link AVLNode}. After a successful insert or remove,
  * {@link #rebalance} walks from {@code actionPos} up to the root, recomputing heights and
  * restructuring unbalanced nodes. {@code actionPos} is presumably inherited from
  * BinarySearchTree and set by its insert/remove; confirm there.
  *
  * @author Michael Goodrich, Roberto Tamassia, Eric Zamore
  */
public class AVLTree extends BinarySearchTree implements Dictionary {
  /** Creates an AVL tree, passing comparator {@code c} to the superclass. */
  public AVLTree(Comparator c)  { super(c); }
  /** Creates an AVL tree using the superclass's no-argument constructor. */
  public AVLTree() { super(); }
  /** Nested class for the nodes of an AVL tree. */
  protected static class AVLNode extends BTNode {
    protected int height;  // we add a height field to a BTNode
    AVLNode() {/* default constructor */}
    /**
     * Preferred constructor. The height starts at 0 and becomes 1 + the taller child's
     * height when children are given.
     */
    AVLNode(Object element, BTPosition parent,
        BTPosition left, BTPosition right) {
      super(element, parent, left, right);
      height = 0;
      if (left != null)
        height = Math.max(height, 1 + ((AVLNode) left).getHeight());
      if (right != null)
        height = Math.max(height, 1 + ((AVLNode) right).getHeight());
    } // we assume that the parent will revise its height if needed
    /** Overwrites the cached height. */
    public void setHeight(int h) { height = h; }
    /** Returns the cached height. */
    public int getHeight() { return height; }
  }
  /** Creates a new binary search tree node (overrides super's version). */
  protected BTPosition createNode(Object element, BTPosition parent,
              BTPosition left, BTPosition right) {
    return new AVLNode(element,parent,left,right);  // now use AVL nodes
  }
  /** Returns the height of a node (call back to an AVLNode). */
  protected int height(Position p)  {
    return ((AVLNode) p).getHeight();
  }
  /** Sets the height of an internal node to 1 + the taller child's height. */
  protected void setHeight(Position p)  { // called only if p is internal
    ((AVLNode) p).setHeight(1+Math.max(height(left(p)), height(right(p))));
  }
  /** Returns whether height(left) - height(right) lies between -1 and 1. */
  protected boolean isBalanced(Position p)  {
    int bf = height(left(p)) - height(right(p));
    return ((-1 <= bf) &&  (bf <= 1));
  }
  /**
   * Returns a child of p whose height is no smaller than that of the other child.
   * On a tie it returns the child on the same side that p hangs from its own parent
   * (the left child if p is the root).
   */
  protected Position tallerChild(Position p)  {
    if (height(left(p)) > height(right(p))) return left(p);
    else if (height(left(p)) < height(right(p))) return right(p);
    // equal height children - break tie using parent's type
    if (isRoot(p)) return left(p);
    if (p == left(parent(p))) return left(p);
    else return right(p);
  }
  /**
   * Rebalance method called by insert and remove. Traverses the path from
   * zPos to the root. For each node encountered, it recomputes the node's height
   * and performs a trinode restructuring if the node is unbalanced.
   */
  protected void rebalance(Position zPos) {
    if(isInternal(zPos))
       setHeight(zPos);
    while (!isRoot(zPos)) {  // traverse up the tree towards the root
      zPos = parent(zPos);
      setHeight(zPos);
      if (!isBalanced(zPos)) {
        // perform a trinode restructuring at zPos's taller grandchild
        Position xPos =  tallerChild(tallerChild(zPos));
        zPos = restructure(xPos); // tri-node restructure (from parent class)
        setHeight(left(zPos));  // recompute heights
        setHeight(right(zPos));
        setHeight(zPos);
      }
    }
  }
  // overridden methods of the dictionary ADT
  /**
   * Inserts an item into the dictionary, rebalances from the insertion position,
   * and returns the newly created entry.
   */
  public Entry insert(Object k, Object v) throws InvalidKeyException  {
    Entry toReturn = super.insert(k, v); // calls our new createNode method
    rebalance(actionPos); // rebalance up from the insertion position
    return toReturn;
  }
  /** Removes and returns an entry from the dictionary; rebalances only if something was removed. */
  public Entry remove(Entry ent) throws InvalidEntryException {
    Entry toReturn = super.remove(ent);
    if (toReturn != null)   // we actually removed something
      rebalance(actionPos);  // rebalance up the tree
    return toReturn;
  }
} // end of AVLTree class

BTNode.java

/** A linked binary-tree node holding an element plus parent, left and right links. */
public class BTNode implements BTPosition {
  // Payload stored at this node.
  private Object element;
  // Neighbouring positions; all stay null until set.
  private BTPosition parent;
  private BTPosition left;
  private BTPosition right;

  /** Creates a node whose element and links are all null. */
  public BTNode() { }

  /**
   * Creates a node holding {@code element} and linked to the given neighbours.
   * The fields are assigned through the setters, in the order element, parent, left, right.
   *
   * @param element the element to store
   * @param parent the parent position
   * @param left the left child position
   * @param right the right child position
   */
  public BTNode(Object element, BTPosition parent,
      BTPosition left, BTPosition right) {
    setElement(element);
    setParent(parent);
    setLeft(left);
    setRight(right);
  }

  /** Returns the stored element. */
  public Object element() {
    return this.element;
  }

  /** Replaces the stored element. */
  public void setElement(Object value) {
    this.element = value;
  }

  /** Returns the left child position. */
  public BTPosition getLeft() {
    return this.left;
  }

  /** Sets the left child position. */
  public void setLeft(BTPosition position) {
    this.left = position;
  }

  /** Returns the right child position. */
  public BTPosition getRight() {
    return this.right;
  }

  /** Sets the right child position. */
  public void setRight(BTPosition position) {
    this.right = position;
  }

  /** Returns the parent position. */
  public BTPosition getParent() {
    return this.parent;
  }

  /** Sets the parent position. */
  public void setParent(BTPosition position) {
    this.parent = position;
  }
}

BTPosition.java

/**
 * A position in a linked binary tree. Besides {@code element()} inherited from
 * {@code Position}, it exposes a setter for the element and getters/setters for the
 * parent, left and right links.
 */
public interface BTPosition extends Position {  // inherits element()
  /** Replaces the element stored at this position. */
  void setElement(Object o);

  /** Returns the parent position. */
  BTPosition getParent();

  /** Sets the parent position. */
  void setParent(BTPosition v);

  /** Returns the left child position. */
  BTPosition getLeft();

  /** Sets the left child position. */
  void setLeft(BTPosition v);

  /** Returns the right child position. */
  BTPosition getRight();

  /** Sets the right child position. */
  void setRight(BTPosition v);
}


Answer 4:

I have a video playlist explaining how AVL trees work that I recommend.

Here's a working, well-documented implementation of an AVL tree. The add/remove operations work just like in a regular binary search tree, except that you need to update the balance factor as you go.

Note that this is a recursive implementation which is much simpler to understand but likely slower than its iterative counterpart.

This data structure was taken from my GitHub repo.

/**
 * This file contains an implementation of an AVL tree. An AVL tree is a
 * special type of binary search tree which balances itself to keep its
 * insert, remove and lookup operations logarithmic.
 *
 * @author William Fiset, william.alexandre.fiset@gmail.com
 **/

public class AVLTreeRecursive <T extends Comparable<T>> implements Iterable<T> {

  class Node {

    // 'bf' is short for Balance Factor: height(right) - height(left).
    int bf;

    // The value/data contained within the node.
    T value;

    // The height of this node in the tree; a leaf has height 0.
    int height;

    // The left and the right children of this node.
    Node left, right;

    public Node(T value) {
      this.value = value;
    }

  }

  // The root node of the AVL tree (null when the tree is empty).
  Node root;

  // Tracks the number of nodes inside the tree.
  private int nodeCount = 0;

  // The height of a rooted tree is the number of edges between the tree's
  // root and its furthest leaf. This means that a tree containing a single
  // node has a height of 0. Note that an empty tree also reports 0.
  public int height() {
    if (root == null) return 0;
    return root.height;
  }

  // Returns the number of nodes in the tree.
  public int size() {
    return nodeCount;
  }

  // Returns whether or not the tree is empty.
  public boolean isEmpty() {
    return size() == 0;
  }

  // Return true/false depending on whether a value exists in the tree.
  // A null value is never stored, so contains(null) is simply false
  // (consistent with insert(null) and remove(null)).
  public boolean contains(T value) {
    if (value == null) return false;
    return contains(root, value);
  }

  // Recursive contains helper method.
  private boolean contains(Node node, T value) {

    if (node == null) return false;

    // Compare current value to the value in the node.
    int cmp = value.compareTo(node.value);

    // Dig into left subtree.
    if (cmp < 0) return contains(node.left, value);

    // Dig into right subtree.
    if (cmp > 0) return contains(node.right, value);

    // Found value in tree.
    return true;

  }

  // Insert/add a value to the AVL tree. Returns false (and does nothing) if
  // the value is null or already present, O(log(n)).
  public boolean insert(T value) {
    if (value == null) return false;
    if (!contains(root, value)) {
      root = insert(root, value);
      nodeCount++;
      return true;
    }
    return false;
  }

  // Inserts a value inside the subtree rooted at 'node' and returns the
  // (possibly new) root of that subtree after rebalancing.
  private Node insert(Node node, T value) {

    // Base case.
    if (node == null) return new Node(value);

    // Compare current value to the value in the node.
    int cmp = value.compareTo(node.value);

    // Insert node in left subtree.
    if (cmp < 0) {
      node.left = insert(node.left, value);

    // Insert node in right subtree. Equal values would also go right, but
    // the public insert() rejects duplicates before we ever get here.
    } else {
      node.right = insert(node.right, value);
    }

    // Update balance factor and height values.
    update(node);

    // Re-balance tree.
    return balance(node);

  }

  // Update a node's height and balance factor from its children.
  // A missing child counts as height -1.
  private void update(Node node) {

    int leftNodeHeight  = (node.left  == null) ? -1 : node.left.height;
    int rightNodeHeight = (node.right == null) ? -1 : node.right.height;

    // Update this node's height.
    node.height = 1 + Math.max(leftNodeHeight, rightNodeHeight);

    // Update balance factor.
    node.bf = rightNodeHeight - leftNodeHeight;

  }

  // Re-balance a node if its balance factor is +2 or -2 and return the
  // root of the rebalanced subtree.
  private Node balance(Node node) {

    // Left heavy subtree.
    if (node.bf == -2) {

      // Left-Left case.
      if (node.left.bf <= 0) {
        return leftLeftCase(node);

      // Left-Right case.
      } else {
        return leftRightCase(node);
      }

    // Right heavy subtree needs balancing.
    } else if (node.bf == +2) {

      // Right-Right case.
      if (node.right.bf >= 0) {
        return rightRightCase(node);

      // Right-Left case.
      } else {
        return rightLeftCase(node);
      }

    }

    // Node either has a balance factor of 0, +1 or -1 which is fine.
    return node;

  }

  private Node leftLeftCase(Node node) {
    return rightRotation(node);
  }

  private Node leftRightCase(Node node) {
    node.left = leftRotation(node.left);
    return leftLeftCase(node);
  }

  private Node rightRightCase(Node node) {
    return leftRotation(node);
  }

  private Node rightLeftCase(Node node) {
    node.right = rightRotation(node.right);
    return rightRightCase(node);
  }

  // Rotates 'node' down to the left; its right child becomes the new
  // subtree root. The demoted node is updated first because the new
  // parent's height depends on it.
  private Node leftRotation(Node node) {
    Node newParent = node.right;
    node.right = newParent.left;
    newParent.left = node;
    update(node);
    update(newParent);
    return newParent;
  }

  // Mirror image of leftRotation: the left child becomes the subtree root.
  private Node rightRotation(Node node) {
    Node newParent = node.left;
    node.left = newParent.right;
    newParent.right = node;
    update(node);
    update(newParent);
    return newParent;
  }

  // Remove a value from this binary tree if it exists. Returns false if the
  // value is null or absent, O(log(n)).
  public boolean remove(T elem) {

    if (elem == null) return false;

    if (contains(root, elem)) {
      root = remove(root, elem);
      nodeCount--;
      return true;
    }

    return false;
  }

  // Removes a value from the subtree rooted at 'node' and returns the
  // (possibly new) root of that subtree after rebalancing.
  private Node remove(Node node, T elem) {

    if (node == null) return null;

    int cmp = elem.compareTo(node.value);

    // Dig into left subtree, the value we're looking
    // for is smaller than the current value.
    if (cmp < 0) {
      node.left = remove(node.left, elem);

    // Dig into right subtree, the value we're looking
    // for is greater than the current value.
    } else if (cmp > 0) {
      node.right = remove(node.right, elem);

    // Found the node we wish to remove.
    } else {

      // This is the case with only a right subtree or no subtree at all.
      // In this situation just swap the node we wish to remove
      // with its right child.
      if (node.left == null) {
        return node.right;

      // This is the case with only a left subtree. In this situation just
      // swap the node we wish to remove with its left child.
      } else if (node.right == null) {
        return node.left;

      // When removing a node from a binary tree with two links the
      // successor of the node being removed can either be the largest
      // value in the left subtree or the smallest value in the right
      // subtree. As a heuristic, remove from the taller subtree in hopes
      // that this may help with balancing.
      } else {

        // Choose to remove from left subtree
        if (node.left.height > node.right.height) {

          // Swap the value of the successor into the node.
          T successorValue = findMax(node.left);
          node.value = successorValue;

          // Remove the largest node in the left subtree, whose value we
          // just copied up.
          node.left = remove(node.left, successorValue);

        } else {

          // Swap the value of the successor into the node.
          T successorValue = findMin(node.right);
          node.value = successorValue;

          // Go into the right subtree and remove the leftmost node we
          // found and swapped data with. This prevents us from having
          // two nodes in our tree with the same value.
          node.right = remove(node.right, successorValue);
        }
      }
    }

    // Update balance factor and height values.
    update(node);

    // Re-balance tree.
    return balance(node);

  }

  // Helper method to find the leftmost node (which has the smallest value).
  // 'node' must not be null.
  private T findMin(Node node) {
    while (node.left != null)
      node = node.left;
    return node.value;
  }

  // Helper method to find the rightmost node (which has the largest value).
  // 'node' must not be null.
  private T findMax(Node node) {
    while (node.right != null)
      node = node.right;
    return node.value;
  }

  // Returns an iterator that traverses the tree in order (ascending values).
  // The iterator fails fast with ConcurrentModificationException if the node
  // count changes while iterating, and next() throws NoSuchElementException
  // once every value has been returned (as the Iterator contract requires).
  @Override
  public java.util.Iterator<T> iterator() {

    final int expectedNodeCount = nodeCount;
    // ArrayDeque rejects null elements, so only seed it for a non-empty tree.
    final java.util.Deque<Node> stack = new java.util.ArrayDeque<>();
    if (root != null) stack.push(root);

    return new java.util.Iterator<T>() {
      Node trav = root;

      @Override
      public boolean hasNext() {
        if (expectedNodeCount != nodeCount) throw new java.util.ConcurrentModificationException();
        return root != null && !stack.isEmpty();
      }

      @Override
      public T next() {

        // hasNext() also performs the concurrent-modification check.
        if (!hasNext()) throw new java.util.NoSuchElementException();

        // Push the left spine of the current subtree so its smallest value
        // ends up on top of the stack.
        while (trav != null && trav.left != null) {
          stack.push(trav.left);
          trav = trav.left;
        }

        Node node = stack.pop();

        // The in-order successor lies in the right subtree, if any.
        if (node.right != null) {
          stack.push(node.right);
          trav = node.right;
        }

        return node.value;
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException();
      }
    };
  }

  // Make sure every value in a node's left subtree is smaller than the
  // node's value and every value in its right subtree is greater. Checking
  // only the immediate children is not sufficient: a grandchild can satisfy
  // its parent yet lie on the wrong side of an ancestor.
  // (Used only for testing)
  boolean validateBstInvarient(Node node) {
    return validateBstInvarient(node, null, null);
  }

  // Recursive helper: every value in this subtree must lie strictly between
  // 'low' and 'high'; a null bound means that side is unbounded.
  private boolean validateBstInvarient(Node node, T low, T high) {
    if (node == null) return true;
    T val = node.value;
    if (low  != null && val.compareTo(low)  <= 0) return false;
    if (high != null && val.compareTo(high) >= 0) return false;
    return validateBstInvarient(node.left, low, val)
        && validateBstInvarient(node.right, val, high);
  }

  // Example usage of AVL tree.
  public static void main(String[] args) {
    AVLTreeRecursive<Integer> tree = new AVLTreeRecursive<>();

    tree.insert(7);
    tree.insert(16);
    tree.insert(-2);
    tree.insert(10);
    tree.insert(12);

    // Prints: -2 7 10 12 16
    for (Integer value : tree) System.out.print(value + " ");
    System.out.println();

    tree.remove(12);
    tree.remove(-5);
    tree.remove(10);

    // Prints: -2 7 16
    for (Integer value : tree) System.out.print(value + " ");
    System.out.println();


    System.out.println(tree.contains(10)); // false
    System.out.println(tree.contains(16)); // true

  }

}