2012-05-22 20 views
30

Para mi prueba unitaria, quiero verificar si dos matrices son idénticas. Ejemplo reducido:comparando matrices numpy que contienen NaN

a=np.array([1, 2, np.NaN]) 
b=np.array([1, 2, np.NaN]) 
if np.all(a==b): 
    print 'arrays are equal' 

Esto no funciona porque nan! = Nan. ¿Cuál es la mejor manera de proceder?

Gracias de antemano.

Respuesta

20

Alternativamente, puede utilizar numpy.testing.assert_equal o numpy.testing.assert_array_equal con un try/except:

In : import numpy as np 

In : def nan_equal(a,b): 
...:  try: 
...:   np.testing.assert_equal(a,b) 
...:  except AssertionError: 
...:   return False 
...:  return True 

In : a=np.array([1, 2, np.NaN]) 

In : b=np.array([1, 2, np.NaN]) 

In : nan_equal(a,b) 
Out: True 

In : a=np.array([1, 2, np.NaN]) 

In : b=np.array([3, 2, np.NaN]) 

In : nan_equal(a,b) 
Out: False 

Editar

Puesto que usted está usando este para la prueba unitaria, desnudo assert (en lugar de envolverlo para obtener True/False) podría ser más natural.

+0

Excelente, esta es la solución más elegante e integrada. Acabo de agregar 'np.testing.assert_equal (a, b)' en mi prueba unitaria, y si aumenta la excepción, la prueba falla (sin error), e incluso obtengo una buena impresión con las diferencias y la falta de coincidencia. Gracias. – saroele

+3

Tenga en cuenta que esta solución funciona porque 'numpy.testing.assert_ *' no sigue la misma semántica de python 'assert''s. En las excepciones simples de Python 'AssertionError' se plantean iff' __debug__ es True', es decir, si el script se ejecuta no optimizado (no -O indicador), vea el [documento] (http://docs.python.org/3.3/reference /simple_stmts.html#grammar-token-assert_stmt). Por esta razón, desaconsejaría encarecidamente envolver 'AssertionErrors' para el control de flujo. Por supuesto, dado que estamos en un banco de pruebas, la mejor solución es dejar el numpy.testing.assert solo. –

8

Usted podría utilizar matrices numpy enmascarados, enmascarar los valores NaN y luego usar numpy.ma.all o numpy.ma.allclose:

http://docs.scipy.org/doc/numpy/reference/generated/numpy.ma.all.html

http://docs.scipy.org/doc/numpy/reference/generated/numpy.ma.allclose.html

Por ejemplo:

a=np.array([1, 2, np.NaN]) 
b=np.array([1, 2, np.NaN]) 
np.ma.all(np.ma.masked_invalid(a) == np.ma.masked_invalid(b)) #True 
+1

gracias por hacerme consciente del uso de matrices enmascarados. Sin embargo, prefiero la solución de Avaris. – saroele

+0

Debería usar 'np.ma.masked_where (np.isnan (a), a)' de lo contrario, no podrá comparar los valores infinitos. –

+0

Probé con 'a = np.array ([1, 2, np.NaN])' y 'b = np.array ([1, np.NaN, 2])' que claramente no son iguales y 'np. ma.all (np.ma.masked_invalid (a) == np.ma.masked_invalid (b)) 'todavía devuelve True, así que ten cuidado si usas este método. – tavo

20

No estoy seguro de que este es el mejor manera de proceder, pero es una manera:

>>> ((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all() 
True 
+0

+1 Esta solución parece ser un poco más rápida que la solución que publiqué con matrices enmascaradas, aunque si estuviera creando la máscara para utilizarla en otras partes de su código, la sobrecarga de la creación de la máscara sería menos importante en el eficiencia general de la estrategia ma. – JoshAdel

+0

Gracias.Su solución funciona de hecho, pero prefiero la prueba incorporada en numpy como sugiere Avaris – saroele

+1

Me gusta mucho la simplicidad de esto. Además, parece una solución más rápida que @Avaris. Al convertir esto en una función lambda, la prueba con el '% timeit' de Ipython produce 23.7 μs frente a 1.01 ms. – AllanLRH

1

cuando utilicé la respuesta anterior:

((a == b) | (numpy.isnan(a) & numpy.isnan(b))).all() 

Me dio algunos erros cuando se evalúan lista de cadenas.

Esto es más de tipo genérico:

def EQUAL(a,b): 
    return ((a == b) | ((a != a) & (b != b))) 
6

La forma más sencilla es utilizar numpy.allclose() método, que permite especificar el comportamiento al tener valores nan. Luego, su ejemplo se verá como la siguiente:

a = np.array([1, 2, np.nan]) 
b = np.array([1, 2, np.nan]) 

if np.allclose(a, b, equal_nan=True): 
    print 'arrays are equal' 

continuación se imprimirán arrays are equal.

puede encontrar here la documentación relacionada

+0

+1 porque su solución no reinventa la rueda. Sin embargo, esto solo funciona con elementos similares a números. De lo contrario, obtienes el desagradable 'TypeError: ufunc 'isfinite' no admitido para los tipos de entrada, y las entradas no se pueden forzar de forma segura a ningún tipo soportado de acuerdo con la regla de conversión '' safe''' – MLguy