算法小抄 --dp 篇
2021-08-19 09:25:36 # algorithm

dp

dp是动态规划(dynamic programming)的简称,其实本质也是遍历了一棵N叉数

比如一个凑零钱问题:

1
2
3
4
5
6
7
8
9
10
11
12
13
def coinChange(coins: List[int], amount: int):
def dp(n):
if n == 0: return 0
if n < 0: return -1

res = float('INF')
for coin in coins:
subproblem = dp(n - coin)
# 子问题无解,跳过
if subproblem == -1: continue
res = min(res, 1 + subproblem)
return res if res != float('INF') else -1
return dp(amount)

其实上面的dp函数就是一个遍历n叉树的过程

1
2
3
def dp(n):
for coin in coins:
dp(n - coin)
img

dp的一般形式就是求最值,在运筹学上是一种最优化方法,核心问题就是穷举,因为要求的是最值,所以要把所有可行的答案都穷举出来。

dp一般都有重叠子问题,如果都暴力穷举会使效率低下,所以需要备忘录或dp table来优化穷举过程。

dp具备"最优子结构",这样通过子问题的最值来得到原问题的最值。

斐波那契数列

递归解法

1
2
3
4
5
6
7
8
9
int fib(int n) {
if(n == 0) {
return 0;
}
if(n == 1||n == 2) {
return 1;
}
return fib(n-1)+fib(n-2);
}
img

这样写代码虽然简洁易懂,但是十分低效,有很多重复计算的,如f(17)

带备忘录的递归解法

我们可以造一个「备忘录」,每次算出某个子问题的答案后别急着返回,先记到「备忘录」里再返回

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public int fib(int N) {
// 备忘录全初始化为 0
int[] memo = new int[N + 1];
// 进行带备忘录的递归
return helper(memo, N);
}

private int helper(int[] memo, int n) {
// base case
if (n == 0 || n == 1) return n;
// 已经计算过,不用再计算了
if (memo[n] != 0) return memo[n];
memo[n] = helper(memo, n - 1) + helper(memo, n - 2);
return memo[n];
}

dp 数组的迭代解法

我们可以把这个「备忘录」独立出来成为一张表,叫做 DP table

1
2
3
4
5
6
7
8
9
10
11
12
public int fib(int N) {
if (N == 0) return 0;
int[] dp = new int[N + 1];
// base case
dp[0] = 0; dp[1] = 1;
// 状态转移
for (int i = 2; i <= N; i++) {
dp[i] = dp[i - 1] + dp[i - 2];
}

return dp[N];
}

「状态转移方程」这个名词,实际上就是描述问题结构的数学形式:

img

dp空间优化

根据斐波那契数列的状态转移方程,当前状态只和之前的两个状态有关,其实并不需要那么长的一个 DP table 来存储所有的状态,只要想办法存储之前的两个状态就行了。

1
2
3
4
5
6
7
8
9
10
11
12
int fib(int n) {
if (n < 1) return 0;
if (n == 2 || n == 1)
return 1;
int prev = 1, curr = 1;
for (int i = 3; i <= n; i++) {
int sum = prev + curr;
prev = curr;
curr = sum;
}
return curr;
}

这个技巧就是所谓的「状态压缩」,如果我们发现每次状态转移只需要 DP table 中的一部分,那么可以尝试用状态压缩来缩小 DP table 的大小,只记录必要的数据,上述例子就相当于把DP table 的大小从 n 缩小到 2。

凑零钱问题

题目描述

给你 k 种面值的硬币,面值分别为 c1, c2 ... ck,每种硬币的数量无限,再给一个总金额 amount,
问你最少需要几枚硬币凑出这个金额,如果不可能凑出,算法返回 -1 。

1、确定 base case,这个很简单,显然目标金额 amount 为 0 时算法返回 0,因为不需要任何硬币就已经凑出目标金额了。

2、确定「状态」,也就是原问题和子问题中会变化的变量。由于硬币数量无限,硬币的面额也是题目给定的,只有目标金额会不断地向 base case 靠近,所以唯一的「状态」就是目标金额 amount

3、确定「选择」,也就是导致「状态」产生变化的行为。目标金额为什么变化呢,因为你在选择硬币,你每选择一枚硬币,就相当于减少了目标金额。所以说所有硬币的面值,就是你的「选择」。

4、明确 dp 函数/数组的定义。我们这里讲的是自顶向下的解法,所以会有一个递归的 dp 函数,一般来说函数的参数就是状态转移中会变化的量,也就是上面说到的「状态」;函数的返回值就是题目要求我们计算的量。就本题来说,状态只有一个,即「目标金额」,题目要求我们计算凑出目标金额所需的最少硬币数量。所以我们可以这样定义 dp 函数:

dp(n) 的定义:输入一个目标金额 n,返回凑出目标金额 n 的最少硬币数量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public int coinChange(int[] coins, int amount) {
// 题目要求的最终结果是 dp(amount)
return dp(coins, amount)
}

private int dp(int[] coins, int amount) {
// base case
if (amount == 0) return 0;
if (amount < 0) return -1;

int res = Integer.MAX_VALUE;
for (int coin : coins) {
// 计算子问题的结果
int subProblem = dp(coins, amount - coin);
// 子问题无解则跳过
if (subProblem == -1) continue;
// 在子问题中选择最优解,然后加一
res = Math.min(res, subProblem + 1);
}

return res == Integer.MAX_VALUE ? -1 : res;
}
img

带备忘录的递归

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
int[] memo;
public int coinChange(int[] coins, int amount) {
memo = new int[amount + 1];
// dp 数组全都初始化为特殊值
Arrays.fill(memo, -666);

return dp(coins, amount);
}

private int dp(int[] coins, int amount) {
if (amount == 0) return 0;
if (amount < 0) return -1;
// 查备忘录,防止重复计算
if (memo[amount] != -666)
return memo[amount];

int res = Integer.MAX_VALUE;
for (int coin : coins) {
// 计算子问题的结果
int subProblem = dp(coins, amount - coin);
// 子问题无解则跳过
if (subProblem == -1) continue;
// 在子问题中选择最优解,然后加一
res = Math.min(res, subProblem + 1);
}
// 把计算结果存入备忘录
memo[amount] = (res == Integer.MAX_VALUE) ? -1 : res;
return memo[amount];
}

dp 数组的迭代解法

当然,我们也可以自底向上使用 dp table 来消除重叠子问题,关于「状态」「选择」和 base case 与之前没有区别,dp 数组的定义和刚才 dp 函数类似,也是把「状态」,也就是目标金额作为变量。不过 dp 函数体现在函数参数,而 dp 数组体现在数组索引:

dp 数组的定义:当目标金额为 i 时,至少需要 dp[i] 枚硬币凑出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public int coinChange(int[] coins, int amount) {
int[] dp = new int[amount + 1];
// 数组大小为 amount + 1,初始值也为 amount + 1
Arrays.fill(dp, amount + 1);

// base case
dp[0] = 0;
// 外层 for 循环在遍历所有状态的所有取值
for (int i = 0; i < dp.length; i++) {
// 内层 for 循环在求所有选择的最小值
for (int coin : coins) {
// 子问题无解,跳过
if (i - coin < 0) {
continue;
}
dp[i] = Math.min(dp[i], 1 + dp[i - coin]);
}
}
return (dp[amount] == amount + 1) ? -1 : dp[amount];
}