AtCoder Beginner Contest 159

しばらくやめてたがまた再開する。

atcoder.jp

「一つだけ抜く」パターンは全体をあらかじめ求めておいて、その抜かれた影響を考えることで計算量がO(N)になるケースが多い気がする。

今回もそんな感じで、抜かれた数字が元々C個あったとすると、そのうちの2個の組み合わせはC * (C - 1) / 2となる。 ここから1つ抜かれて、C-1個になれば(C - 1) * (C - 2) / 2個になるので、この差分をあらかじめ取っておけばよい。

Submission #11102912 - AtCoder Beginner Contest 159

import collections

def solve(N, A):
    c = collections.Counter(A)
    originals = [0 for i in range(N + 1)]
    minus = [0 for i in range(N + 1)]
    for (k, v) in c.items():
        originals[k] = v * (v-1) // 2
        minus[k] = ((v * (v-1)) - (v-1) * (v-2)) // 2
    S = sum(originals)

    ans = []
    for a in A:
        ans.append(str(S - minus[a]))
    return "\n".join(ans)

if __name__ == "__main__":
    N = int(input())
    A = list(map(int, input().split(" ")))
    print(solve(N, A))

atcoder.jp

領域分割の問題。少し考えて、最適な分割方法を何かで探す、というのは早々にあきらめて、全探索ができるか検討。

Hの制約が非常に小さいことから、切るか切らないか、をH-1か所で最大512パターンになるのは予想でき、全探索で行けそうな予感がしたが、どうやってH分割を実装するかを少し悩んだ。

結論としては二進数で、前後が変わったら違うグループにする、という実装をした。 つまり、0101は4つのグループに分かれ、0000は1つのグループ、0011は2つのグループ、といった具合にする方針。

    for g in create_groups(H):
        M = len(set(g))
        s = [0 for _ in range(M)]

ここからcreate_groupsの戻り値は、あとで参照しやすいようにグループ番号を入れた。 例えば、0110の3つのグループに分かれた場合、gには[0, 1, 1, 2]という配列が入る。 このgはHの各行のグルーピングを表していて、sにはグループごとのカウントを入れている。

カウント時にはhの代わりにg[h]とすることで、対象グループのカウントを増やすことになる。

    s[g[h]] += 1

あとは、足しすぎたら直前で分割したことにして、再度この列を足し、

            if any([a > K for a in s]):
                s = [0 for _ in range(M)]
                tans += 1
                for h in range(H):
                    if fields[h][w] == "1":
                        s[g[h]] += 1

もし、この1列だけでもKを超えてしまうようならばH分割が足りてないので終了にする。

            if any([a > K for a in s]):
                tans = H * W + 1
                break

一番外側の1回のループで、求められる分割数はH分割+W分割であり、 全走査しながら、全体の分割数の最小値を取ればよい。

Submission #11123996 - AtCoder Beginner Contest 159

def create_groups(H):
    F = "{:0" + str(H) + "b}"
    ret = set()
    for n in range(2 ** H):
        g = create_group(F.format(n))
        ret.add(tuple(g))
    return list(sorted(ret))


def create_group(s):
    ret = []
    n = 0
    for i in range(len(s)):
        if i == 0:
            ret.append(0)
        else:
            if s[i-1] == s[i]:
                ret.append(n)
            else:
                n += 1
                ret.append(n)
    return ret


def solve(H, W, K, fields):
    ans = H * W + 1
    for g in create_groups(H):
        M = len(set(g))
        s = [0 for _ in range(M)]
        tans = M - 1
        for w in range(W):
            for h in range(H):
                if fields[h][w] == "1":
                    s[g[h]] += 1
            if any([a > K for a in s]):
                s = [0 for _ in range(M)]
                tans += 1
                for h in range(H):
                    if fields[h][w] == "1":
                        s[g[h]] += 1
            if any([a > K for a in s]):
                tans = H * W + 1
                break
        ans = min(ans, tans)
    return ans

if __name__ == "__main__":
    H, W, K = tuple(map(int, input().split(" ")))
    fields = []
    for _ in range(H):
        fields.append(input())
    print(solve(H, W, K, fields))

もう少し早く実装したいところだったがまぁパフォーマンス1600超えたので良しとしよう。