动态规划设计:最长递增子序列

很多读者反应,就算看了前文动态规划详解,了解了动态规划的套路,也不会写状态转移方程,没有思路,怎么办?本文就借助「最长递增子序列」来讲一种设计动态规划的通用技巧:数学归纳思想。

最长递增子序列(Longest Increasing Subsequence,简写 LIS)是比较经典的一个问题,比较容易想到的是动态规划解法,时间复杂度 O(N^2),我们借这个问题来由浅入深讲解如何写动态规划。比较难想到的是利用二分查找,时间复杂度是 O(NlogN),我们通过一种简单的纸牌游戏来辅助理解这种巧妙的解法。

先看一下题目,很容易理解:

title

注意「子序列」和「子串」这两个名词的区别,子串一定是连续的,而子序列不一定是连续的。下面先来一步一步设计动态规划算法解决这个问题。

一、动态规划解法

动态规划的核心设计思想是数学归纳法。

相信大家对数学归纳法都不陌生,高中就学过,而且思路很简单。比如我们想证明一个数学结论,那么我们先假设这个结论在 $k<n$ 时成立,然后想办法证明 $k=n$ 的时候此结论也成立。如果能够证明出来,那么就说明这个结论对于 k 等于任何数都成立。

类似的,我们设计动态规划算法,不是需要一个 dp 数组吗?我们可以假设 $dp[0…i-1]$ 都已经被算出来了,然后问自己:怎么通过这些结果算出 dp[i]?

直接拿最长递增子序列这个问题举例你就明白了。不过,首先要定义清楚 dp 数组的含义,即 dp[i] 的值到底代表着什么?

我们的定义是这样的:dp[i] 表示以 nums[i] 这个数结尾的最长递增子序列的长度。

举两个例子:

1

2

算法演进的过程是这样的,:

gif1

根据这个定义,我们的最终结果(子序列的最大长度)应该是 dp 数组中的最大值。

  1. int res = 0;
  2. for (int i = 0; i < dp.size(); i++) {
  3. res = Math.max(res, dp[i]);
  4. }
  5. return res;

读者也许会问,刚才这个过程中每个 dp[i] 的结果是我们肉眼看出来的,我们应该怎么设计算法逻辑来正确计算每个 dp[i] 呢?

这就是动态规划的重头戏了,要思考如何进行状态转移,这里就可以使用数学归纳的思想:

我们已经知道了 $dp[0…4]$ 的所有结果,我们如何通过这些已知结果推出 $dp[5]$ 呢?

3

根据刚才我们对 dp 数组的定义,现在想求 dp[5] 的值,也就是想求以 nums[5] 为结尾的最长递增子序列。

nums[5] = 3,既然是递增子序列,我们只要找到前面那些结尾比 3 小的子序列,然后把 3 接到最后,就可以形成一个新的递增子序列,而且这个新的子序列长度加一。

当然,可能形成很多种新的子序列,但是我们只要最长的,把最长子序列的长度作为 dp[5] 的值即可。

gif2

  1. for (int j = 0; j < i; j++) {
  2. if (nums[i] > nums[j])
  3. dp[i] = Math.max(dp[i], dp[j] + 1);
  4. }

这段代码的逻辑就可以算出 dp[5]。到这里,这道算法题我们就基本做完了。读者也许会问,我们刚才只是算了 dp[5] 呀,dp[4], dp[3] 这些怎么算呢?

类似数学归纳法,你已经可以算出 dp[5] 了,其他的就都可以算出来:

  1. for (int i = 0; i < nums.length; i++) {
  2. for (int j = 0; j < i; j++) {
  3. if (nums[i] > nums[j])
  4. dp[i] = Math.max(dp[i], dp[j] + 1);
  5. }
  6. }

还有一个细节问题,dp 数组应该全部初始化为 1,因为子序列最少也要包含自己,所以长度最小为 1。下面我们看一下完整代码:

  1. public int lengthOfLIS(int[] nums) {
  2. int[] dp = new int[nums.length];
  3. // dp 数组全都初始化为 1
  4. Arrays.fill(dp, 1);
  5. for (int i = 0; i < nums.length; i++) {
  6. for (int j = 0; j < i; j++) {
  7. if (nums[i] > nums[j])
  8. dp[i] = Math.max(dp[i], dp[j] + 1);
  9. }
  10. }
  11. int res = 0;
  12. for (int i = 0; i < dp.length; i++) {
  13. res = Math.max(res, dp[i]);
  14. }
  15. return res;
  16. }

至此,这道题就解决了,时间复杂度 O(N^2)。总结一下动态规划的设计流程:

首先明确 dp 数组所存数据的含义。这步很重要,如果不得当或者不够清晰,会阻碍之后的步骤。

然后根据 dp 数组的定义,运用数学归纳法的思想,假设 $dp[0…i-1]$ 都已知,想办法求出 $dp[i]$,一旦这一步完成,整个题目基本就解决了。

但如果无法完成这一步,很可能就是 dp 数组的定义不够恰当,需要重新定义 dp 数组的含义;或者可能是 dp 数组存储的信息还不够,不足以推出下一步的答案,需要把 dp 数组扩大成二维数组甚至三维数组。

最后想一想问题的 base case 是什么,以此来初始化 dp 数组,以保证算法正确运行。

二、二分查找解法

这个解法的时间复杂度会将为 O(NlogN),但是说实话,正常人基本想不到这种解法(也许玩过某些纸牌游戏的人可以想出来)。所以如果大家了解一下就好,正常情况下能够给出动态规划解法就已经很不错了。

根据题目的意思,我都很难想象这个问题竟然能和二分查找扯上关系。其实最长递增子序列和一种叫做 patience game 的纸牌游戏有关,甚至有一种排序方法就叫做 patience sorting(耐心排序)。

为了简单起见,后文跳过所有数学证明,通过一个简化的例子来理解一下思路。

首先,给你一排扑克牌,我们像遍历数组那样从左到右一张一张处理这些扑克牌,最终要把这些牌分成若干堆。

poker1

处理这些扑克牌要遵循以下规则:

只能把点数小的牌压到点数比它大的牌上。如果当前牌点数较大没有可以放置的堆,则新建一个堆,把这张牌放进去。如果当前牌有多个堆可供选择,则选择最左边的堆放置。

比如说上述的扑克牌最终会被分成这样 5 堆(我们认为 A 的值是最大的,而不是 1)。

poker2

为什么遇到多个可选择堆的时候要放到最左边的堆上呢?因为这样可以保证牌堆顶的牌有序(2, 4, 7, 8, Q),证明略。

poker3

按照上述规则执行,可以算出最长递增子序列,牌的堆数就是最长递增子序列的长度,证明略。

LIS

我们只要把处理扑克牌的过程编程写出来即可。每次处理一张扑克牌不是要找一个合适的牌堆顶来放吗,牌堆顶的牌不是有序吗,这就能用到二分查找了:用二分查找来搜索当前牌应放置的位置。

PS:旧文二分查找算法详解详细介绍了二分查找的细节及变体,这里就完美应用上了。如果没读过强烈建议阅读。

  1. public int lengthOfLIS(int[] nums) {
  2. int[] top = new int[nums.length];
  3. // 牌堆数初始化为 0
  4. int piles = 0;
  5. for (int i = 0; i < nums.length; i++) {
  6. // 要处理的扑克牌
  7. int poker = nums[i];
  8. /***** 搜索左侧边界的二分查找 *****/
  9. int left = 0, right = piles;
  10. while (left < right) {
  11. int mid = (left + right) / 2;
  12. if (top[mid] > poker) {
  13. right = mid;
  14. } else if (top[mid] < poker) {
  15. left = mid + 1;
  16. } else {
  17. right = mid;
  18. }
  19. }
  20. /*********************************/
  21. // 没找到合适的牌堆,新建一堆
  22. if (left == piles) piles++;
  23. // 把这张牌放到牌堆顶
  24. top[left] = poker;
  25. }
  26. // 牌堆数就是 LIS 长度
  27. return piles;
  28. }

至此,二分查找的解法也讲解完毕。

这个解法确实很难想到。首先涉及数学证明,谁能想到按照这些规则执行,就能得到最长递增子序列呢?其次还有二分查找的运用,要是对二分查找的细节不清楚,给了思路也很难写对。

所以,这个方法作为思维拓展好了。但动态规划的设计方法应该完全理解:假设之前的答案已知,利用数学归纳的思想正确进行状态的推演转移,最终得到答案。

Hanmin 提供 Python3 代码:

动态规划解法

  1. def lengthOfLIS(self, nums: List[int]) -> int:
  2. n = len(nums)
  3. ## dp 数组全部初始化为1
  4. dp = [1 for x in range(0,n)]
  5. for i in range(0,n):
  6. for j in range(0,i):
  7. if nums[i] > nums[j]:
  8. dp[i] = max(dp[i],dp[j]+1)
  9. res = 0
  10. for temp in dp:
  11. res = max(temp,res)
  12. return res

二分查找解法

  1. def lengthOfLIS(self, nums: List[int]) -> int:
  2. top = []
  3. ##牌堆初始化为0
  4. piles = 0
  5. for num in nums:
  6. ## num为要处理的扑克牌
  7. ##二分查找
  8. left, right = 0, piles
  9. while left < right:
  10. mid = (left + right ) // 2
  11. ##搜索左侧边界
  12. if top[mid] > num:
  13. right = mid
  14. ##搜索右侧边界
  15. elif top[mid] < num:
  16. left = mid + 1
  17. else:
  18. right = mid
  19. if left == piles:
  20. ##没有找到合适的堆,新建一堆
  21. piles += 1
  22. ##把这张牌放到牌堆顶
  23. top[left] = num
  24. return piles
  25. ##牌堆数就是LIS的长度