AtCoder Beginner Contest 167 - E - ∙ (Bullet)

atcoder.jp

イワシは相性によってグルーピングができる。

このグルーピングさえうまく処理できれば、あとは数え上げればよい。

a匹のグループAとb匹のグループBの相性がそれぞれ悪かった時、 各グループの内部での取る取らないで2のa乗、2のb乗通りできる。 この時、1匹も取らないをダブルカウントしてしまっているので1を引けばよい。

グループAに対し、相性が悪いグループがいなければそのまま2のa乗とすればよい。

これらとは独立したグループC、グループDと居たとき、これらはグループA/Bとは独立しているため、 掛け算でMODを取りながら集計すればよい。

また、(0, 0)と(0, x), (x, 0)は特別扱いする必要がある。 といっても、(0, 0)はすべてと相性が悪いので、1つずつでグループを作る。 そのため、その数をそのまま足せばよい。(0, x), (x, 0)は特別な値として持てばよい。

と、ここまでは問題なかった。

グルーピングはまず、O(N2)では扱いきれないため、何か値としてハッシュマップに持っておく必要がある。

ここでキーは比の値が最初に思いつくが比の値が非常に小さくなり、誤差を扱いきれなくなる。 この誤差を扱おうとしたが時間切れ。

教訓として、この誤差を減らすように頑張るぐらいならそのまま持つのが吉。

「比は互いに素にしてしまえば、一意に決まる」ってまぁ考えれば当たり前なのを思いつかなかった。

正負の扱いに注意して対応すればまぁ時間内にもとけた気がするだけに悔しい。

import math
MOD = 1000000007

def solve(N, ABs):
    groups = {}
    zero_count = 0
    for a, b in ABs:
        if a == 0 and b == 0:
            zero_count += 1
            continue
        if a == 0:
            k = (0, 1)
        elif b == 0:
            k = (1, 0)
        else:
            g = math.gcd(abs(a), abs(b))
            a //= g
            b //= g
            if a * b < 0:
                k = (abs(a), -abs(b))
            else:
                k = (abs(a), abs(b))
        groups.setdefault(k, 0)
        groups[k] += 1

    visited = set()
    possibles = []
    for k, v in groups.items():
        if k in visited:
            continue
        p = 0
        p += pow(2, v, MOD)
        if k[1] == 0:
            m = (0, 1)
        elif k[0] == 0:
            m = (1, 0)
        else:
            if k[1] < 0:
                m = (-k[1], k[0])
            else:
                m = (k[1], -k[0])


        if m in groups.keys() and m not in visited:
            p += pow(2, groups[m], MOD)
            visited.add(m)
            p -= 1
        visited.add(k)
        possibles.append(p % MOD)

    ans = 1
    for p in possibles:
        ans *= p
        ans %= MOD

    if zero_count:
        ans += zero_count
        ans %= MOD
    return (ans - 1) % MOD


if __name__ == "__main__":
    N = int(input())
    ABs = [tuple(map(int, input().split(" "))) for _ in range(N)]
    print(solve(N, ABs))

E問題は大体そのままやると2回、3回ミスるので 遅いが以下のような確実な解法と合わせてチェックすると考慮漏れが少ない。

def slow_solve(N, ABs):
    import itertools
    ans = 0
    for k in range(1, 2**N):
        b = ("{:0" + str(N) + "b}").format(k)
        taken = []
        for i, c in enumerate(b):
            if c == "1":
                taken.append(ABs[i])

        if len(taken) == 1:
            ans += 1
        elif all(comb[0][0] * comb[1][0] + comb[0][1] * comb[1][1] != 0
                 for comb in itertools.combinations(taken, 2)):
            ans += 1
    return ans % MOD

def create(N):
    import random
    ABs = []
    MM = 10
    for _ in range(N):
        ABs.append( (random.randint(-MM, MM), random.randint(-MM, MM)))
    return ABs

for _ in range(1000):
    for N in range(1, 10):
        ABs = create(N)
        if solve(N, ABs) != slow_solve(N, ABs):
            print("ERROR")
            print(solve(N, ABs))
            print(slow_solve(N, ABs))
            print(N)
            print(ABs)
            break

AtCoder Beginner Contest 163

atcoder.jp

30分ぐらい。もう少し早く解きたかった。

まず、10100のおかげで、K個取るとき、K+1個取るとき、K+2個取るとき、・・・で、和がかぶることはなく独立して考えることができる。

N=10とし、このうち3個だけを考えてみたとすると以下の11個の中から自由に3つ取り、何種類の和ができるかを考える。

0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10

すると、最小は一番左の3つを取った時の3,最大は一番右の3つを取った時の27となり、この間はすべて作ることができるため、 差+1の25パターンが存在することができる。

これを同様に4個、5個と計算していけばよい。

左から取ったとき、右から取ったときはあらかじめ累積和で計算しておけばここの部分の計算はO(1)で計算可能なので十分に間に合う。

def solve(N, K):
    M = 10 ** 9 + 7
    pre = []
    post = []
    s = 0
    for i in range(N + 1):
        s += i
        pre.append(s)

    s = 0
    for i in range(N, -1, -1):
        s += i
        post.append(s)
    post = post[::-1]

    ans = 0
    for i in range(K, N + 2):
        a = post[N + 1 - i] - pre[i - 1] + 1
        ans += a
        ans %= M
    return ans


if __name__ == "__main__":
    N, K = tuple(map(int, input().split(" ")))
    print(solve(N, K))

atcoder.jp

うーむつらい。解説見てコード起こした。

大きい順に左右に振り分ける、というところまでは思いついた、が 局所解が大局解になりえないことに気が付いて(左右が同じ距離だった時局所解は同じだが、そのあとの状況によって大局解として異なる可能性がある)、 DPでも適応するのかな、で終わってしまった。

このDPの適応の仕方は思いつかなかった。。

自分なりの解説の解釈だとこうなった

dp[0][0] # => 1つも振り分けてないときの最大スコア

dp[1][0] # => 1番目に大きい値を左に振り分けたときの最大スコア
dp[2][0] # => 1番目に大きい値、2番目に大きい値をどちらも左に振り分けたときの最大スコア

dp[0][1] # => 1番目に大きい値を右に振り分けたときの最大スコア
dp[0][2] # => 1番目に大きい値、2番目に大きい値をどちらも右に振り分けたときの最大スコア

dp[1][1] # => 1番目に大きい値、2番目に大きい値をどちらか一つは左に、もう一つは右に振り分けたときの最大スコア

dp[1][1]は以下のどちらか大きいほうである。

  1. dp[0][1] + 2番目に大きい値を左に振り分けたときのスコア
  2. dp[1][0] + 2番目に大きい値を右に振り分けたときのスコア

この時、3番目に以降に大きい値は、1番目と2番目がどのように割り振られたとしても状況は変わらないため、順々に計算をすることができ、 再帰的に以下のように言うことができる

dp[x][y] = max(
     dp[x-1][y] + x+y番目に大きい値を左に振り分けたときのスコア,
     dp[x][y-1] + x+y番目に大きい値を右に振り分けたときのスコア
)

以下、後述のコード解説。 以下の部分で、すべてを右に振り分けたとすべてを右に振り分けた場合を先に計算しておく。

    for a, i in q:
        dp[x+1][y] = dp[x][y] + a * abs(i - x)
        x += 1
    for a, i in q:
        dp[x][y+1] = dp[x][y] + a * abs(N - 1 - y - i)
        y += 1

右振り分けは右端からの距離の注意して以下の3つに分解できる。 - N - 1 (=一番右端の座標) - y (=右へ振り分けた後の右端からの距離) - i (=元の座標)

最初の二つを用いて座標を特定し、元の座標との差を取れば移動距離が出る。

肝は以下の部分。

            a, i = q[x+y-1]
            left = a * abs(i - (x - 1))
            right = a * abs(N - 1 - (y - 1) - i)
            dp[x][y] = max(dp[x-1][y] + left, dp[x][y-1] + right)
            ans = max(ans, dp[x][y])

x+y番目に大きい値をとってきた後、これを左に振り分けたときと右に振り分けたときのスコアを考える。

ここで注意しなければならないのは、左に振り分けたときの状況はdp[x-1][y]から考えるため、 左に振り分けたときのその座標はx-1となり、移動距離はabs(i - (x-i))となること。

同様に右に振り分けたときは、先ほどの式のうち、yがy-1になる。

def solve(N, A):
    q = [(a, i) for i, a in enumerate(A)]
    q.sort(reverse=True)
    dp = [[-1 for _ in range(N + 1)] for __ in range(N + 1)]

    x = 0
    y = 0
    dp[x][y] = 0
    for a, i in q:
        dp[x+1][y] = dp[x][y] + a * abs(i - x)
        x += 1
    x = 0
    for a, i in q:
        dp[x][y+1] = dp[x][y] + a * abs(N - 1 - y - i)
        y += 1

    ans = 0
    for x in range(1, N+1):
        for y in range(1, N+1):
            if x + y > N:
                continue
            a, i = q[x+y-1]
            left = a * abs(i - (x - 1))
            right = a * abs(N - 1 - (y - 1) - i)
            dp[x][y] = max(dp[x-1][y] + left, dp[x][y-1] + right)
            ans = max(ans, dp[x][y])
    # [print(row) for row in dp]
    return ans

if __name__ == "__main__":
    N = int(input())
    A = list(map(int, input().split(" ")))
    print(solve(N, A))

dp[x][y]はx+y <=N という制限が付くため、 N+1 x N+1の行列のうち左上三角の部分だけが埋まることになる。

これらのうち最大が答えになる。

AtCoder Beginner Contest 161

atcoder.jp

上から作ったり、下から作ろうとしたり、dpを考えてみたが、 0の扱いが難しくて、下からは作れなかったし、上から作ると、何番目、の扱いが非常に面倒。

しばらく考えて苦肉の策として、Nがそこまで大きくないこととサンプルテストに最大値が出ていることに甘えて、 ルンルン数を100000番目を超える範囲ですべて作った。

コードテストで1000ms以下で実行できることを確認してからsubmit。

def create_lun(K):
    if K == 1:
        return ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
    else:
        r = []
        d = create_lun(K - 1)
        for n in d:
            c = int(n[-1])
            if K == 10 and int(n[0]) > 3:
                continue
            if c - 1 >= 0:
                r.append(n + str(c - 1))
            r.append(n + n[-1])
            if c + 1 <= 9:
                r.append(n + str(c + 1))
        r.extend(d)
        return r


def solve(K):
    A = create_lun(10)
    S = set([int(l) for l in A])
    S.remove(0)
    A = sorted(list(S))

    return A[K - 1]


if __name__ == "__main__":
    K = int(input())
    print(solve(K))

AtCoder Beginner Contest 160

atcoder.jp

もう少し早くできた、、はず。

ワーシャルフロイドの本質をきちんと理解してなかったが故に実装に時間がかかった。 当然、Nの制約からワーシャルフロイドを直接使ったO(N3)はできないが、 それでも、ワーシャルフロイドの考え方がベースになっている。

ワーシャルフロイドは、「ある地点を経由したときに距離が更新されるか」を総当たりで試すもの。 この更新はアルゴリズムイントロダクションでは「緩和(relax)」と表現されている。

本問では、まずX地点とY地点の接続を考えずに2点間の距離を計算した後 X地点とY地点の接続(距離を1)し、そこから緩和をすればよい。

フロイドワーシャルの一番外側のループ、「どこを経由するか」を、XとYだけに限定すれば、 実質定数倍のため、O(N2)として計算することができる。

#include <bits/stdc++.h>

#define rep(i, a, b) for (long long int i = (a); i < (b); i++)

using namespace std;
using ll = long long int;
using ld = long double;

auto p = [](auto s) {
    cout << s << endl;
};

int main() {
    int N, X, Y;
    cin >> N >> X >> Y;
    vector<vector<int>> d(N, vector<int>(N));
    rep(i, 0, N) {
        rep(j, 0, N) {
            d[i][j] = abs(i - j);
        }
    }
    X--;
    Y--;

    d[X][Y] = 1;
    d[Y][X] = 1;
    rep(i, 0, N) {
        rep(j, 0, N) {
            d[i][j] = min(d[i][j], d[i][X] + d[X][j]);
        }
    }
    rep(i, 0, N) {
        rep(j, 0, N) {
            d[i][j] = min(d[i][j], d[i][Y] + d[Y][j]);
        }
    }
    vector<int> ans(N);
    rep(i, 0, N) {
        rep(j, 0, N) {
            ans[d[i][j]] += 1;
        }
    }

    rep(k, 1, N) {
        p(ans[k] / 2);
    }
    return 0;
}

atcoder.jp

代わってこっちは500点問題とは思えないぐらいの簡単さ。

これぐらいの問題で考えるのは特定の方法で貪欲的に行けるか、つまり局所解が最適解を導くかをとても気にする。 これがうまくいかないならば動的計画法に切り替える必要があるが今回は貪欲的に行ける。

P、Qの配列の中で上位X個、Y個だけに注目し、この中でRの中と交換を行ったときに一番利益が多くなるリンゴを考えると、 これは、P、Qの最小のものと、Rの中の最大のものを交換すればよいことになる。(もちろん交換する価値があれば)

その次はP、Qの2番目に小さいものと、Rの2番目に大きいもの、と順々に見て行って交換していけばよい。

その途中で、Rのリンゴが、PQよりも利益が小さくなった場合、Rの残ったリンゴの利益はすべて今調べているものより等しいか小さく、 PQの残ったリンゴも、等しいか大きい。つまりこれ以後チェックする価値はない。


def solve(X, Y, P, Q, R):
    P = sorted(P, reverse=True)[:X]
    Q = sorted(Q, reverse=True)[:Y]
    R.sort(reverse=True)
    PQ = sorted(P + Q)
    i = 0
    while i < len(R) and i < len(PQ) and R[i] > PQ[i]:
        PQ[i], R[i] = R[i], PQ[i]
        i += 1
    return sum(PQ)


if __name__ == "__main__":
    X, Y, A, B, C = tuple(map(int, input().split(" ")))
    P = list(map(int, input().split(" ")))
    Q = list(map(int, input().split(" ")))
    R = list(map(int, input().split(" ")))
    print(solve(X, Y, P, Q, R))

よく考えたら、上位X個、上位Y個とRすべてをソート、でもよいのか。 500点問題ってなんだったっけ、、、

AtCoder Beginner Contest 159

しばらくやめてたがまた再開する。

atcoder.jp

「一つだけ抜く」パターンは全体をあらかじめ求めておいて、その抜かれた影響を考えることで計算量がO(N)になるケースが多い気がする。

今回もそんな感じで、抜かれた数字が元々C個あったとすると、そのうちの2個の組み合わせはC * (C - 1) / 2となる。 ここから1つ抜かれて、C-1個になれば(C - 1) * (C - 2) / 2個になるので、この差分をあらかじめ取っておけばよい。

Submission #11102912 - AtCoder Beginner Contest 159

import collections

def solve(N, A):
    c = collections.Counter(A)
    originals = [0 for i in range(N + 1)]
    minus = [0 for i in range(N + 1)]
    for (k, v) in c.items():
        originals[k] = v * (v-1) // 2
        minus[k] = ((v * (v-1)) - (v-1) * (v-2)) // 2
    S = sum(originals)

    ans = []
    for a in A:
        ans.append(str(S - minus[a]))
    return "\n".join(ans)

if __name__ == "__main__":
    N = int(input())
    A = list(map(int, input().split(" ")))
    print(solve(N, A))

atcoder.jp

領域分割の問題。少し考えて、最適な分割方法を何かで探す、というのは早々にあきらめて、全探索ができるか検討。

Hの制約が非常に小さいことから、切るか切らないか、をH-1か所で最大512パターンになるのは予想でき、全探索で行けそうな予感がしたが、どうやってH分割を実装するかを少し悩んだ。

結論としては二進数で、前後が変わったら違うグループにする、という実装をした。 つまり、0101は4つのグループに分かれ、0000は1つのグループ、0011は2つのグループ、といった具合にする方針。

    for g in create_groups(H):
        M = len(set(g))
        s = [0 for _ in range(M)]

ここからcreate_groupsの戻り値は、あとで参照しやすいようにグループ番号を入れた。 例えば、0110の3つのグループに分かれた場合、gには[0, 1, 1, 2]という配列が入る。 このgはHの各行のグルーピングを表していて、sにはグループごとのカウントを入れている。

カウント時にはhの代わりにg[h]とすることで、対象グループのカウントを増やすことになる。

    s[g[h]] += 1

あとは、足しすぎたら直前で分割したことにして、再度この列を足し、

            if any([a > K for a in s]):
                s = [0 for _ in range(M)]
                tans += 1
                for h in range(H):
                    if fields[h][w] == "1":
                        s[g[h]] += 1

もし、この1列だけでもKを超えてしまうようならばH分割が足りてないので終了にする。

            if any([a > K for a in s]):
                tans = H * W + 1
                break

一番外側の1回のループで、求められる分割数はH分割+W分割であり、 全走査しながら、全体の分割数の最小値を取ればよい。

Submission #11123996 - AtCoder Beginner Contest 159

def create_groups(H):
    F = "{:0" + str(H) + "b}"
    ret = set()
    for n in range(2 ** H):
        g = create_group(F.format(n))
        ret.add(tuple(g))
    return list(sorted(ret))


def create_group(s):
    ret = []
    n = 0
    for i in range(len(s)):
        if i == 0:
            ret.append(0)
        else:
            if s[i-1] == s[i]:
                ret.append(n)
            else:
                n += 1
                ret.append(n)
    return ret


def solve(H, W, K, fields):
    ans = H * W + 1
    for g in create_groups(H):
        M = len(set(g))
        s = [0 for _ in range(M)]
        tans = M - 1
        for w in range(W):
            for h in range(H):
                if fields[h][w] == "1":
                    s[g[h]] += 1
            if any([a > K for a in s]):
                s = [0 for _ in range(M)]
                tans += 1
                for h in range(H):
                    if fields[h][w] == "1":
                        s[g[h]] += 1
            if any([a > K for a in s]):
                tans = H * W + 1
                break
        ans = min(ans, tans)
    return ans

if __name__ == "__main__":
    H, W, K = tuple(map(int, input().split(" ")))
    fields = []
    for _ in range(H):
        fields.append(input())
    print(solve(H, W, K, fields))

もう少し早く実装したいところだったがまぁパフォーマンス1600超えたので良しとしよう。

Inside-Python - lru_cache

functools --- 高階関数と呼び出し可能オブジェクトの操作 — Python 3.7.3 ドキュメント

pythonのlru_cacheはアノテーションとしてユーザの関数をデコレーションすることができ、 その呼び出しを監視し、結果をキャッシュするものである。

この実装はなかなか面白い。

単純なキャッシュアルゴリズムであれば、ハッシュテーブルでよいのだが、lru_cacheはデータのセット時に使われていないデータの追い出し作業が発生する。

つまり、get時には取得されたデータは「使われた」として順位の修正の必要があり、またセット時にキャッシュの限界に達したら最下位のデータを消す必要がある。

https://github.com/python/cpython/blob/master/Lib/functools.py#L496

実装自体はここにあるが、アノテーションのための実装もそこそこあり、またデータが無制限に入る場合、および1件もキャッシュできない場合には面白くないので 一番コアな部分を抽出する。

また、同期処理に関するところ, statsに関するところはコアな部分ではないのでいったん削った。

        def wrapper(*args, **kwds):
            key = make_key(args, kwds, typed)
            link = cache_get(key)
            if link is not None:
                # Move the link to the front of the circular queue
                link_prev, link_next, _key, result = link
                link_prev[NEXT] = link_next
                link_next[PREV] = link_prev
                last = root[PREV]
                last[NEXT] = root[PREV] = link
                link[PREV] = last
                link[NEXT] = root
                return result

            result = user_function(*args, **kwds)

            if full:
                # Use the old root to store the new key and result.
                oldroot = root
                oldroot[KEY] = key
                oldroot[RESULT] = result
                # Empty the oldest link and make it the new root.
                # Keep a reference to the old key and old result to
                # prevent their ref counts from going to zero during the
                # update. That will prevent potentially arbitrary object
                # clean-up code (i.e. __del__) from running while we're
                # still adjusting the links.
                root = oldroot[NEXT]
                oldkey = root[KEY]
                oldresult = root[RESULT]
                root[KEY] = root[RESULT] = None
                # Now update the cache dictionary.
                del cache[oldkey]
                # Save the potentially reentrant cache[key] assignment
                # for last, after the root and links have been put in
                # a consistent state.
                cache[key] = oldroot
            else:
                # Put result in a new link at the front of the queue.
                last = root[PREV]
                link = [last, root, key, result]
                last[NEXT] = root[PREV] = cache[key] = link
                # Use the cache_len bound method instead of the len() function
                # which could potentially be wrapped in an lru_cache itself.
                full = (cache_len() >= maxsize)
            return result

プログラム全体としては大きく二つに分かれ、以下のuser_functionのコールを境に前半はキャッシュの探索、 後半がキャッシュがヒットしなかった場合のキャッシュの保存になっている。

 result = user_function(*args, **kwds)

キャッシュの探索の部分を見てみると、キーの生成とキャッシュの取得後、キャッシュがヒットした場合には、何やらデータ構造をいじっている様子がわかる。

            key = make_key(args, kwds, typed)
            link = cache_get(key)
            if link is not None:
                # Move the link to the front of the circular queue
                link_prev, link_next, _key, result = link
                link_prev[NEXT] = link_next
                link_next[PREV] = link_prev
                last = root[PREV]
                last[NEXT] = root[PREV] = link
                link[PREV] = last
                link[NEXT] = root
                return result

ネタを明かしてしまうと、lru_cacheはハッシュテーブルに加えて、双方向の連結リストを用いてキャッシュを実装している。

f:id:mitsuo_0114:20190529232121p:plain

イメージ的にはこんな感じ。

                # Move the link to the front of the circular queue
                link_prev, link_next, _key, result = link
                link_prev[NEXT] = link_next
                link_next[PREV] = link_prev

まず、この3行ではキャッシュヒットしたデータを持つノードを連結リストの中から取り除く作業をしている。

双方向連結リストでは、自分の一つ前のノードの「次のノードへのポインタ」に対し、自分の次のノードのポインタを渡してあげ、 自分の次のノードの「前のノードへのポインタ」には、自分の前のノードのポインタを渡してあげる。

これによって、自らへの参照を取り除くことができる。

その後、取り除いたノードを一番最後に挿入する。

                last = root[PREV]
                last[NEXT] = root[PREV] = link
                link[PREV] = last
                link[NEXT] = root

rootはダミーオブジェクトであり、双方向連結リストではよく用いられるもので、 一番最後のオブジェクトの次、一番最初のオブジェクトの前に存在するイメージ。

ここでは、rootの前=一番最後のノードを取得し、この最後のノードと、rootの間に入れることで一番後ろにデータを挿入することになる。

これらの作業によって、キャッシュがヒットした場合、そのヒットしたデータを持つノードの順位を変更することが可能になっている。

次に後半の挿入を見てみる。

挿入はキャッシュがヒットしなかった場合に、user_functionの結果を格納するときにおこる。

一番最後の部分を最初に見ると、以下のように新しくデータを作成したうえで、キャッシュへの保存および、双方向連結リストの一番後ろにデータを持ってきていることがわかる。

                # Put result in a new link at the front of the queue.
                last = root[PREV]
                link = [last, root, key, result]
                last[NEXT] = root[PREV] = cache[key] = link
                # Use the cache_len bound method instead of the len() function
                # which could potentially be wrapped in an lru_cache itself.
                full = (cache_len() >= maxsize)

ここの部分が、新しくノードを作っているところ。4つは順番に前のノード、次のノード、キー、データの意味合い。

                link = [last, root, key, result]

最後に、fullであり、追い出しをする場合には以下。

                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result

                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we're
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot

コメントを読むに、gcによってデータが意図せぬところで回収されないように参照カウンタに関して少し気を配っているようだが、 やっていることはデータの上書きである。

まず、現在rootを指しているノードに対して、今回のkeyとresultを挿入する。

                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result

次に、rootの次のノードを指し示し(参照カウンタをキープしたうえで)、ここのデータをNoneに書き換える。

これによって、先ほどまでrootがさしていたノードの次のノードが新しいrootになる。この辺り変数名が親切じゃない気もする。

                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None

その後、キャッシュテーブルからも追い出したデータを削除し、新しく追加されたデータをキャッシュテーブルに追加する。

                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot

なかなか興味深い実装で、ここまで有効にデータ構造を使ってるとは思わなかった。

確かに

「現在の最低順位のデータがどれかを知る」、

「任意の位置のデータを最高順位に持ってきて、それ以外のすべてのデータの順位を一つ下げる」

というあたりの要件を考えるにハッシュテーブルと、最初と最後をいじれる双方向連結リストが必要なんだなと思う。

Androidにもlru_cacheという名前の実装があるので今度実装を見てみたい。

Inside-Python - heapqとPriority Queue

最近cpythonのソースコードの中まで見ることが多いので、ネタごとにまとめてみる。

queue --- 同期キュークラス — Python 3.7.3 ドキュメント

heapq --- ヒープキューアルゴリズム — Python 3.7.3 ドキュメント

heapqはPythonの中で、二分ヒープを実装したものであり、競技プログラミングではよくお世話になっている。

二分ヒープ - Wikipedia

二分ヒープは優先度付きキューによく使うことが多いが、実はpythonには”PriorityQueue"というクラスも存在している。 Priority Queueとheapqの関係は、マニュアルにこのように書かれており、heapqが内部的に利用されていることがわかる。

優先順位付きキュー(priority queue)では、エントリは(heapq モジュールを利用して)ソートされ、

実際、PriorityQueueの実装は以下のように非常にシンプルになっており、内部構造は単なる配列で、heappush / heappopを使ってput/getを実装していることがわかる。

from heapq import heappush, heappop

class PriorityQueue(Queue):
    '''Variant of Queue that retrieves open entries in priority order (lowest first).

    Entries are typically tuples of the form:  (priority number, data).
    '''

    def _init(self, maxsize):
        self.queue = []

    def _qsize(self):
        return len(self.queue)

    def _put(self, item):
        heappush(self.queue, item)

    def _get(self):
        return heappop(self.queue)

ただしこれらのput/getは直接は呼び出されず、 親クラスのQueueのget/putで外部用のIFを提供し、同期処理と共にget / putを呼び出す構造になっている。

ここの同期処理はまたまとめたいが、以下のようにqsize / empty / fullといったメソッドはすべてmutex処理をしており、 これらはすべてthreadingモジュールを利用している。

もちろんこれらの同期処理は競技プログラミングには不要なため、PriorityQueueを競技プログラミングで利用する意味はほぼ無い。

以前間違えてこちらを使ったらLTEになったことがあったな。。

import threading

class Queue:

    def __init__(self, maxsize=0):
        self.mutex = threading.Lock()

    def qsize(self):
        with self.mutex:
            return self._qsize()

    def empty(self):
        with self.mutex:
            return not self._qsize()

    def full(self):
        with self.mutex:
            return 0 < self.maxsize <= self._qsize()

heapqへの要素の追加はheappushで実装されている。

以下の通りの実装で、与えられたリストに対し、appendを行うところまでは通常の配列操作と同じだが、 配列操作後、_siftdown を呼び出している。これらはヒープ条件を保つために呼び出されるものである。

def heappush(heap, item):
    """Push item onto heap, maintaining the heap invariant."""
    heap.append(item)
    _siftdown(heap, 0, len(heap)-1)

さらに、_siftdownは以下のようになっている。

# 'heap' is a heap at all indices >= startpos, except possibly for pos.  pos
# is the index of a leaf with a possibly out-of-order value.  Restore the
# heap invariant.
def _siftdown(heap, startpos, pos):
    newitem = heap[pos]
    # Follow the path to the root, moving parents down until finding a place
    # newitem fits.
    while pos > startpos:
        parentpos = (pos - 1) >> 1
        parent = heap[parentpos]
        if newitem < parent:
            heap[pos] = parent
            pos = parentpos
            continue
        break
    heap[pos] = newitem

heappushから呼び出されたとき、_siftdownの引数には、startposは木構造のルート、posは新しく追加されたノードのindexが入っている。

        parentpos = (pos - 1) >> 1

ここはheapの特徴で、親のindexを計算している。>>1は右ビットシフト演算なので÷2と同じ意味合い。

つまり、追加された子のノードから初めて、上にさかのぼって要素を入れ替えていっていることがわかる。

二分ヒープ - Wikipedia

図示はWikipediaを参照。ただし、図はmaxヒープだが、heapqはminヒープであり、一番上に最小値が来ることに注意。

次に要素の削除を見てみる。

以下の通りの実装で、heappushと同じく、pop()を行い、それでリストが空になればそれをそのまま返すようになっている。

しかし、pop()はリストの一番後ろを返す、にもかかわらず最小値が存在するのはheapの一番である。

heappopは最小値を返すことが求められているため、pop()した要素を返して正しく動くのはリストに1件しかなかったときのみである。

2件以上あった時には、この一番後ろの要素を1番前に持ってきて、最初の要素をreturnするようにとっておいたうえで、siftup を呼び出している。 これもsiftdownと同じく、ヒープ条件を保つために呼び出されるものである。

def heappop(heap):
    """Pop the smallest item off the heap, maintaining the heap invariant."""
    lastelt = heap.pop()    # raises appropriate IndexError if heap is empty
    if heap:
        returnitem = heap[0]
        heap[0] = lastelt
        _siftup(heap, 0)
        return returnitem
    return lastelt

では_siftupを見てみる。

def _siftup(heap, pos):
    endpos = len(heap)
    startpos = pos
    newitem = heap[pos]
    childpos = 2*pos + 1
    while childpos < endpos:
        rightpos = childpos + 1
        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2*pos + 1
    heap[pos] = newitem
    _siftdown(heap, startpos, pos)

heappopから呼び出されたとき、引数には、posはルートのindexである0が入っている。

最初は単に走査に当たっての値の設定だけである。 childpos に関して、これもheapの特徴で、この演算により、左の子のindexが計算できる。教科書ではヒープは1から始まるため、2*posが左の子になることが多いが、 pythonは0から始めているため、+1が入っている。

    endpos = len(heap)
    startpos = pos
    newitem = heap[pos]
    childpos = 2*pos + 1

同様に右の子は左の子の隣にいるので単に+1すればよい。

        rightpos = childpos + 1

それ以後は、左の子と右の子を比べて小さいほうを上にあげて、さらにその子に同じ処理を続けている。

        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2*pos + 1

そして最後に行きついた場所にもともと、一番最後にいた要素を入れる。

    heap[pos] = newitem

最後のこの_siftdownの呼び出しはいまいちわからない、、

    _siftdown(heap, startpos, pos)

以下のコメントが多分これを言っているのだろうが、いまいち理解できない。Knuth, Volume 3を買わなくては。

# The child indices of heap index pos are already heaps, and we want to make
# a heap at index pos too.  We do this by bubbling the smaller child of
# pos up (and so on with that child's children, etc) until hitting a leaf,
# then using _siftdown to move the oddball originally at index pos into place.
#
# We *could* break out of the loop as soon as we find a pos where newitem <=
# both its children, but turns out that's not a good idea, and despite that
# many books write the algorithm that way.  During a heap pop, the last array
# element is sifted in, and that tends to be large, so that comparing it
# against values starting from the root usually doesn't pay (= usually doesn't
# get us out of the loop early).  See Knuth, Volume 3, where this is
# explained and quantified in an exercise.
#

_siftdown以外の図示はこちら

二分ヒープ - Wikipedia

また、同様にheapfiyも以下のように_siftupを呼び出していることわかる。

def heapify(x):
    """Transform list into a heap, in-place, in O(len(x)) time."""
    n = len(x)
    # Transform bottom-up.  The largest index there's any point to looking at
    # is the largest with a child index in-range, so must have 2*i + 1 < n,
    # or i < (n-1)/2.  If n is even = 2*j, this is (2*j-1)/2 = j-1/2 so
    # j-1 is the largest, which is n//2 - 1.  If n is odd = 2*j+1, this is
    # (2*j+1-1)/2 = j so j-1 is the largest, and that's again n//2-1.
    for i in reversed(range(n//2)):
        _siftup(x, i)

なお、以下の演算により、一番後ろにいる子の親が計算できる。0からここまでをreversedすることによって下からヒープ条件を満たしていくことができる。

n = len(x)
range(n//2)

教科書で勉強したことが実際このように実装されているとわかると楽しいよね。