2010-06-12 11 views
8

Antes que nada, soy nuevo en R (comencé ayer).Cálculo de todas las distancias entre un punto y un grupo de puntos de manera eficiente en R

I tienen dos grupos de puntos, data y centers, la primera de tamaño n y el segundo de tamaño K (por ejemplo, n = 3823 y K = 10), y para cada i en el primer conjunto, necesito encontrar j en el segundo con la distancia mínima.

Mi idea es simple: para cada i, vamos dist[j] sea la distancia entre i y j, sólo necesito utilizar which.min(dist) a encontrar lo que estoy buscando.

Cada punto es una matriz de 64 dobles, por lo

> dim(data) 
[1] 3823 64 
> dim(centers) 
[1] 10 64 

He tratado con

for (i in 1:n) { 
    for (j in 1:K) { 
    d[j] <- sqrt(sum((centers[j,] - data[i,])^2)) 
    } 
    S[i] <- which.min(d) 
} 

que es extremadamente lenta (con n = 200, se tarda más de 40 años !!). La solución más rápida que escribí es

distance <- function(point, group) { 
    return(dist(t(array(c(point, t(group)), dim=c(ncol(group), 1+nrow(group)))))[1:nrow(group)]) 
} 

for (i in 1:n) { 
    d <- distance(data[i,], centers) 
    which.min(d) 
} 

incluso si lo hace un montón de computación que no uso (porque dist(m) calcula la distancia entre todas las filas de m), que es la forma más rápida que el otro (¿Alguien puede explicar por qué?), pero no es lo suficientemente rápido para lo que necesito, porque no se usará solo una vez. Y también, el código distance es muy feo. Traté de reemplazarlo con

distance <- function(point, group) { 
    return (dist(rbind(point,group))[1:nrow(group)]) 
} 

pero esto parece ser dos veces más lento. También traté de usar dist para cada par, pero también es más lento.

No sé qué hacer ahora. Parece que estoy haciendo algo muy malo. ¿Alguna idea sobre cómo hacer esto de manera más eficiente?

ps: Necesito esto para implementar k-means a mano (y tengo que hacerlo, es parte de una tarea). Creo que solo necesitaré la distancia euclidiana, pero todavía no estoy seguro, por lo que preferiré tener algún código donde el cálculo de la distancia pueda reemplazarse fácilmente. stats::kmeans hacen todos los cálculos en menos de un segundo.

+1

personas' aquí especie-a-no-como-hacer tareas ... así que trate de concentrarse en un problema específico. – aL3xa

Respuesta

13

En lugar de iterar a través de los puntos de datos, puede condensarlos en una operación de matriz, lo que significa que solo tiene que iterar en K.

# Generate some fake data. 
n <- 3823 
K <- 10 
d <- 64 
x <- matrix(rnorm(n * d), ncol = n) 
centers <- matrix(rnorm(K * d), ncol = K) 

system.time(
    dists <- apply(centers, 2, function(center) { 
    colSums((x - center)^2) 
}) 
) 

Se ejecuta en:

utilisateur  système  écoulé 
     0.100  0.008  0.108 

en mi portátil.

+0

+1 supera mi camino para calcular la matriz de dists. Este es un buen truco con el vector de replicación automática agregado o restado de la matriz. – Marek

+0

Estoy tratando de usar su solución, pero su matriz está transpuesta.¿Hay alguna manera de restar líneas como lo hizo con las columnas? – dbarbosa

+0

Probé la resta con líneas usando aplicar pero no fue tan rápido como tu solución. ¡Ahora estoy transponiendo la matriz y usando tu código y es realmente rápido! ¡¡¡Muchas gracias!!! Y también, gracias por su respuesta completa con un pequeño ejemplo y el uso de system.time. Muchas gracias :) – dbarbosa

1

Es posible que desee echar un vistazo a las funciones apply.

Por ejemplo, este código

for (j in 1:K) 
    { 
    d[j] <- sqrt(sum((centers[j,] - data[i,])^2)) 
    } 

Puede ser fácilmente sustituido por algo así como

dt <- data[i,] 
d <- apply(centers, 1, function(x){ sqrt(sum(x-dt)^2)}) 

Puede definitivamente optimizarlo más, pero usted consigue el punto espero

+0

Gracias ... Es un código más rápido que el primero que escribí, pero ni siquiera cerca del extraño que usa 'distance'. – dbarbosa

+1

@dbarbosa: bueno, aparentemente el paquete 'stats :: kmeans' usa código compilado que es obviamente más rápido. Simplemente escriba 'kmeans' y verá el código fuente para ello. :) – nico

1

dist funciona rápido porque no está vectorizado y llama a las funciones C internas.
El código en el bucle se puede vectorizar de muchas formas.

Por ejemplo, para calcular la distancia entre data y centers podría utilizar outer:

Esto le da n x K matriz de distancias. Y debería ser mucho más rápido que loop.

Luego puede usar max.col para encontrar el máximo en cada fila (consulte la ayuda, hay algunos matices cuando hay muchos máximos). X debe ser negado porque buscamos un mínimo.

CL <- max.col(-X) 

Para ser eficiente en R debe vectorizar como sea posible. Los bucles podrían reemplazarse en muchos casos por un sustituto vectorizado. Consulte la ayuda para rowSums (que también describe rowMeans, colSums, rowSums), pmax, cumsum. Puede buscar SO, p. Ej. https://stackoverflow.com/search?q=[r]+avoid+loop (copia & pegue este enlace, no sé cómo hacer que se pueda hacer clic) para algunos ejemplos.

+0

Hola, estoy tratando de usar tu código pero no está funcionando. Intenté usarlo con el mismo código que @Jonathan Chang, y agregué: 'system.time (outer (seq_len (n), seq_len (K), function (i, j) sqrt (rowSums ((x [, i] -centers [, j])^2)))) ', pero obtengo este error: ' Error en dim (robj) <- c (dX, dY): dims [producto 38230] no coincide con la longitud del objeto [64] ' ¿Ves lo que está mal? – dbarbosa

+0

En realidad, no entendía 'outer' (pensé que estaba llamando a la función una vez por cada par). Ahora lo entiendo, gracias, ¡puede ser útil! Y también, gracias por contar acerca de 'max.col'. – dbarbosa

0

Mi solución:

# data is a matrix where each row is a point 
# point is a vector of values 
euc.dist <- function(data, point) { 
    apply(data, 1, function (row) sqrt(sum((point - row)^2))) 
} 

se puede probar, como:

x <- matrix(rnorm(25), ncol=5) 
euc.dist(x, x[1,]) 
3

rdist() es una función R de {campos} paquete que es capaz de calcular distancias entre dos conjuntos de puntos en formato de matriz rápidamente.

https://www.image.ucar.edu/~nychka/Fields/Help/rdist.html

Uso: redonda

library(fields) 
#generating fake data 
n <- 5 
m <- 10 
d <- 3 

x <- matrix(rnorm(n * d), ncol = d) 
y <- matrix(rnorm(m * d), ncol = d) 

rdist(x, y) 
      [,1]  [,2]  [,3]  [,4]  [,5] 
[1,] 1.512383 3.053084 3.1420322 4.942360 3.345619 
[2,] 3.531150 4.593120 1.9895867 4.212358 2.868283 
[3,] 1.925701 2.217248 2.4232672 4.529040 2.243467 
[4,] 2.751179 2.260113 2.2469334 3.674180 1.701388 
[5,] 3.303224 3.888610 0.5091929 4.563767 1.661411 
[6,] 3.188290 3.304657 3.6668867 3.599771 3.453358 
[7,] 2.891969 2.823296 1.6926825 4.845681 1.544732 
[8,] 2.987394 1.553104 2.8849988 4.683407 2.000689 
[9,] 3.199353 2.822421 1.5221291 4.414465 1.078257 
[10,] 2.492993 2.994359 3.3573190 6.498129 3.337441 
Cuestiones relacionadas