2011-12-28 15 views
18

Alerta de alerta: esto está relacionado con Problem 14 de Project Euler.¿Por qué es tan simple este algoritmo haskell tan lento?

El siguiente código tarda unos 15 segundos en ejecutarse. Tengo una solución Java no recursiva que se ejecuta en 1s. Creo que debería poder obtener este código mucho más cerca de eso.

import Data.List 

collatz a 1 = a 
collatz a x 
    | even x = collatz (a + 1) (x `div` 2) 
    | otherwise = collatz (a + 1) (3 * x + 1) 

main = do 
    print ((foldl1' max) . map (collatz 1) $ [1..1000000]) 

he perfilado con +RHS -p y se dio cuenta de que la memoria asignada es grande, y crece a medida que crece la entrada. Para n = 100,000 se asigna 1gb (!), Para n = 1,000,000 se asignan 13 gb (!!).

Por otra parte, -sstderr muestra que, aunque se asignaron muchos bytes, el uso de la memoria total fue de 1 mb, y la productividad fue del 95% +, por lo que tal vez 13 gb son una auténtica locura.

Se me ocurren algunas posibilidades:

  1. Algo no es tan estricta como tiene que ser. Ya descubrí foldl1', pero ¿quizás necesito hacer más? ¿Es posible marcar collatz tan estrictas (¿eso siquiera tiene sentido?

  2. collatz no es la optimización de llamada final. Creo que debería ser, pero no saben una manera de confirmar.

  3. El compilador no está haciendo algunas optimizaciones creo que debería - por ejemplo solamente dos resultados de collatz tienen que estar en la memoria en un momento dado (máximo y actual)

cualquier sugerencia

?

Esto es más o menos un duplicado de Why is this Haskell expression so slow?, aunque señalaré que la solución rápida de Java no tiene que realizar ninguna memoria. ¿Hay alguna forma de acelerar esto sin tener que recurrir a él?

Para referencia, aquí es mi producción de perfiles:

Wed Dec 28 09:33 2011 Time and Allocation Profiling Report (Final) 

    scratch +RTS -p -hc -RTS 

    total time =  5.12 secs (256 ticks @ 20 ms) 
    total alloc = 13,229,705,716 bytes (excludes profiling overheads) 

COST CENTRE     MODULE    %time %alloc 

collatz      Main     99.6 99.4 


                           individual inherited 
COST CENTRE    MODULE            no. entries %time %alloc %time %alloc 

MAIN      MAIN             1   0 0.0 0.0 100.0 100.0 
CAF      Main             208   10 0.0 0.0 100.0 100.0 
    collatz    Main             215   1 0.0 0.0  0.0 0.0 
    main     Main             214   1 0.4 0.6 100.0 100.0 
    collatz    Main             216   0 99.6 99.4 99.6 99.4 
CAF      GHC.IO.Handle.FD          145   2 0.0 0.0  0.0 0.0 
CAF      System.Posix.Internals        144   1 0.0 0.0  0.0 0.0 
CAF      GHC.Conc            128   1 0.0 0.0  0.0 0.0 
CAF      GHC.IO.Handle.Internals        119   1 0.0 0.0  0.0 0.0 
CAF      GHC.IO.Encoding.Iconv        113   5 0.0 0.0  0.0 0.0 

Y -sstderr:

./scratch +RTS -sstderr 
525 
    21,085,474,908 bytes allocated in the heap 
     87,799,504 bytes copied during GC 
      9,420 bytes maximum residency (1 sample(s))   
      12,824 bytes maximum slop    
       1 MB total memory in use (0 MB lost due to fragmentation) 

    Generation 0: 40219 collections,  0 parallel, 0.40s, 0.51s elapsed 
    Generation 1:  1 collections,  0 parallel, 0.00s, 0.00s elapsed 

    INIT time 0.00s ( 0.00s elapsed) 
    MUT time 35.38s (36.37s elapsed) 
    GC time 0.40s ( 0.51s elapsed) 
    RP time 0.00s ( 0.00s elapsed) PROF time 0.00s ( 0.00s elapsed) 
    EXIT time 0.00s ( 0.00s elapsed) 
    Total time 35.79s (36.88s elapsed) %GC time  1.1% (1.4% elapsed) Alloc rate 595,897,095 bytes per MUT second 

    Productivity 98.9% of total user, 95.9% of total elapsed 

y la solución de Java (no la mía, tomada de los foros Proyecto Euler con memoization eliminado):

public class Collatz { 
    public int getChainLength(int n) 
    { 
    long num = n; 
    int count = 1; 
    while(num > 1) 
    { 
     num = (num%2 == 0) ? num >> 1 : 3*num+1; 
     count++; 
    } 
    return count; 
    } 

    public static void main(String[] args) { 
    Collatz obj = new Collatz(); 
    long tic = System.currentTimeMillis(); 
    int max = 0, len = 0, index = 0; 
    for(int i = 3; i < 1000000; i++) 
    { 
     len = obj.getChainLength(i); 
     if(len > max) 
     { 
     max = len; 
     index = i; 
     } 
    } 
    long toc = System.currentTimeMillis(); 
    System.out.println(toc-tic); 
    System.out.println("Index: " + index + ", length = " + max); 
    } 
} 
+0

Es bastante sorprendente que GHC no optimice (quot n 2) a (rshift n 1) como se esperaría que hiciera cualquier compilador de C que se precie. ¿Hay una razón? –

+0

@solrize: Me sorprendió también. – ehird

Respuesta

20

Al principio, pensé que debería intentar poner un signo de exclamación antes de un en collatz:

collatz !a 1 = a 
collatz !a x 
    | even x = collatz (a + 1) (x `div` 2) 
    | otherwise = collatz (a + 1) (3 * x + 1) 

(Usted tendrá que poner {-# LANGUAGE BangPatterns #-} en la parte superior de su archivo de origen para que esto funcione.)

Mi razonamiento fue el siguiente: El problema es que eres construyendo un enorme thunk en el primer argumento para collatz: comienza como 1, y luego se convierte en 1 + 1, y luego se convierte en (1 + 1) + 1, ... todo sin ser forzado.Este bang pattern obliga al primer argumento de collatz a forzarse cada vez que se realiza una llamada, por lo que comienza como 1, y luego se convierte en 2, y así sucesivamente, sin crear un gran thunk no evaluado: simplemente permanece como un entero.

Tenga en cuenta que un patrón de explosión es solo una abreviatura para usar seq; en este caso, podríamos reescribir collatz de la siguiente manera:

collatz a _ | seq a False = undefined 
collatz a 1 = a 
collatz a x 
    | even x = collatz (a + 1) (x `div` 2) 
    | otherwise = collatz (a + 1) (3 * x + 1) 

El truco aquí es forzar un en la guardia, que luego siempre se evalúa como False (y por lo que el cuerpo es irrelevante). Luego la evaluación continúa con el siguiente caso, a que ya ha sido evaluado. Sin embargo, un patrón de explosión es más claro.

Desafortunadamente, cuando se compila con -O2, ¡esto no se ejecuta más rápido que el original! Que mas podemos probar? Bueno, una cosa que podemos hacer es asumir que los dos números nunca desbordan un entero de máquina de tamaño, y dar collatz este tipo de anotación:

collatz :: Int -> Int -> Int 

Dejaremos el patrón de explosión allí, ya que todavía hay que evitar la construcción hasta thunks, incluso si no son la raíz del problema de rendimiento. Esto reduce el tiempo a 8.5 segundos en mi computadora (lenta).

El próximo paso es intentar acercar esto a la solución Java. Lo primero es darse cuenta de que en Haskell, div se comporta de una manera más matemáticamente correcta con respecto a los enteros negativos, pero es más lenta que la división C "normal", que en Haskell se llama quot. Reemplazando div con quot redujo el tiempo de ejecución a 5.2 segundos, y reemplazando x `quot` 2 con x `shiftR` 1 (importando Data.Bits) para que coincida con la solución Java lo bajó a 4.9 segundos.

Esto es casi lo más bajo que puedo obtener por ahora, pero creo que este es un resultado bastante bueno; dado que su computadora es más rápida que la mía, debería estar aún más cerca de la solución Java.

Aquí está el código final (lo hice un poco de limpieza en el camino):

{-# LANGUAGE BangPatterns #-} 

import Data.Bits 
import Data.List 

collatz :: Int -> Int 
collatz = collatz' 1 
    where collatz' :: Int -> Int -> Int 
     collatz' !a 1 = a 
     collatz' !a x 
      | even x = collatz' (a + 1) (x `shiftR` 1) 
      | otherwise = collatz' (a + 1) (3 * x + 1) 

main :: IO() 
main = print . foldl1' max . map collatz $ [1..1000000] 

Mirando el GHC Core para este programa (con ghc-core), creo que esto es probablemente tan bien como se pone; el bucle collatz usa enteros sin casillas y el resto del programa se ve bien. La única mejora que puedo pensar sería eliminar el boxeo de la iteración map collatz [1..1000000].

Por cierto, no se preocupe por la figura de "total alloc"; es la memoria total asignada durante la vida útil del programa, y nunca disminuye incluso cuando el GC recupera esa memoria. Las cifras de múltiples terabytes son comunes.

+0

Gracias, eso es realmente útil. No sabía acerca de '-O2', eso hace una gran diferencia (baja el tiempo de ejecución a 5s). Se agregó una solución Java a la pregunta. –

+0

Oh, asumí que ya usabas '-O2', ya que el programa revisado con un patrón de explosión se ejecutó en 16 segundos en mi máquina :) Voy a echar un vistazo a tu solución Java. – ehird

+0

OK, he actualizado esta respuesta con más mejoras. – ehird

2

Puede perder la lista y los patrones de explosión y obtener el mismo rendimiento utilizando la pila en su lugar.

import Data.List 
import Data.Bits 

coll :: Int -> Int 
coll 0 = 0 
coll 1 = 1 
coll 2 = 2 
coll n = 
    let a = coll (n - 1) 
     collatz a 1 = a 
     collatz a x 
     | even x = collatz (a + 1) (x `shiftR` 1) 
     | otherwise = collatz (a + 1) (3 * x + 1) 
    in max a (collatz 1 n) 


main = do 
    print $ coll 100000 

Un problema con esto es que va a tener que aumentar el tamaño de la pila para obtener grandes entradas, al igual que 1_000_000.

actualización:

Aquí está una versión recursiva cola que no sufre el problema de desbordamiento de pila.

import Data.Word 
collatz :: Word -> Word -> (Word, Word) 
collatz a x 
    | x == 1 = (a,x) 
    | even x = collatz (a + 1) (x `quot` 2) 
    | otherwise = collatz (a + 1) (3 * x + 1) 

coll :: Word -> Word 
coll n = collTail 0 n 
    where 
    collTail m 1 = m 
    collTail m n = collTail (max (fst $ collatz 1 n) m) (n-1) 

Aviso el uso de Word en lugar de Int. Hace una diferencia en el rendimiento. Todavía podría usar los patrones de explosión si lo desea, y eso casi duplicaría el rendimiento.

0

Una cosa que encontré hizo una diferencia sorprendente en este problema. Me quedé con la relación de recurrencia directa en lugar de plegado, debe perdonar la expresión, el conteo con ella. Reescribiendo

collatz n = if even n then n `div` 2 else 3 * n + 1 

como

collatz n = case n `divMod` 2 of 
      (n', 0) -> n' 
      _  -> 3 * n + 1 

tomó 1.2 segundos del tiempo de ejecución para mi programa en un sistema con un Athlon a 2,8 GHz II X4 430 CPU. Mi versión inicial más rápido (2,3 segundos después de que el uso de DIVMOD):

{-# LANGUAGE BangPatterns #-} 

import Data.List 
import Data.Ord 

collatzChainLen :: Int -> Int 
collatzChainLen n = collatzChainLen' n 1 
    where collatzChainLen' n !l 
      | n == 1 = l 
      | otherwise = collatzChainLen' (collatz n) (l + 1) 

collatz:: Int -> Int 
collatz n = case n `divMod` 2 of 
       (n', 0) -> n' 
       _  -> 3 * n + 1 

pairMap :: (a -> b) -> [a] -> [(a, b)] 
pairMap f xs = [(x, f x) | x <- xs] 

main :: IO() 
main = print $ fst (maximumBy (comparing snd) (pairMap collatzChainLen [1..999999])) 

Un quizás más idiomática Haskell versión de carreras en alrededor de 9,7 segundos (8,5 con DIVMOD); que es idéntica a excepción de

collatzChainLen :: Int -> Int 
collatzChainLen n = 1 + (length . takeWhile (/= 1) . (iterate collatz)) n 

Usando Data.List.Stream se supone para permitir la fusión corriente que haría que esta versión funcione más como que con la acumulación explícita, pero no puede encontrar un paquete de Ubuntu * libghc que tiene Data.List.Stream, por lo que aún no puedo verificarlo.

Cuestiones relacionadas