[BOJ] 2042번 구간 합 구하기
복기
세그먼트 트리 문제에 대표적인 문제다. 세그먼트 트리에 대해서는 이 문서를 참고하라. 노드의 의미를 어떻게 설정할지가 관건이다.
코드
C++ 17
#include <stdio.h>
using ll = long long;
#define MAX_N 1000000
size_t N, M, K; ll Arr[MAX_N + 1]; ll Tree[MAX_N * 4];
// @Desc : Arr의 [s, e]의 합으로 세그먼트 트리를 초기화한다. // @Return : node의 값 // @Param // node : 노드 번호 // s : 범위의 시작 // e : 범위의 마지막 ll Init(size_t node, size_t s, size_t e) { // 단말 노드라면 if (s == e) Tree[node] = Arr[s]; // 단말 노드가 아니라면 else { size_t m = (s + e) / 2; // 왼쪽 자식 노드와 오른쪽 자식 노드의 합이다. Tree[node] = Init(node * 2, s, m) + Init(node * 2 + 1, m + 1, e); }
return Tree[node]; }
// @Desc : Arr[index]의 값이 바뀌었을 때, 세그먼트 트리의 데이터를 갱신한다. // @Param // node : 현재 노드 // s, e : 노드가 담당하고 있는 범위 // index : 바뀌는 위치 // diff : 바뀌는 값 void Update(size_t node, size_t s, size_t e, size_t index, ll diff) { // 해당 노드가 관련이 없을 때 끝낸다. if (index < s || e < index) return;
// 값을 갱신한다. Tree[node] += diff;
// 단말 노드가 아니라면 자식 노드도 변경한다. if (s != e) { ll m = (s + e) / 2; Update(node * 2, s, m, index, diff); Update(node * 2 + 1, m + 1, e, index, diff); } }
// @Desc : Arr의 [l, r] 범위의 합을 구한다. // @Return : Arr[l] + Arr[l + 1] + ... + Arr[r] // @Param // node : 현재 노드 // s, e : 노드가 담당하고 있는 범위 // l, r : 구하고자 하는 범위 ll Sum(size_t node, size_t s, size_t e, size_t l, size_t r) { // 현재 노드와 관련이 없을 때 if (r < s || e < l) return 0; // 현재 노드가 구하고자 하는 범위에 완전히 포함될 때 if (l <= s && e <= r) return Tree[node];
// 그 외에는 자식 노드를 통해서 구한다. ll m = (s + e) / 2; return Sum(node * 2, s, m, l, r) + Sum(node * 2 + 1, m + 1, e, l, r); }
int main() { scanf("%lld %lld %lld", &N, &M, &K); for (size_t i = 1; i <= N; i++) scanf("%lld", &Arr[i]); Init(1, 1, N);
for (size_t i = 0; i < M + K; i++) { ll a, b, c; scanf("%lld %lld %lld", &a, &b, &c); if (a == 1) { ll diff = c - Arr[b]; Arr[b] = c; Update(1, 1, N, b, diff); } else { printf("%lld\n", Sum(1, 1, N, b, c)); } } } |
C# 6.0
using System;
namespace Csharp { class Program { const int MAX_N = 1000000; static int N, M, K; static long[] Arr = new long[MAX_N + 1]; static long[] Tree = new long[MAX_N * 4];
static void Main(string[] args) { var inputs = Console.ReadLine().Split(); N = Convert.ToInt32(inputs[0]); M = Convert.ToInt32(inputs[1]); K = Convert.ToInt32(inputs[2]); for (int i = 1; i <= N; i++) Arr[i] = Convert.ToInt64(Console.ReadLine()); Init(1, 1, N);
for (int i = 0; i < M + K; i++) { inputs = Console.ReadLine().Split(); long a = Convert.ToInt64(inputs[0]);
if (a == 1) { int b = Convert.ToInt32(inputs[1]); long c = Convert.ToInt64(inputs[2]); long diff = c - Arr[b]; Arr[b] = c; Update(1, 1, N, b, diff); } else { int b = Convert.ToInt32(inputs[1]); int c = Convert.ToInt32(inputs[2]); Console.WriteLine(Sum(1, 1, N, b, c)); } } }
static long Init(int node, int s, int e) { if (s == e) Tree[node] = Arr[s]; else { int m = (s + e) / 2; Tree[node] = Init(node * 2, s, m) + Init(node * 2 + 1, m + 1, e); }
return Tree[node]; }
static void Update(int node, int s, int e, int index, long diff) { if (index < s || e < index) return;
Tree[node] += diff;
if (s != e) { int m = (s + e) / 2; Update(node * 2, s, m, index, diff); Update(node * 2 + 1, m + 1, e, index, diff); } }
static long Sum(int node, int s, int e, int l, int r) { if (r < s || e < l) return 0; if (l <= s && e <= r) return Tree[node];
int m = (s + e) / 2; return Sum(node * 2, s, m, l, r) + Sum(node * 2 + 1, m + 1, e, l, r); } } } |
Python 3
from sys import stdin input = stdin.readline
N, M, K = map(int, input().split()) arr = [0] * (N + 1) for i in range(1, N + 1): arr[i] = int(input()) tree = [0] * (N * 4)
def init(node, s, e): global arr, tree
if s == e: tree[node] = arr[s] else: m = (s + e) // 2 tree[node] = init(node * 2, s, m) + init(node * 2 + 1, m + 1, e) return tree[node]
def sum(node, s, e, l, r): global tree if r < s or e < l: return 0 if l <= s and e <= r: return tree[node]
m = (s + e) // 2 return sum(node * 2, s, m, l, r) + sum(node * 2 + 1, m + 1, e, l, r)
def update(node, s, e, index, diff): global tree
if index < s or e < index: return
tree[node] += diff if s != e: m = (s + e) // 2 update(node * 2, s, m, index, diff) update(node * 2 + 1, m + 1, e, index, diff)
init(1, 1, N) for _ in range(M + K): a, b, c = map(int, input().split()) if a == 1: diff = c - arr[b] arr[b] = c update(1, 1, N, b, diff) else: print(sum(1, 1, N, b, c)) |