Как подготовить данные временных рядов для LSTM
Всем привет.
Я использую метод контролируемого обучения с сетью LSTM для прогнозирования цен на рынке Форекс. Для этого я использую библиотеку deeplearning4j, но сомневаюсь в некоторых моментах своей реализации.
Я отключил функцию мини-пакетов, затем создал множество торговых индикаторов на основе данных форекс. Я создал итератор набора данных, который перебирает данные временных рядов от начала до конца. На каждой итерации итератор набора данных возвращает набор данных, содержащий одиночные входные данные. Таким образом я перебираю все данные, которые у меня есть.
Первая проблема заключается в том, что в начале каждой эпохи предыдущие входные данные из предыдущей эпохи мешают текущего прогноза.
Вторая проблема заключается в том, что я не знаю, как нормализовать свои данные. Сначала, когда я использовал данные о ценах в том виде, в каком они были, оценки моей модели были слишком малы, поэтому через несколько эпох она начала выдавать одно и то же значение. Я решил использовать только дробную часть цен (1,723455034 -- 7234), и теперь оценки слишком велики (160542,18602891127).
У вас есть идеи по решению этих проблем?
Neural net configuration
public static MultiLayerNetwork buildNetwork(int nIn, int nOut, int windowSize) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(System.currentTimeMillis())
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.weightInit(WeightInit.XAVIER)
.updater(Updater.RMSPROP)
.miniBatch(false)
.l2(25e-4)
.list()
.layer(0, new LSTM.Builder()
.nIn(nIn)
.nOut(256)
.activation(Activation.TANH)
.gateActivationFunction(Activation.HARDSIGMOID)
.dropOut(0.2)
.build())
.layer(1, new LSTM.Builder()
.nIn(256)
.nOut(256)
.activation(Activation.TANH)
.gateActivationFunction(Activation.HARDSIGMOID)
.dropOut(0.2)
.build())
.layer(2, new DenseLayer.Builder()
.nIn(256)
.nOut(32)
.activation(Activation.RELU)
.build())
.layer(3, new RnnOutputLayer.Builder()
.nIn(32)
.nOut(nOut)
.activation(Activation.IDENTITY)
.lossFunction(LossFunctions.LossFunction.MSE)
.build())
.backpropType(BackpropType.TruncatedBPTT)
.tBPTTForwardLength(windowSize)
.tBPTTBackwardLength(windowSize)
.build();
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
return network;
}
Dataset Iterator
@Override
public DataSet next() {
INDArray observationArray = Nd4j.create(new int[]{1 , this.featureSize, this.windowSize}, 'f');
INDArray labelArray = Nd4j.create(new int[]{1, PREDICTION_VALUES_SIZE, this.windowSize}, 'f');
int windowStartOffset = this.seriesIndex;
int windowEndOffset = windowStartOffset + this.windowSize;
for (int windowOffset = windowStartOffset; windowOffset < windowEndOffset; windowOffset++) {
int windowIndex = windowOffset - windowStartOffset;
for (int featureIndex = ZERO_INDEX; featureIndex < this.featureSize; featureIndex++) {
observationArray.putScalar(
new int[]{ZERO_INDEX, featureIndex, windowIndex},
this.dataProvider.data(windowOffset, featureIndex)
);
}
labelArray.putScalar(new int[]{ZERO_INDEX, ZERO_INDEX, windowIndex},
this.dataProvider.pip(windowOffset + this.predictionStep)
);
}
seriesIndex++;
return new DataSet(observationArray, labelArray);
}
Training
public static final int EPOCHS = 100;
public static final int WINDOW_SIZE = 20;
public static final int PREDICTION_STEP = 1;
public static void prepare(String network, String dataset) throws IOException {
TradingDataProvider provider = new TradingDataProvider(CommonFileTools.loadSeries(dataset));
TradingDataIterator dataIterator = new TradingDataIterator(provider, WINDOW_SIZE, PREDICTION_STEP);
MultiLayerNetwork net = LSTMNetworkFactory.buildNetwork(dataIterator.inputColumns(), dataIterator.totalOutcomes(), WINDOW_SIZE);
long start;
for (int i = 0; i < EPOCHS; i++) {
start = System.currentTimeMillis();
net.fit(dataIterator);
logger.info("Epoch: {}, Score: {}, Duration: {} ms", i + 1, net.score(),
System.currentTimeMillis() - start
);
}
File locationToSave = new File(network);
ModelSerializer.writeModel(net, locationToSave, true);
logger.info("Model saved");
System.exit(0);
}