1 /**Basic information theory.  Joint entropy, mutual information, conditional
2  * mutual information.  This module uses the base 2 definition of these
 * quantities, i.e., entropy, mutual info, etc. are output in bits.
4  *
5  * Author:  David Simcha*/
6  /*
7  * License:
8  * Boost Software License - Version 1.0 - August 17th, 2003
9  *
10  * Permission is hereby granted, free of charge, to any person or organization
11  * obtaining a copy of the software and accompanying documentation covered by
12  * this license (the "Software") to use, reproduce, display, distribute,
13  * execute, and transmit the Software, and to prepare derivative works of the
14  * Software, and to permit third-parties to whom the Software is furnished to
15  * do so, all subject to the following:
16  *
17  * The copyright notices in the Software and this entire statement, including
18  * the above license grant, this restriction and the following disclaimer,
19  * must be included in all copies of the Software, in whole or in part, and
20  * all derivative works of the Software, unless such copies or derivative
21  * works are solely in the form of machine-executable object code generated by
22  * a source language processor.
23  *
24  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26  * FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
27  * SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
28  * FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
29  * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
30  * DEALINGS IN THE SOFTWARE.
31  */
32 
33 module dstats.infotheory;
34 
35 import std.traits, std.math, std.typetuple, std.functional, std.range,
36        std.array, std.typecons, std.algorithm;
37 
38 import dstats.base, dstats.alloc;
39 import dstats.summary : sum;
40 import dstats.distrib : chiSquareCDFR;
41 
42 import dstats.tests : toContingencyScore, gTestContingency;
43 
44 version(unittest) {
45     import std.stdio, std.bigint, dstats.tests : gTestObs;
46 }
47 
/**Computes the Shannon entropy, in bits, of a forward range whose elements
 * are treated as the frequency counts of a set of discrete observations.
 * The total number of observations is obtained by summing the counts.
 *
 * Examples:
 * ---
 * double uniform3 = entropyCounts([4, 4, 4]);
 * assert(approxEqual(uniform3, log2(3)));
 * double uniform4 = entropyCounts([5, 5, 5, 5]);
 * assert(approxEqual(uniform4, 2));
 * ---
 */
double entropyCounts(T)(T data)
if(isForwardRange!(T) && doubleInput!(T)) {
    // Take a saved copy first:  data itself is handed to sum(), and the
    // counts still have to be walked once more for the entropy pass.
    auto counts = data.save();
    immutable double total = sum!(T, double)(data);
    return entropyCounts(counts, total);
}
64 
/**Computes the Shannon entropy, in bits, of frequency counts when the total
 * number of observations n is already known.  Every non-zero count c
 * contributes -(c / n) * log2(c / n); zero counts are skipped.*/
double entropyCounts(T)(T data, double n)
if(isIterable!(T)) {
    double bits = 0;
    // Multiply by the reciprocal instead of dividing once per count.
    immutable double scale = 1.0 / n;

    foreach(count; data) {
        if(count != 0) {
            double p = cast(double) count * scale;
            bits -= p * log2(p);
        }
    }
    return bits;
}
77 
unittest {
    // Uniform counts over k categories carry log2(k) bits.
    double threeWay = entropyCounts([4, 4, 4].dup);
    double fourWay = entropyCounts([5, 5, 5, 5].dup);
    assert(approxEqual(threeWay, log2(3)));
    assert(approxEqual(fourWay, 2));

    // Powers of two come out exactly, for integer and floating point counts.
    assert(entropyCounts([2,2].dup)==1);
    assert(entropyCounts([5.1,5.1,5.1,5.1].dup)==2);

    // Non-uniform counts, checked against a precomputed value.
    assert(approxEqual(entropyCounts([1,2,3,4,5].dup), 2.1492553971685));
}
87 
// Type-level counterpart of flatten():  yields the type list computed by
// FlattenTypeImpl, in which every argument type that has a _jointRanges
// member is replaced by the types of those members.
template FlattenType(T...) {
    alias FlattenType = FlattenTypeImpl!(T).ret;
}
91 
// Recursive worker for FlattenType.  The result is exposed as the member
// `ret`:  the first type is either kept as is or, when it has a
// _jointRanges member, expanded into the types of that member, and the
// remaining types are processed the same way.
template FlattenTypeImpl(T...) {
    static if(T.length > 0) {
        // Only ever used inside typeof() to probe for _jointRanges.
        T[0] probe;
        alias FlattenType!(T[1..$]) Tail;

        static if(is(typeof(probe._jointRanges))) {
            alias TypeTuple!(typeof(probe._jointRanges), Tail) ret;
        } else {
            alias TypeTuple!(T[0], Tail) ret;
        }
    } else {
        alias TypeTuple!() ret;
    }
}
104 
// Folds the remaining arguments into `start` one at a time.  An argument
// that has a _jointRanges member contributes those members; any other
// argument is appended as is.
private Joint!(FlattenType!(T, U)) flattenImpl(T, U...)(T start, U rest) {
    static if(U.length == 0) {
        return start;
    } else {
        static if(is(typeof(rest[0]._jointRanges))) {
            auto merged = jointImpl(start.tupleof, rest[0]._jointRanges);
        } else {
            auto merged = jointImpl(start.tupleof, rest[0]);
        }
        return flattenImpl(merged, rest[1..$]);
    }
}
114 
// Builds one Joint out of the arguments, expanding any argument that has a
// _jointRanges member into the ranges it holds.  At least one argument is
// required.
Joint!(FlattenType!(T)) flatten(T...)(T args) {
    static assert(T.length > 0);

    // Seed the result with the first argument, wrapping it if necessary.
    static if(is(typeof(args[0]._jointRanges))) {
        auto head = args[0];
    } else {
        auto head = jointImpl(args[0]);
    }

    static if(T.length > 1) {
        return flattenImpl(head, args[1..$]);
    } else {
        return head;
    }
}
128 
/**Binds a set of ranges together so that they represent a joint probability
 * distribution.  The arguments are passed through flatten() first.
 *
 * Examples:
 * ---
 * auto foo = [1,2,3,1,1];
 * auto bar = [2,4,6,2,2];
 * auto e = entropy(joint(foo, bar));  // Calculate joint entropy of foo, bar.
 * ---
 */
Joint!(FlattenType!(T)) joint(T...)(T args) {
    auto flat = flatten(args);
    return jointImpl(flat.tupleof);
}
141 
// Wraps the given ranges in a Joint exactly as passed, without flattening.
Joint!(T) jointImpl(T...)(T args) {
    alias typeof(return) Bound;
    return Bound(args);
}
145 
/**Lockstep view over a set of ranges.  Each call to front packs the current
 * element of every range into one ObsEnt, which the entropy functions use
 * internally on each iteration.  The view is empty as soon as any one of
 * the underlying ranges is empty.*/
struct Joint(T...) {
    T _jointRanges;

    @property ObsEnt!(ElementsTuple!(T)) front() {
        typeof(return) obs;
        foreach(i, r; _jointRanges) {
            obs.tupleof[i] = r.front;
        }
        return obs;
    }

    void popFront() {
        // Index the fields directly so that the stored ranges advance.
        foreach(i, Unused; T) {
            _jointRanges[i].popFront();
        }
    }

    @property bool empty() {
        foreach(r; _jointRanges) {
            if(r.empty) {
                return true;
            }
        }
        return false;
    }

    // Only available when every underlying range has a length.
    static if(T.length > 0 && allSatisfy!(hasLength, T)) {
        // Length of the shortest underlying range.
        @property size_t length() {
            size_t shortest = size_t.max;
            foreach(r; _jointRanges) {
                auto candidate = r.length;
                if(candidate < shortest) {
                    shortest = candidate;
                }
            }
            return shortest;
        }
    }
}
189 
// Maps a non-empty list of range types to the list of their element types,
// with top-level qualifiers stripped.
template ElementsTuple(T...) {
    static if(T.length > 1) {
        alias TypeTuple!(Unqual!(ElementType!(T[0])), ElementsTuple!(T[1..$]))
            ElementsTuple;
    } else {
        alias TypeTuple!(Unqual!(ElementType!(T[0]))) ElementsTuple;
    }
}
198 
// True if two default-initialized lvalues of type T can be compared with
// the < operator.  The check is made on lvalues, not on T.init.
private template Comparable(T) {
    enum bool Comparable = is(typeof({
        T a;
        T b;
        return a < b; }));
}

// Sanity check.  (This assertion used to be stated twice verbatim.)
static assert(Comparable!ubyte);
208 
// One joint observation:  holds one value per bound range, field by field.
// Joint.front returns an ObsEnt, and the entropy functions use it as the
// key when counting observations.
struct ObsEnt(T...) {
    T compRep;
    alias compRep this;

    static if(anySatisfy!(hasIndirections, T)) {

        // Then there's indirection involved.  We can't just do all our
        // comparison and hashing operations bitwise.
        hash_t toHash() {
            hash_t sum = 0;
            foreach(i, elem; this.tupleof) {
                sum *= 11;
                // The test has to be made on typeof(elem).  Written as
                // is(elem : long) it asks about a *type* called elem, which
                // is always false for a variable, so this integral fast
                // path could never be selected.
                static if(is(typeof(elem) : long)
                       && elem.sizeof <= hash_t.sizeof) {
                    sum += elem;
                } else static if(__traits(compiles, elem.toHash)) {
                    sum += elem.toHash;
                } else {
                    auto ti = typeid(typeof(elem));
                    sum += ti.getHash(&elem);
                }
            }
            return sum;
        }

        // Field-by-field equality.
        bool opEquals(const ref typeof(this) rhs) const {
            foreach(ti, elem; this.tupleof) {
                if(elem != rhs.tupleof[ti])
                    return false;
            }
            return true;
        }
    }
    // Else just use the default runtime functions for hash and equality.


    static if(allSatisfy!(Comparable, T)) {
        // Lexicographic ordering, first field most significant.  Returns a
        // negative value when this object orders before rhs, as the
        // language requires for opCmp.  The previous version returned the
        // opposite sign, which made a < b true exactly when a was greater.
        int opCmp(const ref typeof(this) rhs) const {
            foreach(ti, elem; this.tupleof) {
                if(elem < rhs.tupleof[ti]) {
                    return -1;
                } else if(elem > rhs.tupleof[ti]) {
                    return 1;
                }
            }
            return 0;
        }
    }
}
257 
// Decides whether entropy() can count in the region-allocated containers
// (false) or has to fall back to a regular associative array (true).
// The fallback is chosen only when the elements contain indirections and T
// is neither an array nor a Joint made up entirely of arrays.
private template NeedsHeap(T) {
    static if(!hasIndirections!(ForeachType!(T)) || isArray!(T)) {
        enum bool NeedsHeap = false;
    } else static if(is(Joint!(typeof(T.init.tupleof)))
           && is(T == Joint!(typeof(T.init.tupleof)))
           && allSatisfy!(isArray, typeof(T.init.tupleof))) {
        enum bool NeedsHeap = false;
    } else {
        enum bool NeedsHeap = true;
    }
}
273 
unittest {
    // A lazy range whose elements are arrays, and one whose elements are
    // plain integers.
    auto indirectElems = filter!"a"(cast(uint[][]) [[1]]);
    auto plainElems = filter!("a")([1,2,3][]);
    alias typeof(indirectElems) Indirect;
    alias typeof(plainElems) Plain;

    static assert(NeedsHeap!(Indirect));
    static assert(!NeedsHeap!(Plain));

    // The same two ranges bound together with an array.
    static assert(NeedsHeap!(Joint!(uint[], Indirect)));
    static assert(!NeedsHeap!(Joint!(uint[], Plain)));

    // A Joint consisting only of arrays.
    static assert(!NeedsHeap!(Joint!(uint[], uint[])));
}
283 
/**Calculates the entropy, in bits, of a set of observations.  Pass the
 * result of joint() to obtain the joint entropy of several vectors of
 * observations; with a single range this is the plain old entropy.
 *
 * If the input has a length, the length is used to pick the narrowest
 * counter type that can hold every count; otherwise uint counters are used.
 *
 * Note:  This function specializes if ElementType!(T) is a byte, ubyte, or
 * char, resulting in a much faster entropy calculation.  When possible, try
 * to provide data in the form of a byte, ubyte, or char.
 *
 * Examples:
 * ---
 * int[] foo = [1, 1, 1, 2, 2, 2, 3, 3, 3];
 * double entropyFoo = entropy(foo);  // Plain old entropy of foo.
 * assert(approxEqual(entropyFoo, log2(3)));
 * int[] bar = [1, 2, 3, 1, 2, 3, 1, 2, 3];
 * double HFooBar = entropy(joint(foo, bar));  // Joint entropy of foo and bar.
 * assert(approxEqual(HFooBar, log2(9)));
 * ---
 */
double entropy(T)(T data)
if(isIterable!(T)) {
    static if(hasLength!(T)) {
        // No count can exceed the number of observations.
        immutable nObs = data.length;

        if(nObs > ushort.max) {
            return entropyImpl!(uint, T)(data);
        }
        if(nObs > ubyte.max) {
            return entropyImpl!(ushort, T)(data);
        }
        return entropyImpl!(ubyte, T)(data);
    } else {
        return entropyImpl!(uint, T)(data);
    }
}
316 
// Generic version:  counts observations in a region-allocated container
// (StackHash when the input has a length, StackTreeAA otherwise) and hands
// the counts to entropyCounts.  U is the counter type.
private double entropyImpl(U, T)(T data)
if((ForeachType!(T).sizeof > 1 || is(ForeachType!T == struct)) && !NeedsHeap!(T)) {
    alias ForeachType!(T) Elem;
    auto region = newRegionAllocator();

    static if(hasLength!T) {
        auto tally = StackHash!(Elem, U)(max(20, data.length / 20), region);
    } else {
        auto tally = StackTreeAA!(Elem, U)(region);
    }

    uint nObs = 0;
    foreach(obs; data) {
        tally[obs]++;
        nObs++;
    }

    return entropyCounts(tally.values, nObs);
}
338 
// Generic version for inputs that NeedsHeap rejects for the region
// containers:  counts observations in a built-in associative array.
// U is the counter type.
private double entropyImpl(U, T)(T data)
if(ForeachType!(T).sizeof > 1 && NeedsHeap!(T)) {
    alias ForeachType!(T) Elem;

    U[Elem] tally;
    uint nObs = 0;
    foreach(obs; data) {
        nObs++;
        tally[obs]++;
    }

    // Iterating the associative array yields the counts.
    return entropyCounts(tally, nObs);
}
351 
// byte/char specialization:  with one-byte elements there are at most 256
// distinct values, so the counts live in a fixed table indexed by the
// element value.  U is the counter type.  Returns 0 for empty input, like
// the generic versions.
private double entropyImpl(U, T)(T data)  // byte/char specialization
if(ForeachType!(T).sizeof == 1 && !is(ForeachType!T == struct)) {
    alias ForeachType!(T) E;

    U[ubyte.max + 1] counts;

    // min and max bracket the part of the table that was actually touched,
    // so that only that window is handed to entropyCounts.
    uint min = ubyte.max, max = 0, len = 0;
    foreach(elem; data)  {
        len++;
        static if(is(E == byte)) {
            // Keep adjacent elements adjacent.  In real world use cases,
            // probably will have ranges like [-1, 1].
            ubyte e = cast(ubyte) (cast(ubyte) (elem) + byte.max);
        } else {
            ubyte e = cast(ubyte) elem;
        }
        counts[e]++;
        if(e > max) {
            max = e;
        }
        if(e < min) {
            min = e;
        }
    }

    // With no observations min (255) is still above max (0), and the
    // pointer slice below would be invalid.
    if(len == 0) {
        return 0;
    }

    return entropyCounts(counts.ptr[min..max + 1], len);
}
379 
unittest {
    // Generic version:  three equally frequent values, then a joint
    // distribution in which all nine pairs are distinct.
    {
        int[] xs = [1, 1, 1, 2, 2, 2, 3, 3, 3];
        double hXs = entropy(xs);
        assert(approxEqual(hXs, log2(3)));

        int[] ys = [1, 2, 3, 1, 2, 3, 1, 2, 3];
        auto bound = joint(xs, ys);  // Result unused.
        double hBoth = entropy(joint(xs, ys));
        assert(approxEqual(hBoth, log2(9)));
    }

    // byte specialization, with signed bytes and with chars.
    {
        byte[] signedVals = [-1, -1, -1, 2, 2, 2, 3, 3, 3];
        double hSigned = entropy(signedVals);
        assert(approxEqual(hSigned, log2(3)));

        string dna = "ACTGGCTA";
        assert(entropy(dna) == 2);
    }

    // NeedsHeap version:  a lazy range of strings.
    {
        string[] words = ["1", "1", "1", "2", "2", "2", "3", "3", "3"];
        auto lazyWords = map!("a")(words);
        assert(approxEqual(entropy(lazyWords), log2(3)));
    }
}
403 
/**Calculates the conditional entropy H(data | cond), computed as the joint
 * entropy of data and cond minus the entropy of cond.  If cond is not a
 * forward range it is first copied into a temporary array, because it has
 * to be traversed twice.*/
double condEntropy(T, U)(T data, U cond)
if(isInputRange!(T) && isInputRange!(U)) {
    static if(!isForwardRange!U) {
        auto region = newRegionAllocator();
        auto given = region.array(cond);
    } else {
        alias cond given;
    }

    immutable double hJoint = entropy(joint(data, given.save));
    immutable double hGiven = entropy(given.save);
    return hJoint - hGiven;
}
416 
unittest {
    // Basic identity check:  H(X) - H(X | Y) must equal I(X; Y).
    int[] xs = [1,2,2,1,1];
    int[] ys = [1,2,3,1,2];
    assert(approxEqual(entropy(xs) - condEntropy(xs, ys),
           mutualInfo(xs, ys)));
}
424 
// Contribution of a single contingency table cell to the mutual information
// score:  observed * log2(observed / expected), or 0 for an empty cell.
private double miContingency(double observed, double expected) {
    if(observed == 0) {
        return 0;
    }
    return observed * log2(observed / expected);
}
429 
430 
/**Calculates the mutual information, in bits, of two vectors of discrete
 * observations.
 *
 * When both ranges have a length, the shorter length is used to pick the
 * narrowest counter type for the contingency table.  If either range has
 * no length, uint counters are used.
 */
double mutualInfo(T, U)(T x, U y)
if(isInputRange!(T) && isInputRange!(U)) {
    // Passed to toContingencyScore, presumably as output parameters; only
    // n is used afterwards.
    uint xFreedom, yFreedom, n;
    typeof(return) ret;

    // A narrow counter type needs the lengths of BOTH ranges.  The
    // condition used to be (!hasLength!T && !hasLength!U), which sent the
    // case where only one range has a length into the branch that calls
    // .length on both of them, so that case did not compile.
    static if(hasLength!T && hasLength!U) {
        immutable minLen = min(x.length, y.length);
        if(minLen <= ubyte.max) {
            ret = toContingencyScore!(T, U, ubyte)
                (x, y, &miContingency, xFreedom, yFreedom, n);
        } else if(minLen <= ushort.max) {
            ret = toContingencyScore!(T, U, ushort)
                (x, y, &miContingency, xFreedom, yFreedom, n);
        } else {
            ret = toContingencyScore!(T, U, uint)
                (x, y, &miContingency, xFreedom, yFreedom, n);
        }
    } else {
        ret = toContingencyScore!(T, U, uint)
            (x, y, &miContingency, xFreedom, yFreedom, n);
    }

    // Divide the accumulated score by the number of observations.
    return ret / n;
}
457 
unittest {
    // Reference values from R, converted from base e to base 2.
    auto fineX = bin([1,2,3,3,8].dup, 10);
    auto fineY = bin([8,6,7,5,3].dup, 10);
    assert(approxEqual(mutualInfo(fineX, fineY), 1.921928));

    // The same eight observations binned into two and into four bins.
    auto twoBinX = bin([1,2,1,1,3,4,3,6].dup, 2);
    auto twoBinY = bin([2,7,9,6,3,1,7,40].dup, 2);
    assert(approxEqual(mutualInfo(twoBinX, twoBinY), .2935645));

    auto fourBinX = bin([1,2,1,1,3,4,3,6].dup, 4);
    auto fourBinY = bin([2,7,9,6,3,1,7,40].dup, 4);
    assert(approxEqual(mutualInfo(fourBinX, fourBinY), .5435671));
}
468 
/**
Calculates the mutual information of a contingency table that represents a
joint discrete probability distribution.  The table is given as a set of
finite forward ranges, one per column, expressed either as a tuple of ranges
or as a range of ranges.
*/
double mutualInfoTable(T...)(T table) {
    // Thin wrapper:  the work is done by gTestContingency.  It exists to
    // give conceptual unity to the infotheory module.
    auto gTest = gTestContingency(table);
    return gTest.mutualInfo;
}
480 
/**
Calculates the conditional mutual information I(x, y | z) from a set of
observations, as H(x, z) - H(x, y, z) - H(z) + H(y, z).  A negative result
is clamped to zero.
*/
double condMutualInfo(T, U, V)(T x, U y, V z) {
    immutable double hXZ = entropy(joint(x, z));
    immutable double hXYZ = entropy(joint(x, y, z));
    immutable double hZ = entropy(z);
    immutable double hYZ = entropy(joint(y, z));

    auto ret = hXZ - hXYZ - hZ + hYZ;
    return max(ret, 0);
}
490 
unittest {
    // Reference values from the Matlab mi package by Hanchuan Peng.

    // Three plain vectors.
    auto plain = condMutualInfo([1,2,1,2,1,2,1,2].dup, [3,1,2,3,4,2,1,2].dup,
                                [1,2,3,1,2,3,1,2].dup);
    assert(approxEqual(plain, 0.4387));

    // Conditioning on a joint of two vectors.
    auto onJoint = condMutualInfo([1,2,3,1,2].dup, [2,1,3,2,1].dup,
                                  joint([1,1,1,2,2].dup, [2,2,2,1,1].dup));
    assert(approxEqual(onJoint, 1.3510));
}
500 
/**Calculates the entropy of any old input range of observations more quickly
 * than entropy(), provided that all equal values are adjacent.  If the input
 * is sorted by more than one key, i.e. structs, the result will be the joint
 * entropy of all of the keys.  The compFun alias will be used to compare
 * adjacent elements and determine how many instances of each value exist.
 *
 * The range does not need to have a length.  Empty input yields 0.*/
double entropySorted(alias compFun = "a == b", T)(T data)
if(isInputRange!(T)) {
    alias binaryFun!(compFun) comp;

    // Nothing observed:  avoid calling front on an empty range.
    if(data.empty) {
        return 0;
    }

    static if(is(typeof(data.length))) {
        // The number of observations is known up front, so every run can
        // be turned into its -p * log2(p) term as soon as it ends.
        immutable n = data.length;
        immutable nrNeg1 = 1.0L / n;

        double sum = 0.0;
        int nSame = 1;
        auto last = data.front;
        data.popFront;
        foreach(elem; data) {
            if(comp(elem, last)) {
                nSame++;
            } else {
                immutable p = nSame * nrNeg1;
                nSame = 1;
                sum -= p * log2(p);
            }
            last = elem;
        }
        // Handle last run.
        immutable p = nSame * nrNeg1;
        sum -= p * log2(p);

        return sum;
    } else {
        // No length available:  count the observations while walking the
        // runs and use H = log2(n) - sum(c * log2(c)) / n, where c is the
        // length of each run.
        size_t n = 1;
        size_t nSame = 1;
        real weighted = 0;

        auto last = data.front;
        data.popFront();
        foreach(elem; data) {
            if(comp(elem, last)) {
                nSame++;
            } else {
                weighted += nSame * log2(cast(real) nSame);
                nSame = 1;
            }
            last = elem;
            n++;
        }
        // Handle last run.
        weighted += nSame * log2(cast(real) nSame);

        return cast(double) (log2(cast(real) n) - weighted / n);
    }
}
533 
unittest {
    // entropySorted on sorted data must agree with entropy on the same
    // data in its original order.
    uint[] shuffled = [1U,2,3,1,3,2,6,3,1,6,3,2,2,1,3,5,2,1].dup;
    auto ordered = shuffled.dup;
    sort(ordered);

    assert(approxEqual(entropySorted(ordered), entropy(shuffled)));
}
540 
/**
Much faster implementations of the information theory functions for the
special but common case where every observation is an integer on the range
[0, nBin$(RPAREN).  This is the case, for example, when the observations have
previously been binned using dstats.base.frqBin().

Because of the optimizations used, joint() cannot be combined with the
member functions of this struct, with the exception of entropy().

According to the original author's quick and dirty benchmarks this is on
the order of 10x faster than the generic implementations.
*/
struct DenseInfoTheory {
    private uint nBin;

    // Calls fun instantiated with the narrowest of ubyte, ushort and uint
    // that can hold a count as large as the length of the first argument.
    // This saves space and makes things cache efficient.  If any argument
    // has no length, uint is used.
    double selectSize(alias fun, T...)(T args) {
        static if(!allSatisfy!(hasLength, T)) {
            return fun!uint(args);
        } else {
            immutable nObs = args[0].length;

            // For now, assume that no one is going to have more than
            // 4 billion observations.
            if(nObs > ushort.max) {
                return fun!uint(args);
            }
            if(nObs > ubyte.max) {
                return fun!ushort(args);
            }
            return fun!ubyte(args);
        }
    }

    /**
    Constructs a DenseInfoTheory object for nBin bins.  The values taken by
    each observation must then be on the interval [0, nBin$(RPAREN).
    */
    this(uint nBin) {
        this.nBin = nBin;
    }

    /**
    Computes the entropy of a set of observations.  For this function only,
    joint() may be used to compute joint entropies, as long as each
    individual range contains only integers on [0, nBin$(RPAREN).
    */
    double entropy(R)(R range) if(isIterable!R) {
        return selectSize!entropyImpl(range);
    }

    // Counts the observations in a dense table and passes the table to
    // entropyCounts.  Uint is the counter type chosen by selectSize.
    private double entropyImpl(Uint, R)(R range) {
        auto region = newRegionAllocator();
        uint nObs = 0;

        static if(!is(typeof(range._jointRanges))) {
            // Single range:  one counter per bin.
            auto tally = region.uninitializedArray!(Uint[])(nBin);
            tally[] = 0;

            foreach(obs; range) {
                tally[obs]++;
                nObs++;
            }
        } else {
            // Joint entropy:  one counter per combination of bins.  Each
            // observation is mapped to a cell by treating the values of
            // the bound ranges as digits in base nBin.
            immutable nRanges = range._jointRanges.length;
            auto tally = region.uninitializedArray!(Uint[])(nBin ^^ nRanges);
            tally[] = 0;

            Observations:
            while(true) {
                uint stride = 1;
                uint cell = 0;

                foreach(ti, Unused; typeof(range._jointRanges)) {
                    // Stop as soon as any of the bound ranges runs out.
                    if(range._jointRanges[ti].empty) break Observations;
                    immutable binNo = range._jointRanges[ti].front;
                    assert(binNo < nBin);  // Enforce is too costly here.

                    cell += stride * cast(uint) binNo;
                    range._jointRanges[ti].popFront();
                    stride *= nBin;
                }

                tally[cell]++;
                nObs++;
            }
        }

        return entropyCounts(tally, nObs);
    }

    /// I(x; y)
    double mutualInfo(R1, R2)(R1 x, R2 y)
    if(isIterable!R1 && isIterable!R2) {
        return selectSize!mutualInfoImpl(x, y);
    }

    // Computes H(x) + H(y) - H(x, y) from dense count tables, clamped at
    // zero.  Stops at the end of the shorter input.
    private double mutualInfoImpl(Uint, R1, R2)(R1 x, R2 y) {
        auto region = newRegionAllocator();
        auto xyCounts = region.uninitializedArray!(Uint[])(nBin * nBin);
        auto xCounts = region.uninitializedArray!(Uint[])(nBin);
        auto yCounts = region.uninitializedArray!(Uint[])(nBin);
        xyCounts[] = 0;
        xCounts[] = 0;
        yCounts[] = 0;
        uint nObs;

        for(; !x.empty && !y.empty; x.popFront(), y.popFront()) {
            immutable xBin = cast(uint) x.front;
            immutable yBin = cast(uint) y.front;
            assert(xBin < nBin);
            assert(yBin < nBin);

            xyCounts[xBin * nBin + yBin]++;
            xCounts[xBin]++;
            yCounts[yBin]++;
            nObs++;
        }

        immutable double hX = entropyCounts(xCounts, nObs);
        immutable double hY = entropyCounts(yCounts, nObs);
        immutable double hXY = entropyCounts(xyCounts, nObs);

        auto ret = hX + hY - hXY;
        return max(0, ret);
    }

    /**
    Calculates the P-value for I(X; Y) assuming x and y both have supports
    of [0, nBin$(RPAREN).  The P-value is calculated using a Chi-Square
    approximation.  It is asymptotically correct, but is approximate for
    finite sample size.

    Params:
        mutualInfo = I(x; y), in bits
        n = The number of samples used to calculate I(x; y)
    */
    double mutualInfoPval(double mutualInfo, double n) {
        immutable degreesOfFreedom = (nBin - 1) ^^ 2;

        // Scale the mutual information by 2 * LN2 * n to get the statistic
        // passed to the Chi-Square distribution.
        immutable statistic = mutualInfo * 2 * LN2 * n;
        return chiSquareCDFR(statistic, degreesOfFreedom);
    }

    /// H(X | Y)
    double condEntropy(R1, R2)(R1 x, R2 y)
    if(isIterable!R1 && isIterable!R2) {
        return selectSize!condEntropyImpl(x, y);
    }

    // Computes H(x, y) - H(y) from dense count tables, clamped at zero.
    // Stops at the end of the shorter input.
    private double condEntropyImpl(Uint, R1, R2)(R1 x, R2 y) {
        auto region = newRegionAllocator();
        auto xyCounts = region.uninitializedArray!(Uint[])(nBin * nBin);
        auto yCounts = region.uninitializedArray!(Uint[])(nBin);
        xyCounts[] = 0;
        yCounts[] = 0;
        uint nObs;

        for(; !x.empty && !y.empty; x.popFront(), y.popFront()) {
            immutable xBin = cast(uint) x.front;
            immutable yBin = cast(uint) y.front;
            assert(xBin < nBin);
            assert(yBin < nBin);

            xyCounts[xBin * nBin + yBin]++;
            yCounts[yBin]++;
            nObs++;
        }

        immutable double hXY = entropyCounts(xyCounts, nObs);
        immutable double hY = entropyCounts(yCounts, nObs);

        auto ret = hXY - hY;
        return max(0, ret);
    }

    /// I(X; Y | Z)
    double condMutualInfo(R1, R2, R3)(R1 x, R2 y, R3 z)
    if(allSatisfy!(isIterable, R1, R2, R3)) {
        return selectSize!condMutualInfoImpl(x, y, z);
    }

    // Computes H(x, z) - H(x, y, z) - H(z) + H(y, z) from dense count
    // tables, clamped at zero.  Stops at the end of the shortest input.
    private double condMutualInfoImpl(Uint, R1, R2, R3)(R1 x, R2 y, R3 z) {
        auto region = newRegionAllocator();
        immutable nBinSq = nBin * nBin;
        auto xyzCounts = region.uninitializedArray!(Uint[])(nBin * nBin * nBin);
        auto xzCounts = region.uninitializedArray!(Uint[])(nBinSq);
        auto yzCounts = region.uninitializedArray!(Uint[])(nBinSq);
        auto zCounts = region.uninitializedArray!(Uint[])(nBin);
        xyzCounts[] = 0;
        xzCounts[] = 0;
        yzCounts[] = 0;
        zCounts[] = 0;
        uint nObs = 0;

        for(; !x.empty && !y.empty && !z.empty;
              x.popFront(), y.popFront(), z.popFront()) {
            immutable xBin = cast(uint) x.front;
            immutable yBin = cast(uint) y.front;
            immutable zBin = cast(uint) z.front;
            assert(xBin < nBin);
            assert(yBin < nBin);
            assert(zBin < nBin);

            xyzCounts[xBin * nBinSq + yBin * nBin + zBin]++;
            xzCounts[xBin * nBin + zBin]++;
            yzCounts[yBin * nBin + zBin]++;
            zCounts[zBin]++;
            nObs++;
        }

        immutable double hXZ = entropyCounts(xzCounts, nObs);
        immutable double hXYZ = entropyCounts(xyzCounts, nObs);
        immutable double hZ = entropyCounts(zCounts, nObs);
        immutable double hYZ = entropyCounts(yzCounts, nObs);

        auto ret = hXZ - hXYZ - hZ + hYZ;
        return max(0, ret);
    }
}
763 
unittest {
    // Three vectors of observations on [0, 3), compared against the
    // generic implementations.
    auto denseInfo = DenseInfoTheory(3);
    auto a = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2];
    auto b = [1, 2, 2, 2, 0, 0, 1, 1, 1, 1, 0, 0];
    auto c = [1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 0];

    // Entropy and joint entropy must match exactly.
    assert(entropy(a) == denseInfo.entropy(a));
    assert(entropy(b) == denseInfo.entropy(b));
    assert(entropy(c) == denseInfo.entropy(c));
    assert(entropy(joint(a, c)) == denseInfo.entropy(joint(c, a)));
    assert(entropy(joint(a, b)) == denseInfo.entropy(joint(a, b)));
    assert(entropy(joint(c, b)) == denseInfo.entropy(joint(c, b)));

    // So must conditional entropy.
    assert(condEntropy(a, c) == denseInfo.condEntropy(a, c));
    assert(condEntropy(a, b) == denseInfo.condEntropy(a, b));
    assert(condEntropy(c, b) == denseInfo.condEntropy(c, b));

    // Mutual information is only compared approximately.
    alias approxEqual ae;
    assert(ae(mutualInfo(a, c), denseInfo.mutualInfo(c, a)));
    assert(ae(mutualInfo(a, b), denseInfo.mutualInfo(a, b)));
    assert(ae(mutualInfo(c, b), denseInfo.mutualInfo(c, b)));

    assert(ae(condMutualInfo(a, b, c), denseInfo.condMutualInfo(a, b, c)));
    assert(ae(condMutualInfo(a, c, b), denseInfo.condMutualInfo(a, c, b)));
    assert(ae(condMutualInfo(b, c, a), denseInfo.condMutualInfo(b, c, a)));

    // Test P-value stuff against the G-test on the raw observations.
    immutable pFromDense = denseInfo.mutualInfoPval(denseInfo.mutualInfo(a, b), a.length);
    immutable pFromGTest = gTestObs(a, b).p;
    assert(approxEqual(pFromDense, pFromGTest));
}