Как передать массив в нативную функцию и вернуть его обратно

Всем здравствуйте. Недавно начал изучение JNI и появился вопрос как передать массив и вернуть его обратно в java. У меня есть код на C++, который выполняет умножение двух матриц:

void gemm_v2(int M, int N, int K, const float* A, const float* B, float* C)
{
    for (int i = 0; i < M; ++i)
    {
        float* c = C + i * N;
        for (int j = 0; j < N; j += 8)
            _mm256_storeu_ps(c + j + 0, _mm256_setzero_ps());
        for (int k = 0; k < K; ++k)
        {
            const float* b = B + k * N;
            __m256 a = _mm256_set1_ps(A[i * K + k]);
            for (int j = 0; j < N; j += 16)
            {
                _mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,
                                                            _mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
                _mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,
                                                            _mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
            }
        }
    }
}

Который собственно принимает на вход размерности матриц, первую матрицу, вторую и третью, в которую будет записан результат. Далее есть структура:

struct buf_t
{
    float* p;
    int n;

    buf_t(int size) : n(size), p((float*)_mm_malloc(size * 4, 64)) {}
    ~buf_t() { _mm_free(p); }
};

И функция заполнения "структурной" матрицы из обычной (Матрица записывается в виде массива):

void init_f(buf_t& buf, float*m)
{
    std::srand(time(0));
    for (int i = 0; i < buf.n; ++i) {
        buf.p[i] = m[i];
    }
}

Непосредственно в C++ коде вызов выглядит так:

buf_t a(M * K), b(K * N), c(M * N); //создали структурные матрицы
init_f(a, M1);
init_f(b, M2); //заполнили матрицы из каких-то М1 и М2
gemm_v2(M, N, K, a.p, b.p, c.p); //перемножили

И вот теперь вопрос как это всё организовать в виде нативного кода. Я попробовал так, но дальше увы не понимаю:

JNIEXPORT jfloatArray JNICALL Java_MatrMultiply_multiply
  (JNIEnv *, jobject, jint m, jint n, jint k, jfloatArray m1, jfloatArray m2){
    int M = (int)m;
    int N = (int)n;
    int K = (int)k;
    float* M1 = (float *)env->GetFloatArrayElements(m1, 0);
    float* M2 = (float *)env->GetFloatArrayElements(m2, 0);
    buf_t a(M * K), b(K * N), c(M * N);
    init_f(a,M1);
    init_f(b,M2);
    gemm_v2(M, N, K, a.p, b.p, c.p);
}

Ответы (0 шт):