Deep QLearning не работает, веса нейронной сети возрастают до бесконечности

Делаю реализацию алгоритма обучения нейронной сети Deep QLearning на C#. В качестве проверки, работает ли алгоритм, тренирую его на игре "змейка". В качестве реализации нейронной сети и алгоритма back propogation исипользую библиотеку классов AForge.NET. В алгоритме я разобрался, и все написанное выглядит разумно. Есть одно "но". Если использовать в качестве функции активации сигмоидную функцию, значения всех действий, выдаваемые нейронной сетью примерно одинаковые, и нейронная сеть даже спустя 600 000 ходов не умеет искать яблоки, и избегать стен. Если использовать функцию активации ReLu, все намного хуже. В определенный момент (это зависит от настроек) веса нейросети начинают значительно расти, и я не понимаю почему. Это не было вызвано тем, что я обновляю веса так чтобы они выдавали большое значение на выходе. У меня когда наибольшее значение, на которое я обновляю веса алгоритмом back propogation, равняется 6, в нейронной сети уже присутствую веса в 30 и -30, и спустя несколько шагов эти веса превращаются несколько тысяч, затем миллионов, а затем NAN. Однако ReLu может успеть научиться съедать до 5-8 яблок. Веса нейросети с ReLu инициализирую случайными значениями от -0.1 до 0.1, так она дольше живет. Буду рад любым замечаниям

public class DeepQLearning
{
    public DeepQLearning(AForge.Neuro.ActivationNetwork network)
    {
        _network = network;
        _targetNetwork = AForgeExtensions.Features.CloneActivationNetwork(_network);
        _backPropagationLearning = new AForge.Neuro.Learning.BackPropagationLearning(_network);
        _learningRate = 0.5;
        _discountFactor = 0.9;
        _targetNetworkUpdateTime = 400;
        _targetNetworkUpdateTimeElapsed = 0;
    }
    private AForge.Neuro.ActivationNetwork _network; //основная нейронная сеть
    public AForge.Neuro.ActivationNetwork Network { get { return _network; } }
    private AForge.Neuro.ActivationNetwork _targetNetwork; //целевая нейронная сеть, из неё будут браться значения будующих наград
    private AForge.Neuro.Learning.BackPropagationLearning _backPropagationLearning;
    private int _targetNetworkUpdateTime;
    private int _targetNetworkUpdateTimeElapsed;
    /// <summary>
    /// Количество итераций UpdateState(), через которое основная нейронная сеть копируется в целевую.
    /// </summary>
    public int TargetNetworkUpdateTime { get { return _targetNetworkUpdateTime; } set { _targetNetworkUpdateTime = value; } }
    private double _learningRate;
    /// <summary>
    /// Скорость обучения.
    /// </summary>
    public double LearningRate { get { return _learningRate; } set { _learningRate = value; } }
    private double _discountFactor;
    /// <summary>
    /// Фактор дисконтирования, в соответствии с которым размер будующей награды будет уменьшен.
    /// </summary>
    public double DiscountFactor { get { return _discountFactor; } set { _discountFactor = value; } }
    public void UpdateState(double[] previousStateInput, double[] previousStateOutput, int action, double reward, double[] nextStateInput)
    {
        double[] nextStateOutput = _targetNetwork.Compute(nextStateInput);
        double previousQvalue = previousStateOutput[action];
        double updatedQvalue = previousQvalue + _learningRate * (reward + _discountFactor * nextStateOutput.Max() - previousQvalue);
        double[] updatedPreviousStateOutput = new double[previousStateOutput.Length];
        previousStateOutput.CopyTo(updatedPreviousStateOutput, 0);
        updatedPreviousStateOutput[action] = updatedQvalue;
        _backPropagationLearning.Run(previousStateInput, updatedPreviousStateOutput);

        _targetNetworkUpdateTimeElapsed++;
        //если совершили достаточно итераций обновления состояния, обновляем целевую нейронную сеть
        if(_targetNetworkUpdateTimeElapsed >= _targetNetworkUpdateTime)
        {
            _targetNetworkUpdateTimeElapsed = 0;
            _targetNetwork = AForgeExtensions.Features.CloneActivationNetwork(_network);
        }
    }
}
public class ReLuActivationFunction : AForge.Neuro.IActivationFunction, ICloneable
{
    public ReLuActivationFunction()
    {

    }
    public object Clone()
    {
        return new ReLuActivationFunction();
    }
    public double Derivative(double value)
    {
        return value >= 0 ? 1 : 0;
    }
    public double Derivative2(double value)
    {
        return value >= 0 ? 1 : 0;
    }
    public double Function(double value)
    {
        return Math.Max(0, value);
    }
}

Ответы (0 шт):