NumPyのsortは遅いがpartitionを使うと速い

プログラミング

PythonのNumpyで、要素が非常に多い1次元配列をnp.sortnp.argsortでソートした際の処理に時間がかかることに気づきました。いろいろ試していく中で、ソート後の値が小さい順に10個の要素だけ欲しいというような場合はnp.partitionnp.argpartitionが速いということが分かりました。

スポンサーリンク

遅いコード

以下のコードを動かします。
1億個の乱数をソートしますが、環境に合わせてNの値は変えてください。
また同じ乱数セットでソートのテストを行いたいため、np.saveで配列を保存して使いまわします。

import numpy as np

try:
    x = np.load("x.npy")
except:
    N = 100000000
    x = np.random.rand(N)
    np.save("x.npy", x)

idx        = np.argsort(x)
sorted_x   = x[idx]
idx10      = idx[:10]
sorted10_x = x[idx10]

print(idx10)
print(sorted10_x)


こちらのコードをWindowsのPythonで動かします。
動かす際はPowershellのコマンドを使って実行時間を測ります。

#コマンドプロンプトの場合
powershell -C (Measure-Command {python test_numpy.py}).TotalMilliseconds
#Powershellの場合
(Measure-Command {python test_numpy.py}).TotalMilliseconds

このコマンドで実行すると標準出力は出てきませんが、時間がミリ秒単位で出てきます。
この場合は17439.8781ミリ秒(17.4秒)でした。何回か試してもほぼ同じ時間でした。

スポンサーリンク

速いコード

同じ乱数セットをロードして、ソート処理の部分を丸ごと書き換えました。

import numpy as np

try:
    x = np.load("x.npy")
except:
    N = 100000000
    x = np.random.rand(N)
    np.save("x.npy", x)

smallest_idx = np.argpartition(x, 10)[:10]
smallest_x   = x[smallest_idx]
sorted_idx   = np.argsort(smallest_x)
sorted10_x   = smallest_x[sorted_idx]
idx10        = smallest_idx[sorted_idx]

print(idx10)
print(sorted10_x)

こちらのコードは1007.9567ミリ秒(1.0秒)でした。
なんと17倍も高速化されました。びっくり。

なお、partitionやargpartitionでは小さい順に10個の要素を取得することはできますが、その10個の順番は小さい順になりません。別途ソートをかける必要があり、そのためコードが少し分かりづらくなります。

スポンサーリンク

np.partitionの動き

np.partitionやnp.argpartitionは、指定した要素よりも「小さいもの」「大きいもの」を分ける関数です。また「〇番目に小さい要素」を抽出する用途にも使います。

#1次元配列xを表示
>>> x
array([10,  5,  6,  2,  7,  3,  8,  4,  9])

#3番目に小さい要素(4)で区切る
>>> np.partition(x, 2)
array([ 2,  3,  4, 10,  7,  5,  8,  6,  9]) #4から左側は4より小さく、右側は4より大きい

#3番目に小さい要素(4)を抽出する
>>> np.partition(x, 2)[2]
4

注意点として、Pythonのインデックスは0から数え始めるため、3番目を指定する場合は第2引数に2を入れます。

スポンサーリンク

まとめ

Numpyのソートが遅い場合はpartitionで何倍も高速化できる場合があります。

コメント

タイトルとURLをコピーしました