しゃくとり法
はじめに
本記事では,「しゃくとり法」と呼ばれる計算量削減テクニックを紹介する.
筆者は競技プログラミング初心者である.AtCoderのビギナーズコンテストに4回参加したことがあり,そのうちD問を突破したことがあるのは1回のみである.そもそもD問を回答できる人は,「この問題にはこのアルゴリズムを使えば良い」ということが分かっている人であることに気づいた.自分にはまだその能力がないので,これからその能力をつけようと思い,記事にしてまとめることにした.
以上の経緯より,初心者が記事を書いているので誤ったことを記載しているかもしれない.もし間違ったことを書いていたら,コメントで教えていただきたい.
しゃくとり法とは
しゃくとり法とは,簡単に言ってしまうと数列{a1, a2, ..., an}において,条件を満たす区間の最小,最大,数え上げを効率的に行うアルゴリズムである.
言葉で説明してもピンとこないと思うので,例題を示す.
例題1:AOJ Course The Number of Windows
問題リンク: http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DSL_3_C&lang=jp
この解法としてまず挙げられるのは,「全てのパターンを確かめる」ことである.つまり,左端alと右端arの全組み合わせで条件を満たすかチェックする方法である.
N, Q = input().split(" ") N, Q = int(N), int(Q) a_list = [int(i) for i in input().split(" ")] x_list = [int(i) for i in input().split(" ")] for x in x_list: count = 0 # 条件を満たす個数 for left in range(0, N): for right in range(left+1, N+1): if sum(a_list[left:right]) <= x: count += 1 print(count)
当然この方法だと,計算量はO(n2)になるためTLEになってしまう. 改善策として,「ifの条件に合致しない場合は,右端をインクリメントしても当然条件に合致しない」ということを利用し,条件に合致しないタイミングでfor文をbreakする方法である.
N, Q = input().split(" ") N, Q = int(N), int(Q) a_list = [int(i) for i in input().split(" ")] x_list = [int(i) for i in input().split(" ")] for x in x_list: count = 0 # 条件を満たす個数 for left in range(0, N): for right in range(left+1, N+1): if sum(a_list[left:right]) <= x: count += 1 else: break print(count)
これで多少計算量は小さくなったが,まだ改善の余地がある. 上記のプログラムだと,条件に合致しない場合,leftをインクリメントして,rightをleft+1の位置まで戻している.このrightをleft+1の位置まで戻す作業が実は無駄になっている.なぜなら,最後に条件に合致していたrightをright' と書くとすると,leftをインクリメントしても,右端がright'のときまでは必ず条件に合致するからである.
以上の挙動をプログラムに落とし込んでみよう.
N, Q = input().split(" ") N, Q = int(N), int(Q) a_list = [int(i) for i in input().split(" ")] x_list = [int(i) for i in input().split(" ")] for x in x_list: count = 0 # 条件を満たす個数 start_right = 0 # 初期right位置 sum_num = 0 # leftからrightまでのa_listの合計値 for left in range(0, N): if left > start_right: start_right = left for right in range(start_right, N): sum_num += a_list[right] if sum_num <= x: count += 1 else: break # すでに条件を満たすものについては,ここでcount up if (right - left - 1) >= 1: count += (right - left - 1) # sum_numを修正 & start_rightの位置を修正 if left != right: sum_num -= (a_list[left]+a_list[right]) start_right = right else: sum_num -= a_list[left] start_right = left+1 print(count)
上記のプログラムによって,rightを差し戻す必要がなくなり,leftもrightも値が減ることはなくなる.したがって,計算量もO(n)となる.
しゃくとり法が利用できる場合
しゃくとり法が利用できる場合は,以下の通り.(引用: しゃくとり法 (尺取り法) の解説と、それを用いる問題のまとめ - Qiita )
しゃくとり法は
・「条件」を満たす区間 (連続する部分列) のうち、最小の長さを求める
・「条件」を満たす区間 (連続する部分列) のうち、最大の長さを求める
・「条件」を満たす区間 (連続する部分列) を数え上げる
といったことを効率良く実現できる手法ですが、「条件」というのが何でもいいわけではないです。「条件を満たす区間」が以下のいずれかの構造になっている場合には、しゃくとり法を適用することができます:
上記の構造が成り立っている場合,rightの差し戻しが必要なくなるわけだが,そうでない場合は,rightの差し戻しが必要になってきてしまうので,しゃくとり法は使えないわけである.
例題2:3061 -- Subsequence
こちらの問題も,先ほど説明した構造を満たしているので,しゃくとり法が使える.以下がそのプログラムである.
T = int(input()) for _ in range(0, T): N, S = input().split(" ") N, S = int(N), int(S) x_list = [int(i) for i in input().split(" ")] start_right = 0 sum_num = 0 minimum = 10**18 for left in range(0, N): if left > start_right: start_right = left for right in range(start_right, N): sum_num += x_list[right] if sum_num >= S: if (right - left) < minimum: minimum = (right - left) + 1 break if left == right: sum_num -= x_list[right] start_right = left+1 else: sum_num -= (x_list[left] + x_list[right]) start_right = right print(minimum)
例題3:C - 列
これもしゃくとり法でいける.
N, K = input().split(" ") N, K = int(N), int(K) s_list = [] is_run = True for _ in range(0, N): tmp = int(input()) if tmp == 0: print(N) is_run = False break s_list.append(tmp) ans = 0 start_right = 0 mult_num = 1 if is_run: for left in range(0, N): plus = 1 if left > start_right: start_right = left for right in range(start_right, N): mult_num *= s_list[right] if mult_num > K: plus = 0 break if ans < (right - left): ans = (right - left) + plus if (left == right): mult_num = 1 start_right = left+1 else: mult_num //= s_list[left] mult_num //= s_list[right] start_right = right print(ans)
例題4:C - 単調増加
これもしゃくとり法の典型である.
N = int(input()) a_list = [int(i) for i in input().split(" ")] start_left = 0 count = N right = start_left while (start_left < N): right += 1 if(right >= N or a_list[right-1] >= a_list[right]): diff = right - start_left - 1 count += diff*(diff+1)//2 start_left = right print(count)
例題5:B - 細長いお菓子
N = int(input()) A = list(map(int, input().split())) used = [False]*(10**5+1) right = 0 ans = 0 for left in range(N): while right < N and used[A[right]] == False: used[A[right]] = True right += 1 ans = max(ans, right-left) used[A[left]] = False print(ans)