import sys
input = sys.stdin.readline

N, M = map(int, input().split())
dp = [[0] * (N+1)]

for _ in range(N):
    nums = [0] + list(map(int, input().split()))
    dp.append(nums)

for i in range(1, N+1):
    for j in range(1, N):
        dp[i][j+1] += dp[i][j]
    
for j in range(1, N+1):
    for i in range (1, N):
        dp[i+1][j] += dp[i][j]

for _ in range(M):
    x1, y1, x2, y2 = map(int, input().split())
    result = dp[x2][y2] - dp[x1-1][y2] - dp[x2][y1-1] + dp[x1-1][y1-1]
    print (result)

 

이 문제를 관통하는 핵심아이디어는 아래와 같이, 입력 받은 표의 수를 행축으로 합을 구해준 다음 열 축으로 합을 구해주는 것이다. 

 

행 슬라이딩 합 이후 열 슬라이딩 합까지 해준 테이블을 우리는 DP 테이블이라 하자. 이후 구간합을 구하기 위한 예시로 (1, 1)과 (4, 4)를 입력 받았다면 구간합은 입력 받은 표의 빨간 영역의 합이 될 것이다. (45 = 3+4+5+4+5+6+5+6+7)

빨간 영역의 합을 구하기 위해선 전체 입력 받은 표에서 파란 부분을 각각 빼주고 초록 부분이 두 번 차감되었으니 한 번 더해주는 것이다. 

 

전체 입력 받은 표를 대표하는 값은 (1, 1) ~ (4, 4)의 합이 모두 저장된 DP[$x_2$][$y_2$]에 있다. 또한

첫 번째 파란색 영역을 대표하는 값은 DP[$x_1-1$][$y_2$]에 저장돼 있으며

두 번째 파란색 영역을 대표하는 값은 DP[$x_2$][$y_1-1$]에 저장돼 있으며

세 번째인 초록 영역을 대표하는 값은 DP[$x_1-1$][$y_1-1$]에 저장돼 있다.

 

따라서 최종 점화식은 dp[$x2$][$y2$] $-$ dp[$x_1-1$][$y_2$] $-$ dp[$x_2$][$y_1-1$] $+$ dp[$x_1-1$][$y_1-1$]이 된다.

문제 링크

https://www.acmicpc.net/problem/1068

풀이 코드

import sys
input = sys.stdin.readline

def dfs(K: int) -> None:
    tree[K] = -2
    for i in range(N):
        if K == tree[i]:
            dfs(i)

if __name__ == "__main__":
    # input 
    N = int(input())
    tree = list(map(int, input().split()))
    K = int(input())

    # main algorithm
    dfs(K)

    # output
    count = 0 
    for i in range(len(tree)):
        if tree[i] != -2 and i not in tree: # 제거 표시인 -2도 아니고, 자신을 부모 노드로 갖는 노드도 존재하지 않을 경우
            count +=1
    print(count)

풀이 과정

핵심 아이디어는 DFS 알고리즘을 통해 부모 노드를 제거 한 뒤, 부모 노드에 종속된 자식 노드를 탐색하며 제거해주는 것이다.

가령 예제 입력 4번으로 예시를 들면 다음과 같다.

 

[-1, 0, 0, 2, 2, 4, 4, 6, 6]은 0~8까지의 노드 번호 위치에 부모 노드를 나타낸 것이다. 이 중 4번 노드를 없애고자 한다면 5번째 인덱스 값인 2를 제거해주어야 한다. 하지만 실제로 리스트에서 제거할 경우 인덱스의 변화가 생겨 일관된 알고리즘 적용에 번거로워지므로 제거했단 표기로 대신한다. -1은 루트 노드로 사용되니 -2란 값이 적당할 것이고 이를 적용하면 다음과 같아진다. [-1, 0, 0, 2, -2, 4, 4, 6, 6]

 

위 그림에서 노드 4를 -2로 제거 표기 해준 것이다. 이후 4에 종속된 5, 6 그리고 6에 또 종속된 7, 8을 DFS 알고리즘으로 반복 탐색하며 -2로 변경한다면 4에 종속되지 않았던 노드들만 남아 있게 될 것이다.

문제에서 요구하는 핵심은 어떤 한 마을에 우체국이 세워진다 가정할 때, 다른 마을 사람들이 우체국이 세워진 마을에 가기 위해 필요한 거리의 합이 최소가 되는 마을 번호를 찾는 것이다. 크게 두 가지 아이디어가 있다. 첫 번째는 단순하게 모든 마을 사람의 총합을 구하고 [(마을 번호, 사람 수), (마을 번호, 사람 수), ...]의 리스트를 순회할 때 절반이 넘어가는 그 마을 번호를 찾는 것이다. 이 접근이 정답의 핵심 아이디어다. 처음에 이 방법이 희번뜩하며 떠올랐지만 단순하단 이유로 올바른 접근이 아닐 것이란 생각을 하고 다른 방법을 찾게 됐다. 

 

[정답 코드]

import sys
input = sys.stdin.readline

villiage = []
all_people = 0

N = int(input())

for i in range(N):
    n_viliage, people = map(int, input().split())
    villiage.append([n_viliage, people])
    all_people += people

villiage.sort(key= lambda x: x[0])

count = 0
for i in range(N):
    count += villiage[i] [1]
    if count >= all_people/2:
        print (villiage[i][0])
        break

 

 

두 번째 아이디어는, 반복문을 통해 [(마을 번호, 사람 수), (마을 번호, 사람 수), ...]을 차례대로 돌면서 최소 값을 구하는 것이다. 구체적으로 첫 번째 마을에 우체국이 세워졌다 가정하면 두 번째 세 번째 마을에서 사람들이 오기 위해 얼마나 비용이 드는지를 구하는 것이다. 예제 입력을 예시로 삼는다면

 

첫 번째 루프 즉, 첫 번째 마을에 우체국이 세워진다면 아래와 같이 총 11의 비용이 든다.

(1번째 마을 - 1) * 3 = (1-1) * 3 = 0

(2번째 마을 - 1) * 5 = (2-1) * 5 = 5

(3번째 마을 - 1) * 3 = (3-1) * 3 = 6

 

두 번째 루프 즉, 두 번째 마을에 우체국이 세워진다면 아래와 같이 총 6의 비용이 든다.

(1번째 마을 - 2) * 3 = (1-2) * 3 = 3 

(2번째 마을 - 2) * 5 = (2-2) * 5 = 0

(3번째 마을 - 2) * 3 = (3-2) * 3 = 3

 

세 번째 루프 즉, 세 번째 마을에 우체국이 세워진다면 아래와 같이 총 11의 비용이 든다.

(1번째 마을 - 3) * 3 = (1-3) * 3 = 6

(2번째 마을 - 3) * 5 = (2-3) * 5 = 5

(3번째 마을 - 3) * 3 = (3-3) * 3 = 0

 

* 음수가 되는 부분은 절대값을 취해줘야 함

 

따라서 총 6의 비용이 드는 두 번째 마을이 가장 우체국을 세우기 적합한 것이다. 이를 구현한 코드는 아래와 같으며, O(N^2)의 시간 복잡도로 인해 시간 초과가 발생한다.

 

[초기 시도한 코드]

import sys
input = sys.stdin.readline

N = int(input())

village = [0] * (N + 1)
people = [0] * (N + 1)

min_values = [0] * (N +1)

for i in range(1, N+1):
    x, a = map(int, input().split())
    village[i] = x
    people[i] = a

for i in range(1, N+1):
    value = 0
    for j in range(1, N+1):
        value += abs(village[j] - i) * people[j]
    
    min_values[i] = value

del min_values[0]
index = min_values.index(min(min_values))
print (village[index+1])

 

 

 

import sys
input = sys.stdin.readline
N = int(input())

_max = int(-1e9)
_min = int(1e9)

numbers = list(map(int, input().split()))
add, sub, mul, div = map(int, input().split())

def dfs(idx, _sum, add, sub, mul, div):
    global _max, _min
    if idx == N:
        _max = max(_max, _sum)
        _min = min(_min, _sum)
        return
    
    if add:
        dfs(idx+1, _sum + numbers[idx], add-1, sub, mul, div)

    if sub:
        dfs(idx+1, _sum - numbers[idx], add, sub-1, mul, div)
        
    if mul:
        dfs(idx+1, _sum * numbers[idx], add, sub, mul-1, div)
        
    if div:
        dfs(idx+1, _sum // numbers[idx] if _sum > 0 else -((-_sum)//numbers[idx]), add, sub, mul, div-1)

dfs(1, numbers[0], add, sub, mul, div)

print (_max)
print (_min)
def recursion(start):
    if M == len(picked):
        print (*picked)
        return

    for i in range(start, len(numbers)):
        picked.append(numbers[i])
        recursion(i)
        picked.pop()

if __name__ == '__main__':
    N, M = list(map(int, input().split()))
    numbers = list(map(int, input().split()))
    numbers.sort()
    picked = []
    
    recursion(0)
def recursion():
    if M == len(picked):
        print (*picked)
        return 

    for i in range(len(numbers)):
        picked.append(numbers[i])
        recursion()
        picked.pop()

if __name__ == "__main__":
    N, M = list(map(int, input().split()))
    numbers = list(map(int, input().split()))
    numbers.sort()
    picked = []
    recursion()
from itertools import combinations

def library():
    numbers.sort()
    result = list(combinations(numbers, M))
    for i in result:
        print (*i)

def recursion(start):
    if M == len(picked):
        print (*picked)
        return

    for i in range(start, len(numbers)):
        if numbers[i] not in picked:
            picked.append(numbers[i])
            recursion(i+1)
            picked.pop()

if __name__ == "__main__":
    N, M = list(map(int, input().split()))
    numbers = list(map(int, input().split()))
    numbers.sort()
    picked = []
    
    recursion(0)
    #library()
def recursion():
    if M == len(picked):
        print (*picked)
        return

    for i in inputValue:
        if i not in picked:
            picked.append(i)
            recursion()
            picked.pop()

if __name__ == "__main__":
    N, M = list(map(int, input().split()))
    numbers = list(map(int, input().split()))
    
    picked = []
    inputValue = [i for i in numbers]
    inputValue.sort()
    
    recursion()
def recursion(start):
    if M == len(picked):
        print (*picked)
        return
    
    for i in range(start, N):
        picked.append(i+1)
        recursion(start)
        picked.pop()
        start += 1

if __name__ == "__main__":
    N, M = list(map(int, input().split()))
    picked = []
    recursion(0)
def recursion():
    if M == len(picked):
        print (*picked)
        return
    
    for i in range(N):
        picked.append(i+1)
        recursion()
        picked.pop()

if __name__ == "__main__":
    
    N, M = list(map(int, input().split()))
    picked = []
    recursion()

+ Recent posts