2011-12-29 60 views
27

Al dibujar un diagrama de puntos usando matplotlib, me gustaría desplazar puntos de datos superpuestos para mantenerlos todos visibles. Por ejemplos, si tengoMatplotlib: evitando la superposición de puntos de datos en un gráfico "scatter/dot/beeswarm"

CategoryA: 0,0,3,0,5 
CategoryB: 5,10,5,5,10 

quiero cada uno de los puntos de datos "0" CategoryA que se fijará al lado del otro, en lugar de a la derecha en la parte superior de uno al otro, sin dejar de ser distinto de CategoryB.

En R (ggplot2) hay una opción "jitter" que hace esto. ¿Hay una opción similar en matplotlib, o hay otro enfoque que conduzca a un resultado similar?

Editar: para aclarar, the "beeswarm" plot in R es esencialmente lo que tengo en mente, y pybeeswarm es un comienzo temprano, pero útil en una versión matplotlib/Python.

Editar: añadir que Seaborn de Swarmplot, introducida en la versión 0.7, es una excelente aplicación de lo que quería.

+0

En un [gráfico de puntos] (http://en.wikipedia.org/wiki/Dot_plot_ (estadísticas)) estos puntos ya están separados en su columna – joaquin

+1

La definición wiki de "gráfico de puntos" no es lo que estoy tratando de describir, pero nunca he oído hablar de un término que no sea "gráfico de puntos" para él. Es aproximadamente un diagrama de dispersión pero con etiquetas x arbitrarias (no necesariamente numéricas). Así, en el ejemplo que describo en la pregunta, habría una columna de valores para "Categoría A", una segunda columna para "Categoría B", etc. (_Editar_: la definición de wikipedia de "Plan de puntos de Cleveland" es más similar a lo que Estoy buscando, aunque todavía no es exactamente lo mismo.) – iayork

Respuesta

6

Sin saber de una alternativa directa MPL aquí se tiene una propuesta muy rudimentaria:

from matplotlib import pyplot as plt 
from itertools import groupby 

CA = [0,4,0,3,0,5] 
CB = [0,0,4,4,2,2,2,2,3,0,5] 

x = [] 
y = [] 
for indx, klass in enumerate([CA, CB]): 
    klass = groupby(sorted(klass)) 
    for item, objt in klass: 
     objt = list(objt) 
     points = len(objt) 
     pos = 1 + indx + (1 - points)/50. 
     for item in objt: 
      x.append(pos) 
      y.append(item) 
      pos += 0.04 

plt.plot(x, y, 'o') 
plt.xlim((0,3)) 

plt.show() 

enter image description here

7

Solía ​​numpy.random a "dispersión/beeswarm" los datos a lo largo del eje X, sino en todo un punto fijo para cada categoría, y luego, básicamente, hacer pyplot.scatter() para cada categoría:

import matplotlib.pyplot as plt 
import numpy as np 

#random data for category A, B, with B "taller" 
yA, yB = np.random.randn(100), 5.0+np.random.randn(1000) 

xA, xB = np.random.normal(1, 0.1, len(yA)), 
     np.random.normal(3, 0.1, len(yB)) 

plt.scatter(xA, yA) 
plt.scatter(xB, yB) 
plt.show() 

X-scattered data

29

La extensión de la respuesta por @ user2467675, así es como lo hice:

def rand_jitter(arr): 
    stdev = .01*(max(arr)-min(arr)) 
    return arr + np.random.randn(len(arr)) * stdev 

def jitter(x, y, s=20, c='b', marker='o', cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, verts=None, hold=None, **kwargs): 
    return scatter(rand_jitter(x), rand_jitter(y), s=s, c=c, marker=marker, cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths, verts=verts, hold=hold, **kwargs) 

La variable stdev se asegura de que la fluctuación es suficiente para ser visto en diferentes escalas, pero se supone que los límites de los ejes son 0 y el valor máximo.

A continuación, puede llamar al jitter en lugar de scatter.

+0

Me gusta mucho su cálculo automático de la escala de jitter. Funciona bien para mi –

+0

¿Funciona esto si 'arr' contiene solo ceros (es decir, stdev = 0)? – Dataman

5

Una manera de abordar el problema es pensar en cada una 'fila' en su Dispersión/Puntos parcela/beeswarm como una papelera en un histograma:

data = np.random.randn(100) 

width = 0.8  # the maximum width of each 'row' in the scatter plot 
xpos = 0  # the centre position of the scatter plot in x 

counts, edges = np.histogram(data, bins=20) 

centres = (edges[:-1] + edges[1:])/2. 
yvals = centres.repeat(counts) 

max_offset = width/counts.max() 
offsets = np.hstack((np.arange(cc) - 0.5 * (cc - 1)) for cc in counts) 
xvals = xpos + (offsets * max_offset) 

fig, ax = plt.subplots(1, 1) 
ax.scatter(xvals, yvals, s=30, c='b') 

Obviamente, esto implica a agrupar los datos, por lo que puede perder algo de precisión.Si tiene datos discretos, puede reemplazar:

counts, edges = np.histogram(data, bins=20) 
centres = (edges[:-1] + edges[1:])/2. 

con:

centres, counts = np.unique(data, return_counts=True) 

Un enfoque alternativo que conserva las exactas coordenadas, incluso para los datos continuos, es utilizar un kernel density estimate para escalar la amplitud de la fluctuación de fase aleatoria en el eje x:

from scipy.stats import gaussian_kde 

kde = gaussian_kde(data) 
density = kde(data)  # estimate the local density at each datapoint 

# generate some random jitter between 0 and 1 
jitter = np.random.rand(*data.shape) - 0.5 

# scale the jitter by the KDE estimate and add it to the centre x-coordinate 
xvals = 1 + (density * jitter * width * 2) 

ax.scatter(xvals, data, s=30, c='g') 
for sp in ['top', 'bottom', 'right']: 
    ax.spines[sp].set_visible(False) 
ax.tick_params(top=False, bottom=False, right=False) 

ax.set_xticks([0, 1]) 
ax.set_xticklabels(['Histogram', 'KDE'], fontsize='x-large') 
fig.tight_layout() 

Este segundo cumplido hod está basado libremente en cómo funciona violin plots. Todavía no puedo garantizar que ninguno de los puntos se superponga, pero me parece que en la práctica tiende a dar resultados bastante agradables siempre que haya una cantidad decente de puntos (> 20), y la distribución se puede aproximar razonablemente bien por una suma de gaussianos.

enter image description here

3

Seaborn ofrece histograma como categóricas punto-parcelas a través sns.swarmplot() y jittered categóricas punto-a través de parcelas sns.stripplot():

import seaborn as sns 

sns.set(style='ticks', context='talk') 
iris = sns.load_dataset('iris') 

sns.swarmplot('species', 'sepal_length', data=iris) 
sns.despine() 

enter image description here

sns.stripplot('species', 'sepal_length', data=iris, jitter=0.2) 
sns.despine() 

enter image description here

1

swarmplot de Seaborn parece que el ajuste más apto para lo que tiene en mente, pero también puede fluctuar con regplot de Seaborn:

import seaborn as sns 
iris = sns.load_dataset('iris') 

sns.regplot(x='sepal_length', 
      y='sepal_width', 
      data=iris, 
      fit_reg=False, # do not fit a regression line 
      x_jitter=0.1, # could also dynamically set this with range of data 
      y_jitter=0.1, 
      scatter_kws={'alpha': 0.5}) # set transparency to 50% 
Cuestiones relacionadas