2461c7257e9d — Michael Johnson 10 months ago
Improve speed of Uniform.Int128 distribution

This switches Int128 to use the same widening multiplication algorithm that UInt128 now uses.
M src/Benchmarks/UniformDists.cs +1 -1
@@ 158,7 158,7 @@ public class UniformDists
     }
 #if NET8_0_OR_GREATER
     [Benchmark]
-    public Int128 SampleUnt128()
+    public Int128 SampleInt128()
     {
         Int128 sum = 0;
         for (Int32 i = 0; i < Iterations; i++)

          
M src/RandN/Distributions/Uniform128.cs +21 -40
@@ 62,12 62,17 @@ public static partial class Uniform
             if (_range == 0) // 0 is a special case where we sample the entire range.
                 return unchecked((System.Int128)unsigned);
 
-            while (unsigned > _zone)
+            var zone = _zone;
+
+            while (true)
             {
-                unsigned = NextUInt128(rng);
+                var (hi, lo) = unsigned.WideningMultiply(_range);
+
+                if (lo <= zone)
+                    return unchecked((System.Int128)hi + _low);
+
+                unsigned = rng.NextUInt128();
             }
-
-            return unchecked((System.Int128)(unsigned % _range) + _low);
         }
 
         /// <inheritdoc />

          
@@ 80,9 85,12 @@ public static partial class Uniform
                 return true;
             }
 
-            if (unsigned <= _zone)
+            var zone = _zone;
+            var (hi, lo) = unsigned.WideningMultiply(_range);
+
+            if (lo <= zone)
             {
-                result = unchecked((System.Int128)(unsigned % _range) + _low);
+                result = unchecked((System.Int128)hi + _low);
                 return true;
             }
 

          
@@ 152,7 160,7 @@ public static partial class Uniform
         /// <inheritdoc />
         public System.UInt128 Sample<TRng>(TRng rng) where TRng : notnull, IRng
         {
-            var unsigned = NextUInt128(rng);
+            var unsigned = rng.NextUInt128();
             if (_range == 0) // 0 is a special case where we sample the entire range.
                 return unsigned;
 

          
@@ 160,19 168,19 @@ public static partial class Uniform
 
             while (true)
             {
-                var (hi, lo) = WideningMultiply(unsigned, _range);
+                var (hi, lo) = unsigned.WideningMultiply(_range);
 
                 if (lo <= zone)
-                    return unchecked(_low + hi);
+                    return unchecked(hi + _low);
 
-                unsigned = NextUInt128(rng);
+                unsigned = rng.NextUInt128();
             }
         }
 
         /// <inheritdoc />
         public Boolean TrySample<TRng>(TRng rng, out System.UInt128 result) where TRng : notnull, IRng
         {
-            var unsigned = NextUInt128(rng);
+            var unsigned = rng.NextUInt128();
             if (_range == 0) // 0 is a special case where we sample the entire range.
             {
                 result = unsigned;

          
@@ 180,44 188,17 @@ public static partial class Uniform
             }
 
             var zone = _zone;
-            var (hi, lo) = WideningMultiply(unsigned, _range);
+            var (hi, lo) = unsigned.WideningMultiply(_range);
 
             if (lo <= zone)
             {
-                result = unchecked(_low + hi);
+                result = unchecked(hi + _low);
                 return true;
             }
 
             result = default;
             return false;
         }
-
-        private static (System.UInt128, System.UInt128) WideningMultiply(System.UInt128 left, System.UInt128 right)
-        {
-            System.UInt128 LOWER_MASK = System.UInt128.MaxValue >> 64;
-            System.UInt128 low = unchecked((left & LOWER_MASK) * (right & LOWER_MASK));
-            System.UInt128 t = low >> 64;
-            low &= LOWER_MASK;
-            t += unchecked((left >> 64) * (right & LOWER_MASK));
-            low += (t & LOWER_MASK) << 64;
-            System.UInt128 high = t >> 64;
-            t = low >> 64;
-            low &= LOWER_MASK;
-            t += unchecked((right >> 64) * (left & LOWER_MASK));
-            low += (t & LOWER_MASK) << 64;
-            high += t >> 64;
-            high += unchecked((left >> 64) * (right >> 64));
-
-            return (high, low);
-        }
-
-        private static System.UInt128 NextUInt128<TRng>(TRng rng) where TRng : notnull, IRng
-        {
-            // Use Little Endian; we explicitly generate one value before the next.
-            var x = rng.NextUInt64();
-            var y = rng.NextUInt64();
-            return new System.UInt128(y, x);
-        }
     }
 }
 #endif

          
A => src/RandN/Distributions/Utils.cs +48 -0
@@ 0,0 1,48 @@ 
+#if NET8_0_OR_GREATER
+using System;
+
+namespace RandN.Distributions;
+
+/// <summary>
+/// Utilities for 128 bit distributions.
+/// </summary>
+internal static class Utils
+{
+    /// <summary>
+    /// Multiplies two <see cref="System.UInt128"/>s and returns the result split into upper and lower bits.
+    /// </summary>
+    /// <param name="left">The left multiplicand.</param>
+    /// <param name="right">The right multiplicand.</param>
+    /// <returns>A tuple of the high bits and the low bits of the multiplication result.</returns>
+    public static (UInt128 hi, UInt128 lo) WideningMultiply(this UInt128 left, UInt128 right)
+    {
+        UInt128 lowerMask = new UInt128(0, UInt64.MaxValue);
+
+        UInt128 low = unchecked((left & lowerMask) * (right & lowerMask));
+        UInt128 t = low >> 64;
+        low &= lowerMask;
+        t += unchecked((left >> 64) * (right & lowerMask));
+        low += (t & lowerMask) << 64;
+        UInt128 high = t >> 64;
+        t = low >> 64;
+        low &= lowerMask;
+        t += unchecked((right >> 64) * (left & lowerMask));
+        low += (t & lowerMask) << 64;
+        high += t >> 64;
+        high += unchecked((left >> 64) * (right >> 64));
+
+        return (high, low);
+    }
+
+    /// <summary>
+    /// Returns the next 128 bits in the RNG as a <see cref="System.UInt128"/>.
+    /// </summary>
+    public static UInt128 NextUInt128<TRng>(this TRng rng) where TRng : notnull, IRng
+    {
+        // Use Little Endian; we explicitly generate one value before the next.
+        var x = rng.NextUInt64();
+        var y = rng.NextUInt64();
+        return new UInt128(y, x);
+    }
+}
+#endif

          
M src/Tests/Distributions/UniformInteger128Tests.cs +15 -57
@@ 136,7 136,6 @@ public sealed class UniformInteger128Tes
         Assert.False(dist.TrySample(rng, out _));
         Assert.True(dist.TrySample(rng, out result));
         Assert.Equal(new UInt128(0x4000000000000000, 3), result);
-        Assert.False(dist.TrySample(rng, out _));
 
         // Now test a blocking sample
         Assert.Equal(new UInt128(0x4000000000000000, 4), dist.Sample(rng));

          
@@ 289,70 288,29 @@ public sealed class UniformInteger128Tes
     [Fact]
     public void RejectionsInt128()
     {
-        Int128 midpoint = 0;
+        Int128 midpoint = Int128.MaxValue / 2 + Int128.MinValue / 2;
         Int128 low = Int128.MinValue;
         Int128 high = midpoint + 1;
 
-        UInt128 maxRand = UInt128.MaxValue;
-        UInt128 rangeSize = unchecked((UInt128)high - (UInt128)low + 1);
-        UInt128 rejectCount = (maxRand - rangeSize + 1) % rangeSize;
-
-        UInt128 lastAccepted = maxRand - rejectCount;
-        UInt128 penultimateAccepted = lastAccepted - 1;
-        UInt128 firstRejected = lastAccepted + 1;
-        UInt128 secondRejected = firstRejected + 1;
-        UInt128 thirdRejected = secondRejected + 1;
+        var sequence = Enumerable.Range(0, 10).Select(x => UInt128.MaxValue / 2 + (UInt128)x);
 
         var dist = Uniform.NewInclusive(low, high);
-        var rng = new SequenceRng([
-            penultimateAccepted.IsolateLow(),
-            penultimateAccepted.IsolateHigh(),
-            lastAccepted.IsolateLow(),
-            lastAccepted.IsolateHigh(),
-            firstRejected.IsolateLow(),
-            firstRejected.IsolateHigh(),
-            secondRejected.IsolateLow(),
-            secondRejected.IsolateHigh(),
-            thirdRejected.IsolateLow(),
-            thirdRejected.IsolateHigh(),
+        var rng = new SequenceRng(sequence.SelectMany(x => new[] { x.IsolateLow(), x.IsolateHigh() }));
 
-            firstRejected.IsolateLow(),
-            firstRejected.IsolateHigh(),
-            secondRejected.IsolateLow(),
-            secondRejected.IsolateHigh(),
-            thirdRejected.IsolateLow(),
-            thirdRejected.IsolateHigh(),
-            0,
-            0,
-        ]);
-
+        Assert.False(dist.TrySample(rng, out _));
         Assert.True(dist.TrySample(rng, out Int128 result));
-        Assert.Equal(midpoint, result);
+        Assert.Equal(new Int128(0xC000000000000000, 0), result);
+        Assert.True(dist.TrySample(rng, out result));
+        Assert.Equal(new Int128(0xC000000000000000, 1), result);
+        Assert.False(dist.TrySample(rng, out _));
         Assert.True(dist.TrySample(rng, out result));
-        Assert.Equal(midpoint + 1, result);
+        Assert.Equal(new Int128(0xC000000000000000, 2), result);
         Assert.False(dist.TrySample(rng, out _));
-        Assert.False(dist.TrySample(rng, out _));
-        Assert.False(dist.TrySample(rng, out _));
+        Assert.True(dist.TrySample(rng, out result));
+        Assert.Equal(new Int128(0xC000000000000000, 3), result);
 
         // Now test a blocking sample
-        Assert.Equal(Int128.MinValue, dist.Sample(rng));
-    }
-
-    [Fact]
-    public void ZoneEqualToGeneratedInt128()
-    {
-        Int128 low = Int128.MinValue;
-        Int128 high = Int128.MaxValue - 1;
-        UInt128 rngState = UInt128.MaxValue - 1;
-
-        var rng = new SequenceRng([rngState.IsolateLow(), rngState.IsolateHigh()]);
-        var dist = Uniform.NewInclusive(low, high);
-
-        Assert.Equal(UInt32.MaxValue - 1, rng.NextUInt32());
-        Assert.Equal(UInt32.MaxValue, rng.NextUInt32());
-        Assert.Equal(UInt32.MaxValue, rng.NextUInt32());
-        Assert.Equal(UInt32.MaxValue, rng.NextUInt32());
-        Assert.Equal(high, dist.Sample(rng));
+        Assert.Equal(new Int128(0xC000000000000000, 4), dist.Sample(rng));
     }
 
     /// <summary>

          
@@ 383,9 341,9 @@ public sealed class UniformInteger128Tes
         var dist = Uniform.New(new Int128(0, 50), new Int128(2_000, 200_000_000_000ul));
         var expectedValues = new[]
         {
-            new Int128(0x0E9, 0x2C2DE5AC5A44EFE3),
-            new Int128(0x04C, 0xB23B6C6817A5B06C),
-            new Int128(0x1F6, 0xBD76F937FCFCFB8D),
+            new Int128(0x666, 0xD3A5622DD2B654F2),
+            new Int128(0x45D, 0x5667CA7B62179B9D),
+            new Int128(0x4D4, 0x81F33768224909BA),
         };
         foreach (var expected in expectedValues)
             Assert.Equal(expected, dist.Sample(rng));