Генератор псевдослучайных чисел "Вихрь Мерсенна" - SIMD реализация

Мне понадобился качественный ГСПЧ с хорошим распределением и длинным периодом. Самый популярный из существующих - Вихрь Мерсенна.

Я взял SIMD реализацию от его авторов SFMT 1.5.1 на C и переписал на C# (.NET 6).

Простите, разработчики игр, с Unity этот вариант Вихря несовместим и легко доработать не получится, надо ждать, когда сама Unity перейдет на CoreCLR и .NET 6. :(

Работал по такому заданию:

  • Процессор должен поддерживать SSE2, фоллбэк не предусмотрен
  • Поддерживается потокобезопасность на уровне статических членов
  • Совместим по поведению публичных членов с System.Random
  • Не имеет болезней нулевого сида
  • Сделать настолько просто, насколько возможно
  • Ориентирован для x64 приложения
/// <summary>
/// Генератор псевдослучайных чисел на основе алгоритма Вихрь Мерсенна
/// Конфигурация: SFMT-19937:122-18-1-11-1:dfffffef-ddfecb7f-bffaffff-bffffff6
/// </summary>
public class MersenneTwister
{
    const int MersenneExponent = 19937;
    const int Length128 = MersenneExponent / 128 + 1;
    const int Length32 = Length128 * 4;

    private readonly Vector128<uint>[] state = new Vector128<uint>[Length128];
    private int index;

    [ThreadStatic]
    private static MersenneTwister _shared;

    /// <summary>
    /// Общий экземпляр генератора, создается отдельно для каждого потока
    /// </summary>
    public static MersenneTwister Shared => _shared ??= new(GetRandomSeed());

    /// <summary>
    /// Создает экземпляр генератора со случайным сидом
    /// </summary>
    public MersenneTwister() 
    {
        InitGenerator(Shared.NextUInt32());
    }

    /// <summary>
    /// Создает экземпляр генератора с указанным сидом
    /// </summary>
    /// <param name="seed">Сид для создания генератора</param>
    public MersenneTwister(int seed) : this((uint)seed) { }

    private MersenneTwister(uint seed)
    {
        InitGenerator(seed);
    }

    private static uint GetRandomSeed()
    {
        ReadOnlySpan<byte> bytes = Guid.NewGuid().ToByteArray();
        ReadOnlySpan<uint> hash = MemoryMarshal.Cast<byte, uint>(SHA256.HashData(bytes));
        uint result = 0;
        foreach (uint value in hash)
            result ^= value;
        return result;
    }

    private void InitGenerator(uint seed)
    {
        if (!Sse2.IsSupported)
            throw new InvalidOperationException("SSE2 не поддерживается на данном устройстве.");

        Span<uint> values = MemoryMarshal.Cast<Vector128<uint>, uint>(state);
        values[0] = seed;
        for (int i = 1; i < Length32; i++)
        {
            values[i] = (uint)(1812433253UL * (values[i - 1] ^ (values[i - 1] >> 30)) + (uint)i);
        }
        index = Length32;
        PeriodCertification(values);
    }

    private void PeriodCertification(Span<uint> values)
    {
        uint inner = 0;
        ReadOnlySpan<uint> parity = stackalloc uint[] { 0x00000001U, 0x00000000U, 0x00000000U, 0x13c9e684U };

        for (int i = 0; i < 4; i++)
            inner ^= values[i] & parity[i];
        for (int i = 16; i > 0; i >>= 1)
            inner ^= inner >> i;
        inner &= 1;

        if (inner == 1)
            return;

        for (int i = 0; i < 4; i++)
        {
            for (uint work = 1; work != 0; work <<= 1)
            {
                if ((work & parity[i]) != 0)
                {
                    values[i] ^= work;
                    return;
                }
            }
        }
    }

    private void UpdateState()
    {
        const int offset = 122;
        int i;
        Vector128<uint> r1 = state[^2];
        Vector128<uint> r2 = state[^1];
        for (i = 0; i < Length128 - offset; i++)
        {
            state[i] = DoRecursion(state[i], state[i + offset], r1, r2);
            r1 = r2;
            r2 = state[i];
        }
        for (; i < Length128; i++)
        {
            state[i] = DoRecursion(state[i], state[i + offset - Length128], r1, r2);
            r1 = r2;
            r2 = state[i];
        }
    }

    private Vector128<uint> DoRecursion(Vector128<uint> a, Vector128<uint> b, Vector128<uint> c, Vector128<uint> d)
    {
        Vector128<uint> z = Sse2.ShiftRightLogical128BitLane(c, 1);
        z = Sse2.Xor(z, a);
        Vector128<uint> v = Sse2.ShiftLeftLogical(d, 18);
        z = Sse2.Xor(z, v);
        Vector128<uint> x = Sse2.ShiftLeftLogical128BitLane(a, 1);
        z = Sse2.Xor(z, x);
        Vector128<uint> y = Sse2.ShiftRightLogical(b, 11);
        Vector128<uint> mask = Vector128.Create(0xdfffffefU, 0xddfecb7fU, 0xbffaffffU, 0xbffffff6U);
        y = Sse2.And(y, mask);
        return Sse2.Xor(z, y);
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,ulong.MaxValue]
    /// </summary>
    public ulong NextUInt64()
    {
        if (index >= Length32)
        {
            UpdateState();
            index = 0;
        }
        Span<ulong> values = MemoryMarshal.Cast<Vector128<uint>, ulong>(state);
        ulong r = values[index / 2];
        index += 2;
        return r;
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,uint.MaxValue]
    /// </summary>
    public uint NextUInt32()
    {
        if (index >= Length32)
        {
            UpdateState();
            index = 0;
        }
        Span<uint> values = MemoryMarshal.Cast<Vector128<uint>, uint>(state);
        return values[index++];
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,int.MaxValue)
    /// </summary>
    public int Next()
    {
        return (int)(((ulong)int.MaxValue * NextUInt32()) >> 32);
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,maxValue)
    /// </summary>
    public int Next(int maxValue)
    {
        if (maxValue <= 0)
            throw new ArgumentOutOfRangeException(nameof(maxValue), "Значение должно быть больше 0.");

        return (int)(((ulong)maxValue * NextUInt32()) >> 32);
    }

    /// <summary>
    /// Возвращает число в диапазоне [minValue,maxValue)
    /// </summary>
    public int Next(int minValue, int maxValue)
    {
        if (minValue < 0)
            throw new ArgumentOutOfRangeException(nameof(minValue), "Значение должно быть больше либо равно 0.");
        if (maxValue <= minValue)
            throw new ArgumentOutOfRangeException(nameof(maxValue), $"Значение должно быть больше, чем {nameof(minValue)}.");

        return (int)(((ulong)(maxValue - minValue) * NextUInt32()) >> 32) + minValue;
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,long.MaxValue)
    /// </summary>
    public long NextInt64()
    {
        return (long)(((BigInteger)long.MaxValue * NextUInt64()) >> 64);
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,maxValue)
    /// </summary>
    public long NextInt64(long maxValue)
    {
        if (maxValue <= 0)
            throw new ArgumentOutOfRangeException(nameof(maxValue), "Значение должно быть больше 0.");

        return (long)(((BigInteger)maxValue * NextUInt64()) >> 64);
    }

    /// <summary>
    /// Возвращает число в диапазоне [minValue,maxValue)
    /// </summary>
    public long NextInt64(long minValue, long maxValue)
    {
        if (minValue < 0)
            throw new ArgumentOutOfRangeException(nameof(minValue), "Значение должно быть больше либо равно 0.");
        if (maxValue <= minValue)
            throw new ArgumentOutOfRangeException(nameof(maxValue), $"Значение должно быть больше, чем {nameof(minValue)}.");

        return (long)(((BigInteger)(maxValue - minValue) * NextUInt64()) >> 64) + minValue;
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,1)
    /// </summary>
    public double NextDouble()
    {
        return (NextUInt64() >> 11) / 9007199254740992.0; // 2^53
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,1)
    /// </summary>
    public float NextSingle()
    {
        return (NextUInt32() >> 8) / 16777216.0f; // 2^24
    }
}

Использовать просто, например вот так:

static void Main(string[] args)
{
    MersenneTwister mt = new MersenneTwister();
    int randomNumber = mt.Next(10);
    Console.WriteLine(randomNumber);

    Console.ReadKey();
}

По производительности он хорош, не буду приводить числа, для каждого процессора они свои, а просто скажу, что он всего в 2-4 раза медленнее, чем System.Random, хотя для меня допустима была просадка в 10 и более раз.

Выдает на одном и том же сиде идентичную последовательность чисел, как оригинальная разработка SFMT 1.5.1 на С, то есть с вычислениями здесь тоже всё ок.

От System.Random не стал наследоваться, так как не хотел дергать его внутренню логику. А так как сам System.Random ничего не реализует (могли бы за столько лет и интерфейс к нему прикрутить), смысла связываться с ним я не увидел.

Посмотрите пожалуйста, все ли ок у меня в преобразованиях в публичных методах, и не перемудрил ли я с сидами. А может стоит что-нибудь подкрутить, чтобы стало быстрее?


Ответы (1 шт):

Автор решения: aepot
  • Исправлено именование полей.
  • Оптимизирован, а затем уничтожен метод PeriodCertification. Вмержил его остатки в InitGenerator.
  • Исправлен баг нечетного индекса в NextUInt64(), ранее при генерации ulong если до этого была генерация uint и индекс остался нечетным, то старшая часть нового ulong повторяла ранее отданное uint число.
/// <summary>
/// Генератор псевдослучайных чисел на основе алгоритма Вихрь Мерсенна
/// Конфигурация: SFMT-19937:122-18-1-11-1:dfffffef-ddfecb7f-bffaffff-bffffff6
/// </summary>
public class MersenneTwister
{
    private const int MersenneExponent = 19937;
    private const int Length128 = MersenneExponent / 128 + 1;
    private const int Length32 = Length128 * 4;

    private readonly Vector128<uint>[] _state = new Vector128<uint>[Length128];
    private int _index;

    [ThreadStatic]
    private static MersenneTwister _shared;

    /// <summary>
    /// Общий экземпляр генератора, создается отдельно для каждого потока
    /// </summary>
    public static MersenneTwister Shared => _shared ??= new(GetRandomSeed());

    /// <summary>
    /// Создает экземпляр генератора со случайным сидом
    /// </summary>
    public MersenneTwister() : this(Shared.NextUInt32()) { }

    /// <summary>
    /// Создает экземпляр генератора с указанным сидом
    /// </summary>
    /// <param name="seed">Сид для создания генератора</param>
    public MersenneTwister(int seed) : this((uint)seed) { }

    private MersenneTwister(uint seed)
    {
        InitGenerator(seed);
    }

    private static uint GetRandomSeed()
    {
        ReadOnlySpan<byte> bytes = Guid.NewGuid().ToByteArray();
        ReadOnlySpan<uint> hash = MemoryMarshal.Cast<byte, uint>(SHA256.HashData(bytes));
        uint result = 0;
        foreach (uint value in hash)
            result ^= value;
        return result;
    }

    private void InitGenerator(uint seed)
    {
        if (!Sse2.IsSupported)
            throw new InvalidOperationException("SSE2 не поддерживается на данном устройстве.");

        Span<uint> values = MemoryMarshal.Cast<Vector128<uint>, uint>(_state);
        values[0] = seed;
        for (int i = 1; i < Length32; i++)
        {
            values[i] = (uint)(1812433253ul * (values[i - 1] ^ (values[i - 1] >> 30)) + (uint)i);
        }
        _index = Length32;

        Vector128<uint> parity = Vector128.Create(0x00000001u, 0x00000000u, 0x00000000u, 0x13c9e684u);
        Vector128<uint> v = Sse2.And(_state[0], parity);

        uint inner = 0;
        for (int i = 0; i < 4; i++)
            inner ^= v.GetElement(i);
        for (int i = 16; i > 0; i >>= 1)
            inner ^= inner >> i;

        if ((inner & 1) == 0)
            values[0] ^= 1;
    }

    private void UpdateState()
    {
        const int offset = 122;
        Vector128<uint> r1 = _state[^2];
        Vector128<uint> r2 = _state[^1];
        int i = 0;
        for (; i < Length128 - offset; i++)
        {
            _state[i] = DoRecursion(_state[i], _state[i + offset], r1, r2);
            r1 = r2;
            r2 = _state[i];
        }
        for (; i < Length128; i++)
        {
            _state[i] = DoRecursion(_state[i], _state[i + offset - Length128], r1, r2);
            r1 = r2;
            r2 = _state[i];
        }
    }

    private Vector128<uint> DoRecursion(Vector128<uint> a, Vector128<uint> b, Vector128<uint> c, Vector128<uint> d)
    {
        Vector128<uint> z = Sse2.ShiftRightLogical128BitLane(c, 1);
        z = Sse2.Xor(z, a);
        Vector128<uint> v = Sse2.ShiftLeftLogical(d, 18);
        z = Sse2.Xor(z, v);
        Vector128<uint> x = Sse2.ShiftLeftLogical128BitLane(a, 1);
        z = Sse2.Xor(z, x);
        Vector128<uint> y = Sse2.ShiftRightLogical(b, 11);
        Vector128<uint> mask = Vector128.Create(0xdfffffefu, 0xddfecb7fu, 0xbffaffffu, 0xbffffff6u);
        y = Sse2.And(y, mask);
        return Sse2.Xor(z, y);
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,ulong.MaxValue]
    /// </summary>
    public ulong NextUInt64()
    {
        if ((_index & 1) == 1)
            _index++;
        if (_index >= Length32)
        {
            UpdateState();
            _index = 0;
        }
        Span<ulong> values = MemoryMarshal.Cast<Vector128<uint>, ulong>(_state);
        ulong r = values[_index >> 1];
        _index += 2;
        return r;
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,uint.MaxValue]
    /// </summary>
    public uint NextUInt32()
    {
        if (_index >= Length32)
        {
            UpdateState();
            _index = 0;
        }
        Span<uint> values = MemoryMarshal.Cast<Vector128<uint>, uint>(_state);
        return values[_index++];
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,int.MaxValue)
    /// </summary>
    public int Next()
    {
        return (int)(((ulong)int.MaxValue * NextUInt32()) >> 32);
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,maxValue)
    /// </summary>
    public int Next(int maxValue)
    {
        if (maxValue <= 0)
            throw new ArgumentOutOfRangeException(nameof(maxValue), "Значение должно быть больше 0.");

        return (int)(((ulong)maxValue * NextUInt32()) >> 32);
    }

    /// <summary>
    /// Возвращает число в диапазоне [minValue,maxValue)
    /// </summary>
    public int Next(int minValue, int maxValue)
    {
        if (minValue < 0)
            throw new ArgumentOutOfRangeException(nameof(minValue), "Значение должно быть больше либо равно 0.");
        if (maxValue <= minValue)
            throw new ArgumentOutOfRangeException(nameof(maxValue), $"Значение должно быть больше, чем {nameof(minValue)}.");

        return (int)(((ulong)(maxValue - minValue) * NextUInt32()) >> 32) + minValue;
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,long.MaxValue)
    /// </summary>
    public long NextInt64()
    {
        return (long)(((BigInteger)long.MaxValue * NextUInt64()) >> 64);
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,maxValue)
    /// </summary>
    public long NextInt64(long maxValue)
    {
        if (maxValue <= 0)
            throw new ArgumentOutOfRangeException(nameof(maxValue), "Значение должно быть больше 0.");

        return (long)(((BigInteger)maxValue * NextUInt64()) >> 64);
    }

    /// <summary>
    /// Возвращает число в диапазоне [minValue,maxValue)
    /// </summary>
    public long NextInt64(long minValue, long maxValue)
    {
        if (minValue < 0)
            throw new ArgumentOutOfRangeException(nameof(minValue), "Значение должно быть больше либо равно 0.");
        if (maxValue <= minValue)
            throw new ArgumentOutOfRangeException(nameof(maxValue), $"Значение должно быть больше, чем {nameof(minValue)}.");

        return (long)(((BigInteger)(maxValue - minValue) * NextUInt64()) >> 64) + minValue;
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,1)
    /// </summary>
    public double NextDouble()
    {
        return (NextUInt64() >> 11) / 9007199254740992.0; // 2^53
    }

    /// <summary>
    /// Возвращает число в диапазоне [0,1)
    /// </summary>
    public float NextSingle()
    {
        return (NextUInt32() >> 8) / 16777216.0f; // 2^24
    }
}

На данным этапе оставлю так, здесь более оптимизировать особо нечего.

→ Ссылка