Segment tree is a balanced binary tree with O(logn) height given n input segments.
Segment tree supports fast range query O(logn + k), and update O(logn).
Building such a tree takes O(n) time if the input is an array of numbers.
Tree之所以大行其道就是因为其结构和人类组织结构非常接近。就拿公司来说好了,CEO统领全局(root),下面CTO,CFO等各自管理一个部门,每个部门下面又正好有好两个VP,每个VP下面又有两个director,每个director下面又有2个manager,每个manager下面又有2个底层员工(leaf),为什么是2?因为我们用二叉树啊~
故事还是要从LeetCode 307. Range Sum Query – Mutable说起。
题目大意是:给你一个数组,再给你一个范围,让你求这个范围内所有元素的和,其中元素的值是可变的,通过update(index, val)更新。
nums = [1, 3, 5],
sumRange(0, 2) = 1+3+5 = 9
update(1, 2) => [1, 2, 5]
sumRange(0, 2) = 1 + 2 + 5 = 7
暴力求解就是扫描一下这个范围。
时间复杂度:Update O(1), Query O(n)。
如果数组元素不变的话(303题),我们可以使用动态规划求出前n个元素的和然后存在sums数组中。i到j所有元素的和等于0~j所有元素的和减去0~(i-1)所有元素的和,即:
sumRange(i, j) := sums[j] – sums[i – 1] if i > 0 else sums[j]
这样就可以把query的时间复杂度降低到O(1)。
但是这道题元素的值可变,那么就需要维护sums,虽然可以把query的时间复杂度降低到了O(1),但update的时间复杂度是O(n),并没有比暴力求解快。
这个时候就要请出我们今天的主人公Segment Tree了,可以做到
Update: O(logn),Query: O(logn+k)
其实Segment Tree的思想还是很好理解的,比我们之前讲过的Binary Indexed Tree要容易理解的多(回复 SP3 获取视频),但是代码量又是另外一回事情了…
回到一开始讲的公司组织结构上面,假设一个公司有200个底层员工,
最原始的信息(元素的值)只掌握在他们工手里,层层上报。每个manager把他所管理的人的元素值求和之后继续上报,直到CEO。CEO知道但仅知道0-199的和。
当你问CEO,5到199的和是多少?他手上只有0-199的和,一下子慌了,赶紧找来CTO,CFO说你们把5到199的和给报上来,CFO一看报表,心中暗喜:还好我负责的这个区间(100~199)已经计算过了,就直接把之前的总和上报了。CTO一看报表,自己负责0-99这个区间,只知道0-99的和,但5-99的和,还是问下面的人吧… 最后结果再一层层返回给CEO。
说到这里大家应该已经能想象Segment Tree是怎么工作了吧:
每个leaf负责一个元素的值
每个parent负责的范围是他的children所负责的范围的union,并把所有范围内的元素值相加。
同一层的节点没有overlap。
root存储的是所有元素的和。
所以一个SegmentTreeNode需要记录以下信息
start #起始范围
end #终止范围
mid #拆分点,通常是 (start + end) // 2
val #所有子元素的和
left #左子树
right #右子树
Update: 由于每次只更新一个元素的值,所以CEO知道这个员工是哪个人管的,派发下去就行了,然后把新的结果再返回回来。
1 2 3 4 5 6 7 8 9 10 11 12 |
def update(root, index, val): # 到底层员工了,直接更新 if root.start == index and root.end == index: root.val = val return # 根据拆分点,更新左子树或右子树 if val <= root.mid: update(root.left, index, val) else: update(root.right, index, val) # 重新计算和 root.val = root.left?.val + root.right?.val |
T(n) = T(n/2) + 1 => O(logn)
Query: query的range可以是任意的,就有以下三种情况:
case 1: 这个range正好和我负责的range完全相同。比如sumQuery(CTO, 0, 99),这个时候CTO直接返回已经求解过的所有下属的和即可。
case 2: 这个range只由我其中一个下属负责。比如sumQuery(CEO, 0, 10),CEO知道0~10全部由CFO负责,那么他就直接返回sumQuery(CTO, 0, 10)。
case 3: 这个range覆盖了我的两个下属,那么我就需要调用2次。比如sumQuery(CEO, 80, 120),CEO知道0~99由CTO管,100~199由CFO管,所以他只需要返回:
sumQuery(CTO, 80, 99) + sumQuery(CFO, 100, 120)
由此可见sumQuery的时间复杂度是不确定的,运气好时O(1),运气不好时是O(logn+k),k是需要访问的节点的数量。
我做了一个实验,给定N,尝试所有的(i,j)组合,看sumQuery的最坏情况和平均情况,结果如下图:
Query需要访问到的节点数量的worst和average case。Worst case 大约访问 4*logn – 3 个节点,这个数字远远小于n。和n成对数关系。
虽然不像Binary Indexed Tree query是严格的O(logn),但Segment tree query的worst case增长的非常慢,可以说是对数级别的。当N是2^20的时候,worst case也“只需要”访问77个节点。
最后我们再来讲一下这棵树是怎么创建的,其实方法有很多种,一分为二是比较常规的一种做法。
CEO管200人,那么就找2个人(CTO,CFO各管100人。CTO管100人,再找2个VP各管50人,以此类推,直到manager管2个人,2个人都是底层员工(leaf),没有人管(双关)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
def buildTree(start, end): # 超过范围返回空节点 if start > end: return None boss.start = start boss.end = end # 如果是叶子节点,记录元素的值 if start == end: boss.val = nums[start] boss.mid = (start + end) // 2 # 递归构建子树 boss.left = buildTree(start, mid) boss.right = buildTree(mid + 1, end) # 求和 boss.val += boss.left?.val + boss.right?.val return boss |
CEO = buildTree(0, 199)
CEO.left # CTO 负责0~99
CEO.right # CFO 负责100~199
Query: # of nodes visited: Average and worst are both ~O(logn)
Applications
LeetCode 307 Range Sum Query – Mutable
C++
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
// Author: Huahua, running time: 24 ms class SegmentTreeNode { public: SegmentTreeNode(int start, int end, int sum, SegmentTreeNode* left = nullptr, SegmentTreeNode* right = nullptr): start(start), end(end), sum(sum), left(left), right(right){} SegmentTreeNode(const SegmentTreeNode&) = delete; SegmentTreeNode& operator=(const SegmentTreeNode&) = delete; ~SegmentTreeNode() { delete left; delete right; left = right = nullptr; } int start; int end; int sum; SegmentTreeNode* left; SegmentTreeNode* right; }; class NumArray { public: NumArray(vector<int> nums) { nums_.swap(nums); if (!nums_.empty()) root_.reset(buildTree(0, nums_.size() - 1)); } void update(int i, int val) { updateTree(root_.get(), i, val); } int sumRange(int i, int j) { return sumRange(root_.get(), i, j); } private: vector<int> nums_; std::unique_ptr<SegmentTreeNode> root_; SegmentTreeNode* buildTree(int start, int end) { if (start == end) { return new SegmentTreeNode(start, end, nums_[start]); } int mid = start + (end - start) / 2; auto left = buildTree(start, mid); auto right = buildTree(mid + 1, end); auto node = new SegmentTreeNode(start, end, left->sum + right->sum, left, right); return node; } void updateTree(SegmentTreeNode* root, int i, int val) { if (root->start == i && root->end == i) { root->sum = val; return; } int mid = root->start + (root->end - root->start) / 2; if (i <= mid) { updateTree(root->left, i, val); } else { updateTree(root->right, i, val); } root->sum = root->left->sum + root->right->sum; } int sumRange(SegmentTreeNode* root, int i, int j) { if (i == root->start && j == root->end) { return root->sum; } int mid = root->start + (root->end - root->start) / 2; if (j <= mid) { return sumRange(root->left, i, j); } else if (i > mid) { return sumRange(root->right, i, j); } else { return sumRange(root->left, i, mid) + sumRange(root->right, mid + 1, j); } } }; |
Python3
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
# Author: Huahua, running time: 176 ms class SegmentTreeNode: def __init__(self, start, end, val, left=None, right=None): self.start = start self.end = end self.mid = start + (end - start) // 2 self.val = val self.left = left self.right = right class NumArray: def __init__(self, nums): self.nums = nums if self.nums: self.root = self._buildTree(0, len(nums) - 1) def update(self, i, val): self._updateTree(self.root, i, val) def sumRange(self, i, j): return self._sumRange(self.root, i, j) def _buildTree(self, start, end): if start == end: return SegmentTreeNode(start, end, self.nums[start]) mid = start + (end - start) // 2 left = self._buildTree(start, mid) right = self._buildTree(mid + 1, end) return SegmentTreeNode(start, end, left.val + right.val, left, right) def _updateTree(self, root, i, val): if root.start == i and root.end == i: root.val = val return if i <= root.mid: self._updateTree(root.left, i, val) else: self._updateTree(root.right, i, val) root.val = root.left.val + root.right.val def _sumRange(self, root, i, j): if root.start == i and root.end == j: return root.val if j <= root.mid: return self._sumRange(root.left, i, j) elif i > root.mid: return self._sumRange(root.right, i, j) else: return self._sumRange(root.left, i, root.mid) + self._sumRange(root.right, root.mid + 1, j) |
请尊重作者的劳动成果,转载请注明出处!花花保留对文章/视频的所有权利。
如果您喜欢这篇文章/视频,欢迎您捐赠花花。
If you like my articles / videos, donations are welcome.
Be First to Comment