2009-11-08 21 views
11

Tengo dos matrices (a y b) con n elementos enteros en el rango (0, N).Numpy, problema con matrices largas

typo: matrices con 2^n números enteros, donde el entero más grande toma el valor N = 3^n

Quiero calcular la suma de todas las combinaciones de elementos en A y B (sum_ij_ = a_i_ + b_j_ para todos i, j). Luego tome el módulo N (sum_ij_ = sum_ij_% N), y finalmente calcule la frecuencia de las diferentes sumas.

Para hacer esto rápido con numpy, sin ningún bucle, traté de usar la malla de malla y la función de recuento de cuentas.

A,B = numpy.meshgrid(a,b) 
A = A + B 
A = A % N 
A = numpy.reshape(A,A.size) 
result = numpy.bincount(A) 

Ahora, el problema es que mis matrices de entrada son largas. Y meshgrid me da MemoryError cuando uso entradas con 2^13 elementos. Me gustaría calcular esto para matrices con 2^15-2^20 elementos.

que es n en el intervalo de 15 a 20

¿Hay algún truco inteligente de hacer esto con numpy?

Cualquier ayuda será muy apreciada.

- Jon

+0

¿Y qué tan grande es N? – unutbu

+0

¿Numpy realmente va a ser tan eficiente? Supongo que estarás mejor en C++, escribiendo tus propias funciones y optimizando todo lo que puedas. De lo que parece numpy no se puede manejar una matriz tan grande. Aunque debo decir que si tiene dos matrices con 2^15 a 2^20 elementos, entonces si mira todas sus diferentes sumas, terminará con una matriz de 2^30 a 2^40 elementos. Que es mucho .. – JSchlather

+0

@unutbu: N ~ 3^n @liberalkid: Supongo que tienes razón. Aunque mis habilidades en C++ no son tan buenas. – jonalm

Respuesta

7

try chunking it.tu malla de malla es una matriz NxN, bloquea hasta 10x10 N/10xN/10 y solo calcula 100 bandejas, agrégalas al final. esto solo usa ~ 1% de memoria tanto como hacer todo.

+0

Supongo que este es el camino a seguir, pero ¿hay alguna manera inteligente de hacerlo con matrices numpy? Minimizando el uso de for bucles. – jonalm

+0

Oye, ¿hay un tamaño óptimo para un bloque? – jonalm

+0

probablemente el más grande que pueda hacer un bloque y aún así mantenerlo seguro metido en ram. – Autoplectic

1

Editar en respuesta al comentario de jonalm:

jonalm: N ~ 3^n no n ~ 3^N. N es el elemento máximo en a y n es el número de elementos en a.

n es ~ 2^20. Si N es ~ 3^n, entonces N es ~ 3^(2^20)> 10^(500207). Los científicos estiman (http://www.stormloader.com/ajy/reallife.html) que solo hay alrededor de 10^87 partículas en el universo. Por lo tanto, no hay una forma (ingenua) en que una computadora pueda manejar un int del tamaño 10^(500207).

jonalm: Sin embargo, soy un poco curioso acerca de la función pv() que defina. (I no logro ejecutarlo porque text.find() no está definido (adivinarlo en otro módulo )). ¿Cómo funciona esta función y cuál es su ventaja?

pv es una pequeña función auxiliar que escribí para depurar el valor de las variables. Funciona como print() excepto cuando dice pv (x) imprime tanto el nombre de la variable literal (o cadena de expresión), dos puntos y luego el valor de la variable.

Si pones

#!/usr/bin/env python 
import traceback 
def pv(var): 
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2] 
    print('%s: %s'%(text[text.find('(')+1:-1],var)) 
x=1 
pv(x) 

en un script que debe obtener

x: 1 

La modesta ventaja de utilizar pv estampado es que le ahorra escribir.En lugar de tener que escribir

print('x: %s'%x) 

sólo puede dar una palmada abajo

pv(x) 

Cuando existen múltiples variables para realizar un seguimiento, es útil para etiquetar las variables. Me cansé de escribirlo todo.

La función pv funciona utilizando el módulo traceback para ver la línea de código que se usa para llamar a la función pv. (Consulte http://docs.python.org/library/traceback.html#module-traceback) Esa línea de código se almacena como una cadena en el texto variable. text.find() es una llamada al método de cadena habitual find(). Por ejemplo, si

text='pv(x)' 

entonces

text.find('(') == 2    # The index of the '(' in string text 
text[text.find('(')+1:-1] == 'x' # Everything in between the parentheses 

Asumo n ~ 3^N, y n ~ 2 ** 20

La idea es trabajar módulo N. Este recortes abajo en el tamaño de las matrices. La segunda idea (importante cuando n es enorme) es usar ndarrays numpy del tipo 'objeto' porque si usa un tipo entero, corre el riesgo de desbordar el tamaño del entero máximo permitido.

#!/usr/bin/env python 
import traceback 
import numpy as np 

def pv(var): 
    (filename,line_number,function_name,text)=traceback.extract_stack()[-2] 
    print('%s: %s'%(text[text.find('(')+1:-1],var)) 

Puede cambiar n ser 2 ** 20, pero a continuación les muestro lo que sucede con los pequeños n lo que la salida es más fácil de leer.

n=100 
N=int(np.exp(1./3*np.log(n))) 
pv(N) 
# N: 4 

a=np.random.randint(N,size=n) 
b=np.random.randint(N,size=n) 
pv(a) 
pv(b) 
# a: [1 0 3 0 1 0 1 2 0 2 1 3 1 0 1 2 2 0 2 3 3 3 1 0 1 1 2 0 1 2 3 1 2 1 0 0 3 
# 1 3 2 3 2 1 1 2 2 0 3 0 2 0 0 2 2 1 3 0 2 1 0 2 3 1 0 1 1 0 1 3 0 2 2 0 2 
# 0 2 3 0 2 0 1 1 3 2 2 3 2 0 3 1 1 1 1 2 3 3 2 2 3 1] 
# b: [1 3 2 1 1 2 1 1 1 3 0 3 0 2 2 3 2 0 1 3 1 0 0 3 3 2 1 1 2 0 1 2 0 3 3 1 0 
# 3 3 3 1 1 3 3 3 1 1 0 2 1 0 0 3 0 2 1 0 2 2 0 0 0 1 1 3 1 1 1 2 1 1 3 2 3 
# 3 1 2 1 0 0 2 3 1 0 2 1 1 1 1 3 3 0 2 2 3 2 0 1 3 1] 

wa mantiene el número de 0s, 1s, 2s, 3s en una wb contiene el número de 0s, 1s, 2s, 3s en b

wa=np.bincount(a) 
wb=np.bincount(b) 
pv(wa) 
pv(wb) 
# wa: [24 28 28 20] 
# wb: [21 34 20 25] 
result=np.zeros(N,dtype='object') 

Piense en un 0 como una muestra o chip Del mismo modo para 1,2,3.

Creo que wa = [24 28 28 20] significa que hay una bolsa con 24 chips 0, 28 1 chips, 28 chips 2, 20 chips 3.

Usted tiene un wa-bag y un wb-bag. Cuando sacas un chip de cada bolsa, las "agregas" para formar un nuevo chip. Usted "mod" la respuesta (módulo N).

Imagínese tomar un 1-chip del wb-bag y agregarlo con cada chip en el wa-bag.

1-chip + 0-chip = 1-chip 
1-chip + 1-chip = 2-chip 
1-chip + 2-chip = 3-chip 
1-chip + 3-chip = 4-chip = 0-chip (we are mod'ing by N=4) 

Puesto que hay 34 1-chips en la bolsa BM, cuando se agregan en contra de todas las fichas en el wa = [24 28 28 20] bolsa, se obtiene

34*24 1-chips 
34*28 2-chips 
34*28 3-chips 
34*20 0-chips 

Este es solo el recuento parcial debido a los 34 1 chips. También tiene que manejar los otros tipos de chips en el BM-bolsa, pero esto se muestra el método utilizado a continuación:

for i,count in enumerate(wb): 
    partial_count=count*wa 
    pv(partial_count) 
    shifted_partial_count=np.roll(partial_count,i) 
    pv(shifted_partial_count) 
    result+=shifted_partial_count 
# partial_count: [504 588 588 420] 
# shifted_partial_count: [504 588 588 420] 
# partial_count: [816 952 952 680] 
# shifted_partial_count: [680 816 952 952] 
# partial_count: [480 560 560 400] 
# shifted_partial_count: [560 400 480 560] 
# partial_count: [600 700 700 500] 
# shifted_partial_count: [700 700 500 600] 

pv(result)  
# result: [2444 2504 2520 2532] 

Este es el resultado final: 2444 0s, 2504 1s, 2520 2s, 2532 3s .

# This is a test to make sure the result is correct. 
# This uses a very memory intensive method. 
# c is too huge when n is large. 
if n>1000: 
    print('n is too large to run the check') 
else: 
    c=(a[:]+b[:,np.newaxis]) 
    c=c.ravel() 
    c=c%N 
    result2=np.bincount(c) 
    pv(result2) 
    assert(all(r1==r2 for r1,r2 in zip(result,result2))) 
# result2: [2444 2504 2520 2532] 
+0

Tenga en cuenta que 'c% = N' funciona (y puede usar el doble de memoria). – EOL

+0

@EOL, sí, c% = N es mejor. Sin embargo, definir 'c = (a [:] + b [:, np.newaxis])' significa que ya ha perdido la batalla, ya que esta es una gran variedad de formas de 2-d (n, n) mientras que la anterior solución utiliza nada más que un par de matrices de forma 1-d (N). – unutbu

+0

Muchas gracias por la respuesta, me gusta este método. Pero no creo que esto me ayude, ya que todos los números en el conjunto a (y en b) son diferentes (no lo mencioné, mi mal). bincount (a) solo consistirá en 1 y 0. N ~ 3^n not n ~ 3^N. N es el elemento máximo en a y n es el número de elementos en a. Sin embargo, soy un poco curioso sobre la función pv() que defines. (No logro ejecutarlo porque text.find() no está definido (adivinarlo en otro módulo)). ¿Cómo funciona esta función y cuál es su ventaja? – jonalm

1

Compruebe su matemáticas, eso es mucho del espacio que está pidiendo:

2^20 * 2^20 = 2^40 = 1 099 511 627 776

Si cada uno de sus elementos era solo un byte, eso ya es un terabyte de memoria.

Agregue un bucle o dos. Este problema no es adecuado para maximizar su memoria y minimizar su cálculo.