Генератор псевдослучайных чисел "Вихрь Мерсенна" - SIMD реализация
Мне понадобился качественный ГСПЧ с хорошим распределением и длинным периодом. Самый популярный из существующих - Вихрь Мерсенна.
Я взял SIMD реализацию от его авторов SFMT 1.5.1 на C и переписал на C# (.NET 6).
Простите, разработчики игр, с Unity этот вариант Вихря несовместим и легко доработать не получится, надо ждать, когда сама Unity перейдет на CoreCLR и .NET 6. :(
Работал по такому заданию:
- Процессор должен поддерживать SSE2, фоллбэк не предусмотрен
- Поддерживается потокобезопасность на уровне статических членов
- Совместим по поведению публичных членов с
System.Random - Не имеет болезней нулевого сида
- Сделать настолько просто, насколько возможно
- Ориентирован для x64 приложения
/// <summary>
/// Генератор псевдослучайных чисел на основе алгоритма Вихрь Мерсенна
/// Конфигурация: SFMT-19937:122-18-1-11-1:dfffffef-ddfecb7f-bffaffff-bffffff6
/// </summary>
public class MersenneTwister
{
const int MersenneExponent = 19937;
const int Length128 = MersenneExponent / 128 + 1;
const int Length32 = Length128 * 4;
private readonly Vector128<uint>[] state = new Vector128<uint>[Length128];
private int index;
[ThreadStatic]
private static MersenneTwister _shared;
/// <summary>
/// Общий экземпляр генератора, создается отдельно для каждого потока
/// </summary>
public static MersenneTwister Shared => _shared ??= new(GetRandomSeed());
/// <summary>
/// Создает экземпляр генератора со случайным сидом
/// </summary>
public MersenneTwister()
{
InitGenerator(Shared.NextUInt32());
}
/// <summary>
/// Создает экземпляр генератора с указанным сидом
/// </summary>
/// <param name="seed">Сид для создания генератора</param>
public MersenneTwister(int seed) : this((uint)seed) { }
private MersenneTwister(uint seed)
{
InitGenerator(seed);
}
private static uint GetRandomSeed()
{
ReadOnlySpan<byte> bytes = Guid.NewGuid().ToByteArray();
ReadOnlySpan<uint> hash = MemoryMarshal.Cast<byte, uint>(SHA256.HashData(bytes));
uint result = 0;
foreach (uint value in hash)
result ^= value;
return result;
}
private void InitGenerator(uint seed)
{
if (!Sse2.IsSupported)
throw new InvalidOperationException("SSE2 не поддерживается на данном устройстве.");
Span<uint> values = MemoryMarshal.Cast<Vector128<uint>, uint>(state);
values[0] = seed;
for (int i = 1; i < Length32; i++)
{
values[i] = (uint)(1812433253UL * (values[i - 1] ^ (values[i - 1] >> 30)) + (uint)i);
}
index = Length32;
PeriodCertification(values);
}
private void PeriodCertification(Span<uint> values)
{
uint inner = 0;
ReadOnlySpan<uint> parity = stackalloc uint[] { 0x00000001U, 0x00000000U, 0x00000000U, 0x13c9e684U };
for (int i = 0; i < 4; i++)
inner ^= values[i] & parity[i];
for (int i = 16; i > 0; i >>= 1)
inner ^= inner >> i;
inner &= 1;
if (inner == 1)
return;
for (int i = 0; i < 4; i++)
{
for (uint work = 1; work != 0; work <<= 1)
{
if ((work & parity[i]) != 0)
{
values[i] ^= work;
return;
}
}
}
}
private void UpdateState()
{
const int offset = 122;
int i;
Vector128<uint> r1 = state[^2];
Vector128<uint> r2 = state[^1];
for (i = 0; i < Length128 - offset; i++)
{
state[i] = DoRecursion(state[i], state[i + offset], r1, r2);
r1 = r2;
r2 = state[i];
}
for (; i < Length128; i++)
{
state[i] = DoRecursion(state[i], state[i + offset - Length128], r1, r2);
r1 = r2;
r2 = state[i];
}
}
private Vector128<uint> DoRecursion(Vector128<uint> a, Vector128<uint> b, Vector128<uint> c, Vector128<uint> d)
{
Vector128<uint> z = Sse2.ShiftRightLogical128BitLane(c, 1);
z = Sse2.Xor(z, a);
Vector128<uint> v = Sse2.ShiftLeftLogical(d, 18);
z = Sse2.Xor(z, v);
Vector128<uint> x = Sse2.ShiftLeftLogical128BitLane(a, 1);
z = Sse2.Xor(z, x);
Vector128<uint> y = Sse2.ShiftRightLogical(b, 11);
Vector128<uint> mask = Vector128.Create(0xdfffffefU, 0xddfecb7fU, 0xbffaffffU, 0xbffffff6U);
y = Sse2.And(y, mask);
return Sse2.Xor(z, y);
}
/// <summary>
/// Возвращает число в диапазоне [0,ulong.MaxValue]
/// </summary>
public ulong NextUInt64()
{
if (index >= Length32)
{
UpdateState();
index = 0;
}
Span<ulong> values = MemoryMarshal.Cast<Vector128<uint>, ulong>(state);
ulong r = values[index / 2];
index += 2;
return r;
}
/// <summary>
/// Возвращает число в диапазоне [0,uint.MaxValue]
/// </summary>
public uint NextUInt32()
{
if (index >= Length32)
{
UpdateState();
index = 0;
}
Span<uint> values = MemoryMarshal.Cast<Vector128<uint>, uint>(state);
return values[index++];
}
/// <summary>
/// Возвращает число в диапазоне [0,int.MaxValue)
/// </summary>
public int Next()
{
return (int)(((ulong)int.MaxValue * NextUInt32()) >> 32);
}
/// <summary>
/// Возвращает число в диапазоне [0,maxValue)
/// </summary>
public int Next(int maxValue)
{
if (maxValue <= 0)
throw new ArgumentOutOfRangeException(nameof(maxValue), "Значение должно быть больше 0.");
return (int)(((ulong)maxValue * NextUInt32()) >> 32);
}
/// <summary>
/// Возвращает число в диапазоне [minValue,maxValue)
/// </summary>
public int Next(int minValue, int maxValue)
{
if (minValue < 0)
throw new ArgumentOutOfRangeException(nameof(minValue), "Значение должно быть больше либо равно 0.");
if (maxValue <= minValue)
throw new ArgumentOutOfRangeException(nameof(maxValue), $"Значение должно быть больше, чем {nameof(minValue)}.");
return (int)(((ulong)(maxValue - minValue) * NextUInt32()) >> 32) + minValue;
}
/// <summary>
/// Возвращает число в диапазоне [0,long.MaxValue)
/// </summary>
public long NextInt64()
{
return (long)(((BigInteger)long.MaxValue * NextUInt64()) >> 64);
}
/// <summary>
/// Возвращает число в диапазоне [0,maxValue)
/// </summary>
public long NextInt64(long maxValue)
{
if (maxValue <= 0)
throw new ArgumentOutOfRangeException(nameof(maxValue), "Значение должно быть больше 0.");
return (long)(((BigInteger)maxValue * NextUInt64()) >> 64);
}
/// <summary>
/// Возвращает число в диапазоне [minValue,maxValue)
/// </summary>
public long NextInt64(long minValue, long maxValue)
{
if (minValue < 0)
throw new ArgumentOutOfRangeException(nameof(minValue), "Значение должно быть больше либо равно 0.");
if (maxValue <= minValue)
throw new ArgumentOutOfRangeException(nameof(maxValue), $"Значение должно быть больше, чем {nameof(minValue)}.");
return (long)(((BigInteger)(maxValue - minValue) * NextUInt64()) >> 64) + minValue;
}
/// <summary>
/// Возвращает число в диапазоне [0,1)
/// </summary>
public double NextDouble()
{
return (NextUInt64() >> 11) / 9007199254740992.0; // 2^53
}
/// <summary>
/// Возвращает число в диапазоне [0,1)
/// </summary>
public float NextSingle()
{
return (NextUInt32() >> 8) / 16777216.0f; // 2^24
}
}
Использовать просто, например вот так:
static void Main(string[] args)
{
MersenneTwister mt = new MersenneTwister();
int randomNumber = mt.Next(10);
Console.WriteLine(randomNumber);
Console.ReadKey();
}
По производительности он хорош, не буду приводить числа, для каждого процессора они свои, а просто скажу, что он всего в 2-4 раза медленнее, чем System.Random, хотя для меня допустима была просадка в 10 и более раз.
Выдает на одном и том же сиде идентичную последовательность чисел, как оригинальная разработка SFMT 1.5.1 на С, то есть с вычислениями здесь тоже всё ок.
От System.Random не стал наследоваться, так как не хотел дергать его внутренню логику. А так как сам System.Random ничего не реализует (могли бы за столько лет и интерфейс к нему прикрутить), смысла связываться с ним я не увидел.
Посмотрите пожалуйста, все ли ок у меня в преобразованиях в публичных методах, и не перемудрил ли я с сидами. А может стоит что-нибудь подкрутить, чтобы стало быстрее?
Ответы (1 шт):
- Исправлено именование полей.
- Оптимизирован, а затем уничтожен метод
PeriodCertification. Вмержил его остатки вInitGenerator. - Исправлен баг нечетного индекса в
NextUInt64(), ранее при генерацииulongесли до этого была генерацияuintи индекс остался нечетным, то старшая часть новогоulongповторяла ранее отданноеuintчисло.
/// <summary>
/// Генератор псевдослучайных чисел на основе алгоритма Вихрь Мерсенна
/// Конфигурация: SFMT-19937:122-18-1-11-1:dfffffef-ddfecb7f-bffaffff-bffffff6
/// </summary>
public class MersenneTwister
{
private const int MersenneExponent = 19937;
private const int Length128 = MersenneExponent / 128 + 1;
private const int Length32 = Length128 * 4;
private readonly Vector128<uint>[] _state = new Vector128<uint>[Length128];
private int _index;
[ThreadStatic]
private static MersenneTwister _shared;
/// <summary>
/// Общий экземпляр генератора, создается отдельно для каждого потока
/// </summary>
public static MersenneTwister Shared => _shared ??= new(GetRandomSeed());
/// <summary>
/// Создает экземпляр генератора со случайным сидом
/// </summary>
public MersenneTwister() : this(Shared.NextUInt32()) { }
/// <summary>
/// Создает экземпляр генератора с указанным сидом
/// </summary>
/// <param name="seed">Сид для создания генератора</param>
public MersenneTwister(int seed) : this((uint)seed) { }
private MersenneTwister(uint seed)
{
InitGenerator(seed);
}
private static uint GetRandomSeed()
{
ReadOnlySpan<byte> bytes = Guid.NewGuid().ToByteArray();
ReadOnlySpan<uint> hash = MemoryMarshal.Cast<byte, uint>(SHA256.HashData(bytes));
uint result = 0;
foreach (uint value in hash)
result ^= value;
return result;
}
private void InitGenerator(uint seed)
{
if (!Sse2.IsSupported)
throw new InvalidOperationException("SSE2 не поддерживается на данном устройстве.");
Span<uint> values = MemoryMarshal.Cast<Vector128<uint>, uint>(_state);
values[0] = seed;
for (int i = 1; i < Length32; i++)
{
values[i] = (uint)(1812433253ul * (values[i - 1] ^ (values[i - 1] >> 30)) + (uint)i);
}
_index = Length32;
Vector128<uint> parity = Vector128.Create(0x00000001u, 0x00000000u, 0x00000000u, 0x13c9e684u);
Vector128<uint> v = Sse2.And(_state[0], parity);
uint inner = 0;
for (int i = 0; i < 4; i++)
inner ^= v.GetElement(i);
for (int i = 16; i > 0; i >>= 1)
inner ^= inner >> i;
if ((inner & 1) == 0)
values[0] ^= 1;
}
private void UpdateState()
{
const int offset = 122;
Vector128<uint> r1 = _state[^2];
Vector128<uint> r2 = _state[^1];
int i = 0;
for (; i < Length128 - offset; i++)
{
_state[i] = DoRecursion(_state[i], _state[i + offset], r1, r2);
r1 = r2;
r2 = _state[i];
}
for (; i < Length128; i++)
{
_state[i] = DoRecursion(_state[i], _state[i + offset - Length128], r1, r2);
r1 = r2;
r2 = _state[i];
}
}
private Vector128<uint> DoRecursion(Vector128<uint> a, Vector128<uint> b, Vector128<uint> c, Vector128<uint> d)
{
Vector128<uint> z = Sse2.ShiftRightLogical128BitLane(c, 1);
z = Sse2.Xor(z, a);
Vector128<uint> v = Sse2.ShiftLeftLogical(d, 18);
z = Sse2.Xor(z, v);
Vector128<uint> x = Sse2.ShiftLeftLogical128BitLane(a, 1);
z = Sse2.Xor(z, x);
Vector128<uint> y = Sse2.ShiftRightLogical(b, 11);
Vector128<uint> mask = Vector128.Create(0xdfffffefu, 0xddfecb7fu, 0xbffaffffu, 0xbffffff6u);
y = Sse2.And(y, mask);
return Sse2.Xor(z, y);
}
/// <summary>
/// Возвращает число в диапазоне [0,ulong.MaxValue]
/// </summary>
public ulong NextUInt64()
{
if ((_index & 1) == 1)
_index++;
if (_index >= Length32)
{
UpdateState();
_index = 0;
}
Span<ulong> values = MemoryMarshal.Cast<Vector128<uint>, ulong>(_state);
ulong r = values[_index >> 1];
_index += 2;
return r;
}
/// <summary>
/// Возвращает число в диапазоне [0,uint.MaxValue]
/// </summary>
public uint NextUInt32()
{
if (_index >= Length32)
{
UpdateState();
_index = 0;
}
Span<uint> values = MemoryMarshal.Cast<Vector128<uint>, uint>(_state);
return values[_index++];
}
/// <summary>
/// Возвращает число в диапазоне [0,int.MaxValue)
/// </summary>
public int Next()
{
return (int)(((ulong)int.MaxValue * NextUInt32()) >> 32);
}
/// <summary>
/// Возвращает число в диапазоне [0,maxValue)
/// </summary>
public int Next(int maxValue)
{
if (maxValue <= 0)
throw new ArgumentOutOfRangeException(nameof(maxValue), "Значение должно быть больше 0.");
return (int)(((ulong)maxValue * NextUInt32()) >> 32);
}
/// <summary>
/// Возвращает число в диапазоне [minValue,maxValue)
/// </summary>
public int Next(int minValue, int maxValue)
{
if (minValue < 0)
throw new ArgumentOutOfRangeException(nameof(minValue), "Значение должно быть больше либо равно 0.");
if (maxValue <= minValue)
throw new ArgumentOutOfRangeException(nameof(maxValue), $"Значение должно быть больше, чем {nameof(minValue)}.");
return (int)(((ulong)(maxValue - minValue) * NextUInt32()) >> 32) + minValue;
}
/// <summary>
/// Возвращает число в диапазоне [0,long.MaxValue)
/// </summary>
public long NextInt64()
{
return (long)(((BigInteger)long.MaxValue * NextUInt64()) >> 64);
}
/// <summary>
/// Возвращает число в диапазоне [0,maxValue)
/// </summary>
public long NextInt64(long maxValue)
{
if (maxValue <= 0)
throw new ArgumentOutOfRangeException(nameof(maxValue), "Значение должно быть больше 0.");
return (long)(((BigInteger)maxValue * NextUInt64()) >> 64);
}
/// <summary>
/// Возвращает число в диапазоне [minValue,maxValue)
/// </summary>
public long NextInt64(long minValue, long maxValue)
{
if (minValue < 0)
throw new ArgumentOutOfRangeException(nameof(minValue), "Значение должно быть больше либо равно 0.");
if (maxValue <= minValue)
throw new ArgumentOutOfRangeException(nameof(maxValue), $"Значение должно быть больше, чем {nameof(minValue)}.");
return (long)(((BigInteger)(maxValue - minValue) * NextUInt64()) >> 64) + minValue;
}
/// <summary>
/// Возвращает число в диапазоне [0,1)
/// </summary>
public double NextDouble()
{
return (NextUInt64() >> 11) / 9007199254740992.0; // 2^53
}
/// <summary>
/// Возвращает число в диапазоне [0,1)
/// </summary>
public float NextSingle()
{
return (NextUInt32() >> 8) / 16777216.0f; // 2^24
}
}
На данным этапе оставлю так, здесь более оптимизировать особо нечего.