Segment Tree
Segment Tree๋?
โ ์ด๋ค ๋ฐ์ดํฐ๊ฐ ์กด์ฌํ ๋, ํน์ ๊ตฌ๊ฐ์ ๊ฒฐ๊ณผ๊ฐ์ ๊ตฌํ๋๋ฐ ์ฌ์ฉํ๋ ์๋ฃ ๊ตฌ์กฐ!
โ prefix sum? -> ์ ์ฉํ์ง๋ง ๊ฐ์ ๋ณ๊ฒฝ์ ์ทจ์ฝํ๋ค
โ ๋ฐ๋ผ์ segment tree๋ ์ด์ง ํธ๋ฆฌ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๋ค!
โ ๋ชฉํ๋ก ํ๋ ๊ฐ์ ์ต๋๋ก ์ปค๋ฒํ๋ ๋ฒ์์ segment๋ค์ ๋ํด์ ํฉ์ ๊ตฌํ๋ค
โ ๊ฐ ๋ณ๊ฒฝ ์์ ์์ ๋ ธ๋์ ๊ฐ๋ง ๋ฐ๊พธ๋ฉด ๋๊ธฐ ๋๋ฌธ์ logN ์๊ฐ ๋ณต์ก๋๋ก ๋์ํ ์ ์๋ค!
์ ์ฒด ์ฝ๋ (์ฌ๊ท์ ์ผ๋ก ๊ตฌํ)
from math import log2, ceil, gcd
class SegmentTree:
def __init__(self, input_list, calculation_method='sum'):
self.level = 0
self.length = 0
self.input_list = input_list
self.input_list_length = len(self.input_list)
self.input_start_index = 0
self.tree_index = 1
self.input_end_index = self.input_list_length - 1
self.calculation_method = calculation_method
self.result_list = []
def method(self, left_result, right_result):
if self.calculation_method == 'sum':
return left_result + right_result
elif self.calculation_method == 'max':
return max(left_result, right_result)
elif self.calculation_method == 'gcd':
return gcd(left_result, right_result)
def update_process(self, input_start_index, input_end_index, tree_index, update_index, update_value):
# ๊ตฌ๊ฐ์ ์ํฅ์ ๋ฏธ์น์ง ์๋ ๊ฒฝ์ฐ.
if update_index < input_start_index or update_index > input_end_index:
return self.result_list[tree_index]
# ์
๋ฐ์ดํธํ๊ณ ์ํ๋ ์์น์ ๋๋ฌํ ๊ฒฝ์ฐ.
if input_start_index == input_end_index:
self.result_list[tree_index] = update_value
return self.result_list[tree_index]
input_mid_index = (input_start_index + input_end_index) // 2
left_result = self.update_process(input_start_index, input_mid_index, tree_index * 2, update_index, update_value)
right_result = self.update_process(input_mid_index + 1, input_end_index, tree_index * 2 + 1, update_index, update_value)
self.result_list[tree_index] = self.method(left_result, right_result)
return self.result_list[tree_index]
def update(self, update_index, update_value):
self.tree_index = 1
self.input_list[update_index] = update_value
self.update_process(self.input_start_index, self.input_end_index, self.tree_index, update_index, update_value)
def get_range_process(self, input_start_index, input_end_index, tree_index, range_start_index, range_end_index):
if input_end_index < range_start_index or input_start_index > range_end_index:
return 0
if input_start_index >= range_start_index and input_end_index <= range_end_index:
return self.result_list[tree_index]
input_mid_index = (input_start_index + input_end_index) // 2
left_result = self.get_range_process(input_start_index, input_mid_index, tree_index * 2, range_start_index, range_end_index)
right_result = self.get_range_process(input_mid_index + 1, input_end_index, tree_index * 2 + 1, range_start_index, range_end_index)
return self.method(left_result, right_result)
def get_range(self, range_start_index, range_end_index):
self.tree_index = 1
return self.get_range_process(self.input_start_index, self.input_end_index, self.tree_index, range_start_index, range_end_index)
def process(self, input_start_index, input_end_index, tree_index):
if input_start_index == input_end_index:
self.result_list[tree_index] = self.input_list[input_start_index]
return self.result_list[tree_index]
input_mid_index = (input_start_index + input_end_index) // 2
left_result = self.process(input_start_index, input_mid_index, tree_index * 2)
right_result = self.process(input_mid_index + 1, input_end_index, tree_index * 2 + 1)
self.result_list[tree_index] = self.method(left_result, right_result)
return self.result_list[tree_index]
def make(self):
self.level = ceil(log2(self.input_list_length)) + 1
self.length = pow(2, self.level)
self.result_list = [0] * self.length
self.process(0, self.input_list_length-1, 1)
def main():
#number_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
number_list = [1, 2, 5, 5, 5, 5, 5, 5, 9, 10]
segment_tree_sum = SegmentTree(number_list, 'sum')
segment_tree_sum.make()
print(segment_tree_sum.result_list)
print(segment_tree_sum.get_range(3, 5))
segment_tree_sum.update(4, 7)
print(segment_tree_sum.result_list)
print(segment_tree_sum.get_range(3, 5))
segment_tree_max = SegmentTree(number_list, 'max')
segment_tree_max.make()
print(segment_tree_max.result_list)
print(segment_tree_max.get_range(3, 5))
segment_tree_max.update(4, 7)
print(segment_tree_max.result_list)
print(segment_tree_max.get_range(3, 5))
segment_tree_gcd = SegmentTree(number_list, 'gcd')
segment_tree_gcd.make()
print(segment_tree_gcd.result_list)
print(segment_tree_gcd.get_range(3, 5))
segment_tree_gcd.update(4, 7)
print(segment_tree_gcd.result_list)
print(segment_tree_gcd.get_range(3, 5))
if __name__ == '__main__':
main()
ํธ๋ฆฌ ๋ง๋ค๊ธฐ
def process(self, input_start_index, input_end_index, tree_index):
# 1. ๋ฆฌํ๋
ธ๋๋ผ๋ฉด tree_index์ ํ์ฌ ๊ฐ์ ์ฑ์ฐ๊ณ ํด๋น ๊ฐ์ ๋ฐํ(๊ฐ์ง๊ณ ์ฌ๋ผ์ด)
if input_start_index == input_end_index:
self.result_list[tree_index] = self.input_list[input_start_index]
return self.result_list[tree_index]
# 2. ๋ค์ ์ข/์ฐ ๊ตฌ๋ถํ๊ธฐ ์ํด ์ค๊ฐ๊ฐ ์ฐพ๊ธฐ
input_mid_index = (input_start_index + input_end_index) // 2
# 3. ์ผ์ชฝ๊ฐ๊ณผ ์ค๋ฅธ์ชฝ๊ฐ ๊ฐ์ ธ์ค๊ธฐ
left_result = self.process(input_start_index, input_mid_index, tree_index * 2)
right_result = self.process(input_mid_index + 1, input_end_index, tree_index * 2 + 1)
# 4. ๋ ๊ฐ์ ์ฐ์ฐ๊ฒฐ๊ณผ๋ฅผ ํ ์์น์ ์ ์ฅํ๊ณ ํด๋น๊ฐ์ ๋ฐํ
self.result_list[tree_index] = self.method(left_result, right_result)
return self.result_list[tree_index]
โ ์ฌ๊ท์ ์ผ๋ก ๋ด๋ ค๊ฐ์ ๋ฆฌํ๋ ธ๋์์๋ถํฐ ํธ๋ฆฌ๋ฅผ ์ฑ์ฐ๋ฉด์ ์ฌ๋ผ์จ๋ค!
์ฟผ๋ฆฌ
def get_range_process(self, input_start_index, input_end_index, tree_index, range_start_index, range_end_index):
# ๋ฒ์๋ฅผ ์์ ํ๊ฒ ๋ฒ์ด๋ ์์น(๊ตฌ๊ฐ)์ ๋ฌดํจ
if input_end_index < range_start_index or input_start_index > range_end_index:
return 0
# ๊ตฌ๊ฐ์ ์์ ํ ๋ค์ด๊ฐ๋ ๊ฒฝ์ฐ(๋ฆฌํ์ด๊ฑฐ๋ ์๋ ์๋ ์๋ค.)
# ๋์ด์ ์ฌ๊ท์ ์ผ๋ก ๋ด๋ ค๊ฐ์ง ์๊ณ ๋ฐ๋ก ๊ฐ์ ๋ฐํ๋ค.
if input_start_index >= range_start_index and input_end_index <= range_end_index:
return self.result_list[tree_index]
# ๊ตฌ๊ฐ์ ์ผ๋ถ๋ง ๊ฑธ์น ๊ฒฝ์ฐ ์ข/์ฐ๋ก ์ฌ๊ท์ ์ผ๋ก ๊ณ์ ๋ด๋ ค๊ฐ๋ค.
input_mid_index = (input_start_index + input_end_index) // 2
left_result = self.get_range_process(input_start_index, input_mid_index, tree_index * 2, range_start_index, range_end_index)
right_result = self.get_range_process(input_mid_index + 1, input_end_index, tree_index * 2 + 1, range_start_index, range_end_index)
return self.method(left_result, right_result)
โ ํจ์ ํธ์ถ ํํ๋ process(์ฒ์์ ํธ๋ฆฌ๋ฅผ ์ฑ์ฐ๋ ๊ณผ์ )๊ณผ ๋์ผํ๋ค!
โ ๋ด๊ฐ ๊ตฌํ๊ณ ์ ํ๋ ๋ฒ์(range_start_index, range_end_index)๋ ๋ณํ์ง ์๋๋ค.
โ ์ฌ๊ท ํธ์ถ์ ํ๊ณ ๋ด๋ ค๊ฐ๋ฉด์ ๊ตฌํ๊ณ ์ ํ๋ ๋ฒ์์ ์์ ํ ๋ค์ด๊ฐ๋ ๊ตฌ๊ฐ์ ๊ฐ์ ์ฐพ์์ ๋ฐํํ๋ค.
update
def update_process(self, input_start_index, input_end_index, tree_index, update_index, update_value):
# ๊ตฌ๊ฐ์ ์ํฅ์ ๋ฏธ์น์ง ์๋ ๊ฒฝ์ฐ.
# ๊ฐ์ ๊ทธ๋๋ก ๋ฐํํ๋ค.
if update_index < input_start_index or update_index > input_end_index:
return self.result_list[tree_index]
# ์
๋ฐ์ดํธํ๊ณ ์ํ๋ ์์น์ ๋๋ฌํ ๊ฒฝ์ฐ.
# ๊ฐ์ ๋ฐ๊พผ๋ค.
if input_start_index == input_end_index:
self.result_list[tree_index] = update_value
return self.result_list[tree_index]
input_mid_index = (input_start_index + input_end_index) // 2
left_result = self.update_process(input_start_index, input_mid_index, tree_index * 2, update_index, update_value)
right_result = self.update_process(input_mid_index + 1, input_end_index, tree_index * 2 + 1, update_index, update_value)
self.result_list[tree_index] = self.method(left_result, right_result)
return self.result_list[tree_index]
def update(self, update_index, update_value):
self.tree_index = 1
self.input_list[update_index] = update_value
self.update_process(self.input_start_index, self.input_end_index, self.tree_index, update_index, update_value)
โ ์
๋ฐ์ดํธ ํ๋ ๋ฆฌํ๋
ธ๋๋ก๋ถํฐ ์๋ก ์ฌ๋ผ๊ฐ๋ฉด์ ๊ฐ์ ์
๋ฐ์ดํธ ํ๋ค
โ ์ํฅ์ ๋ฏธ์น์ง ์๋ ๊ฒฝ์ฐ๋ ์
๋ฐ์ดํธํ์ง ์๊ณ ๊ฐ์ ๊ทธ๋๋ก ๋ฐํํ๋ค
Fenwick Tree
โ '๊ตฌ๊ฐํฉ'์ ๊ตฌํ๋ ๊ฒฝ์ฐ๋ผ๋ฉด ๋ชจ๋ ์ธ๊ทธ๋จผํธ๊ฐ ํ์ํ ๊ฒ ์๋๋ค!
โ ํ์ํ ์ต์์ ์ธ๊ทธ๋จผํธ๋ง ์์ผ๋ฉด ํฉ๊ณผ ์ฐจ๋ฅผ ์ด์ฉํด์ ๋ชจ๋ ๊ตฌ๊ฐํฉ์ ๊ตฌํ ์ ์๋ค.
โ ๋นํธ๋ฅผ ์ด์ฉํด ๊ตฌ๊ฐํฉ์ ๊ตฌํ๋ค!!
โ ๋ง์ง๋ง ๋นํธ์ 1์ ์๋ฆฌ์ ์์น๋งํผ ๋ํด์ค๋ค
'โญ Personal_Study > Algorithm' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[์๊ณ ๋ฆฌ์ฆ] ๋ค์ต์คํธ๋ผ (Dijkstra) (0) | 2022.10.01 |
---|---|
[์๊ณ ๋ฆฌ์ฆ] ์ต์ ์ ์ฅ ํธ๋ฆฌ (MST) (1) | 2022.09.29 |
[์๊ณ ๋ฆฌ์ฆ] Union-Find (1) | 2022.09.28 |
๋๊ธ