알고리즘 풀이/백준

[백준 12014] 주식 - 이진탐색으로 다시 풀기

mhko411 2021. 5. 1. 11:11
728x90

문제

앞으로 N일간 주식 가격이 숫자로 주어진다. 이 중 K번 거래를 하려고한다. 거래를 할 때는 하루에 한 번만 할 수 있으며 첫 날을 제외하고 이전날보다 주식가격이 올랐을 때만 거래를 한다.

이 때 K번 거래를 할 수 있는지 검사하는 프로그램을 작성하시오.

 

입력

입력 파일에는 여러 테스트 메이스가 포함될 수 있다. 파일의 첫째 줄에 케이스의 개수 T(2 ≤ T ≤ 100)가 주어지고, 이후 차례로 T 개 테스트 케이스가 주어진다. 각 테스트 케이스의 첫 줄에 두 정수 N과 K이 주어진다. N은 앞으로 주가를 알 수 있는 날 수이며, (1 ≤ N ≤ 10,000) K는 거래의 회수이다. (1 ≤ K ≤ 10,000) 다음 줄에는 앞으로 N 날의 주가가 사이에 공백을 두고 주어진다. 주가는 1부터 10,000 사이의 정수이다.

 

출력

각 테스트 케이스에 대해서 출력은 한 줄로 구성된다. T 번째 테스트 케이스에 대해서는 첫째 줄에는 "Case #T"를 출력한다. 두 번째 줄에는 주어진 조건을 만족하게 주식을 살 수 있으면 1, 아니면 0을 출력한다.

 


접근

N일 중에 K길이의 오름차순 수열이 있는지 검사를 한다.

이 문제는 이진탐색을 활용한 LIS로 풀어야 시간초과가 발생하지 않을 것 같다. 지금은 DP로 풀었고 PYPY3로 제출을 하여 간신히 통과를 하였다.

 

나중에 꼭 이진탐색으로 LIS를 푸는 법을 공부하고 다시 풀어봐야겠다.

 

구현

- 현재 i를 포함하여 오름차순 수열의 길이를 구한다.

- cache에서 K가 있는지 검사를 하여

- 있으면 1, 없으면 0을 출력하도록 한다.

    for i in range(N):
        for j in range(i):
            if prices[j] < prices[i]:
                cache[i] = max(cache[j]+1, cache[i])

    answer = 0
    for c in cache:
        if c == K:
            answer = 1
            break
    print("Case #{}".format(tc+1))
    print(answer)

전체 코드

import sys
input = sys.stdin.readline

test_case = int(input())
for tc in range(test_case):
    N, K = map(int, input().split())
    prices = list(map(int, input().split()))
    cache = [1 for _ in range(N)]

    for i in range(N):
        for j in range(i):
            if prices[j] < prices[i]:
                cache[i] = max(cache[j]+1, cache[i])

    answer = 0
    for c in cache:
        if c == K:
            answer = 1
            break
    print("Case #{}".format(tc+1))
    print(answer)

이진탐색으로 풀기

이진탐색으로 LIS를 구할 수 있다. 문제를 읽었을 때 가장 긴 증가하는 부분수열이 K 이상일 때 만족할 수 있다는 것을 알 수 있다.

 

import sys
input = sys.stdin.readline

def binary_search(left, right, target):
    while left <= right:
        mid = (left + right) // 2
        if lis[mid] < target:
            left = mid + 1
        else:
            right = mid - 1

    return left

test_case = int(input())

for tc in range(1, test_case+1):
    N, K = map(int, input().split())
    numbers = list(map(int, input().split()))
    lis = [numbers[0]]
    for i in range(1, N):
        if lis[-1] < numbers[i]:
            lis.append(numbers[i])
        else:
            j = binary_search(0, len(lis)-1, numbers[i])
            lis[j] = numbers[i]

    print('Case #{}'.format(tc))
    if len(lis) >= K:
        print(1)
    else:
        print(0)