numpyの多次元配列の比較時に発生するエラーについて

numpyの多次元配列を比較しようとすると、配列の各要素を比較した結果を配列として返すので、all()やany()を使ってANDやORでreduceした結果を取得できそうだが、以下のようなエラーに遭遇する。

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

多次元の場合はちゃんとメソッド使いましょう、という話。

また、差集合や積集合が欲しい場合、numpyの関数setdiff1dやintersect1dはその名の通り1次元にflattenしてしまうので、多次元配列の要素をそのまま結果に取り出したい場合はsetを使う。ただし、numpyのメソッドだけで完結しないので速度は出ないかもしれない。

$ python
Python 3.4.3 (default, Oct 24 2015, 14:51:44) 
[GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import numpy as np
>>> a = np.array([1,2,3])
>>> b = np.array([2,3,4])
>>> c = np.array([1,3,4])
>>> a==b
array([False, False, False], dtype=bool)
>>> a==c
array([ True, False, False], dtype=bool)
>>> all(a==b)
False
>>> any(a==b)
False
>>> any(a==c)
True
>>> all(a==a)
True
>>> (a==b).all()
False
>>> (a==b).any()
False
>>> (a==c).any()
True
>>> (a==a).all()
True
>>> aa = np.array([[10,20,30],[20,30,40],[30,40,50]])
>>> bb = np.array([[20,30,40],[30,40,50],[40,50,60]])
>>> cc = np.array([[10,20,30],[30,40,50],[40,50,60]])
>>> aa
array([[10, 20, 30],
       [20, 30, 40],
       [30, 40, 50]])
>>> aa==bb
array([[False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)
>>> all(aa==bb)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
>>> (aa==bb).all()
False
>>> (aa==bb).any()
False
>>> (aa==cc).any()
True
>>> (aa==aa).any()
True
>>> np.setdiff1d(aa,bb)
array([10])
>>> np.setdiff1d(aa,cc)
array([], dtype=int64)
>>> np.intersect1d(aa,bb)
array([20, 30, 40, 50])
>>> np.intersect1d(aa,cc)
array([10, 20, 30, 40, 50])
>>> aas = set([tuple(x) for x in aa])
>>> bbs = set([tuple(x) for x in bb])
>>> ccs = set([tuple(x) for x in cc])
>>> np.array([x for x in aas-bbs])
array([[10, 20, 30]])
>>> np.array([x for x in aas-ccs])
array([[20, 30, 40]])
>>> np.array([x for x in aas-aas])
array([], dtype=float64)
>>> np.array([x for x in aas&bbs])
array([[20, 30, 40],
       [30, 40, 50]])
>>> np.array([x for x in aas&ccs])
array([[10, 20, 30],
       [30, 40, 50]])
>>> np.array([x for x in aas&aas])
array([[20, 30, 40],
       [10, 20, 30],
       [30, 40, 50]])

大きいサイズの配列でどの程度かかるのか試してみる。メモリに乗り切るなるべく大きいサイズの配列を準備して計測する。結果は違うが、numpyの関数と比較してみる。それぞれで差集合と積集合の処理速度が違うようだが、それほど差は無いように見える。

import numpy as np
import time

ai = [[x for y in range(2000)] for x in range(20000)]
bi = [[x+1 for y in range(2000)] for x in range(20000)]

al = np.array(ai)
bl = np.array(bi)

print(al)
print(bl)

start = time.time()
als = set([tuple(x) for x in al])
end = time.time()
print("{0:.5f} sec.".format(end-start))

bls = set([tuple(x) for x in bl])

start = time.time()
a_m_b = np.array([x for x in (als-bls)])
end = time.time()
print("{0:.5f} sec.".format(end-start))

start = time.time()
a_a_b = np.array([x for x in (als&bls)])
end = time.time()
print("{0:.5f} sec.".format(end-start))

print("# of setdiff  : {0}".format(len(a_m_b)))
print("# of intersect: {0}".format(len(a_a_b)))

start = time.time()
a_nm_b = np.setdiff1d(al,bl)
end = time.time()
print("{0:.5f} sec.".format(end-start))

start = time.time()
a_na_b = np.intersect1d(al,bl)
end = time.time()
print("{0:.5f} sec.".format(end-start))

print("# of setdiff1d  : {0}".format(len(a_nm_b)))
print("# of intersect1d: {0}".format(len(a_na_b)))
$ python numpy_test.py 
[[    0     0     0 ...,     0     0     0]
 [    1     1     1 ...,     1     1     1]
 [    2     2     2 ...,     2     2     2]
 ..., 
 [19997 19997 19997 ..., 19997 19997 19997]
 [19998 19998 19998 ..., 19998 19998 19998]
 [19999 19999 19999 ..., 19999 19999 19999]]
[[    1     1     1 ...,     1     1     1]
 [    2     2     2 ...,     2     2     2]
 [    3     3     3 ...,     3     3     3]
 ..., 
 [19998 19998 19998 ..., 19998 19998 19998]
 [19999 19999 19999 ..., 19999 19999 19999]
 [20000 20000 20000 ..., 20000 20000 20000]]
3.99010 sec.
1.13345 sec.
4.29675 sec.
# of setdiff  : 1
# of intersect: 19999
2.65021 sec.
1.83890 sec.
# of setdiff1d  : 1
# of intersect1d: 19999