接雨水问题详解

接雨水这道题目挺有意思,在面试题中出现频率还挺高的,本文就来步步优化,讲解一下这道题。

先看一下题目:

接雨水 - 图1

就是用一个数组表示一个条形图,问你这个条形图最多能接多少水。

  1. int trap(int[] height);

下面就来由浅入深介绍暴力解法 -> 备忘录解法 -> 双指针解法,在 O(N) 时间 O(1) 空间内解决这个问题。

一、核心思路

我第一次看到这个问题,无计可施,完全没有思路,相信很多朋友跟我一样。所以对于这种问题,我们不要想整体,而应该去想局部;就像之前的文章处理字符串问题,不要考虑如何处理整个字符串,而是去思考应该如何处理每一个字符。

这么一想,可以发现这道题的思路其实很简单。具体来说,仅仅对于位置 i,能装下多少水呢?

接雨水 - 图2

能装 2 格水。为什么恰好是两格水呢?因为 height[i] 的高度为 0,而这里最多能盛 2 格水,2-0=2。

为什么位置 i 最多能盛 2 格水呢?因为,位置 i 能达到的水柱高度和其左边的最高柱子、右边的最高柱子有关,我们分别称这两个柱子高度为 l_maxr_max位置 i 最大的水柱高度就是 min(l_max, r_max)

更进一步,对于位置 i,能够装的水为:

  1. water[i] = min(
  2. # 左边最高的柱子
  3. max(height[0..i]),
  4. # 右边最高的柱子
  5. max(height[i..end])
  6. ) - height[i]

接雨水 - 图3

接雨水 - 图4

这就是本问题的核心思路,我们可以简单写一个暴力算法:

  1. int trap(vector<int>& height) {
  2. int n = height.size();
  3. int ans = 0;
  4. for (int i = 1; i < n - 1; i++) {
  5. int l_max = 0, r_max = 0;
  6. // 找右边最高的柱子
  7. for (int j = i; j < n; j++)
  8. r_max = max(r_max, height[j]);
  9. // 找左边最高的柱子
  10. for (int j = i; j >= 0; j--)
  11. l_max = max(l_max, height[j]);
  12. // 如果自己就是最高的话,
  13. // l_max == r_max == height[i]
  14. ans += min(l_max, r_max) - height[i];
  15. }
  16. return ans;
  17. }

有之前的思路,这个解法应该是很直接粗暴的,时间复杂度 O(N^2),空间复杂度 O(1)。但是很明显这种计算 r_maxl_max 的方式非常笨拙,一般的优化方法就是备忘录。

二、备忘录优化

之前的暴力解法,不是在每个位置 i 都要计算 r_maxl_max 吗?我们直接把结果都缓存下来,别傻不拉几的每次都遍历,这时间复杂度不就降下来了嘛。

我们开两个数组 r_maxl_max 充当备忘录,l_max[i] 表示位置 i 左边最高的柱子高度,r_max[i] 表示位置 i 右边最高的柱子高度。预先把这两个数组计算好,避免重复计算:

  1. int trap(vector<int>& height) {
  2. if (height.empty()) return 0;
  3. int n = height.size();
  4. int ans = 0;
  5. // 数组充当备忘录
  6. vector<int> l_max(n), r_max(n);
  7. // 初始化 base case
  8. l_max[0] = height[0];
  9. r_max[n - 1] = height[n - 1];
  10. // 从左向右计算 l_max
  11. for (int i = 1; i < n; i++)
  12. l_max[i] = max(height[i], l_max[i - 1]);
  13. // 从右向左计算 r_max
  14. for (int i = n - 2; i >= 0; i--)
  15. r_max[i] = max(height[i], r_max[i + 1]);
  16. // 计算答案
  17. for (int i = 1; i < n - 1; i++)
  18. ans += min(l_max[i], r_max[i]) - height[i];
  19. return ans;
  20. }

这个优化其实和暴力解法差不多,就是避免了重复计算,把时间复杂度降低为 O(N),已经是最优了,但是空间复杂度是 O(N)。下面来看一个精妙一些的解法,能够把空间复杂度降低到 O(1)。

三、双指针解法

这种解法的思路是完全相同的,但在实现手法上非常巧妙,我们这次也不要用备忘录提前计算了,而是用双指针边走边算,节省下空间复杂度。

首先,看一部分代码:

  1. int trap(vector<int>& height) {
  2. int n = height.size();
  3. int left = 0, right = n - 1;
  4. int l_max = height[0];
  5. int r_max = height[n - 1];
  6. while (left <= right) {
  7. l_max = max(l_max, height[left]);
  8. r_max = max(r_max, height[right]);
  9. left++; right--;
  10. }
  11. }

对于这部分代码,请问 l_maxr_max 分别表示什么意义呢?

很容易理解,l_maxheight[0..left] 中最高柱子的高度,r_maxheight[right..end] 的最高柱子的高度

明白了这一点,直接看解法:

  1. int trap(vector<int>& height) {
  2. if (height.empty()) return 0;
  3. int n = height.size();
  4. int left = 0, right = n - 1;
  5. int ans = 0;
  6. int l_max = height[0];
  7. int r_max = height[n - 1];
  8. while (left <= right) {
  9. l_max = max(l_max, height[left]);
  10. r_max = max(r_max, height[right]);
  11. // ans += min(l_max, r_max) - height[i]
  12. if (l_max < r_max) {
  13. ans += l_max - height[left];
  14. left++;
  15. } else {
  16. ans += r_max - height[right];
  17. right--;
  18. }
  19. }
  20. return ans;
  21. }

你看,其中的核心思想和之前一模一样,换汤不换药。但是细心的读者可能会发现次解法还是有点细节差异:

之前的备忘录解法,l_max[i]r_max[i] 代表的是 height[0..i]height[i..end] 的最高柱子高度。

  1. ans += min(l_max[i], r_max[i]) - height[i];

接雨水 - 图5

但是双指针解法中,l_maxr_max 代表的是 height[0..left]height[right..end] 的最高柱子高度。比如这段代码:

  1. if (l_max < r_max) {
  2. ans += l_max - height[left];
  3. left++;
  4. }

接雨水 - 图6

此时的 l_maxleft 指针左边的最高柱子,但是 r_max 并不一定是 left 指针右边最高的柱子,这真的可以得到正确答案吗?

其实这个问题要这么思考,我们只在乎 min(l_max, r_max)。对于上图的情况,我们已经知道 l_max < r_max 了,至于这个 r_max 是不是右边最大的,不重要,重要的是 height[i] 能够装的水只和 l_max 有关。

接雨水 - 图7

坚持原创高质量文章,致力于把算法问题讲清楚,欢迎关注我的公众号 labuladong 获取最新文章:

labuladong

newler提供java代码:

暴力解法

  1. public int trap(int[] height) {
  2. int ans = 0;
  3. for (int i = 1; i < height.length - 1; i++) {
  4. int leftMax = 0, rightMax = 0;
  5. // 找右边最高的柱子
  6. for (int j = i; j < height.length; j++) {
  7. rightMax = Math.max(height[j], rightMax);
  8. }
  9. // 找左边最高的柱子
  10. for (int j = i; j >= 0; j--) {
  11. leftMax = Math.max(height[j], leftMax);
  12. }
  13. // 如果自己就是最高的话,
  14. // leftMax == rightMax == height[i]
  15. ans += Math.min(leftMax, rightMax) - height[i];
  16. }
  17. return ans;
  18. }

备忘录优化解法

  1. public int trap(int[] height) {
  2. if (height == null || height.length == 0) return 0;
  3. int ans = 0;
  4. // 数组充当备忘录
  5. int[] leftMax = new int[height.length];
  6. int[] rightMax = new int[height.length];
  7. // 初始化base case
  8. leftMax[0] = height[0];
  9. rightMax[height.length - 1] = height[height.length - 1];
  10. // 从左到右计算leftMax
  11. for (int i = 1; i < height.length; i++) {
  12. leftMax[i] = Math.max(height[i], leftMax[i-1]);
  13. }
  14. // 从右到左计算rightMax
  15. for (int i = height.length - 2; i >= 0; i--) {
  16. rightMax[i] = Math.max(height[i], rightMax[i + 1]);
  17. }
  18. // 计算结果
  19. for (int i = 1; i < height.length - 1; i++) {
  20. ans += Math.min(leftMax[i], rightMax[i]) - height[i];
  21. }
  22. return ans;
  23. }

双指针解法

  1. public int trap(int[] height) {
  2. if (height == null || height.length == 0) return 0;
  3. int ans = 0;
  4. int leftMax, rightMax;
  5. // 左右指针
  6. int left = 0, right = height.length - 1;
  7. // 初始化
  8. leftMax = height[0];
  9. rightMax = height[height.length - 1];
  10. while (left < right) {
  11. // 更新左右两边柱子最大值
  12. leftMax = Math.max(height[left], leftMax);
  13. rightMax = Math.max(height[right], rightMax);
  14. // 相当于ans += Math.min(leftMax, rightMax) - height[i]
  15. if (leftMax < rightMax) {
  16. ans += leftMax - height[left];
  17. left++;
  18. } else {
  19. ans += rightMax - height[right];
  20. right--;
  21. }
  22. }
  23. return ans;
  24. }

eric wang 提供 Java 代码

  1. public int trap(int[] height) {
  2. if (height.length == 0) {
  3. return 0;
  4. }
  5. int n = height.length;
  6. int left = 0, right = n - 1;
  7. int ans = 0;
  8. int l_max = height[0];
  9. int r_max = height[n - 1];
  10. while (left <= right) {
  11. l_max = Math.max(l_max, height[left]);
  12. r_max = Math.max(r_max, height[right]);
  13. if (l_max < r_max) {
  14. ans += l_max - height[left];
  15. left++;
  16. } else {
  17. ans += r_max - height[right];
  18. right--;
  19. }
  20. }
  21. return ans;
  22. }

eric wang 提供 Python3 代码

  1. def trap(self, height: List[int]) -> int:
  2. if not height:
  3. return 0
  4. n = len(height)
  5. left, right = 0, n - 1
  6. ans = 0
  7. l_max = height[0]
  8. r_max = height[n - 1]
  9. while left <= right:
  10. l_max = max(l_max, height[left])
  11. r_max = max(r_max, height[right])
  12. if l_max < r_max:
  13. ans += l_max - height[left]
  14. left += 1
  15. else:
  16. ans += r_max - height[right]
  17. right -= 1
  18. return ans