久しぶりです。修論も終わり、出張の準備も落ち着いたので、ちょっとした tips とかまた書き溜めていこうと思います。サイト名も急な思いつきで変更しました...^^;

ところで、python で高速な数値計算を実現する numpy、それと大体互換性のある cupy というGPU用ライブラリが便利で愛用してます。しかし、ちょっとした関数を書くとき、引数が numpy の array か、 cupy なのか教えてくれる (より正確には cupy の array があれば cupy モジュールを、それ以外は numpy を返す) cupy.get_array_module という関数名を思い出して、タイプするのは、ほぼボイラープレートなので面倒でした。そこで多くの場合、私はデコレータを使って、関数の"見えない"引数である xp により numpy と cupy を切り替えてます。

def host_device(func):
    import functools
    import cupy
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        xp = cupy.get_array_module(*args, **kwargs) # ボイラープレート
        return func(xp, *args, **kwargs)
    return wrapper


@host_device
def sigmoid(xp, x):
    """
    >>> import numpy, cupy
    >>> i = numpy.random.randn(2,3,4)
    >>> n = cupy.asarray(sigmoid(i))
    >>> c = sigmoid(cupy.asarray(i))
    >>> cupy.testing.assert_array_equal(n, c)
    """
    return 1.0 / (1.0 + xp.exp(-x))


if __name__ == "__main__":
    import doctest
    doctest.testmod()

例示しやすいので doctest で書きましたが、要素毎の操作する関数などテストパターンを共通化してもいいかもしれません。それもデコレータで書いてみましょう

def test_elementwise(func):
    if __debug__:
        # 最適化オプション(-O)無し、定義時の一度だけ実行
        import numpy, cupy
        i = numpy.random.randn(2,3,4)
        n = cupy.asarray(func(i))
        c = func(cupy.asarray(i))
        cupy.testing.assert_array_almost_equal(n, c)
        print("<{}> is ok".format(func.__name__))
    return func


@test_elementwise
def relu(x):
    return x * (x > 0.0)


@test_elementwise # 順番に注意
@host_device
def sigmoid(xp, x):
    return 1.0 / (1.0 + xp.exp(-x))

ipython 等で使ってみると、assert はデコレータ内で debug 時に一度だけ実行されて以降は呼ばれないことが print により確認できます (単体テストは一瞬で終わるべきと思っている私はこの些細な debug 時オーバーヘッドを気にしません)。 ちなみにデコレータは後に記述されたものから順に適用されます。その理由は一番下の関数本体から定義されるからだと推測しますが、控えめに言ってデコレータ記法って言語設計ミスだと思います。

最後に今回は挙げなかった @host_device のようなデコレータの使い道として、cupy に numpy 非互換な挙動があればデコレータ内のwrapperreturn する際に吸収することもできます。ただ、もし非互換な挙動を見つけたら issue を報告した方がいいのかも