Оптимизация несложного кода (String.valueOf)

Дана задача:

Входные выходные данные:

Примеры данных:

Я решил эту задачу, но в моем алгоритме я каждый элемент массива перевожу в строку методом String.valueOf(), что тратит очень много памяти. Как можно оптимизировать код?

public static void main(String[] args) {
    Scanner scanner = new Scanner(System.in);
    int n = scanner.nextInt();
    int k = scanner.nextInt();
    int[] list = new int[n + 1];
    for (int i = 0; i < n; i++) {
        list[i] = scanner.nextInt();
    }
    System.out.println(findMaxDifference(list,k));
}

public static long findMaxDifference(int[] list, int k) {
    long result = 0;
    for (int i = 0; i < k; i++) {
        list = refactorNum(list);
        result += list[list.length - 1]; // увеличиваю результат по последнему индексу массива
        list[list.length - 1] = 0; // обновляю его
    }
    return result;
}

private static int[] refactorNum(int[] list) {
    int prevIndex = -1;
    int prevValue = -1;
    for (int i = 0; i < list.length-1; i++) {
        String stringNum = String.valueOf(list[i]); // (проблема) создаю n*k объектов String
        for (int j = 0; j < stringNum.length(); j++) {
            if (stringNum.charAt(j) != '9') {
                stringNum = stringNum.replaceFirst(String.valueOf(stringNum.charAt(j)),"9");
                int parsedInt = Integer.parseInt(stringNum);
                if (parsedInt - list[i] > list[list.length-1]) {
                    if (prevIndex != -1) {
                        list[prevIndex] = prevValue;
                    }
                    prevIndex = i;
                    prevValue = list[i];
                    list[list.length-1] = parsedInt - list[i]; // ожидаемый результат
                    list[i] = parsedInt;
                }
                break;
            }
        }
    }
    return list;
}

}


Ответы (1 шт):

Автор решения: Stanislav Volodarskiy

Да, вы действительно делаете nk операций чтобы получить результат. Это слишком много, но не из-за памяти, а из-за скорости. Можно управиться за nlogn.

Пусть на входе число 12345. Какие можно получить "увеличения" работая с его цифрами? Каждую цифру вычтем из девяти и припишем к ней нули по номеру позиции в числе:

12345 → 80000, 7000, 600, 50, 4.

Все входные числа превратим в такие наборы "увеличений". "Увеличения" соберём в один список, упорядочим его по убыванию и сосчитаем сумму первых k элементов этого списка:

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Scanner;

public class Temp {
    public static void main(String... args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();

        List<Integer> diffs = new ArrayList<>();
        for (int i = 0; i < n; ++i) {
            String s = sc.next();
            int p10 = 1;
            for (int j = s.length() - 1; j >= 0; --j) {
                diffs.add((9 - (s.charAt(j) - '0')) * p10);
                p10 *= 10;
            }
        }
        diffs.sort(Collections.reverseOrder());
        if (k < diffs.size()) {
            diffs = diffs.subList(0, k);
        }
        long sum = 0;
        for (int d : diffs) {
            sum += d;
        }
        System.out.println(sum);
    }
}

Скорее всего задачу в таком виде примут. Но можно сделать лучше. Заметим что в "увеличения" бывают только вида d·10j, где 0 ≤ d ≤ 9, 0 ≤ j ≤ 9. То есть их не более ста разных вариантов. А это значит что вместо сортировки можно устроить импровизированную сортировку подсчётом за время n + k, что лучше чем nlogn. И память нужна будет только константная.

import java.util.Scanner;

public class Temp {
    public static void main(String... args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();

        int[] p10 = new int[10];
        for (int j = 0, p = 1; j < p10.length; ++j, p *= 10) {
            p10[j] = p;
        }

        int[][] diffs = new int[10][10];
        for (int i = 0; i < n; ++i) {
            String s = sc.next();
            for (int j = s.length() - 1; j >= 0; --j) {
                ++diffs[s.length() - j - 1][9 - (s.charAt(j) - '0')];
            }
        }

        long sum = 0;
        for (int j = p10.length - 1; j >= 0; --j) {
            for (int d = 9; d >= 0; --d) {
                int r = Math.min(diffs[j][d], k);
                sum += d * r * (long)p10[j];
                k -= r;
            }
        }
        System.out.println(sum);
    }
}
→ Ссылка