2010-04-21 12 views
26

R tiene una función útil pairs que proporciona una buena matriz de gráficos de conexiones por pares entre variables en un conjunto de datos. El gráfico resultante tiene una apariencia similar a la siguiente figura, copiado de this blog post:matplotlib analógico de pares de R '

pairs

¿Hay alguna lista para utilizar la función basada en matplotlib de pitón? He buscado en gallery, pero no he podido encontrar nada que se parezca a lo que necesito. Técnicamente, esto debería ser una tarea simple, pero el manejo adecuado de todos los casos posibles, etiquetas, títulos, etc. es muy tedioso.

ACTUALIZACIÓN vea a continuación mi respuesta con una aproximación rápida y sucia.

+0

Seaborn tiene esto, ver: http://seaborn.pydata.org/generated/seaborn. pairplot.html –

Respuesta

33

Pandas ha construido en función de scatter_matrix (source code) que es algo como esto.

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 

df = pd.DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D']) 
axes = pd.tools.plotting.scatter_matrix(df, alpha=0.2) 
plt.tight_layout() 
plt.savefig('scatter_matrix.png') 

scatter_matrix.png

Sin embargo, es pandas específico (pero podría ser utilizado como punto de partida).

Hay algunos más R como parcelas en pandas. Eche un vistazo al docs.

0

Hasta donde yo sé, no hay una función lista para usar como esa.

3

aproximación rápida y sucia para mis necesidades:

def pair(data, labels=None): 
    """ Generate something similar to R `pair` """ 

    nVariables = data.shape[1] 
    if labels is None: 
     labels = ['var%d'%i for i in range(nVariables)] 
    fig = pl.figure() 
    for i in range(nVariables): 
     for j in range(nVariables): 
      nSub = i * nVariables + j + 1 
      ax = fig.add_subplot(nVariables, nVariables, nSub) 
      if i == j: 
       ax.hist(data[:,i]) 
       ax.set_title(labels[i]) 
      else: 
       ax.plot(data[:,i], data[:,j], '.k') 

    return fig 

El código anterior queda liberado al dominio público

+0

Para mí hay un valor extra en el código del módulo base. Esto es muy claro, lo tomaré como instructivo para las tareas de manipulación de datos. Una pregunta: ¿qué tipos de objetos pueden ser 'datos'? – Merlin

2

La función subplots en las versiones recientes de matplotlib (al menos 1.4) hace que esta un poco más fácil:

def pairs(data, names): 
    "Quick&dirty scatterplot matrix" 
    d = len(data) 
    fig, axes = plt.subplots(nrows=d, ncols=d, sharex='col', sharey='row') 
    for i in range(d): 
     for j in range(d): 
      ax = axes[i,j] 
      if i == j: 
       ax.text(0.5, 0.5, names[i], transform=ax.transAxes, 
         horizontalalignment='center', verticalalignment='center', 
         fontsize=16) 
      else: 
       ax.scatter(data[j], data[i], s=10) 
+1

Si se aplica en un estándar de conjunto de datos para scikit-learn, 'len (data)' devolverá el número de observaciones, no el número de variables. 'd = data.shape [1]' es la solución en este caso. –

Cuestiones relacionadas