알고리즘 풀이/백준

[백준 14003] 가장 긴 증가하는 부분 수열 5

mhko411 2021. 9. 2. 23:03
728x90

문제

수열 A가 주어졌을 때, 가장 긴 증가하는 부분 수열을 구하는 프로그램을 작성하시오.

예를 들어, 수열 A = {10, 20, 10, 30, 20, 50} 인 경우에 가장 긴 증가하는 부분 수열은 A = {1020, 10, 30, 20, 50} 이고, 길이는 4이다.

 

입력

첫째 줄에 수열 A의 크기 N (1 ≤ N ≤ 1,000,000)이 주어진다.

둘째 줄에는 수열 A를 이루고 있는 Ai가 주어진다. (-1,000,000,000 ≤ Ai ≤ 1,000,000,000)

 

출력

첫째 줄에 수열 A의 가장 긴 증가하는 부분 수열의 길이를 출력한다.

둘째 줄에는 정답이 될 수 있는 가장 긴 증가하는 부분 수열을 출력한다.


접근

가장 긴 증가하는 부분 수열의 문제를 이분 탐색으로 풀면서 출력되는 부분 수열이 맞지 않는다고 생각했다. 예를 들면 1, 3, 4, 2가 있을 때 1, 3, 4가 LIS가 되는데 출력되는 값은 1, 2, 4가 된다. 이는 이진 탐색을 통해 LIS를 만들 때 2보다 큰 값이 처음 나오는 위치에 2를 넣게 되기 때문이다. 이진 탐색을 이용해서 LIS를 구할 때 lower bound를 사용하기 때문에 이와 같은 결과가 나온다. 

따라서 1, 3, 4, 2가 LIS의 배열에 들어갈 때 위치를 기록하고 이를 활용하여 LIS를 최종적으로 만들어주도록 한다.

 

구현

- 일단 LIS 리스트의 마지막 원소와 현재 탐색하는 원소를 비교했을 때 현재 원소가 크다면 LIS에 추가하는데

- 추가하기전에 현재 원소가 LIS의 어느 위치에 있는지 기록한다.

- 만약 LIS 리스트의 마지막 원소가 크다면 이진 탐색으로 현재 원소가 들어갈 위치를 찾는데

- lower bound로 리스트 중에 자신보다 이상인 값이 처음 나오는 위치를 찾게된다.

- 이진 탐색으로 반환된 위치를 index_value에 기록을 하고 LIS 리스트에서 위치를 변경해준다.

for i in range(1, N):
    index_value[i][1] = numbers[i]
    if lis[-1] < numbers[i]:
        index_value[i][0] = len(lis)
        lis.append(numbers[i])
    else:
        idx = binary_search(0, len(lis)-1, numbers[i])
        index_value[i][0] = idx
        lis[idx] = numbers[i]

- 이제 최종적으로 올바르게 구성된 LIS를 찾아야한다.

- LIS 개수는 알맞게 LIS 리스트에 저장되어있다.

- 이제 어떤 원소가 최종적인 LIS 리스트에 담을지 찾게되는데 

- 처음 idx는 LIS의 마지막 원소의 위치를 가리키고 이를 하나씩 감소시켜서 LIS에 담길 원소를 찾는다.

- 만약 index_value에서 기록된 인덱스와 현재 idx가 같다면 index_value에 기록된 값을 answer에 저장한다.

- 그러면 내림차순으로 answer가 정렬될 것이고 이를 반대로하여 출력한다.

idx = len(lis) - 1
answer = []
for i in range(N-1, -1, -1):
    if idx == -1:
        break

    if idx == index_value[i][0]:
        answer.append(index_value[i][1])
        idx -= 1
print(len(answer))
print(*answer[::-1])

- 만약 1, 3, 4, 2라면 각각의 원소가 LIS 리스트에서 0, 1, 2, 1에 위치해 있을 것이다.

- 여기서 2에 들어갈 위치를 먼저 찾아보면 0, 1, 2, 1에서 가장 마지막에 있는 1은 아니다. 즉, 2는 들어가지 못하고 

- 그 다음 2로 넘어가서 4를 추가하게 된다. 이제 1에 들어갈 위치를 찾으면 3이 맞게되고 이를 추가하게되는 것이다.


전체 코드

import sys
input = sys.stdin.readline

def binary_search(left, right, target):
    while left < right:
        mid = (left + right) // 2
        if lis[mid] < target:
            left = mid + 1
        else:
            right = mid
    return right

N = int(input())
numbers = list(map(int, input().split()))

lis = [numbers[0], ]
index_value = [[0, 0] for _ in range(N)]
index_value[0][1] = numbers[0]

for i in range(1, N):
    index_value[i][1] = numbers[i]
    if lis[-1] < numbers[i]:
        index_value[i][0] = len(lis)
        lis.append(numbers[i])
    else:
        idx = binary_search(0, len(lis)-1, numbers[i])
        index_value[i][0] = idx
        lis[idx] = numbers[i]

idx = len(lis) - 1
answer = []
for i in range(N-1, -1, -1):
    if idx == -1:
        break

    if idx == index_value[i][0]:
        answer.append(index_value[i][1])
        idx -= 1
print(len(answer))
print(*answer[::-1])