public class ExtendsBST<E extends Comparable<E>> { private class Node { public E e; public Node left, right; public int size, depth, count; public Node(E e) { this.e = e; size = 1; count = 1; depth = 1; } @Override public String toString() { return e.toString(); } } private Node root; private int size; public ExtendsBST() { root = null; size = 0; } public void add(E e) { if (root == null) { root = new Node(e); return; } Node cur = root; for (;;) { int cmp = e.compareTo(cur.e); if (cmp == 0) { cur.count++; break; } if (cmp < 0 && cur.left == null) { cur.left = new Node(e); size ++; return; } else if(cmp > 0 && cur.right == null) {// e.compareTo(cur.e) > 0 cur.right = new Node(e); size ++; return; } if (cmp < 0) cur = cur.left; else // cmp > 0, 相等已经被for语句过滤了 cur = cur.right; cur.size ++; cur.depth ++; } } public E floor(E e) { Node minNode = floor(root, e); return minNode == null ? null : minNode.e; } private Node floor(Node node, E e) { // 中序遍历得到最接近且比e小的节点 if (node == null) return null; int cmp = e.compareTo(node.e); if (cmp == 0) return node; if (cmp < 0) return floor(node.left, e); Node rightNode = floor(node.right, e); if (rightNode != null) return rightNode; else return node; } public E ceil(E e) { Node maxNode = ceil(root, e); return maxNode == null ? null : maxNode.e; } private Node ceil(Node node, E e) { if (node == null) return null; int cmp = e.compareTo(node.e); if (cmp == 0) return node; if (cmp > 0) return ceil(node.right, e); Node leftNode = ceil(node.left, e); if (leftNode != null) return leftNode; else return node; } public int rank(E e) { return rank(root, e); } private int rank(Node node, E e) { if (node == null) return 0; int cmp = e.compareTo(node.e); int leftSize = node.left == null ? 0 : node.left.size; if (cmp == 0) return leftSize + 1; else if (cmp > 0) return leftSize + 1 + rank(node.right, e); else// cmp < 0 return rank(node.left, e); } public E select(int rank) { if (rank >= size) throw new IllegalArgumentException( String.format("select failed; rank out of bound; SIZE=%d, RANK=%d", size, rank)); return select(root, rank).e; } private Node select(Node node, int rank) { if (node == null) return null; int t = node.left == null ? 0 : node.left.size; if (t > rank) return select(node.left, rank); else if (t < rank) return select(node.right, rank - t - 1); else return node; } @Override public String toString() { StringBuilder sb = new StringBuilder(); generateBSTString(root, 0, sb); return sb.toString(); } // 生成以node为节点,深度为depth的描述二叉树字符串 private void generateBSTString(Node node, int depth, StringBuilder sb) { if (node == null) { sb.append(generateDepthString(depth) + "null\n"); return; } sb.append(generateDepthString(depth) + node.e + "\n"); generateBSTString(node.left, depth+1 , sb); generateBSTString(node.right, depth + 1, sb); } private String generateDepthString(int depth) { StringBuilder sb = new StringBuilder(); for (int i = 0; i < depth; i ++) sb.append("--"); return sb.toString(); } public static void main(String[] args) { ExtendsBST<Integer> bst = new ExtendsBST<>(); int[] nums = {5, 10, 20, 13, 56, 31, 88, 62, 19}; for (int i = 0; i < nums.length; i ++) bst.add(nums[i]); //System.out.println(bst.floor(21)); //System.out.println(bst.ceil(21)); // System.out.println(bst.rank(56)); // System.out.println(bst.rank(88)); // System.out.println(bst.rank(2000)); System.out.println(bst.select(3)); } }