본문 바로가기
โญ Personal_Study/Algorithm

Segment Tree

by 포스트쉐이크 2022. 12. 3.

Segment Tree

Segment Tree๋ž€?

✔ 어떤 데이터가 존재할 때, 특정 구간의 결과값(합, 최댓값 등)을 구하는 데 사용하는 자료 구조!

✔ prefix sum? -> 구간합 조회는 O(1)로 빠르지만, 값이 하나만 바뀌어도 누적합을 O(N)으로 다시 계산해야 해서 값의 변경에 취약하다

โœ” ๋”ฐ๋ผ์„œ segment tree๋Š” ์ด์ง„ ํŠธ๋ฆฌ ๊ตฌ์กฐ๋ฅผ ๊ฐ€์ง„๋‹ค!

โœ” ๋ชฉํ‘œ๋กœ ํ•˜๋Š” ๊ฐ’์„ ์ตœ๋Œ€๋กœ ์ปค๋ฒ„ํ•˜๋Š” ๋ฒ”์œ„์˜ segment๋“ค์„ ๋”ํ•ด์„œ ํ•ฉ์„ ๊ตฌํ•œ๋‹ค

✔ 값 변경 시에 해당 리프 노드와 그 조상 노드들(트리 높이만큼)의 값만 바꾸면 되기 때문에 O(log N) 시간 복잡도로 대응할 수 있다!

전체 코드 (재귀적으로 구현)

from math import log2, ceil, gcd


class SegmentTree:
    """Recursive segment tree for range queries with point updates.

    Nodes are stored in ``result_list`` with 1-based heap indexing: node 1
    covers the whole input, and node ``i`` has children ``2 * i`` (left half
    of its segment) and ``2 * i + 1`` (right half).
    Call :meth:`make` once before :meth:`get_range` or :meth:`update`.

    Supported ``calculation_method`` values: ``'sum'``, ``'max'``, ``'gcd'``.
    """

    # Identity element of each operation (combining it with any x yields x).
    # Returned for segments lying entirely outside a query range so they do
    # not influence the result; a plain 0 would be wrong for 'max' whenever
    # every value in the queried range is negative.
    _IDENTITY = {'sum': 0, 'max': float('-inf'), 'gcd': 0}

    def __init__(self, input_list, calculation_method='sum'):
        """Store the input and the operation; the tree is built by make().

        Args:
            input_list: Sequence of values to index. It is copied, so later
                calls to update() never modify the caller's sequence.
            calculation_method: One of 'sum', 'max' or 'gcd'.

        Raises:
            ValueError: If calculation_method is not supported.
        """
        if calculation_method not in self._IDENTITY:
            raise ValueError(
                'unsupported calculation_method {!r}; expected one of {}'.format(
                    calculation_method, sorted(self._IDENTITY)))
        self.level = 0
        self.length = 0
        # Own copy: several trees are often built from the same list, and
        # update() writes into this one.
        self.input_list = list(input_list)
        self.input_list_length = len(self.input_list)
        self.input_start_index = 0
        self.tree_index = 1
        self.input_end_index = self.input_list_length - 1
        self.calculation_method = calculation_method
        self.result_list = []

    def method(self, left_result, right_result):
        """Combine two partial results with the configured operation.

        Raises:
            ValueError: If calculation_method has been set to an unsupported
                value (returning None would silently corrupt every parent).
        """
        if self.calculation_method == 'sum':
            return left_result + right_result
        if self.calculation_method == 'max':
            return max(left_result, right_result)
        if self.calculation_method == 'gcd':
            return gcd(left_result, right_result)
        raise ValueError(
            'unsupported calculation_method {!r}'.format(self.calculation_method))

    def update_process(self, input_start_index, input_end_index, tree_index, update_index, update_value):
        """Write ``update_value`` at leaf ``update_index`` and refresh its ancestors.

        Node ``tree_index`` covers ``[input_start_index, input_end_index]``.
        Returns the (possibly refreshed) value stored at ``tree_index``.
        """
        # This segment does not contain the updated position: its stored
        # value is still valid, return it unchanged.
        if update_index < input_start_index or update_index > input_end_index:
            return self.result_list[tree_index]

        # Reached the leaf for update_index: overwrite it.
        if input_start_index == input_end_index:
            self.result_list[tree_index] = update_value
            return self.result_list[tree_index]

        input_mid_index = (input_start_index + input_end_index) // 2

        left_result = self.update_process(input_start_index, input_mid_index, tree_index * 2, update_index, update_value)

        right_result = self.update_process(input_mid_index + 1, input_end_index, tree_index * 2 + 1, update_index, update_value)

        # Recompute this node from its (possibly updated) children.
        self.result_list[tree_index] = self.method(left_result, right_result)

        return self.result_list[tree_index]

    def update(self, update_index, update_value):
        """Set ``input_list[update_index]`` to ``update_value`` and refresh the tree.

        Negative indices count from the end, as for Python lists.

        Raises:
            IndexError: If ``update_index`` is outside the input list.
        """
        # Normalise negative indices first: update_process only recognises
        # 0..n-1, so a raw -1 would change input_list but leave the tree stale.
        if update_index < 0:
            update_index += self.input_list_length
        if not 0 <= update_index < self.input_list_length:
            raise IndexError('update_index out of range')

        self.tree_index = 1
        self.input_list[update_index] = update_value

        self.update_process(self.input_start_index, self.input_end_index, self.tree_index, update_index, update_value)

    def get_range_process(self, input_start_index, input_end_index, tree_index, range_start_index, range_end_index):
        """Combine the part of the query range that lies in node ``tree_index``.

        Node ``tree_index`` covers ``[input_start_index, input_end_index]``;
        the query range ``[range_start_index, range_end_index]`` is the same
        in every recursive call. Disjoint segments yield the operation's
        identity element.
        """
        # Segment completely outside the query range: neutral contribution.
        if input_end_index < range_start_index or input_start_index > range_end_index:
            return self._IDENTITY[self.calculation_method]

        # Segment completely inside the query range (leaf or not): use the
        # stored value directly without descending further.
        if input_start_index >= range_start_index and input_end_index <= range_end_index:
            return self.result_list[tree_index]

        # Partial overlap: split the segment and query both halves.
        input_mid_index = (input_start_index + input_end_index) // 2

        left_result = self.get_range_process(input_start_index, input_mid_index, tree_index * 2, range_start_index, range_end_index)

        right_result = self.get_range_process(input_mid_index + 1, input_end_index, tree_index * 2 + 1, range_start_index, range_end_index)

        return self.method(left_result, right_result)

    def get_range(self, range_start_index, range_end_index):
        """Return the combined value of ``input_list[range_start_index..range_end_index]`` (inclusive)."""
        self.tree_index = 1
        return self.get_range_process(self.input_start_index, self.input_end_index, self.tree_index, range_start_index, range_end_index)

    def process(self, input_start_index, input_end_index, tree_index):
        """Build the subtree rooted at ``tree_index`` bottom-up and return its value.

        Leaves copy their input element; internal nodes store the
        combination of their two children.
        """
        if input_start_index == input_end_index:
            self.result_list[tree_index] = self.input_list[input_start_index]
            return self.result_list[tree_index]

        input_mid_index = (input_start_index + input_end_index) // 2

        left_result = self.process(input_start_index, input_mid_index, tree_index * 2)

        right_result = self.process(input_mid_index + 1, input_end_index, tree_index * 2 + 1)

        self.result_list[tree_index] = self.method(left_result, right_result)

        return self.result_list[tree_index]

    def make(self):
        """Allocate ``result_list`` and build the whole tree from ``input_list``.

        ``level`` is ``ceil(log2(n)) + 1`` and ``length`` is ``2 ** level``,
        enough slots for 1-based heap indexing. Empty input gives an empty tree.
        """
        n = self.input_list_length
        if n == 0:
            # log2(0) is undefined and there is nothing to build.
            self.level = 0
            self.length = 0
            self.result_list = []
            return
        # Exact integer form of ceil(log2(n)) + 1 (no float rounding for huge n).
        self.level = (n - 1).bit_length() + 1
        self.length = 1 << self.level
        self.result_list = [0] * self.length
        self.process(0, n - 1, 1)


def main():
    """Demonstrate building, querying and updating each supported tree type.

    Every tree is built from its own copy of ``number_list`` so that one
    tree's ``update`` can never change the data the next tree is built from.
    """
    number_list = [1, 2, 5, 5, 5, 5, 5, 5, 9, 10]

    for calculation_method in ('sum', 'max', 'gcd'):
        # Pass a copy: SegmentTree stores the list it is given and update()
        # writes into it, so sharing one list would leak updates between trees.
        segment_tree = SegmentTree(list(number_list), calculation_method)
        segment_tree.make()
        print(segment_tree.result_list)
        print(segment_tree.get_range(3, 5))
        segment_tree.update(4, 7)
        print(segment_tree.result_list)
        print(segment_tree.get_range(3, 5))


if __name__ == '__main__':
    main()

트리 만들기

def process(self, input_start_index, input_end_index, tree_index):
    """Fill the subtree rooted at ``tree_index`` and return its value.

    Node ``tree_index`` covers ``input_list[input_start_index..input_end_index]``.
    A leaf stores its input element. An internal node first builds both
    halves (post-order), then stores ``self.method(left, right)``. The stored
    value is returned so the caller one level up can combine it.
    """
    lo, hi, node = input_start_index, input_end_index, tree_index

    if lo != hi:
        # Internal node: build the left half, then the right half, then
        # combine the two results.
        mid = (lo + hi) // 2
        combined = self.method(
            self.process(lo, mid, node * 2),
            self.process(mid + 1, hi, node * 2 + 1),
        )
    else:
        # Leaf: the segment is exactly one input element.
        combined = self.input_list[lo]

    self.result_list[node] = combined
    return self.result_list[node]

✔ 재귀적으로 내려가서 리프 노드에서부터 트리를 채우면서 올라온다!

쿼리

    def get_range_process(self, input_start_index, input_end_index, tree_index, range_start_index, range_end_index):
        """Combine the part of the query range that lies in node ``tree_index``.

        Node ``tree_index`` covers ``[input_start_index, input_end_index]``;
        the query range ``[range_start_index, range_end_index]`` never changes
        between recursive calls. A segment disjoint from the query returns the
        operation's identity element: -inf for 'max', 0 otherwise.
        """
        # Segment completely outside the query range: contribute the identity
        # element so it cannot affect the result. A plain 0 would be wrong
        # for 'max' when every value in the queried range is negative.
        if input_end_index < range_start_index or input_start_index > range_end_index:
            return float('-inf') if self.calculation_method == 'max' else 0

        # Segment completely inside the query range (it may or may not be a
        # leaf): return its stored value without descending any further.
        if input_start_index >= range_start_index and input_end_index <= range_end_index:
            return self.result_list[tree_index]

        # Partial overlap: keep descending into the left and right halves.
        input_mid_index = (input_start_index + input_end_index) // 2

        left_result = self.get_range_process(input_start_index, input_mid_index, tree_index * 2, range_start_index, range_end_index)

        right_result = self.get_range_process(input_mid_index + 1, input_end_index, tree_index * 2 + 1, range_start_index, range_end_index)

        return self.method(left_result, right_result)

✔ 함수 호출 형태는 process(처음에 트리를 채우는 과정)과 동일하다!

✔ 내가 구하고자 하는 범위(range_start_index, range_end_index)는 변하지 않는다.
✔ 재귀 호출을 타고 내려가면서 구하고자 하는 범위에 완전히 들어가는 구간의 값을 찾아서 반환한다.

update

    def update_process(self, input_start_index, input_end_index, tree_index, update_index, update_value):
        """Write ``update_value`` at leaf ``update_index`` and refresh its ancestors.

        Node ``tree_index`` covers ``[input_start_index, input_end_index]``.
        Nodes whose segment does not contain ``update_index`` are left
        untouched. The value stored at ``tree_index`` after the call is
        returned, so the parent can recombine it.
        """
        lo, hi, node = input_start_index, input_end_index, tree_index

        if lo <= update_index <= hi:
            if lo == hi:
                # Reached the leaf for update_index: replace its value.
                self.result_list[node] = update_value
            else:
                # Recurse into both halves (the one not containing the index
                # just hands back its stored value), then recombine.
                mid = (lo + hi) // 2
                left = self.update_process(lo, mid, node * 2, update_index, update_value)
                right = self.update_process(mid + 1, hi, node * 2 + 1, update_index, update_value)
                self.result_list[node] = self.method(left, right)

        # Unaffected segments fall through here and return their value as-is.
        return self.result_list[node]

    def update(self, update_index, update_value):
        """Set ``input_list[update_index]`` to ``update_value`` and refresh the tree.

        Negative indices count from the end, as for Python lists.

        Raises:
            IndexError: If ``update_index`` is outside the input list.
        """
        length = len(self.input_list)
        # Normalise negative indices up front: update_process only recognises
        # positions input_start_index..input_end_index, so a raw -1 would
        # change input_list while leaving the tree stale.
        if update_index < 0:
            update_index += length
        if not 0 <= update_index < length:
            raise IndexError('update_index out of range')

        self.tree_index = 1
        self.input_list[update_index] = update_value

        self.update_process(self.input_start_index, self.input_end_index, self.tree_index, update_index, update_value)

✔ 업데이트하는 리프 노드로부터 위로 올라가면서 값을 업데이트한다
✔ 업데이트 위치를 포함하지 않는(영향을 받지 않는) 구간은 업데이트하지 않고 저장된 값을 그대로 반환한다

Fenwick Tree

✔ '구간합'을 구하는 경우라면 모든 세그먼트가 필요한 게 아니다!
✔ 필요한 최소의 세그먼트만 있으면 합과 차를 이용해서 모든 구간합을 구할 수 있다. (예: sum(l..r) = prefix(r) − prefix(l − 1))

✔ 비트를 이용해 구간합을 구한다!!

✔ 인덱스의 마지막(가장 낮은) 1 비트가 나타내는 값(i & -i)만큼 인덱스를 더해 준다 (업데이트 시에는 더하고, 구간합 조회 시에는 뺀다)

댓글