2010-12-02 6 views
9

PyBrain es una biblioteca de Python que proporciona (entre otras cosas) redes neuronales artificiales fáciles de usar.¿Cómo serializar/redes de pybrain deserializadas?

No serializo/deserializo correctamente las redes PyBrain usando pickle o cPickle.

Véase el siguiente ejemplo:

from pybrain.datasets   import SupervisedDataSet 
from pybrain.tools.shortcuts  import buildNetwork 
from pybrain.supervised.trainers import BackpropTrainer 
import cPickle as pickle 
import numpy as np 

#generate some data 
np.random.seed(93939393) 
data = SupervisedDataSet(2, 1) 
for x in xrange(10): 
    y = x * 3 
    z = x + y + 0.2 * np.random.randn() 
    data.addSample((x, y), (z,)) 

#build a network and train it  

net1 = buildNetwork(data.indim, 2, data.outdim) 
trainer1 = BackpropTrainer(net1, dataset=data, verbose=True) 
for i in xrange(4): 
    trainer1.trainEpochs(1) 
    print '\tvalue after %d epochs: %.2f'%(i, net1.activate((1, 4))[0]) 

Ésta es la salida del código anterior:

Total error: 201.501998476 
    value after 0 epochs: 2.79 
Total error: 152.487616382 
    value after 1 epochs: 5.44 
Total error: 120.48092561 
    value after 2 epochs: 7.56 
Total error: 97.9884043452 
    value after 3 epochs: 8.41 

Como se puede ver, la red de error total disminuye a medida que progresa la formación. También se puede ver que el valor predicho se acerca al valor esperado de 12.

Ahora vamos a hacer un ejercicio similar, pero incluirá serialización/deserialización:

print 'creating net2' 
net2 = buildNetwork(data.indim, 2, data.outdim) 
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True) 
trainer2.trainEpochs(1) 
print '\tvalue after %d epochs: %.2f'%(1, net2.activate((1, 4))[0]) 

#So far, so good. Let's test pickle 
pickle.dump(net2, open('testNetwork.dump', 'w')) 
net2 = pickle.load(open('testNetwork.dump')) 
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True) 
print 'loaded net2 using pickle, continue training' 
for i in xrange(1, 4): 
     trainer2.trainEpochs(1) 
     print '\tvalue after %d epochs: %.2f'%(i, net2.activate((1, 4))[0]) 

Ésta es la salida de este bloque:

creating net2 
Total error: 176.339378639 
    value after 1 epochs: 5.45 
loaded net2 using pickle, continue training 
Total error: 123.392181859 
    value after 1 epochs: 5.45 
Total error: 94.2867637623 
    value after 2 epochs: 5.45 
Total error: 78.076711114 
    value after 3 epochs: 5.45 

Como se puede ver, parece que la formación tiene algún efecto sobre la red (el valor de error total reportado sigue disminuyendo), sin embargo, el valor de salida de la red se congela en un valor que era relevante para la primera iteración de entrenamiento.

¿Hay algún mecanismo de caché que deba tener en cuenta que cause este comportamiento erróneo? ¿Hay mejores formas de serializar/deserializar las redes de Pybrain?

números de versión pertinentes:

  • Python 2.6.5 (R265: 79096, el 19 de mar 2010, 21:48:26) [MSC v.1500 32 bit (Intel)]
  • NumPy 1.5. 1
  • cPickle 1,71
  • pybrain 0,3

PS He creado a bug report en el sitio del proyecto y mantener tanto el SO y el seguimiento de errores updatedj

+0

¿Estás seguro de que no deberías hacer 'trainer2 = BackpropTrainer (net2, dataset = data, verbose = True)' nuevamente después de volver a cargar 'net2'? –

+0

@Seth Johnson Por supuesto que sí, pero hacerlo no resuelve el problema. En realidad, mi código de prueba incluía esa línea, pero por error la cometí al pegar aquí. Solucionado ahora –

Respuesta

11

Causa

El mecanismo que hace que este comportamiento es el manejo de parámetros (.params) y derivados (.derivs) en PyBrain módulos: de hecho, todos los parámetros de red se almacenan en una matriz, pero los objetos individuales Module o Connection tienen acceso a "su propia" .params, que, sin embargo, son solo una vista en una porción de la matriz total. Esto permite escrituras y lecturas tanto locales como a nivel de red en la misma estructura de datos.

Aparentemente este enlace slice-view se pierde al descascarillar.

Solución

Insertar

net2.sorted = False 
net2.sortModules() 

después de la carga del archivo (que recrea este intercambio), y debería funcionar.

Cuestiones relacionadas