AVL tree rotation in Java

2019-03-16 23:22发布

问题:

I want to implement an AVL tree in Java and rotate the tree left and right, but I can't figure out how to do it.

Looking at the code below, can anybody tell me how to rotate the tree left and right, and then how to use those two functions in fixup to balance the AVL tree?

I hope someone here can guide me through this.

import java.util.Random;
import java.util.SortedSet;
import java.util.TreeSet;

/**
 * An AVL tree: a binary search tree in which, at every node, the heights of
 * the left and right subtrees differ by at most one.
 *
 * <p>Each node caches the height of its subtree in {@link Node#h}. After every
 * structural change ({@link #add(Object)} and {@link #splice(Node)}),
 * {@link #fixup(Node)} walks from the changed position up to the root,
 * refreshing the cached heights and rotating wherever the balance condition
 * is violated.
 *
 * <p>The node links ({@code left}, {@code right}, {@code parent}), the root
 * {@code r}, the {@code nil} marker, {@code sampleNode}, the comparator
 * {@code c} and the {@code find}/{@code remove}/{@code clear} operations come
 * from the superclass and {@code BSTNode}, which are not part of this listing.
 *
 * @param <T> the element type
 */
public class AVLTree<T> extends
        BinarySearchTree<AVLTree.Node<T>, T> implements SSet<T> {

    Random rand;

    /** A tree node that additionally caches the height of its subtree. */
    public static class Node<T> extends BSTNode<Node<T>, T> {
        int h;  // the height of the node
    }

    /** Creates an empty tree. */
    public AVLTree() {
        sampleNode = new Node<T>();
        rand = new Random();
        c = new DefaultComparator<T>();
    }

    /**
     * Returns the cached height of the subtree rooted at {@code u}.
     *
     * @param u a node, or the empty-subtree marker
     * @return {@code u.h}, or 0 for an empty subtree
     */
    public int height(Node<T> u) {
        return (u == null || u == nil) ? 0 : u.h;
    }

    /** Recomputes the cached height of {@code u} from its children's cached heights. */
    private void updateHeight(Node<T> u) {
        u.h = Math.max(height(u.left), height(u.right)) + 1;
    }

    /**
     * Adds {@code x} to the tree and restores the AVL property.
     *
     * @param x the value to add
     * @return {@code true} if the superclass accepted the new node,
     *         {@code false} otherwise
     */
    public boolean add(T x) {
        Node<T> u = new Node<T>();
        u.x = x;
        if (!super.add(u)) {
            return false;
        }
        // fixup refreshes every height on the path from u to the root
        // (including u itself) and rebalances along the way.
        fixup(u);
        return true;
    }

    /**
     * Removes {@code u} from the tree via the superclass and restores the AVL
     * property on the path from {@code u}'s former parent to the root.
     *
     * @param u the node to splice out
     */
    public void splice(Node<T> u) {
        // Remember the parent first: it is the lowest node whose height can
        // change once u is gone.
        Node<T> w = u.parent;
        super.splice(u);
        fixup(w);
    }

    /**
     * Verifies the subtree rooted at {@code u}: every cached height must be
     * correct and every node must satisfy the AVL balance condition.
     *
     * @param u root of the subtree to check
     * @throws RuntimeException if a cached height is wrong or a node's
     *         subtrees differ in height by more than one
     */
    public void checkHeights(Node<T> u) {
        if (u == nil) return;
        checkHeights(u.left);
        checkHeights(u.right);
        if (height(u) != 1 + Math.max(height(u.left), height(u.right)))
            throw new RuntimeException("Check heights shows incorrect heights");
        int dif = height(u.left) - height(u.right);
        if (dif < -1 || dif > 1)
            throw new RuntimeException("Check heights found height difference of " + dif);
    }

    /**
     * Restores cached heights and the AVL balance condition on the path from
     * {@code u} to the root.
     *
     * <p>At each node the height is recomputed first, because a rotation
     * further down may have changed a child's height. If the node is then out
     * of balance, one of the four standard cases applies:
     * <ul>
     *   <li>left-left: rotate right at the node;</li>
     *   <li>left-right: rotate left at the left child, then right at the node;</li>
     *   <li>right-right: rotate left at the node;</li>
     *   <li>right-left: rotate right at the right child, then left at the node.</li>
     * </ul>
     *
     * @param u the lowest node whose subtree may have changed; may be {@code nil}
     */
    public void fixup(Node<T> u) {
        while (u != nil) {
            updateHeight(u);
            int dif = height(u.left) - height(u.right);
            if (dif > 1) {
                // Left-heavy. If the left child leans right, straighten it
                // first so that a single right rotation fixes u.
                Node<T> l = u.left;
                if (height(l.left) < height(l.right)) {
                    rotateLeft(l);
                }
                rotateRight(u);
            } else if (dif < -1) {
                // Right-heavy: mirror image of the case above.
                Node<T> rt = u.right;
                if (height(rt.right) < height(rt.left)) {
                    rotateRight(rt);
                }
                rotateLeft(u);
            }
            // After a rotation u's parent is the new root of this subtree, so
            // this step visits it next and then continues towards the root.
            u = u.parent;
        }
    }

    /**
     * Rotates left at {@code u}: {@code u}'s right child {@code w} takes
     * {@code u}'s place and {@code u} becomes {@code w}'s left child. The
     * cached heights of {@code u} and {@code w} are recomputed; heights of
     * ancestors are left to the caller.
     *
     * <p>NOTE(review): if the superclass already provides this pointer
     * manipulation, it could be delegated there - confirm against
     * BinarySearchTree.
     *
     * @param u the node to rotate at; nothing happens if it has no right child
     */
    public void rotateLeft(Node<T> u) {
        if (u == nil || u.right == nil) return;
        Node<T> w = u.right;
        // Hook w into u's old position.
        w.parent = u.parent;
        if (w.parent != nil) {
            if (w.parent.left == u) {
                w.parent.left = w;
            } else {
                w.parent.right = w;
            }
        }
        // w's left subtree holds the keys between u and w; it moves under u.
        u.right = w.left;
        if (u.right != nil) {
            u.right.parent = u;
        }
        u.parent = w;
        w.left = u;
        if (u == r) {
            r = w;
            r.parent = nil;
        }
        // u is now below w, so its height must be refreshed first.
        updateHeight(u);
        updateHeight(w);
    }

    /**
     * Rotates right at {@code u}: {@code u}'s left child {@code w} takes
     * {@code u}'s place and {@code u} becomes {@code w}'s right child. The
     * cached heights of {@code u} and {@code w} are recomputed; heights of
     * ancestors are left to the caller.
     *
     * @param u the node to rotate at; nothing happens if it has no left child
     */
    public void rotateRight(Node<T> u) {
        if (u == nil || u.left == nil) return;
        Node<T> w = u.left;
        // Hook w into u's old position.
        w.parent = u.parent;
        if (w.parent != nil) {
            if (w.parent.left == u) {
                w.parent.left = w;
            } else {
                w.parent.right = w;
            }
        }
        // w's right subtree holds the keys between w and u; it moves under u.
        u.left = w.right;
        if (u.left != nil) {
            u.left.parent = u;
        }
        u.parent = w;
        w.right = u;
        if (u == r) {
            r = w;
            r.parent = nil;
        }
        // u is now below w, so its height must be refreshed first.
        updateHeight(u);
        updateHeight(w);
    }

    /**
     * Returns the smallest element of {@code ss} that is greater than or
     * equal to {@code x}.
     *
     * @param ss the reference sorted set
     * @param x  the value to search for
     * @return the first element of {@code ss.tailSet(x)}, or {@code null} if
     *         that view is empty
     */
    public static <T> T find(SortedSet<T> ss, T x) {
        SortedSet<T> ts = ss.tailSet(x);
        if (!ts.isEmpty()) {
            return ts.first();
        }
        return null;
    }

    /** Compares two possibly-null results by value rather than by reference. */
    private static boolean sameValue(Integer a, Integer b) {
        return (a == null) ? (b == null) : a.equals(b);
    }

    /**
     * This just does some very basic correctness testing
     * @param args
     */
    public static void main(String[] args) {
        AVLTree<Integer> t = new AVLTree<Integer>();
        Random r = new Random(0);
        System.out.print("Running AVL tests...");
        int n = 1000;
        for (int i = 0; i < n; i++) {
            t.add(r.nextInt(2*n));
            t.checkHeights(t.r);
        }
        for (int i = 0; i < n; i++) {
            t.remove(r.nextInt(2*n));
            t.checkHeights(t.r);
        }
        System.out.println("done");

        t.clear();
        System.out.print("Running correctness tests...");
        n = 100000;
        SortedSet<Integer> ss = new TreeSet<Integer>();
        Random rand = new Random();
        for (int i = 0; i < n; i++) {
            Integer x = rand.nextInt(2*n);
            boolean b1 = t.add(x);
            boolean b2 = ss.add(x);
            if (b1 != b2) {
                throw new RuntimeException("Adding " + x + " gives " + b2
                        + " in SortedSet and " + b1 + " in AVL Tree");
            }
        }
        for (int i = 0; i < n; i++) {
            Integer x = rand.nextInt(2*n);
            Integer x1 = t.find(x);
            Integer x2 = find(ss, x);
            // Values this large fall outside the Integer cache, so compare by value.
            if (!sameValue(x1, x2)) {
                throw new RuntimeException("Searching " + x + " gives " + x2
                        + " in SortedSet and " + x1 + " in AVL Tree");
            }
        }
        for (int i = 0; i < n; i++) {
            Integer x = rand.nextInt(2*n);
            boolean b1 = t.remove(x);
            boolean b2 = ss.remove(x);
            if (b1 != b2) {
                throw new RuntimeException("Error (2): Removing " + x + " gives " + b2
                        + " in SortedSet and " + b1 + " in AVL Tree");
            }
        }
        for (int i = 0; i < n; i++) {
            Integer x = rand.nextInt(2*n);
            Integer x1 = t.find(x);
            Integer x2 = find(ss, x);
            if (!sameValue(x1, x2)) {
                throw new RuntimeException("Error (3): Searching " + x + " gives " + x2
                        + " in SortedSet and " + x1 + " in AVL Tree");
            }
        }
        System.out.println("done");
    }
}

回答1:

Full AVL tree implementation:

/**
 * A self-balancing binary search tree (AVL tree).
 *
 * <p>Elements are ordered by their natural ordering, so every inserted value
 * must implement {@link Comparable}. Duplicates are allowed; a value that
 * compares equal to an existing one is placed in that node's right subtree.
 *
 * @param <T> the element type
 */
public class AVLTree<T> {

    private AVLNode<T> root;

    /** A tree node holding one value and the cached height of its subtree. */
    private static class AVLNode<T> {

        private final T t;
        private int height;
        private AVLNode<T> left;
        private AVLNode<T> right;

        private AVLNode(T t) {
            this.t = t;
            height = 1; // a leaf has height 1; an empty subtree has height 0
        }
    }

    /**
     * Inserts {@code value} and rebalances the tree.
     *
     * @param value the value to insert
     * @throws NullPointerException if {@code value} is {@code null}
     * @throws ClassCastException   if {@code value} does not implement
     *                              {@link Comparable}
     */
    public void insert(T value) {
        // Fail on the offending call instead of on some later insert that
        // happens to compare against the bad element.
        if (value == null) {
            throw new NullPointerException("value must not be null");
        }
        if (!(value instanceof Comparable)) {
            throw new ClassCastException(value.getClass().getName() + " is not Comparable");
        }
        root = insert(root, value);
    }

    /**
     * Reports whether the tree holds an element that compares equal to
     * {@code value}.
     *
     * @param value the value to look for; {@code null} is never contained
     * @return {@code true} if an equal element is present
     */
    public boolean contains(T value) {
        if (value == null) {
            return false;
        }
        AVLNode<T> n = root;
        while (n != null) {
            int k = compare(n.t, value);
            if (k == 0) {
                return true;
            }
            n = (k > 0) ? n.left : n.right;
        }
        return false;
    }

    /**
     * Returns the height of the tree.
     *
     * @return 0 for an empty tree, otherwise the number of nodes on the
     *         longest root-to-leaf path
     */
    public int height() {
        return height(root);
    }

    /** Compares two elements by natural ordering. */
    @SuppressWarnings("unchecked") // insert() only admits Comparable values
    private int compare(T a, T b) {
        return ((Comparable<? super T>) a).compareTo(b);
    }

    /** Inserts {@code v} below {@code n} and returns the new root of that subtree. */
    private AVLNode<T> insert(AVLNode<T> n, T v) {
        if (n == null) {
            return new AVLNode<T>(v);
        }
        if (compare(n.t, v) > 0) {
            n.left = insert(n.left, v);
        } else {
            n.right = insert(n.right, v);
        }
        updateHeight(n);
        return rebalance(n);
    }

    /** Applies the single or double rotation needed at {@code n}, if any. */
    private AVLNode<T> rebalance(AVLNode<T> n) {
        int diff = heightDiff(n);
        if (diff < -1) {
            // Right-heavy; a right child leaning left needs a double rotation.
            if (heightDiff(n.right) > 0) {
                n.right = rightRotate(n.right);
            }
            return leftRotate(n);
        }
        if (diff > 1) {
            // Left-heavy; a left child leaning right needs a double rotation.
            if (heightDiff(n.left) < 0) {
                n.left = leftRotate(n.left);
            }
            return rightRotate(n);
        }
        return n;
    }

    /** Rotates left at {@code n} and returns the node that replaces it. */
    private AVLNode<T> leftRotate(AVLNode<T> n) {
        AVLNode<T> r = n.right;
        n.right = r.left;
        r.left = n;
        // n is now below r, so refresh n first.
        updateHeight(n);
        updateHeight(r);
        return r;
    }

    /** Rotates right at {@code n} and returns the node that replaces it. */
    private AVLNode<T> rightRotate(AVLNode<T> n) {
        AVLNode<T> r = n.left;
        n.left = r.right;
        r.right = n;
        // n is now below r, so refresh n first.
        updateHeight(n);
        updateHeight(r);
        return r;
    }

    /** Recomputes the cached height of {@code n} from its children. */
    private void updateHeight(AVLNode<T> n) {
        n.height = Math.max(height(n.left), height(n.right)) + 1;
    }

    /** Returns left height minus right height; 0 for an empty subtree. */
    private int heightDiff(AVLNode<T> a) {
        if (a == null) {
            return 0;
        }
        return height(a.left) - height(a.right);
    }

    /** Returns the cached height of {@code a}; 0 for an empty subtree. */
    private int height(AVLNode<T> a) {
        if (a == null) {
            return 0;
        }
        return a.height;
    }

}


回答2:

Here's a full implementation of an AVL tree in Java:

/** A node of the integer AVL tree: a key, two child links and a cached subtree height. */
class Node {
    int key;
    Node left;
    Node right;
    int height; // 1 for a leaf; an empty subtree counts as 0

    Node(int value) {
        key = value;
        left = null;
        right = null;
        height = 1;
    }
}

/**
 * An AVL tree of {@code int} keys.
 *
 * <p>Duplicate keys are allowed; a key equal to an existing one is inserted
 * into that node's right subtree.
 */
class AVLTree {
    Node root;

    /** Returns the cached height of {@code root}, or 0 for an empty subtree. */
    int height(Node root) {
        if (root == null)
            return 0;

        return root.height;
    }

    /** Returns the height of the whole tree (0 when empty). */
    int findHeight() {
        return height(root);
    }

    /**
     * Returns the height of the subtree rooted at a node holding
     * {@code value}, or -1 if no such node exists.
     */
    int findHeightFrom(int value) {
        Node node = search(root, value);
        if (node == null)
            return -1;

        return node.height;
    }

    /**
     * Searches the subtree rooted at {@code root} for {@code value}.
     *
     * @return a node whose key equals {@code value}, or {@code null} if none
     */
    Node search(Node root, int value) {
        Node current = root;
        while (current != null && current.key != value) {
            current = (value < current.key) ? current.left : current.right;
        }
        return current;
    }

    /** Reports whether {@code value} is present in the tree. */
    boolean find(int value) {
        return search(root, value) != null;
    }

    /** Returns the larger of the two arguments. */
    int max(int one, int two) {
        return (one > two) ? one : two;
    }

    /**
     * Rotates right at {@code root}: its left child moves up and
     * {@code root} becomes that child's right child.
     *
     * @param root the node to rotate at; must have a left child
     * @return the node that now heads this subtree
     */
    Node rightRotate(Node root) {
        Node rootLeftChild = root.left;
        root.left = rootLeftChild.right;
        rootLeftChild.right = root;

        // root is now the lower node, so its height is refreshed first.
        root.height = max(height(root.left), height(root.right)) + 1;
        rootLeftChild.height = max(height(rootLeftChild.left), height(rootLeftChild.right)) + 1;

        return rootLeftChild;
    }

    /**
     * Rotates left at {@code root}: its right child moves up and
     * {@code root} becomes that child's left child.
     *
     * @param root the node to rotate at; must have a right child
     * @return the node that now heads this subtree
     */
    Node leftRotate(Node root) {
        Node rootRightChild = root.right;
        root.right = rootRightChild.left;
        rootRightChild.left = root;

        // root is now the lower node, so its height is refreshed first.
        root.height = max(height(root.left), height(root.right)) + 1;
        rootRightChild.height = max(height(rootRightChild.left), height(rootRightChild.right)) + 1;

        return rootRightChild;
    }

    /**
     * Inserts {@code value} into the subtree rooted at {@code root} and
     * rebalances it.
     *
     * <p>The rotation case is chosen from the heights of the heavy child's
     * subtrees rather than by comparing {@code value} with the child's key.
     * A duplicate of the child's key lands in the child's right subtree, and
     * a key comparison would mistake that for the right-left case and rotate
     * around a missing grandchild.
     *
     * @return the node that now heads this subtree
     */
    Node insertNode(Node root, int value) {
        if (root == null)
            return new Node(value);

        if (value < root.key)
            root.left = insertNode(root.left, value);
        else
            root.right = insertNode(root.right, value);

        root.height = max(height(root.left), height(root.right)) + 1;

        int balanceFactor = height(root.left) - height(root.right);

        if (balanceFactor > 1) {
            // either left-left case or left-right case
            if (height(root.left.left) < height(root.left.right)) {
                // left-right case: straighten the left child first
                root.left = leftRotate(root.left);
            }
            return rightRotate(root);
        }
        if (balanceFactor < -1) {
            // either right-right case or right-left case
            if (height(root.right.right) < height(root.right.left)) {
                // right-left case: straighten the right child first
                root.right = rightRotate(root.right);
            }
            return leftRotate(root);
        }

        return root;
    }

    /** Inserts {@code value} into the tree. */
    void insert(int value) {
        root = insertNode(root, value);
    }

    /** Prints the keys of the subtree rooted at {@code root} in ascending order, each followed by a space. */
    void inorder(Node root) {
        if (root != null) {
            inorder(root.left);
            System.out.print(root.key + " ");
            inorder(root.right);
        }
    }

    /** Prints all keys in ascending order followed by a line break. */
    void inorderTraversal() {
        inorder(root);
        System.out.println();
    }

    /** Prints the keys of the subtree rooted at {@code root} in pre-order (node, left, right), each followed by a space. */
    void preorder(Node root) {
        if (root != null) {
            System.out.print(root.key + " ");
            preorder(root.left);
            preorder(root.right);
        }
    }

    /** Prints all keys in pre-order followed by a line break. */
    void preorderTraversal() {
        preorder(root);
        System.out.println();
    }
}

/** Small demonstration driver: builds a six-key AVL tree and prints what it finds. */
public class AVLTreeExample {
    public static void main(String[] args) {
        AVLTree avl = new AVLTree();

        // Same keys, same order as a hand-written sequence of insert calls.
        int[] keys = {10, 20, 30, 40, 50, 25};
        for (int key : keys) {
            avl.insert(key);
        }

        // Traversals print their own trailing line break.
        System.out.print("Inorder Traversal : ");
        avl.inorderTraversal();
        System.out.print("Preorder Traversal : ");
        avl.preorderTraversal();

        // Membership queries.
        boolean has10 = avl.find(10);
        System.out.println("Searching for 10 : " + has10);
        boolean has11 = avl.find(11);
        System.out.println("Searching for 11 : " + has11);
        boolean has20 = avl.find(20);
        System.out.println("Searching for 20 : " + has20);

        // Height queries.
        int treeHeight = avl.findHeight();
        System.out.println("Height of the tree : " + treeHeight);
        int heightAt10 = avl.findHeightFrom(10);
        System.out.println("Finding height from 10 : " + heightAt10);
        int heightAt20 = avl.findHeightFrom(20);
        System.out.println("Finding height from 20 : " + heightAt20);
        int heightAt25 = avl.findHeightFrom(25);
        System.out.println("Finding height from 25 : " + heightAt25);
    }
}


回答3:

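To rotate right, you also have to re-link the rotated subtree to the rest of the tree. First check whether the parent (the node you are rotating at) is the root. If it is not, check whether the parent is the right child of the grandparent: if so, set the grandparent's right child to the child (the node that moves up); otherwise, set the grandparent's left child to the child.

Otherwise (the parent is the root), the child becomes the new root.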



标签: java avl-tree