import java.util.*;
import java.io.*;
import java.lang.*;
/**
 * Minimum-cost partition of an array into at most {@code k} contiguous segments.
 *
 * <p>The cost of one segment is the sum, over every distinct value in it, of
 * (last index of that value − first index of that value). Equivalently, it is the sum of the
 * gaps between consecutive equal elements inside the segment, which is what {@link #cal}
 * maintains incrementally.
 *
 * <p>The DP is filled one layer (segment count) at a time. Each layer uses the
 * divide-and-conquer optimization: the best split point is assumed to be monotone in the
 * segment end. The cost of a candidate segment comes from a two-pointer sliding window whose
 * pointers move only as far as needed between queries.
 */
public class Abhinav {
    /** Original fixed array capacity; kept for compatibility, arrays are now sized from the input. */
    public static final int N = 400_000;
    /** Original fixed layer capacity; kept for compatibility, unused. */
    public static final int M = 110;
    /** Original int "infinity" sentinel; kept for compatibility, the DP uses {@link #INF}. */
    public static final int inf = 0x3f3f3f3f;

    /** Unreachable-state sentinel; INF plus any real segment cost still fits in a long. */
    private static final long INF = Long.MAX_VALUE / 4;

    /** Number of elements. */
    private final int n;
    /** pre[j]: previous 1-based index holding the same value as j, or 0 if none. */
    private final int[] pre;
    /** nxt[j]: next 1-based index holding the same value as j, or n + 1 if none. */
    private final int[] nxt;

    // Sliding window [winL, winR] (1-based, may be empty) and its segment cost; starts empty.
    private int winL = 1;
    private int winR = 0;
    private long winCost = 0;

    /** prev[i]: best cost of the first i elements using the previous layer's segment budget. */
    private long[] prev;
    /** cur[i]: layer currently being filled by {@link #solve}. */
    private long[] cur;

    private Abhinav(int[] a) {
        n = a.length;
        pre = new int[n + 2];
        nxt = new int[n + 2];
        // A map instead of a value-indexed array so arbitrary (large or negative) values work.
        Map<Integer, Integer> seen = new HashMap<>();
        for (int i = 1; i <= n; i++) {
            Integer p = seen.put(a[i - 1], i);
            pre[i] = (p == null) ? 0 : p;
        }
        seen.clear();
        for (int i = n; i >= 1; i--) {
            Integer q = seen.put(a[i - 1], i);
            nxt[i] = (q == null) ? n + 1 : q;
        }
    }

    /**
     * Returns the minimum total cost of splitting {@code a} into at most {@code k} contiguous
     * segments.
     *
     * @param a the values to partition (not modified)
     * @param k maximum number of segments, at least 1
     * @return the minimum total cost; 0 for an empty array
     * @throws NullPointerException if {@code a} is null
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public static long minCost(int[] a, int k) {
        Objects.requireNonNull(a, "a");
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        return new Abhinav(a).run(k);
    }

    /** Runs the layered DP and returns the best cost for all n elements. */
    private long run(int k) {
        if (n == 0) {
            return 0;
        }
        prev = new long[n + 1];
        cur = new long[n + 1];
        // prev[0] = 0 for every layer: an empty prefix is free, so unused segments cost nothing.
        Arrays.fill(prev, 1, n + 1, INF);
        // More than n non-empty segments can never help.
        int layers = Math.min(k, n);
        for (int layer = 1; layer <= layers; layer++) {
            solve(1, n, 0, n);
            // solve writes every cur[1..n]; cur[0] stays 0, so swapping buffers is safe.
            long[] tmp = prev;
            prev = cur;
            cur = tmp;
        }
        return prev[n];
    }

    /**
     * Moves the sliding window to [l, r] and returns the cost of that segment.
     * Works in any order of queries; the pointers move only as far as needed.
     */
    private long cal(int l, int r) {
        while (winL < l) {
            if (nxt[winL] <= winR) {
                winCost -= nxt[winL] - winL;
            }
            winL++;
        }
        while (winL > l) {
            winL--;
            if (nxt[winL] <= winR) {
                winCost += nxt[winL] - winL;
            }
        }
        while (winR < r) {
            winR++;
            if (pre[winR] >= winL) {
                winCost += winR - pre[winR];
            }
        }
        while (winR > r) {
            if (pre[winR] >= winL) {
                winCost -= winR - pre[winR];
            }
            winR--;
        }
        return winCost;
    }

    /**
     * Fills cur[lo..hi], knowing the optimal previous split for these ends lies in
     * [optLo, optHi].
     */
    private void solve(int lo, int hi, int optLo, int optHi) {
        if (lo > hi || optLo > optHi) {
            return;
        }
        int mid = (lo + hi) >>> 1;
        long best = INF;
        // Fallback so the recursion bounds stay valid even if no candidate beats INF.
        int bestPos = optLo;
        for (int i = optLo; i < mid && i <= optHi; i++) {
            long cand = prev[i] + cal(i + 1, mid);
            if (cand < best) {
                best = cand;
                bestPos = i;
            }
        }
        cur[mid] = best;
        solve(lo, mid - 1, optLo, bestPos);
        solve(mid + 1, hi, bestPos, optHi);
    }

    /** Reads the next integer token, failing loudly on malformed or truncated input. */
    private static int nextInt(StreamTokenizer in) throws IOException {
        if (in.nextToken() != StreamTokenizer.TT_NUMBER) {
            throw new IOException("expected an integer, got: " + in);
        }
        return (int) in.nval;
    }

    /**
     * Reads {@code n k} followed by n integers from standard input and prints the minimum
     * partition cost followed by "\n".
     */
    public static void main(String[] args) throws IOException {
        StreamTokenizer in =
                new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
        int n = nextInt(in);
        int k = nextInt(in);
        if (n < 0) {
            throw new IOException("n must be non-negative, got " + n);
        }
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = nextInt(in);
        }
        System.out.print(minCost(a, k));
        System.out.print("\n");
        System.out.flush();
    }
}
What I have tried:
import java.util.*;
import java.io.*;
import java.lang.*;
/**
 * Minimum-cost partition of an array into at most {@code k} contiguous segments.
 *
 * <p>The cost of one segment is the sum, over every distinct value in it, of
 * (last index of that value − first index of that value). Equivalently, it is the sum of the
 * gaps between consecutive equal elements inside the segment, which is what {@link #cal}
 * maintains incrementally.
 *
 * <p>The DP is filled one layer (segment count) at a time. Each layer uses the
 * divide-and-conquer optimization: the best split point is assumed to be monotone in the
 * segment end. The cost of a candidate segment comes from a two-pointer sliding window whose
 * pointers move only as far as needed between queries.
 */
public class Abhinav {
    /** Original fixed array capacity; kept for compatibility, arrays are now sized from the input. */
    public static final int N = 400_000;
    /** Original fixed layer capacity; kept for compatibility, unused. */
    public static final int M = 110;
    /** Original int "infinity" sentinel; kept for compatibility, the DP uses {@link #INF}. */
    public static final int inf = 0x3f3f3f3f;

    /** Unreachable-state sentinel; INF plus any real segment cost still fits in a long. */
    private static final long INF = Long.MAX_VALUE / 4;

    /** Number of elements. */
    private final int n;
    /** pre[j]: previous 1-based index holding the same value as j, or 0 if none. */
    private final int[] pre;
    /** nxt[j]: next 1-based index holding the same value as j, or n + 1 if none. */
    private final int[] nxt;

    // Sliding window [winL, winR] (1-based, may be empty) and its segment cost; starts empty.
    private int winL = 1;
    private int winR = 0;
    private long winCost = 0;

    /** prev[i]: best cost of the first i elements using the previous layer's segment budget. */
    private long[] prev;
    /** cur[i]: layer currently being filled by {@link #solve}. */
    private long[] cur;

    private Abhinav(int[] a) {
        n = a.length;
        pre = new int[n + 2];
        nxt = new int[n + 2];
        // A map instead of a value-indexed array so arbitrary (large or negative) values work.
        Map<Integer, Integer> seen = new HashMap<>();
        for (int i = 1; i <= n; i++) {
            Integer p = seen.put(a[i - 1], i);
            pre[i] = (p == null) ? 0 : p;
        }
        seen.clear();
        for (int i = n; i >= 1; i--) {
            Integer q = seen.put(a[i - 1], i);
            nxt[i] = (q == null) ? n + 1 : q;
        }
    }

    /**
     * Returns the minimum total cost of splitting {@code a} into at most {@code k} contiguous
     * segments.
     *
     * @param a the values to partition (not modified)
     * @param k maximum number of segments, at least 1
     * @return the minimum total cost; 0 for an empty array
     * @throws NullPointerException if {@code a} is null
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public static long minCost(int[] a, int k) {
        Objects.requireNonNull(a, "a");
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got " + k);
        }
        return new Abhinav(a).run(k);
    }

    /** Runs the layered DP and returns the best cost for all n elements. */
    private long run(int k) {
        if (n == 0) {
            return 0;
        }
        prev = new long[n + 1];
        cur = new long[n + 1];
        // prev[0] = 0 for every layer: an empty prefix is free, so unused segments cost nothing.
        Arrays.fill(prev, 1, n + 1, INF);
        // More than n non-empty segments can never help.
        int layers = Math.min(k, n);
        for (int layer = 1; layer <= layers; layer++) {
            solve(1, n, 0, n);
            // solve writes every cur[1..n]; cur[0] stays 0, so swapping buffers is safe.
            long[] tmp = prev;
            prev = cur;
            cur = tmp;
        }
        return prev[n];
    }

    /**
     * Moves the sliding window to [l, r] and returns the cost of that segment.
     * Works in any order of queries; the pointers move only as far as needed.
     */
    private long cal(int l, int r) {
        while (winL < l) {
            if (nxt[winL] <= winR) {
                winCost -= nxt[winL] - winL;
            }
            winL++;
        }
        while (winL > l) {
            winL--;
            if (nxt[winL] <= winR) {
                winCost += nxt[winL] - winL;
            }
        }
        while (winR < r) {
            winR++;
            if (pre[winR] >= winL) {
                winCost += winR - pre[winR];
            }
        }
        while (winR > r) {
            if (pre[winR] >= winL) {
                winCost -= winR - pre[winR];
            }
            winR--;
        }
        return winCost;
    }

    /**
     * Fills cur[lo..hi], knowing the optimal previous split for these ends lies in
     * [optLo, optHi].
     */
    private void solve(int lo, int hi, int optLo, int optHi) {
        if (lo > hi || optLo > optHi) {
            return;
        }
        int mid = (lo + hi) >>> 1;
        long best = INF;
        // Fallback so the recursion bounds stay valid even if no candidate beats INF.
        int bestPos = optLo;
        for (int i = optLo; i < mid && i <= optHi; i++) {
            long cand = prev[i] + cal(i + 1, mid);
            if (cand < best) {
                best = cand;
                bestPos = i;
            }
        }
        cur[mid] = best;
        solve(lo, mid - 1, optLo, bestPos);
        solve(mid + 1, hi, bestPos, optHi);
    }

    /** Reads the next integer token, failing loudly on malformed or truncated input. */
    private static int nextInt(StreamTokenizer in) throws IOException {
        if (in.nextToken() != StreamTokenizer.TT_NUMBER) {
            throw new IOException("expected an integer, got: " + in);
        }
        return (int) in.nval;
    }

    /**
     * Reads {@code n k} followed by n integers from standard input and prints the minimum
     * partition cost followed by "\n".
     */
    public static void main(String[] args) throws IOException {
        StreamTokenizer in =
                new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
        int n = nextInt(in);
        int k = nextInt(in);
        if (n < 0) {
            throw new IOException("n must be non-negative, got " + n);
        }
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = nextInt(in);
        }
        System.out.print(minCost(a, k));
        System.out.print("\n");
        System.out.flush();
    }
}