Inside-Python - heapqとPriority Queue

最近cpythonのソースコードの中まで見ることが多いので、ネタごとにまとめてみる。

queue --- 同期キュークラス — Python 3.7.3 ドキュメント

heapq --- ヒープキューアルゴリズム — Python 3.7.3 ドキュメント

heapqはPythonの中で、二分ヒープを実装したものであり、競技プログラミングではよくお世話になっている。

二分ヒープ - Wikipedia

二分ヒープは優先度付きキューによく使うことが多いが、実はpythonには”PriorityQueue"というクラスも存在している。 Priority Queueとheapqの関係は、マニュアルにこのように書かれており、heapqが内部的に利用されていることがわかる。

優先順位付きキュー(priority queue)では、エントリは(heapq モジュールを利用して)ソートされ、

実際、PriorityQueueの実装は以下のように非常にシンプルになっており、内部構造は単なる配列で、heappush / heappopを使ってput/getを実装していることがわかる。

from heapq import heappush, heappop

class PriorityQueue(Queue):
    '''Variant of Queue that retrieves open entries in priority order (lowest first).

    Entries are typically tuples of the form:  (priority number, data).
    '''

    def _init(self, maxsize):
        self.queue = []

    def _qsize(self):
        return len(self.queue)

    def _put(self, item):
        heappush(self.queue, item)

    def _get(self):
        return heappop(self.queue)

ただしこれらのput/getは直接は呼び出されず、 親クラスのQueueのget/putで外部用のIFを提供し、同期処理と共にget / putを呼び出す構造になっている。

ここの同期処理はまたまとめたいが、以下のようにqsize / empty / fullといったメソッドはすべてmutex処理をしており、 これらはすべてthreadingモジュールを利用している。

もちろんこれらの同期処理は競技プログラミングには不要なため、PriorityQueueを競技プログラミングで利用する意味はほぼ無い。

以前間違えてこちらを使ったらLTEになったことがあったな。。

import threading

class Queue:

    def __init__(self, maxsize=0):
        self.mutex = threading.Lock()

    def qsize(self):
        with self.mutex:
            return self._qsize()

    def empty(self):
        with self.mutex:
            return not self._qsize()

    def full(self):
        with self.mutex:
            return 0 < self.maxsize <= self._qsize()

heapqへの要素の追加はheappushで実装されている。

以下の通りの実装で、与えられたリストに対し、appendを行うところまでは通常の配列操作と同じだが、 配列操作後、_siftdown を呼び出している。これらはヒープ条件を保つために呼び出されるものである。

def heappush(heap, item):
    """Push item onto heap, maintaining the heap invariant."""
    heap.append(item)
    _siftdown(heap, 0, len(heap)-1)

さらに、_siftdownは以下のようになっている。

# 'heap' is a heap at all indices >= startpos, except possibly for pos.  pos
# is the index of a leaf with a possibly out-of-order value.  Restore the
# heap invariant.
def _siftdown(heap, startpos, pos):
    newitem = heap[pos]
    # Follow the path to the root, moving parents down until finding a place
    # newitem fits.
    while pos > startpos:
        parentpos = (pos - 1) >> 1
        parent = heap[parentpos]
        if newitem < parent:
            heap[pos] = parent
            pos = parentpos
            continue
        break
    heap[pos] = newitem

heappushから呼び出されたとき、_siftdownの引数には、startposは木構造のルート、posは新しく追加されたノードのindexが入っている。

        parentpos = (pos - 1) >> 1

ここはheapの特徴で、親のindexを計算している。>>1は右ビットシフト演算なので÷2と同じ意味合い。

つまり、追加された子のノードから初めて、上にさかのぼって要素を入れ替えていっていることがわかる。

二分ヒープ - Wikipedia

図示はWikipediaを参照。ただし、図はmaxヒープだが、heapqはminヒープであり、一番上に最小値が来ることに注意。

次に要素の削除を見てみる。

以下の通りの実装で、heappushと同じく、pop()を行い、それでリストが空になればそれをそのまま返すようになっている。

しかし、pop()はリストの一番後ろを返す、にもかかわらず最小値が存在するのはheapの一番である。

heappopは最小値を返すことが求められているため、pop()した要素を返して正しく動くのはリストに1件しかなかったときのみである。

2件以上あった時には、この一番後ろの要素を1番前に持ってきて、最初の要素をreturnするようにとっておいたうえで、siftup を呼び出している。 これもsiftdownと同じく、ヒープ条件を保つために呼び出されるものである。

def heappop(heap):
    """Pop the smallest item off the heap, maintaining the heap invariant."""
    lastelt = heap.pop()    # raises appropriate IndexError if heap is empty
    if heap:
        returnitem = heap[0]
        heap[0] = lastelt
        _siftup(heap, 0)
        return returnitem
    return lastelt

では_siftupを見てみる。

def _siftup(heap, pos):
    endpos = len(heap)
    startpos = pos
    newitem = heap[pos]
    childpos = 2*pos + 1
    while childpos < endpos:
        rightpos = childpos + 1
        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2*pos + 1
    heap[pos] = newitem
    _siftdown(heap, startpos, pos)

heappopから呼び出されたとき、引数には、posはルートのindexである0が入っている。

最初は単に走査に当たっての値の設定だけである。 childpos に関して、これもheapの特徴で、この演算により、左の子のindexが計算できる。教科書ではヒープは1から始まるため、2*posが左の子になることが多いが、 pythonは0から始めているため、+1が入っている。

    endpos = len(heap)
    startpos = pos
    newitem = heap[pos]
    childpos = 2*pos + 1

同様に右の子は左の子の隣にいるので単に+1すればよい。

        rightpos = childpos + 1

それ以後は、左の子と右の子を比べて小さいほうを上にあげて、さらにその子に同じ処理を続けている。

        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2*pos + 1

そして最後に行きついた場所にもともと、一番最後にいた要素を入れる。

    heap[pos] = newitem

最後のこの_siftdownの呼び出しはいまいちわからない、、

    _siftdown(heap, startpos, pos)

以下のコメントが多分これを言っているのだろうが、いまいち理解できない。Knuth, Volume 3を買わなくては。

# The child indices of heap index pos are already heaps, and we want to make
# a heap at index pos too.  We do this by bubbling the smaller child of
# pos up (and so on with that child's children, etc) until hitting a leaf,
# then using _siftdown to move the oddball originally at index pos into place.
#
# We *could* break out of the loop as soon as we find a pos where newitem <=
# both its children, but turns out that's not a good idea, and despite that
# many books write the algorithm that way.  During a heap pop, the last array
# element is sifted in, and that tends to be large, so that comparing it
# against values starting from the root usually doesn't pay (= usually doesn't
# get us out of the loop early).  See Knuth, Volume 3, where this is
# explained and quantified in an exercise.
#

_siftdown以外の図示はこちら

二分ヒープ - Wikipedia

また、同様にheapfiyも以下のように_siftupを呼び出していることわかる。

def heapify(x):
    """Transform list into a heap, in-place, in O(len(x)) time."""
    n = len(x)
    # Transform bottom-up.  The largest index there's any point to looking at
    # is the largest with a child index in-range, so must have 2*i + 1 < n,
    # or i < (n-1)/2.  If n is even = 2*j, this is (2*j-1)/2 = j-1/2 so
    # j-1 is the largest, which is n//2 - 1.  If n is odd = 2*j+1, this is
    # (2*j+1-1)/2 = j so j-1 is the largest, and that's again n//2-1.
    for i in reversed(range(n//2)):
        _siftup(x, i)

なお、以下の演算により、一番後ろにいる子の親が計算できる。0からここまでをreversedすることによって下からヒープ条件を満たしていくことができる。

n = len(x)
range(n//2)

教科書で勉強したことが実際このように実装されているとわかると楽しいよね。