1 /** 2 A comprehensive sorting library for statistical functions. Each function 3 takes N arguments, which are arrays or array-like objects, sorts the first 4 and sorts the rest in lockstep. For merge and insertion sort, if the last 5 argument is a ulong*, increments the dereference of this ulong* by the bubble 6 sort distance between the first argument and the sorted version of the first 7 argument. This is useful for some statistical calculations. 8 9 All sorting functions have the precondition that all parallel input arrays 10 must have the same length. 11 12 Notes: 13 14 Comparison functions must be written such that compFun(x, x) == false. 15 For example, "a < b" is good, "a <= b" is not. 16 17 These functions are heavily optimized for sorting arrays of 18 ints and floats (by far the most common case when doing statistical 19 calculations). In these cases, they can be several times faster than the 20 equivalent functions in std.algorithm. Since sorting is extremely important 21 for non-parametric statistics, this results in important real-world 22 performance gains. However, it comes at a price in terms of generality: 23 24 1. They assume that what they are sorting is cheap to copy via normal 25 assignment. 26 27 2. They don't work at all with general ranges, only arrays and maybe 28 ranges very similar to arrays. 29 30 3. All tuning and micro-optimization is done with ints and floats, not 31 classes, large structs, strings, etc. 
32 33 Examples: 34 --- 35 auto foo = [3, 1, 2, 4, 5].dup; 36 auto bar = [8, 6, 7, 5, 3].dup; 37 qsort(foo, bar); 38 assert(foo == [1, 2, 3, 4, 5]); 39 assert(bar == [6, 7, 8, 5, 3]); 40 auto baz = [1.0, 0, -1, -2, -3].dup; 41 mergeSort!("a > b")(bar, foo, baz); 42 assert(bar == [8, 7, 6, 5, 3]); 43 assert(foo == [3, 2, 1, 4, 5]); 44 assert(baz == [-1.0, 0, 1, -2, -3]); 45 --- 46 47 Author: David Simcha 48 */ 49 /* 50 * License: 51 * Boost Software License - Version 1.0 - August 17th, 2003 52 * 53 * Permission is hereby granted, free of charge, to any person or organization 54 * obtaining a copy of the software and accompanying documentation covered by 55 * this license (the "Software") to use, reproduce, display, distribute, 56 * execute, and transmit the Software, and to prepare derivative works of the 57 * Software, and to permit third-parties to whom the Software is furnished to 58 * do so, all subject to the following: 59 * 60 * The copyright notices in the Software and this entire statement, including 61 * the above license grant, this restriction and the following disclaimer, 62 * must be included in all copies of the Software, in whole or in part, and 63 * all derivative works of the Software, unless such copies or derivative 64 * works are solely in the form of machine-executable object code generated by 65 * a source language processor. 66 * 67 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 68 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 69 * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT 70 * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE 71 * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, 72 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 73 * DEALINGS IN THE SOFTWARE. 
*/

module dstats.sort;

import std.traits, std.algorithm, std.math, std.functional, std.math, std.typecons,
    std.typetuple, std.range, std.array, std.traits, std.ascii : whitespace;

import dstats.alloc;

version(unittest) {
    import std.stdio, std.random;
}

/**Exception thrown by the sorting functions in this module, e.g. when the
 * key array contains NaNs or the comparison function is not a strict
 * ordering.*/
class SortException : Exception {
    this(string msg) {
        super(msg);
    }
}

/* CTFE function.  Used in isSimpleComparison.  Returns a copy of input with
 * every character that appears in std.ascii.whitespace removed.*/
/*private*/ string removeWhitespace(string input) pure nothrow {
    string ret;
    foreach(elem; input) {
        bool shouldAppend = true;
        foreach(whiteChar; whitespace) {
            if(elem == whiteChar) {
                shouldAppend = false;
                break;
            }
        }

        if(shouldAppend) {
            ret ~= elem;
        }
    }
    return ret;
}

/* Conservatively tests whether the comparison function is simple enough that
 * we can get away with comparing floats as if they were ints.  Only the string
 * comparisons "a<b" and "a>b" (ignoring whitespace) qualify; anything that is
 * not a string is rejected.
 */
/*private*/ template isSimpleComparison(alias comp) {
    static if(!isSomeString!(typeof(comp))) {
        enum bool isSimpleComparison = false;
    } else {
        enum bool isSimpleComparison =
            removeWhitespace(comp) == "a<b" ||
            removeWhitespace(comp) == "a>b";
    }
}

/* Returns true if the bit pattern i, reinterpreted as an IEEE 754 float
 * (for int/uint) or double (for long/ulong), is a NaN.  Other integer types
 * are rejected at compile time.
 */
/*private*/ bool intIsNaN(I)(I i) {
    static if(is(I == int) || is(I == uint)) {
        // IEEE 754 single precision float has a 23-bit significand stored in the
        // lowest order bits, followed by an 8-bit exponent.  A NaN is when the
        // exponent bits are all ones and the significand is nonzero.
        enum uint significandMask = 0b111_1111_1111_1111_1111_1111UL;
        enum uint exponentMask = 0b1111_1111UL << 23;
    } else static if(is(I == long) || is(I == ulong)) {
        // IEEE 754 double precision float has a 52-bit significand stored in the
        // lowest order bits, followed by an 11-bit exponent.  A NaN is when the
        // exponent bits are all ones and the significand is nonzero.
        enum ulong significandMask =
            0b1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111UL;
        enum ulong exponentMask = 0b111_1111_1111UL << 52;
    } else {
        static assert(0);
    }

    return ((i & exponentMask) == exponentMask) && ((i & significandMask) != 0);
}

unittest {
    // Test on randomly generated integers punned to floats.  We expect that
    // about 1 in 256 will be NaNs.
    foreach(i; 0..10_000) {
        uint randInt = uniform(0U, uint.max);
        assert(std.math.isNaN(*(cast(float*) &randInt)) == intIsNaN(randInt));
    }

    // Test on randomly generated integers punned to doubles.  We expect that
    // about 1 in 2048 will be NaNs.
    foreach(i; 0..1_000_000) {
        ulong randInt = (cast(ulong) uniform(0U, uint.max) << 32) + uniform(0U, uint.max);
        assert(std.math.isNaN(*(cast(double*) &randInt)) == intIsNaN(randInt));
    }
}

/* Check for NaN and do some bit twiddling so that a float or double can be
 * compared as an integer.  This results in approximately a 40% speedup
 * compared to just sorting as floats.
 *
 * Only active when comp is a simple comparison (see isSimpleComparison):
 *   - real[]:            scanned for NaNs, returned unchanged.
 *   - float[]/double[]:  transformed IN PLACE and returned reinterpreted as
 *                        int[]/long[]; the caller must call postProcess() on
 *                        the original array afterwards.
 *   - anything else:     returned unchanged.
 *
 * Throws:  SortException if a NaN is found.  For float[]/double[] the
 *          elements already transformed are restored before throwing.
 */
auto prepareForSorting(alias comp, T)(T arr) {
    static if(isSimpleComparison!comp) {
        static if(is(T == real[])) {
            foreach(elem; arr) {
                if(isNaN(elem)) {
                    throw new SortException("Can't sort NaNs.");
                }
            }

            return arr;
        } else static if(is(T == double[]) || is(T == float[])) {
            static if(is(T == double[])) {
                alias long Int;
                enum signMask = 1UL << 63;
            } else {
                alias int Int;
                enum signMask = 1U << 31;
            }

            Int[] intArr = cast(Int[]) arr;
            foreach(i, ref elem; intArr) {
                if(intIsNaN(elem)) {
                    // Roll back the bit twiddling in case someone catches the
                    // exception, so that they don't see corrupted values.
                    // This must be passed the floating point slice:  passing
                    // the integer view would select the no-op overload of
                    // postProcess() and leave the prefix corrupted.
                    postProcess!comp(arr[0..i]);

                    throw new SortException("Can't sort NaNs.");
                }

                if(elem & signMask) {
                    // Negative.  Flip all bits except the sign bit so that
                    // integer order matches floating point order.
                    elem ^= signMask;
                    elem = ~elem;
                }
            }

            return intArr;
        } else {
            return arr;
        }

    } else {
        return arr;
    }
}

unittest {
    // A NaN found part-way through the array must leave the elements that
    // were already visited with their original values.
    auto arr = [1.0, -2.0, 3.0, double.nan, -5.0];
    auto orig = arr.dup;
    try {
        prepareForSorting!"a < b"(arr);
        assert(0);
    } catch(SortException) {}
    assert(arr[0..3] == orig[0..3]);
    assert(arr[4] == orig[4]);
}

/* No-op overload of postProcess(), selected whenever prepareForSorting() did
 * not bit-twiddle the array (non-simple comparison, or element type other
 * than float/double).
 */
/*private*/ void postProcess(alias comp, T)(T arr)
if(!isSimpleComparison!comp || (!is(T == double[]) && !is(T == float[]))) {}

/* Undo bit twiddling from prepareForSorting() to get back original
 * floating point numbers.  Operates in place on arr.
 */
/*private*/ void postProcess(alias comp, F)(F arr)
if((is(F == double[]) || is(F == float[])) && isSimpleComparison!comp) {
    static if(is(F == double[])) {
        alias long Int;
        enum mask = 1UL << 63;
    } else {
        alias int Int;
        enum mask = 1U << 31;
    }

    Int[] useMe = cast(Int[]) arr;
    foreach(ref elem; useMe) {
        if(elem & mask) {
            elem = ~elem;
            elem ^= mask;
        }
    }
}

version(unittest) {
    /* Shared test driver for floating point keys:  shuffles two identical
     * arrays in lockstep, sorts a random-length prefix in descending order
     * with fun, and checks that the prefix is sorted and that the two arrays
     * still agree element for element.
     */
    static void testFloating(alias fun, F)() {
        F[] testL = new F[1_000];
        foreach(ref e; testL) {
            e = uniform(-1_000_000, 1_000_000);
        }
        auto testL2 = testL.dup;

        // mergeSortTemp needs one caller-supplied temp array per data array.
        static if(__traits(isSame, fun, mergeSortTemp)) {
            auto temp1 = testL.dup;
            auto temp2 = testL.dup;
        }

        foreach(i; 0..200) {
            randomShuffle(zip(testL, testL2));
            uint len = uniform(0, 1_000);

            static if(__traits(isSame, fun, mergeSortTemp)) {
                fun!"a > b"(testL[0..len], testL2[0..len], temp1[0..len], temp2[0..len]);
            } else {
                fun!("a > b")(testL[0..len], testL2[0..len]);
            }

            assert(isSorted!("a > b")(testL[0..len]));
            assert(testL == testL2, fun.stringof ~ '\t' ~ F.stringof);
        }
    }
}

/* Rotates input one position to the left in place:  the first element moves
 * to the end.  Does nothing for fewer than two elements.
 */
void rotateLeft(T)(T input)
if(isRandomAccessRange!(T)) {
    if(input.length < 2) return;
    ElementType!(T) temp = input[0];
    foreach(i; 1..input.length) {
        input[i-1] = input[i];
    }
    input[$-1] = temp;
}

/* Rotates input one position to the right in place:  the last element moves
 * to the front.  Does nothing for fewer than two elements.
 */
void rotateRight(T)(T input)
if(isRandomAccessRange!(T)) {
    if(input.length < 2) return;
    ElementType!(T) temp = input[$-1];
    for(size_t i = input.length - 1; i > 0; i--) {
        input[i] = input[i-1];
    }
    input[0] = temp;
}

/* Returns the index, NOT the value, of the median of the first, middle, last
 * elements of data.  The three pairwise comparisons are packed into a 3-bit
 * code (first<mid, first<last, mid<last) which is then decoded.*/
size_t medianOf3(alias compFun, T)(T[] data) {
    alias binaryFun!(compFun) comp;
    immutable size_t mid = data.length / 2;
    immutable uint result = ((cast(uint) (comp(data[0], data[mid]))) << 2) |
                            ((cast(uint) (comp(data[0], data[$ - 1]))) << 1) |
                            (cast(uint) (comp(data[mid], data[$ - 1])));

    assert(result != 2 && result != 5 && result < 8); // Cases 2, 5 can't happen.
    switch(result) {
        case 1:  // 001
        case 6:  // 110
            return data.length - 1;
        case 3:  // 011
        case 4:  // 100
            return 0;
        case 0:  // 000
        case 7:  // 111
            return mid;
        default:
            assert(0);
    }
    assert(0);
}

unittest {
    assert(medianOf3!("a < b")([1,2,3,4,5]) == 2);
    assert(medianOf3!("a < b")([1,2,5,4,3]) == 4);
    assert(medianOf3!("a < b")([3,2,1,4,5]) == 0);
    assert(medianOf3!("a < b")([5,2,3,4,1]) == 2);
    assert(medianOf3!("a < b")([5,2,1,4,3]) == 4);
    assert(medianOf3!("a < b")([3,2,5,4,1]) == 0);
}


/**Quick sort.  Unstable, O(N log N) time average, worst
 * case, O(log N) space, small constant term in time complexity.
 *
 * In this implementation, the following steps are taken to avoid the
 * O(N<sup>2</sup>) worst case of naive quick sorts:
 *
 * 1.  At each recursion, the median of the first, middle and last elements of
 *     the array is used as the pivot.
 *
 * 2.  To handle the case of few unique elements, the "Fit Pivot" technique
 *     previously described by Andrei Alexandrescu is used.  This allows
 *     reasonable performance with few unique elements, with zero overhead
 *     in other cases.
 *
 * 3.  After a much larger than expected amount of recursion has occurred,
 *     this function transitions to a heap sort.
This guarantees an O(N log N)
 *     worst case.
 *
 * The first array is the sort key; any further arrays are permuted in
 * lockstep with it.  Returns the (sorted) first array.*/
T[0] qsort(alias compFun = "a < b", T...)(T data)
if(T.length != 0)
in {
    assert(data.length > 0);
    immutable keyLen = data[0].length;
    foreach(other; data[1..$]) {
        assert(other.length == keyLen);
    }
} body {
    // Tiny inputs go straight to insertion sort, which also saves computing
    // the logarithm below.
    if(data[0].length < 25) {
        return insertionSort!compFun(data);
    }

    // Recursion budget:  once it is used up, qsortImpl switches to heap sort.
    immutable real keyLength = cast(real) data[0].length;
    uint depthBudget = cast(uint) (log2(keyLength) * 2);

    auto keys = prepareForSorting!compFun(data[0]);

    // qsortImpl throws when handed an invalid comparison function.  Whatever
    // happens from here on, the float bit twiddling has to be undone on the
    // way out.
    scope(exit) postProcess!compFun(data[0]);

    qsortImpl!(compFun)(keys, data[1..$], depthBudget);
    return data[0];
}

//TTL = time to live, before transitioning to heap sort.
/* Recursive worker for qsort().  Sorts data[0] and permutes the remaining
 * arrays in lockstep.  Falls back to insertion sort for fewer than 25
 * elements and to heap sort once TTL reaches zero.
 *
 * Throws:  SortException if compFun(x, x) is true for the chosen pivot.
 */
void qsortImpl(alias compFun, T...)(T data, uint TTL) {
    alias binaryFun!(compFun) comp;
    if(data[0].length < 25) {
        insertionSortImpl!(compFun)(data);
        return;
    }
    if(TTL == 0) {
        heapSortImpl!(compFun)(data);
        return;
    }
    TTL--;

    {
        // Move the median-of-3 pivot to the last position of every array.
        immutable size_t med3 = medianOf3!(comp)(data[0]);
        foreach(array; data) {
            auto temp = array[med3];
            array[med3] = array[$ - 1];
            array[$ - 1] = temp;
        }
    }

    T less, greater;
    // lessI starts at size_t.max so that the first pre-increment wraps to 0.
    size_t lessI = size_t.max, greaterI = data[0].length - 1;

    auto pivot = data[0][$ - 1];
    if(comp(pivot, pivot)) {
        throw new SortException
            ("Comparison function must be such that compFun(x, x) == false.");
    }

    while(true) {
        while(comp(data[0][++lessI], pivot)) {}
        while(greaterI > 0 && comp(pivot, data[0][--greaterI])) {}

        if(lessI < greaterI) {
            foreach(array; data) {
                auto temp = array[lessI];
                array[lessI] = array[greaterI];
                array[greaterI] = temp;
            }
        } else break;
    }

    foreach(ti, array; data) {
        // Put the pivot into its final position, then carve out the two
        // partitions on either side of it.
        auto temp = array[$ - 1];
        array[$ - 1] = array[lessI];
        array[lessI] = temp;
        less[ti] = array[0..min(lessI, greaterI + 1)];
        greater[ti] = array[lessI + 1..$];
    }
    // Allow tail recursion optimization for larger block.  This guarantees
    // that, given a reasonable amount of stack space, no stack overflow will
    // occur even in pathological cases.
    if(greater[0].length > less[0].length) {
        qsortImpl!(compFun)(less, TTL);
        qsortImpl!(compFun)(greater, TTL);
        return;
    } else {
        qsortImpl!(compFun)(greater, TTL);
        qsortImpl!(compFun)(less, TTL);
    }
}

unittest {
    {  // Test integer.
        uint[] test = new uint[1_000];
        foreach(ref e; test) {
            e = uniform(0, 100);
        }
        auto test2 = test.dup;
        foreach(i; 0..1_000) {
            randomShuffle(zip(test, test2));
            uint len = uniform(0, 1_000);
            qsort(test[0..len], test2[0..len]);
            assert(isSorted(test[0..len]));
            assert(test == test2);
        }
    }

    testFloating!(qsort, float)();
    testFloating!(qsort, double)();
    testFloating!(qsort, real)();

    auto nanArr = [double.nan, 1.0];
    try {
        qsort(nanArr);
        assert(0);
    } catch(SortException) {}
}

/* Keeps track of what array merge sort data is in.  This is a speed hack to
 * copy back and forth less.*/
/*private*/ enum {
    DATA,
    TEMP
}

/**Merge sort.  O(N log N) time, O(N) space, small constant.  Stable sort.
 * If last argument is a ulong* instead of an array-like type,
 * the dereference of the ulong* will be incremented by the bubble sort
 * distance between the input array and the sorted version.  This is useful
 * in some statistics functions such as Kendall's tau.
 *
 * Inputs shorter than 65 elements are handed to insertionSort(), which needs
 * no temporary storage.  Returns the (sorted) first array.*/
T[0] mergeSort(alias compFun = "a < b", T...)(T data)
if(T.length != 0)
in {
    assert(data.length > 0);
    size_t len = data[0].length;
    foreach(array; data[1..$]) {
        static if(!is(typeof(array) == ulong*))
            assert(array.length == len);
    }
} body {
    if(data[0].length < 65) {  //Avoid mem allocation.
        // Go through insertionSort() rather than insertionSortImpl() so that
        // the key array gets the same prepareForSorting() treatment (NaN
        // rejection, integer comparison of floats) as on the large-input
        // path below.  Otherwise behavior would depend on the input length.
        return insertionSort!(compFun)(data);
    }
    static if(is(T[$ - 1] == ulong*)) {
        enum dl = data.length - 1;
        alias data[$ - 1] swapCount;
    } else {
        enum dl = data.length;
        alias TypeTuple!() swapCount;  // Place holder.
    }

    auto keyArr = prepareForSorting!compFun(data[0]);
    auto toSort = TypeTuple!(keyArr, data[1..dl]);

    // One scratch array per data array, taken from a region allocator.
    typeof(toSort) temp;
    auto alloc = newRegionAllocator();
    foreach(i, array; temp) {
        temp[i] = alloc.uninitializedArray!(typeof(temp[i][0])[])(data[i].length);
    }

    // mergeSortImpl reports which set of arrays holds the final result.
    uint res = mergeSortImpl!(compFun)(toSort, temp, swapCount);
    if(res == TEMP) {
        foreach(ti, array; temp) {
            toSort[ti][0..$] = temp[ti][0..$];
        }
    }

    postProcess!compFun(data[0]);
    return data[0];
}

unittest {
    uint[] test = new uint[1_000], stability = new uint[1_000];
    uint[] temp1 = new uint[1_000], temp2 = new uint[1_000];
    foreach(ref e; test) {
        e = uniform(0, 100);  //Lots of ties.
    }
    foreach(i; 0..100) {
        ulong mergeCount = 0, bubbleCount = 0;
        foreach(j, ref e; stability) {
            e = cast(uint) j;
        }
        randomShuffle(test);
        uint len = uniform(0, 1_000);
        // Testing bubble sort distance against bubble sort,
        // since bubble sort distance computed by bubble sort
        // is straightforward, unlikely to contain any subtle bugs.
        bubbleSort(test[0..len].dup, &bubbleCount);
        if(i & 1)  // Test both temp and non-temp branches.
            mergeSort(test[0..len], stability[0..len], &mergeCount);
        else
            mergeSortTemp(test[0..len], stability[0..len], temp1[0..len],
                          temp2[0..len], &mergeCount);
        assert(bubbleCount == mergeCount);
        assert(isSorted(test[0..len]));
        foreach(j; 1..len) {
            if(test[j - 1] == test[j]) {
                assert(stability[j - 1] < stability[j]);
            }
        }
    }
    // Test without swapCounts.
    foreach(i; 0..1000) {
        foreach(j, ref e; stability) {
            e = cast(uint) j;
        }
        randomShuffle(test);
        uint len = uniform(0, 1_000);
        if(i & 1)  // Test both temp and non-temp branches.
            mergeSort(test[0..len], stability[0..len]);
        else
            mergeSortTemp(test[0..len], stability[0..len], temp1[0..len],
                          temp2[0..len]);
        assert(isSorted(test[0..len]));
        foreach(j; 1..len) {
            if(test[j - 1] == test[j]) {
                assert(stability[j - 1] < stability[j]);
            }
        }
    }

    testFloating!(mergeSort, float)();
    testFloating!(mergeSort, double)();
    testFloating!(mergeSort, real)();

    testFloating!(mergeSortTemp, float)();
    testFloating!(mergeSortTemp, double)();
    testFloating!(mergeSortTemp, real)();

    // NaNs must be rejected regardless of the input length.
    auto nanArr = [1.0, double.nan];
    try {
        mergeSort(nanArr);
        assert(0);
    } catch(SortException) {}
}

/**Merge sort, allowing caller to provide a temp variable.  This allows
 * recycling instead of repeated allocations.  If D is data, T is temp,
 * and U is a ulong* for calculating bubble sort distance, this can be called
 * as mergeSortTemp(D, D, D, T, T, T, U) or mergeSortTemp(D, D, D, T, T, T)
 * where each D has a T of corresponding type.
 *
 * Examples:
 * ---
 * int[] foo = [3, 1, 2, 4, 5].dup;
 * int[] temp = new int[5];
 * mergeSortTemp!("a < b")(foo, temp);
 * assert(foo == [1, 2, 3, 4, 5]);  // The contents of temp will be undefined.
 * foo = [3, 1, 2, 4, 5].dup;
 * real[] bar = [3.14L, 15.9, 26.5, 35.8, 97.9];
 * real[] temp2 = new real[5];
 * mergeSortTemp(foo, bar, temp, temp2);
 * assert(foo == [1, 2, 3, 4, 5]);
 * assert(bar == [15.9L, 26.5, 3.14, 35.8, 97.9]);
 * // The contents of both temp and temp2 will be undefined.
* ---
 */
T[0] mergeSortTemp(alias compFun = "a < b", T...)(T data)
if(T.length != 0)
in {
    assert(data.length > 0);
    immutable keyLen = data[0].length;
    foreach(other; data[1..$]) {
        static if(!is(typeof(other) == ulong*))
            assert(other.length == keyLen);
    }
} body {
    // Number of array arguments, i.e. not counting an optional trailing
    // swap counter.  The first half are data, the second half are temps.
    enum size_t dl = is(T[$ - 1] == ulong*) ? data.length - 1 : data.length;
    enum size_t half = dl / 2;

    // The key array may come back reinterpreted as integers; view its temp
    // array the same way so both have the same element type.
    auto keyArr = prepareForSorting!compFun(data[0]);
    auto keyTemp = cast(typeof(keyArr)) data[half];
    auto toSort = TypeTuple!(
        keyArr,
        data[1..half],
        keyTemp,
        data[half + 1..$]
    );

    immutable uint where = mergeSortImpl!(compFun)(toSort);

    // If the sorted result ended up in the temp arrays, copy it back.
    if(where == TEMP) {
        foreach(ti, array; toSort[0..$ / 2]) {
            toSort[ti][] = toSort[ti + half][];
        }
    }

    postProcess!compFun(data[0]);
    return data[0];
}

/* Recursive worker shared by mergeSort() and mergeSortTemp().  dataIn is the
 * data arrays, followed by an equal number of temp arrays, optionally
 * followed by a ulong* swap counter.  Returns DATA or TEMP to say which set
 * of arrays holds the sorted result.
 */
/*private*/ uint mergeSortImpl(alias compFun = "a < b", T...)(T dataIn) {
    alias dataIn[0..dataIn.length / 2] data;
    static if(is(T[$ - 1] == ulong*)) {
        alias dataIn[$ - 1] swapCount;
        alias dataIn[dataIn.length / 2..$ - 1] temp;
    } else {
        alias TypeTuple!() swapCount;  // Empty dummy tuple.
        alias dataIn[dataIn.length / 2..$] temp;
    }

    // Small inputs are insertion sorted in place.
    if(data[0].length < 50) {
        insertionSortImpl!(compFun)(data, swapCount);
        return DATA;
    }

    immutable size_t mid = data[0].length / 2;
    typeof(data) left, right, tempLeft, tempRight;
    foreach(ti, array; data) {
        left[ti] = array[0..mid];
        right[ti] = array[mid..$];
        tempLeft[ti] = temp[ti][0..mid];
        tempRight[ti] = temp[ti][mid..$];
    }

    /* Implementation note:  Rather than copying back and forth between the
     * data and temp arrays after every merge, each recursive call reports
     * where its result lives.  A copy is only made here when the two halves
     * ended up in different places.*/
    immutable uint leftLoc = mergeSortImpl!(compFun)(left, tempLeft, swapCount);
    immutable uint rightLoc = mergeSortImpl!(compFun)(right, tempRight, swapCount);

    if(leftLoc != rightLoc) {
        // Exactly one half is still in data:  move it over to temp so that
        // both halves can be merged from temp into data.
        if(leftLoc == DATA) {
            foreach(ti, array; tempLeft) {
                array[] = left[ti][];
            }
        } else {
            foreach(ti, array; tempRight) {
                array[] = right[ti][];
            }
        }
        merge!(compFun)(tempLeft, tempRight, data, swapCount);
        return DATA;
    }

    if(leftLoc == DATA) {
        merge!(compFun)(left, right, temp, swapCount);
        return TEMP;
    }

    merge!(compFun)(tempLeft, tempRight, data, swapCount);
    return DATA;
}

/* Merges two sorted runs into a result.  data is three equally sized groups
 * of arrays (left, right, result), optionally followed by a ulong* swap
 * counter.  The first array of each group is the key; ties are taken from
 * the left run first.
 */
/*private*/ void merge(alias compFun, T...)(T data) {
    alias binaryFun!(compFun) comp;

    enum bool counting = is(T[$ - 1] == ulong*);
    static if(counting) {
        enum dl = data.length - 1;  // Length after removing swapCount.
        alias data[$ - 1] swapCount;
    } else {
        enum dl = data.length;
    }

    static assert(dl % 3 == 0);
    alias data[0..dl / 3] left;
    alias data[dl / 3..dl * 2 / 3] right;
    alias data[dl * 2 / 3..dl] result;
    static assert(left.length == right.length && right.length == result.length);

    immutable size_t leftLen = left[0].length;
    immutable size_t rightLen = right[0].length;
    size_t dst = 0, li = 0, ri = 0;

    for(; li < leftLen && ri < rightLen; ++dst) {
        if(comp(right[0][ri], left[0][li])) {
            // Taking from the right run jumps over everything still
            // pending in the left run.
            static if(counting) {
                *swapCount += leftLen - li;
            }

            foreach(ti, array; result) {
                result[ti][dst] = right[ti][ri];
            }
            ++ri;
        } else {
            foreach(ti, array; result) {
                result[ti][dst] = left[ti][li];
            }
            ++li;
        }
    }

    // Copy whichever tail is left over.
    if(ri < rightLen) {
        foreach(ti, array; result) {
            result[ti][dst..$] = right[ti][ri..$];
        }
    } else {
        foreach(ti, array; result) {
            result[ti][dst..$] = left[ti][li..$];
        }
    }
}

/**In-place merge sort, based on C++ STL's stable_sort().
O(N log<sup>2</sup> N)
 * time complexity, O(1) space complexity, stable.  Much slower than plain
 * old mergeSort(), so only use it if you really need the O(1) space.
 *
 * The first array is the sort key; any further arrays are permuted in
 * lockstep with it.  Returns the (sorted) first array.*/
T[0] mergeSortInPlace(alias compFun = "a < b", T...)(T data)
if(T.length != 0)
in {
    assert(data.length > 0);
    size_t len = data[0].length;
    foreach(array; data[1..$]) {
        assert(array.length == len);
    }
} body {
    auto toSort = prepareForSorting!compFun(data[0]);
    mergeSortInPlaceImpl!compFun(toSort, data[1..$]);
    postProcess!compFun(data[0]);
    return data[0];
}

/* Recursive worker for mergeSortInPlace().  Insertion sorts inputs of up to
 * 100 elements; otherwise sorts each half recursively and merges them with
 * mergeInPlace().
 */
/*private*/ T[0] mergeSortInPlaceImpl(alias compFun, T...)(T data) {
    if (data[0].length <= 100)
        return insertionSortImpl!(compFun)(data);

    T left, right;
    foreach(ti, array; data) {
        left[ti] = array[0..$ / 2];
        right[ti] = array[$ / 2..$];
    }

    // Recurse on the Impl, not on the public wrapper:  the key array has
    // already been through prepareForSorting(), and repeating that (a full
    // NaN scan for real[]) plus the contract checks at every level of the
    // recursion is pure overhead.
    mergeSortInPlaceImpl!(compFun)(right);
    mergeSortInPlaceImpl!(compFun)(left);
    mergeInPlace!(compFun)(data, data[0].length / 2);
    return data[0];
}

unittest {
    uint[] test = new uint[1_000], stability = new uint[1_000];
    foreach(ref e; test) {
        e = uniform(0, 100);  //Lots of ties.
    }
    uint[] test2 = test.dup;
    foreach(i; 0..1000) {
        foreach(j, ref e; stability) {
            e = cast(uint) j;
        }
        randomShuffle(zip(test, test2));
        uint len = uniform(0, 1_000);
        mergeSortInPlace(test[0..len], test2[0..len], stability[0..len]);
        assert(isSorted(test[0..len]));
        assert(test == test2);
        foreach(j; 1..len) {
            if(test[j - 1] == test[j]) {
                assert(stability[j - 1] < stability[j]);
            }
        }
    }

    testFloating!(mergeSortInPlace, float)();
    testFloating!(mergeSortInPlace, double)();
    testFloating!(mergeSortInPlace, real)();
}

// Loosely based on C++ STL's __merge_without_buffer().
/* Merges the two adjacent sorted runs array[0..middle] and array[middle..$]
 * of the key array data[0] in place, permuting all other arrays in data in
 * lockstep.  Works by picking a cut point in each run, rotating the block
 * between the cuts with bringToFront(), and recursing on the two halves.
 */
/*private*/ void mergeInPlace(alias compFun = "a < b", T...)(T data, size_t middle) {
    alias binaryFun!(compFun) comp;

    // Number of elements of the sorted slice that are ordered strictly
    // before value, i.e. the length of its lower bound.
    static size_t largestLess(T)(T[] data, T value) {
        return assumeSorted!(comp)(data).lowerBound(value).length;
    }

    // Index of the first element of the sorted slice that is ordered
    // strictly after value, i.e. the length minus that of its upper bound.
    static size_t smallestGr(T)(T[] data, T value) {
        return data.length -
            assumeSorted!(comp)(data).upperBound(value).length;
    }


    // Nothing to merge if one of the runs is empty.
    if (data[0].length < 2 || middle == 0 || middle == data[0].length) {
        return;
    }

    // Two elements:  a single conditional swap finishes the job.
    if (data[0].length == 2) {
        if(comp(data[0][1], data[0][0])) {
            foreach(array; data) {
                auto temp = array[0];
                array[0] = array[1];
                array[1] = temp;
            }
        }
        return;
    }

    // half1 is the cut point in the first run; half2 is the cut point in the
    // second run, measured from middle.
    // NOTE(review): firstCut and secondCut are declared but never used.
    size_t half1, half2, firstCut, secondCut;

    // Bisect the longer run and find the matching cut in the other run by
    // binary search.
    if (middle > data[0].length - middle) {
        half1 = middle / 2;
        auto pivot = data[0][half1];
        half2 = largestLess(data[0][middle..$], pivot);
    } else {
        half2 = (data[0].length - middle) / 2;
        auto pivot = data[0][half2 + middle];
        half1 = smallestGr(data[0][0..middle], pivot);
    }

    // Rotate so that the front of the second run comes before the tail of
    // the first run.
    foreach(array; data) {
        bringToFront(array[half1..middle], array[middle..middle + half2]);
    }
    size_t newMiddle = half1 + half2;

    T left, right;
    foreach(ti, array; data) {
        left[ti] = array[0..newMiddle];
        right[ti] = array[newMiddle..$];
    }

    // Each side now consists of two adjacent sorted runs again; the second
    // argument is where the boundary between them lies within that side.
    mergeInPlace!(compFun, T)(left, half1);
    mergeInPlace!(compFun, T)(right, half2 + middle - newMiddle);
}


/**Heap sort.
Unstable, O(N log N) time average and worst case, O(1) space,
 * large constant term in time complexity.
 *
 * The first array is the sort key; any further arrays are permuted in
 * lockstep with it.  Returns the (sorted) first array.*/
T[0] heapSort(alias compFun = "a < b", T...)(T data)
if(T.length != 0)
in {
    assert(data.length > 0);
    immutable keyLen = data[0].length;
    foreach(other; data[1..$]) {
        assert(other.length == keyLen);
    }
} body {
    auto keys = prepareForSorting!compFun(data[0]);
    heapSortImpl!compFun(keys, data[1..$]);
    postProcess!compFun(data[0]);
    return data[0];
}

/* Worker for heapSort(), also used by qsortImpl() as its fallback.  Sorts
 * input[0], permuting the other arrays in lockstep, and returns input[0].
 */
/*private*/ T[0] heapSortImpl(alias compFun, T...)(T input) {
    // Heap sort's constant is so large that insertion sort wins for up to
    // about 100 elements (measured for reals; the break-even point is even
    // higher for smaller types).
    if(input[0].length <= 100) {
        return insertionSortImpl!(compFun)(input);
    }

    makeMultiHeap!(compFun)(input);

    // Repeatedly move the heap's root to the end of the unsorted region and
    // restore the heap property on what remains.
    foreach_reverse(last; 1..input[0].length) {
        foreach(ti, ia; input) {
            auto held = ia[last];
            ia[last] = ia[0];
            ia[0] = held;
        }
        multiSiftDown!(compFun)(input, 0, last);
    }
    return input[0];
}

unittest {
    uint[] keys = new uint[1_000];
    foreach(ref e; keys) {
        e = uniform(0, 100_000);
    }
    auto shadow = keys.dup;
    foreach(round; 0..1_000) {
        randomShuffle(zip(keys, shadow));
        uint len = uniform(0, 1_000);
        heapSort(keys[0..len], shadow[0..len]);
        assert(isSorted(keys[0..len]));
        assert(keys == shadow);
    }

    testFloating!(heapSort, float)();
    testFloating!(heapSort, double)();
    testFloating!(heapSort, real)();
}

/* Arranges input[0] into a heap ordered by compFun, permuting the other
 * arrays in lockstep.  Does nothing for fewer than two elements.
 */
void makeMultiHeap(alias compFun = "a < b", T...)(T input) {
    immutable size_t n = input[0].length;
    if(n < 2) {
        return;
    }

    // Sift down every node from (n - 1) / 2 back to the root.
    foreach_reverse(start; 0..(n - 1) / 2 + 1) {
        multiSiftDown!(compFun)(input, start, n);
    }
}

/* Sifts the element at index root down the heap stored in input[0][0..end],
 * applying the same swaps to every array in input.
 */
void multiSiftDown(alias compFun = "a < b", T...)
    (T input, size_t root, size_t end) {
    alias binaryFun!(compFun) comp;
    alias input[0] a;

    for(size_t child = root * 2 + 1; child < end; child = root * 2 + 1) {
        // Pick whichever child compFun orders last.
        if(child + 1 < end && comp(a[child], a[child + 1])) {
            ++child;
        }

        // Heap property already holds here; done.
        if(!comp(a[root], a[child])) {
            return;
        }

        foreach(ia; input) {
            auto held = ia[root];
            ia[root] = ia[child];
            ia[child] = held;
        }
        root = child;
    }
}

/**Insertion sort.  O(N<sup>2</sup>) time worst, average case, O(1) space, VERY
 * small constant, which is why it's useful for sorting small subarrays in
 * divide and conquer algorithms.  If last argument is a ulong*, increments
 * the dereference of this argument by the bubble sort distance between the
 * input array and the sorted version of the input.*/
T[0] insertionSort(alias compFun = "a < b", T...)(T data)
in {
    assert(data.length > 0);
    immutable keyLen = data[0].length;
    foreach(other; data[1..$]) {
        static if(!is(typeof(other) == ulong*))
            assert(other.length == keyLen);
    }
} body {
    auto keys = prepareForSorting!compFun(data[0]);
    insertionSortImpl!compFun(keys, data[1..$]);
    postProcess!compFun(data[0]);
    return data[0];
}

// Type obtained by indexing a value of type T with [0].
private template IndexType(T) {
    alias typeof(T.init[0]) IndexType;
}

/* Worker for insertionSort(); also the small-input fallback of the other
 * sorts.  Sorts data[0], permuting the other arrays in lockstep, and returns
 * data[0].  If the last argument is a ulong*, the number of positions each
 * element was moved is added to what it points to.
 */
/*private*/ T[0] insertionSortImpl(alias compFun, T...)(T data) {
    alias binaryFun!(compFun) comp;
    static if(is(T[$ - 1] == ulong*)) {
        enum dl = data.length - 1;
        alias data[$ - 1] swapCount;
    } else {
        enum dl = data.length;
    }

    alias data[0] keyArray;
    if(keyArray.length < 2) {
        return keyArray;
    }

    // Caching this was measured to be faster on DMD.
    immutable lastIdx = keyArray.length - 1;

    // Walk from the second-to-last element back to the first, inserting
    // each one into the already sorted tail to its right.
    foreach_reverse(start; 0..keyArray.length - 1) {
        size_t dest = start;

        // Lift the current element of every array out into a tuple.
        Tuple!(staticMap!(IndexType, typeof(data[0..dl]))) held = void;
        foreach(ti, Type; typeof(data[0..dl])) {
            static if(hasElaborateAssign!Type) {
                emplace(&(held.field[ti]), data[ti][start]);
            } else {
                held.field[ti] = data[ti][start];
            }
        }

        // Shift elements left until the insertion point is found.  Doing
        // the copying here was faster than calling rotateLeft() afterwards,
        // probably due to better ILP.
        while(dest < lastIdx && comp(keyArray[dest + 1], held.field[0])) {
            foreach(array; data[0..dl]) {
                array[dest] = array[dest + 1];
            }
            ++dest;
        }

        foreach(ti, Unused; typeof(held.field)) {
            data[ti][dest] = held.field[ti];
        }

        static if(is(typeof(swapCount))) {
            *swapCount += (dest - start);  // Distance this element moved.
        }
    }

    return keyArray;
}

unittest {
    uint[] keys = new uint[100], stability = new uint[100];
    foreach(ref e; keys) {
        e = uniform(0, 100);  //Lots of ties.
    }
    foreach(round; 0..1_000) {
        ulong insertCount = 0, bubbleCount = 0;
        foreach(j, ref e; stability) {
            e = cast(uint) j;
        }
        randomShuffle(keys);
        uint len = uniform(0, 100);
        // Bubble sort distance is checked against an actual bubble sort,
        // whose count is straightforward and unlikely to hide subtle bugs.
        bubbleSort(keys[0..len].dup, &bubbleCount);
        insertionSort(keys[0..len], stability[0..len], &insertCount);
        assert(bubbleCount == insertCount);
        assert(isSorted(keys[0..len]));
        foreach(j; 1..len) {
            if(keys[j - 1] == keys[j]) {
                assert(stability[j - 1] < stability[j]);
            }
        }
    }
}

// Kept around only because it's easy to implement, and therefore good for
// testing more complex sort functions against.
// Especially useful for bubble
// sort distance, since it's straightforward with a bubble sort, and not with
// a merge sort or insertion sort.
version(unittest) {
    /* Reference bubble sort.  Sorts data[0], permuting the other arrays in
     * lockstep, and returns data[0].  If the last argument is a ulong*, what
     * it points to is incremented once per swap.
     */
    T[0] bubbleSort(alias compFun = "a < b", T...)(T data) {
        alias binaryFun!(compFun) comp;
        enum bool counting = is(T[$ - 1] == ulong*);
        enum size_t dl = counting ? data.length - 1 : data.length;

        immutable size_t n = data[0].length;
        if(n < 2) {
            return data[0];
        }

        foreach(pass; 0..n) {
            bool swapped = false;
            foreach(hi; 1..n) {
                if(!comp(data[0][hi], data[0][hi - 1])) {
                    continue;
                }

                swapped = true;
                static if(counting) {
                    ++(*data[$ - 1]);
                }
                foreach(array; data[0..dl]) {
                    swap(array[hi - 1], array[hi]);
                }
            }

            // A pass without swaps means everything is in order.
            if(!swapped) {
                break;
            }
        }
        return data[0];
    }
}

unittest {
    // Sanity check for bubble sort distance.
    ulong dist = 0;
    uint[] sample = [4, 5, 3, 2, 1];
    bubbleSort(sample, &dist);
    assert(dist == 9);

    dist = 0;
    sample = [6, 1, 2, 4, 5, 3];
    bubbleSort(sample, &dist);
    assert(dist == 7);
}

/**Returns the kth largest/smallest element (depending on compFun, 0-indexed)
 * in the input array in O(N) time.  Allocates memory, does not modify input
 * array.*/
T quickSelect(alias compFun = "a < b", T)(T[] data, sizediff_t k) {
    // Partition a scratch copy so that the caller's array stays untouched.
    auto alloc = newRegionAllocator();
    auto scratch = alloc.array(data);
    return partitionK!(compFun)(scratch, k);
}

/**Partitions the input data according to compFun, such that position k contains
 * the kth largest/smallest element according to compFun.  For all elements e
 * with indices < k, !compFun(data[k], e) is guaranteed to be true.  For all
 * elements e with indices > k, !compFun(e, data[k]) is guaranteed to be true.
/**Partitions the input data according to compFun, such that position k contains
 * the kth largest/smallest element according to compFun.  For all elements e
 * with indices < k, !compFun(data[k], e) is guaranteed to be true.  For all
 * elements e with indices > k, !compFun(e, data[k]) is guaranteed to be true.
 * For example, if compFun is "a < b", all elements with indices < k will be
 * <= data[k], and all elements with indices larger than k will be >= data[k].
 * Reorders any additional input arrays in lockstep.
 *
 * Preconditions:  All input arrays must have the same length, and k must
 * satisfy 0 <= k < data[0].length.  Both are checked by the in contract.
 *
 * Examples:
 * ---
 * auto foo = [3, 1, 5, 4, 2].dup;
 * auto secondSmallest = partitionK(foo, 1);
 * assert(secondSmallest == 2);
 * foreach(elem; foo[0..1]) {
 *     assert(elem <= foo[1]);
 * }
 * foreach(elem; foo[2..$]) {
 *     assert(elem >= foo[1]);
 * }
 * ---
 *
 * Returns:  The kth element of the array.
 */
ElementType!(T[0]) partitionK(alias compFun = "a < b", T...)(T data, ptrdiff_t k)
in {
    assert(data.length > 0);
    size_t len = data[0].length;
    foreach(array; data[1..$]) {
        assert(array.length == len);
    }
    // Reject an out-of-range k up front instead of failing with a range
    // error somewhere inside the recursion.
    assert(k >= 0 && cast(size_t) k < len,
        "partitionK: k must satisfy 0 <= k < data[0].length.");
} body {
    // Don't use the float-to-int trick because it's actually slower here
    // because the main part of the algorithm is O(N), not O(N log N).
    return partitionKImpl!compFun(data, k);
}

/* Implementation of partitionK.  Performs one partitioning pass around a
 * pivot chosen by medianOf3, then either returns data[0][k] or continues on
 * the sub-range that still contains position k (through partitionK, so the
 * contract is re-checked on the narrowed slices).  Every element move is
 * applied to all arrays in data.
 */
/*private*/ ElementType!(T[0]) partitionKImpl(alias compFun, T...)(T data, ptrdiff_t k) {
    alias binaryFun!(compFun) comp;

    // Move the element at the index returned by medianOf3 into the last slot
    // of every array; it becomes the pivot.
    {
        immutable size_t med3 = medianOf3!(comp)(data[0]);
        foreach(array; data) {
            auto temp = array[med3];
            array[med3] = array[$ - 1];
            array[$ - 1] = temp;
        }
    }

    ptrdiff_t lessI = -1, greaterI = data[0].length - 1;
    auto pivot = data[0][$ - 1];
    while(true) {
        // Scan up past elements ordered before the pivot.  The pivot in the
        // last slot ends this scan, since comp(pivot, pivot) must be false.
        while(comp(data[0][++lessI], pivot)) {}
        // Scan down past elements ordered after the pivot, stopping at index 0.
        while(greaterI > 0 && comp(pivot, data[0][--greaterI])) {}

        if(lessI < greaterI) {
            foreach(array; data) {
                auto temp = array[lessI];
                array[lessI] = array[greaterI];
                array[greaterI] = temp;
            }
        } else break;
    }

    // Put the pivot at lessI, the position where the upward scan stopped.
    foreach(array; data) {
        auto temp = array[lessI];
        array[lessI] = array[$ - 1];
        array[$ - 1] = temp;
    }

    if((greaterI < k && lessI >= k) || lessI == k) {
        return data[0][k];
    } else if(lessI < k) {
        // Position k lies after the pivot:  continue on the upper part and
        // shift k so that it indexes into the narrowed slices.
        foreach(ti, array; data) {
            data[ti] = array[lessI + 1..$];
        }
        return partitionK!(compFun, T)(data, k - lessI - 1);
    } else {
        // Position k lies before the pivot:  continue on the lower part.
        foreach(ti, array; data) {
            data[ti] = array[0..min(greaterI + 1, lessI)];
        }
        return partitionK!(compFun, T)(data, k);
    }
}

// Aliases itself to the element type E when instantiated with an array type E[].
template ArrayElemType(T : T[]) {
    alias T ArrayElemType;
}

unittest {
    enum n = 1000;
    uint[] test = new uint[n];
    uint[] test2 = new uint[n];
    uint[] lockstep = new uint[n];
    foreach(ref e; test) {
        e = uniform(0, 1000);
    }
    foreach(i; 0..1_000) {
        test2[] = test[];
        lockstep[] = test[];
        uint len = uniform(0, n - 1) + 1;

        // Reference answer:  fully sort a copy in descending order.
        qsort!("a > b")(test2[0..len]);
        int k = uniform(0, len);
        auto qsRes = partitionK!("a > b")(test[0..len], lockstep[0..len], k);
        assert(qsRes == test2[k]);

        // Everything before k must be >= test[k], everything after <= test[k].
        foreach(elem; test[0..k]) {
            assert(elem >= test[k]);
        }
        foreach(elem; test[k + 1..len]) {
            assert(elem <= test[k]);
        }

        // The lockstep array started as a copy, so it must still match.
        assert(test == lockstep);
    }
}
/**Given a set of data points entered through the put function, this output range
 * maintains the invariant that the top N according to compFun will be
 * contained in the data structure.  Uses a heap internally, O(log N) insertion
 * time.  Good for finding the largest/smallest N elements of a very large
 * dataset that cannot be sorted quickly in its entirety, and may not even fit
 * in memory.  If less than N datapoints have been entered, all are contained in
 * the structure.  If N is zero (including a default-initialized TopN), put()
 * discards every element.
 *
 * Examples:
 * ---
 * Random gen;
 * gen.seed(unpredictableSeed);
 * uint[] nums = seq(0U, 100U);
 * auto less = TopN!(uint, "a < b")(10);
 * auto more = TopN!(uint, "a > b")(10);
 * randomShuffle(nums, gen);
 * foreach(n; nums) {
 *     less.put(n);
 *     more.put(n);
 * }
 * assert(less.getSorted == [0U, 1,2,3,4,5,6,7,8,9]);
 * assert(more.getSorted == [99U, 98, 97, 96, 95, 94, 93, 92, 91, 90]);
 * ---
 */
struct TopN(T, alias compFun = "a > b") {
private:
    alias binaryFun!(compFun) comp;
    uint n;       // Maximum number of elements retained.
    uint nAdded;  // Number of slots of nodes filled so far; never exceeds n.

    T[] nodes;    // Backing storage, allocated with length n by the constructor.
public:
    /** The variable ntop controls how many elements are retained.*/
    this(uint ntop) {
        n = ntop;
        nodes = new T[n];
    }

    /** Insert an element into the topN struct.*/
    void put(T elem) {
        // With zero capacity there is no nodes[0] to compare against, so
        // nothing can be retained.
        if(n == 0) {
            return;
        }

        if(nAdded < n) {
            // Still filling up:  store unconditionally, and build the heap
            // once the last free slot has been filled.
            nodes[nAdded] = elem;
            if(nAdded == n - 1) {
                makeMultiHeap!(comp)(nodes);
            }
            nAdded++;
        } else if(comp(elem, nodes[0])) {
            // Full:  elem replaces the root only if it is ordered before the
            // root according to compFun, then the heap is repaired.
            nodes[0] = elem;
            multiSiftDown!(comp)(nodes, 0, nodes.length);
        }
    }

    /**Get the elements currently in the struct.  Returns a reference to
     * internal state, elements will be in an arbitrary order.  Cheap.*/
    T[] getElements() {
        return nodes[0..min(n, nAdded)];
    }

    /**Returns the elements sorted by compFun.  The array returned is a
     * duplicate of the input array.  Not cheap.*/
    T[] getSorted() {
        return qsort!(comp)(nodes[0..min(n, nAdded)].dup);
    }
}

unittest {
    alias TopN!(uint, "a < b") TopNLess;
    alias TopN!(uint, "a > b") TopNGreater;
    Random gen;
    gen.seed(unpredictableSeed);
    uint[] nums = new uint[100];
    foreach(i, ref n; nums) {
        n = cast(uint) i;
    }

    // More elements than capacity:  only the best 10 survive.
    foreach(i; 0..100) {
        auto less = TopNLess(10);
        auto more = TopNGreater(10);
        randomShuffle(nums, gen);
        foreach(n; nums) {
            less.put(n);
            more.put(n);
        }
        assert(less.getSorted == [0U, 1,2,3,4,5,6,7,8,9]);
        assert(more.getSorted == [99U, 98, 97, 96, 95, 94, 93, 92, 91, 90]);
    }

    // Fewer elements than capacity:  everything is kept.
    foreach(i; 0..100) {
        auto less = TopNLess(10);
        auto more = TopNGreater(10);
        randomShuffle(nums, gen);
        foreach(n; nums[0..5]) {
            less.put(n);
            more.put(n);
        }
        assert(less.getSorted == qsort!("a < b")(nums[0..5]));
        assert(more.getSorted == qsort!("a > b")(nums[0..5]));
    }

    // Zero capacity:  put() must discard elements instead of indexing
    // into an empty node array.
    auto none = TopNLess(0);
    none.put(1);
    none.put(2);
    assert(none.getElements.length == 0);
}