PythonのNumpyで、要素が非常に多い1次元配列をnp.sortやnp.argsortでソートした際の処理に時間がかかることに気づきました。いろいろ試していく中で、ソート後の値が小さい順に10個の要素だけ欲しいというような場合はnp.partitionやnp.argpartitionが速いということが分かりました。
遅いコード
以下のコードを動かします。
1億個の乱数をソートしますが、環境に合わせてNの値は変えてください。
また同じ乱数セットでソートのテストを行いたいため、np.saveで配列を保存して使いまわします。
import numpy as np
try:
x = np.load("x.npy")
except:
N = 100000000
x = np.random.rand(N)
np.save("x.npy", x)
idx = np.argsort(x)
sorted_x = x[idx]
idx10 = idx[:10]
sorted10_x = x[idx10]
print(idx10)
print(sorted10_x)
こちらのコードをWindowsのPythonで動かします。
動かす際はPowershellのコマンドを使って実行時間を測ります。
#コマンドプロンプトの場合
powershell -C (Measure-Command {python test_numpy.py}).TotalMilliseconds
#Powershellの場合
(Measure-Command {python test_numpy.py}).TotalMilliseconds
このコマンドで実行すると標準出力は出てきませんが、時間がミリ秒単位で出てきます。
この場合は17439.8781ミリ秒(17.4秒)でした。何回か試してもほぼ同じ時間でした。
速いコード
同じ乱数セットをロードして、ソート処理の部分を丸ごと書き換えました。
import numpy as np
try:
x = np.load("x.npy")
except:
N = 100000000
x = np.random.rand(N)
np.save("x.npy", x)
smallest_idx = np.argpartition(x, 10)[:10]
smallest_x = x[smallest_idx]
sorted_idx = np.argsort(smallest_x)
sorted10_x = smallest_x[sorted_idx]
idx10 = smallest_idx[sorted_idx]
print(idx10)
print(sorted10_x)
こちらのコードは1007.9567ミリ秒(1.0秒)でした。
なんと17倍も高速化されました。びっくり。
なお、partitionやargpartitionでは小さい順に10個の要素を取得することはできますが、その10個の順番は小さい順になりません。別途ソートをかける必要があり、そのためコードが少し分かりづらくなります。
np.partitionの動き
np.partitionやnp.argpartitionは、指定した要素よりも「小さいもの」「大きいもの」を分ける関数です。また「〇番目に小さい要素」を抽出する用途にも使います。
#1次元配列xを表示
>>> x
array([10, 5, 6, 2, 7, 3, 8, 4, 9])
#3番目に小さい要素(4)で区切る
>>> np.partition(x, 2)
array([ 2, 3, 4, 10, 7, 5, 8, 6, 9]) #4から左側は4より小さく、右側は4より大きい
#3番目に小さい要素(4)を抽出する
>>> np.partition(x, 2)[2]
4
注意点として、Pythonのインデックスは0から数え始めるため、3番目を指定する場合は第2引数に2を入れます。
まとめ
Numpyのソートが遅い場合はpartitionで何倍も高速化できる場合があります。
コメント