AtCoder Beginner Contest 167 - E - ∙ (Bullet)

atcoder.jp

イワシは相性によってグルーピングができる。

このグルーピングさえうまく処理できれば、あとは数え上げればよい。

a匹のグループAとb匹のグループBの相性がそれぞれ悪かった時、 各グループの内部での取る取らないで2のa乗、2のb乗通りできる。 この時、1匹も取らないをダブルカウントしてしまっているので1を引けばよい。

グループAに対し、相性が悪いグループがいなければそのまま2のa乗とすればよい。

これらとは独立したグループC、グループDと居たとき、これらはグループA/Bとは独立しているため、 掛け算でMODを取りながら集計すればよい。

また、(0, 0)と(0, x), (x, 0)は特別扱いする必要がある。 といっても、(0, 0)はすべてと相性が悪いので、1つずつでグループを作る。 そのため、その数をそのまま足せばよい。(0, x), (x, 0)は特別な値として持てばよい。

と、ここまでは問題なかった。

グルーピングはまず、O(N2)では扱いきれないため、何か値としてハッシュマップに持っておく必要がある。

ここでキーは比の値が最初に思いつくが比の値が非常に小さくなり、誤差を扱いきれなくなる。 この誤差を扱おうとしたが時間切れ。

教訓として、この誤差を減らすように頑張るぐらいならそのまま持つのが吉。

「比は互いに素にしてしまえば、一意に決まる」ってまぁ考えれば当たり前なのを思いつかなかった。

正負の扱いに注意して対応すればまぁ時間内にもとけた気がするだけに悔しい。

import math
MOD = 1000000007

def solve(N, ABs):
    groups = {}
    zero_count = 0
    for a, b in ABs:
        if a == 0 and b == 0:
            zero_count += 1
            continue
        if a == 0:
            k = (0, 1)
        elif b == 0:
            k = (1, 0)
        else:
            g = math.gcd(abs(a), abs(b))
            a //= g
            b //= g
            if a * b < 0:
                k = (abs(a), -abs(b))
            else:
                k = (abs(a), abs(b))
        groups.setdefault(k, 0)
        groups[k] += 1

    visited = set()
    possibles = []
    for k, v in groups.items():
        if k in visited:
            continue
        p = 0
        p += pow(2, v, MOD)
        if k[1] == 0:
            m = (0, 1)
        elif k[0] == 0:
            m = (1, 0)
        else:
            if k[1] < 0:
                m = (-k[1], k[0])
            else:
                m = (k[1], -k[0])


        if m in groups.keys() and m not in visited:
            p += pow(2, groups[m], MOD)
            visited.add(m)
            p -= 1
        visited.add(k)
        possibles.append(p % MOD)

    ans = 1
    for p in possibles:
        ans *= p
        ans %= MOD

    if zero_count:
        ans += zero_count
        ans %= MOD
    return (ans - 1) % MOD


if __name__ == "__main__":
    N = int(input())
    ABs = [tuple(map(int, input().split(" "))) for _ in range(N)]
    print(solve(N, ABs))

E問題は大体そのままやると2回、3回ミスるので 遅いが以下のような確実な解法と合わせてチェックすると考慮漏れが少ない。

def slow_solve(N, ABs):
    import itertools
    ans = 0
    for k in range(1, 2**N):
        b = ("{:0" + str(N) + "b}").format(k)
        taken = []
        for i, c in enumerate(b):
            if c == "1":
                taken.append(ABs[i])

        if len(taken) == 1:
            ans += 1
        elif all(comb[0][0] * comb[1][0] + comb[0][1] * comb[1][1] != 0
                 for comb in itertools.combinations(taken, 2)):
            ans += 1
    return ans % MOD

def create(N):
    import random
    ABs = []
    MM = 10
    for _ in range(N):
        ABs.append( (random.randint(-MM, MM), random.randint(-MM, MM)))
    return ABs

for _ in range(1000):
    for N in range(1, 10):
        ABs = create(N)
        if solve(N, ABs) != slow_solve(N, ABs):
            print("ERROR")
            print(solve(N, ABs))
            print(slow_solve(N, ABs))
            print(N)
            print(ABs)
            break