从TimSort说起

459 查看

大家可能对timsort并不是很熟悉,不过说起Collections.sort(list) 应该并不陌生。


public static <T extends Comparable<? super T>> void sort(List<T> list) {
Object[] a = list.toArray();
Arrays.sort(a);
ListIterator<T> i = list.listIterator();
for (int j=0; j<a.length; j++) {
    i.next();
    i.set((T)a[j]);
}
}

通过jdk6中的Collections源码可以看到,在sort时是调用了Arrays的sort方法,完成排序再返回list

  

      public static void sort(Object[] a) {
            Object[] aux = (Object[])a.clone();
            mergeSort(aux, a, 0, a.length, 0);
        }
            private static void mergeSort(Object[] src,
                      Object[] dest,
                      int low,
                      int high,
                      int off) {
        int length = high - low;
    
        // Insertion sort on smallest arrays
            if (length < INSERTIONSORT_THRESHOLD) {
                for (int i=low; i<high; i++)
                    for (int j=i; j>low &&
                 ((Comparable) dest[j-1]).compareTo(dest[j])>0; j--)
                        swap(dest, j, j-1);
                return;
            }
    
            // Recursively sort halves of dest into src
            int destLow  = low;
            int destHigh = high;
            low  += off;
            high += off;
            int mid = (low + high) >>> 1;
            mergeSort(dest, src, low, mid, -off);
            mergeSort(dest, src, mid, high, -off);
    
            // If list is already sorted, just copy from src to dest.  This is an
            // optimization that results in faster sorts for nearly ordered lists.
            if (((Comparable)src[mid-1]).compareTo(src[mid]) <= 0) {
                System.arraycopy(src, low, dest, destLow, length);
                return;
            }
    
            // Merge sorted halves (now in src) into dest
            for(int i = destLow, p = low, q = mid; i < destHigh; i++) {
                if (q >= high || p < mid && ((Comparable)src[p]).compareTo(src[q])<=0)
                    dest[i] = src[p++];
                else
                    dest[i] = src[q++];
            }
        }

而Arrays.sort(),对原始类型(int[],double[],char[],byte[]),JDK6里用的是快速排序,对于对象类型(Object[]),JDK6则使用归并排序。
到了jdk7,快速排序升级为双基准快排(双基准快排 vs 三路快排);归并排序升级为归并排序的改进版TimSort。
再到了JDK8, 对大集合增加了Arrays.parallelSort()函数,使用fork-Join框架,充分利用多核,对大的集合进行切分然后再归并排序,而在小的连续片段里,依然使用TimSort与DualPivotQuickSort。
谈到优化过程,先来回忆一下归并排序,长度为1的数组是已经排序好的。对长度为n>1的数组,将其分为2段(partition)(最常见的做法是从中间分开)。对两段数组递归进行归并排序,完成后将其合并(merge):通过扫描个已排序的数组并总是挑出两者中较小数作为合并数组中的下一个元素,来将两个已排序数组合并形成一个更大的已排序数组。
此时如果我们遇到这样一个数组,{5, 6, 7, 8, 9, 10, 1, 2, 3},我们利用快排或者归并带来的开销,似乎都不划算,因为乍一看只需要一步就能完成排序。这也是timSort的一个主要优势,适应性排序,先把5至10截取,1之3截取成两段,之后再归并排序,再来看一个数组,{64, 32, 16, 8, 4, 2, 1},对于这样的案例,通过适应性排序的思路,直接反转,再来检查,已经不需要额外的工作了。TimSort在此基础上还进行了其他的一些优化,这也助推了该算法的成功。下面我们通过源码来进一步了解一下。


    static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c) {
        if (c == null) {
            Arrays.sort(a, lo, hi);
            return;
        }
    
        rangeCheck(a.length, lo, hi);
        int nRemaining  = hi - lo;
        if (nRemaining < 2)
            return;  // Arrays of size 0 and 1 are always sorted
    
        // If array is small, do a "mini-TimSort" with no merges
        if (nRemaining < MIN_MERGE) {
            int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
            binarySort(a, lo, hi, lo + initRunLen, c);
            return;
        }
    
        /**
         * March over the array once, left to right, finding natural runs,
         * extending short natural runs to minRun elements, and merging runs
         * to maintain stack invariant.
         */
        TimSort<T> ts = new TimSort<>(a, c);
        int minRun = minRunLength(nRemaining);
        do {
            // Identify next run
            int runLen = countRunAndMakeAscending(a, lo, hi, c);
    
            // If run is short, extend to min(minRun, nRemaining)
            if (runLen < minRun) {
                int force = nRemaining <= minRun ? nRemaining : minRun;
                binarySort(a, lo, lo + force, lo + runLen, c);
                runLen = force;
            }
    
            // Push run onto pending-run stack, and maybe merge
            ts.pushRun(lo, runLen);
            ts.mergeCollapse();
    
            // Advance to find next run
            lo += runLen;
            nRemaining -= runLen;
        } while (nRemaining != 0);
    
        // Merge all remaining runs to complete sort
        assert lo == hi;
        ts.mergeForceCollapse();
        assert ts.stackSize == 1;
    }

首先是非空的判断,如果没有提供comparator,会调用Arrays.sort,其实也是ComparableTimSort,后面是算法的主体:如果元素个数小于2,直接返回,因为这两个元素已经排序了
如果元素个数小于一个阈值(默认为),调用 binarySort,这是一个不包含合并操作的 mini-TimSort。
在关键的 do-while 循环中,不断地进行排序,合并,排序,合并,一直到所有数据都处理完。
然后会找出run的最小长度,少于这个最小长度就需要对其进行扩展,再来看下binarySort

private static <T> void binarySort(T[] a, int lo, int hi, int start,
                                   Comparator<? super T> c) {
    assert lo <= start && start <= hi;
    if (start == lo)
        start++;
    for ( ; start < hi; start++) {
        T pivot = a[start];

        // Set left (and right) to the index where a[start] (pivot) belongs
        int left = lo;
        int right = start;
        assert left <= right;
        /*
         * Invariants:
         *   pivot >= all in [lo, left).
         *   pivot <  all in [right, start).
         */
        while (left < right) {
            int mid = (left + right) >>> 1;
            if (c.compare(pivot, a[mid]) < 0)
                right = mid;
            else
                left = mid + 1;
        }
        assert left == right;

        /*
         * The invariants still hold: pivot >= all in [lo, left) and
         * pivot < all in [left, start), so pivot belongs at left.  Note
         * that if there are elements equal to pivot, left points to the
         * first slot after them -- that's why this sort is stable.
         * Slide elements over to make room for pivot.
         */
        int n = start - left;  // The number of elements to move
        // Switch is just an optimization for arraycopy in default case
        switch (n) {
            case 2:  a[left + 2] = a[left + 1];
            case 1:  a[left + 1] = a[left];
                     break;
            default: System.arraycopy(a, left, a, left + 1, n);
        }
        a[left] = pivot;
    }
}

binarySort 对数组 a[lo:hi] 进行排序,并且a[lo:start] 是已经排好序的。算法的思路是对 a[start:hi] 中的元素,每次使用 binarySearch 为它在 a[lo:start] 中找到相应位置,并插入。
另一个关键函数是mergeCollapse

/**
 * Examines the stack of runs waiting to be merged and merges adjacent runs
 * until the stack invariants are reestablished:
 *
 *     1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
 *     2. runLen[i - 2] > runLen[i - 1]
 *
 * This method is called each time a new run is pushed onto the stack,
 * so the invariants are guaranteed to hold for i < stackSize upon
 * entry to the method.
 */
private void mergeCollapse() {
    while (stackSize > 1) {
        int n = stackSize - 2;
        if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
            if (runLen[n - 1] < runLen[n + 1])
                n--;
            mergeAt(n);
        } else if (runLen[n] <= runLen[n + 1]) {
            mergeAt(n);
        } else {
            break; // Invariant is established
        }
    }
}

它会把已经排序的 run 合并成一个大 run,此大 run 也会排好序。并的过程会一直循环下去,一直到注释里提到的循环不变式得到满足。
合并的时候,有会特别的技巧。假设两个 run 是 run1,run2 ,先用 gallopRight在 run1 里使用 binarySearch 查找 run2 首元素 的位置 k, 那么 run1 中 k 前面的元素就是合并后最小的那些元素。然后,在 run2 中查找 run1 尾元素 的位置 len2 ,那么 run2 中 len2 后面的那些元素就是合并后最大的那些元素。最后,根据len1 与 len2 大小,调用 mergeLo 或者 mergeHi 将剩余元素合并。
篇幅限制,不能全都说完了,感兴趣的读者可以移步TimeSort in java 7