Segment tree is a balanced binary tree with O(logn) height given n input segments. Segment tree supports fast range query O(logn + k), and update O(logn). Building such a tree takes O(n) time if the input is an array of numbers.
Tree之所以大行其道就是因为其结构和人类组织结构非常接近。就拿公司来说好了,CEO统领全局(root),下面CTO,CFO等各自管理一个部门,每个部门下面又正好有好两个VP,每个VP下面又有两个director,每个director下面又有2个manager,每个manager下面又有2个底层员工(leaf),为什么是2?因为我们用二叉树啊~
故事还是要从LeetCode 307. Range Sum Query – Mutable说起。
题目大意是:给你一个数组,再给你一个范围,让你求这个范围内所有元素的和,其中元素的值是可变的,通过update(index, val)更新。
nums = [1, 3, 5], sumRange(0, 2) = 1+3+5 = 9 update(1, 2) => [1, 2, 5] sumRange(0, 2) = 1 + 2 + 5 = 7
暴力求解就是扫描一下这个范围。
时间复杂度:Update O(1), Query O(n)。
如果数组元素不变的话(303题),我们可以使用动态规划求出前n个元素的和然后存在sums数组中。i到j所有元素的和等于0~j所有元素的和减去0~(i-1)所有元素的和,即:
sumRange(i, j) := sums[j] – sums[i – 1] if i > 0 else sums[j]
这样就可以把query的时间复杂度降低到O(1)。
但是这道题元素的值可变,那么就需要维护sums,虽然可以把query的时间复杂度降低到了O(1),但update的时间复杂度是O(n),并没有比暴力求解快。
这个时候就要请出我们今天的主人公Segment Tree了,可以做到
Update: O(logn),Query: O(logn+k)
其实Segment Tree的思想还是很好理解的,比我们之前讲过的Binary Indexed Tree要容易理解的多(回复 SP3 获取视频),但是代码量又是另外一回事情了…
回到一开始讲的公司组织结构上面,假设一个公司有200个底层员工,
最原始的信息(元素的值)只掌握在他们工手里,层层上报。每个manager把他所管理的人的元素值求和之后继续上报,直到CEO。CEO知道但仅知道0-199的和。
当你问CEO,5到199的和是多少?他手上只有0-199的和,一下子慌了,赶紧找来CTO,CFO说你们把5到199的和给报上来,CFO一看报表,心中暗喜:还好我负责的这个区间(100~199)已经计算过了,就直接把之前的总和上报了。CTO一看报表,自己负责0-99这个区间,只知道0-99的和,但5-99的和,还是问下面的人吧… 最后结果再一层层返回给CEO。
说到这里大家应该已经能想象Segment Tree是怎么工作了吧:
每个leaf负责一个元素的值
每个parent负责的范围是他的children所负责的范围的union,并把所有范围内的元素值相加。
同一层的节点没有overlap。
root存储的是所有元素的和。
所以一个SegmentTreeNode需要记录以下信息
start #起始范围 end #终止范围 mid #拆分点,通常是 (start + end) // 2 val #所有子元素的和 left #左子树 right #右子树
Update: 由于每次只更新一个元素的值,所以CEO知道这个员工是哪个人管的,派发下去就行了,然后把新的结果再返回回来。
def update ( root , index , val ) :
# 到底层员工了,直接更新
if root . start == index and root . end == index :
root . val = val
return
# 根据拆分点,更新左子树或右子树
if val < = root . mid :
update ( root . left , index , val )
else :
update ( root . right , index , val )
# 重新计算和
root . val = root . left ? . val + root . right ? . val
T(n) = T(n/2) + 1 => O(logn)
Query: query的range可以是任意的,就有以下三种情况:
case 1: 这个range正好和我负责的range完全相同。比如sumQuery(CTO, 0, 99),这个时候CTO直接返回已经求解过的所有下属的和即可。
case 2: 这个range只由我其中一个下属负责。比如sumQuery(CEO, 0, 10),CEO知道0~10全部由CFO负责,那么他就直接返回sumQuery(CTO, 0, 10)。
case 3: 这个range覆盖了我的两个下属,那么我就需要调用2次。比如sumQuery(CEO, 80, 120),CEO知道0~99由CTO管,100~199由CFO管,所以他只需要返回:
sumQuery(CTO, 80, 99) + sumQuery(CFO, 100, 120)
由此可见sumQuery的时间复杂度是不确定的,运气好时O(1),运气不好时是O(logn+k),k是需要访问的节点的数量。
我做了一个实验,给定N,尝试所有的(i,j)组合,看sumQuery的最坏情况和平均情况,结果如下图:
Query需要访问到的节点数量的worst和average case。Worst case 大约访问 4*logn – 3 个节点,这个数字远远小于n。和n成对数关系。
虽然不像Binary Indexed Tree query是严格的O(logn),但Segment tree query的worst case增长的非常慢,可以说是对数级别的。当N是2^20的时候,worst case也“只需要”访问77个节点。
最后我们再来讲一下这棵树是怎么创建的,其实方法有很多种,一分为二是比较常规的一种做法。
CEO管200人,那么就找2个人(CTO,CFO各管100人。CTO管100人,再找2个VP各管50人,以此类推,直到manager管2个人,2个人都是底层员工(leaf),没有人管(双关)。
def buildTree ( start , end ) :
# 超过范围返回空节点
if start > end : return None
boss . start = start
boss . end = end
# 如果是叶子节点,记录元素的值
if start == end : boss . val = nums [ start ]
boss . mid = ( start + end ) // 2
# 递归构建子树
boss . left = buildTree ( start , mid )
boss . right = buildTree ( mid + 1 , end )
# 求和
boss . val += boss . left ? . val + boss . right ? . val
return boss
CEO = buildTree(0, 199)
CEO.left # CTO 负责0~99 CEO.right # CFO 负责100~199
Query: # of nodes visited: Average and worst are both ~O(logn)
Applications
LeetCode 307 Range Sum Query – Mutable
C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// Author: Huahua, running time: 24 ms
class SegmentTreeNode {
public :
SegmentTreeNode ( int start , int end , int sum ,
SegmentTreeNode * left = nullptr ,
SegmentTreeNode * right = nullptr ) :
start ( start ) ,
end ( end ) ,
sum ( sum ) ,
left ( left ) ,
right ( right ) { }
SegmentTreeNode ( const SegmentTreeNode &) = delete;
SegmentTreeNode & operator=(const SegmentTreeNode&) = delete;
~ SegmentTreeNode ( ) {
delete left ;
delete right ;
left = right = nullptr ;
}
int start ;
int end ;
int sum ;
SegmentTreeNode * left ;
SegmentTreeNode * right ;
} ;
class NumArray {
public :
NumArray ( vector < int > nums ) {
nums_ . swap ( nums ) ;
if ( ! nums_ . empty ( ) )
root_ . reset ( buildTree ( 0 , nums_ . size ( ) - 1 ) ) ;
}
void update ( int i , int val ) {
updateTree ( root_ . get ( ) , i , val ) ;
}
int sumRange ( int i , int j ) {
return sumRange ( root_ . get ( ) , i , j ) ;
}
private :
vector < int > nums_ ;
std : : unique_ptr < SegmentTreeNode > root_ ;
SegmentTreeNode * buildTree ( int start , int end ) {
if ( start == end ) {
return new SegmentTreeNode ( start , end , nums_ [ start ] ) ;
}
int mid = start + ( end - start ) / 2 ;
auto left = buildTree ( start , mid ) ;
auto right = buildTree ( mid + 1 , end ) ;
auto node = new SegmentTreeNode ( start , end , left -> sum + right -> sum , left , right ) ;
return node ;
}
void updateTree ( SegmentTreeNode * root , int i , int val ) {
if ( root -> start == i && root->end == i) {
root->sum = val;
return ;
}
int mid = root -> start + ( root -> end - root -> start ) / 2 ;
if ( i < = mid ) {
updateTree ( root -> left , i , val ) ;
} else {
updateTree ( root -> right , i , val ) ;
}
root -> sum = root -> left -> sum + root -> right -> sum ;
}
int sumRange ( SegmentTreeNode * root , int i , int j ) {
if ( i == root -> start && j == root->end) {
return root->sum;
}
int mid = root -> start + ( root -> end - root -> start ) / 2 ;
if ( j < = mid ) {
return sumRange ( root -> left , i , j ) ;
} else if ( i > mid ) {
return sumRange ( root -> right , i , j ) ;
} else {
return sumRange ( root -> left , i , mid ) + sumRange ( root -> right , mid + 1 , j ) ;
}
}
} ;
Python3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Author: Huahua, running time: 176 ms
class SegmentTreeNode :
def __init__ ( self , start , end , val , left =None , right =None ) :
self . start = start
self . end = end
self . mid = start + ( end - start ) // 2
self . val = val
self . left = left
self . right = right
class NumArray :
def __init__ ( self , nums ) :
self . nums = nums
if self . nums :
self . root = self . _buildTree ( 0 , len ( nums ) - 1 )
def update ( self , i , val ) :
self . _updateTree ( self . root , i , val )
def sumRange ( self , i , j ) :
return self . _sumRange ( self . root , i , j )
def _buildTree ( self , start , end ) :
if start == end : return SegmentTreeNode ( start , end , self . nums [ start ] )
mid = start + ( end - start ) // 2
left = self . _buildTree ( start , mid )
right = self . _buildTree ( mid + 1 , end )
return SegmentTreeNode ( start , end , left . val + right . val , left , right )
def _updateTree ( self , root , i , val ) :
if root . start == i and root . end == i :
root . val = val
return
if i < = root . mid :
self . _updateTree ( root . left , i , val )
else :
self . _updateTree ( root . right , i , val )
root . val = root . left . val + root . right . val
def _sumRange ( self , root , i , j ) :
if root . start == i and root . end == j :
return root . val
if j < = root . mid :
return self . _sumRange ( root . left , i , j )
elif i > root . mid :
return self . _sumRange ( root . right , i , j )
else :
return self . _sumRange ( root . left , i , root . mid ) + self . _sumRange ( root . right , root . mid + 1 , j )