2012-08-28 35 views
14

Permítanme comenzar diciendo que no tengo experiencia con R, KNN o ciencia de datos en general. Recientemente encontré Kaggle y he estado jugando con la competencia/tutorial Digit Recognition.¿Cómo ver los vecinos más cercanos en R?

En este tutorial se proporcionan algunos ejemplos de código para que pueda empezar con una presentación básica:

# makes the KNN submission 

library(FNN) 

train <- read.csv("c:/Development/data/digits/train.csv", header=TRUE) 
test <- read.csv("c:/Development/data/digits/test.csv", header=TRUE) 

labels <- train[,1] 
train <- train[,-1] 

results <- (0:9)[knn(train, test, labels, k = 10, algorithm="cover_tree")] 

write(results, file="knn_benchmark.csv", ncolumns=1) 

Mis preguntas son:

  1. ¿Cómo puedo ver los vecinos más cercanos que han sido seleccionados para una fila de prueba particular ?
  2. ¿Cómo puedo modificar cuál de esas diez está seleccionada para mi results?

Estas preguntas pueden ser demasiado amplias. De ser así, agradecería cualquier enlace que pueda apuntarme por el camino correcto.

Es muy posible que he dicho algo que no tiene sentido aquí. Si este es el caso, por favor corrígeme.

Respuesta

23

1) Se puede llegar a los vecinos más cercanos de una fila determinada de esta manera:

k <- knn(train, test, labels, k = 10, algorithm="cover_tree") 
indices <- attr(k, "nn.index") 

A continuación, si desea que los índices de los 10 vecinos más cercanos a la fila 20 en el conjunto de entrenamiento:

print(indices[20, ]) 

(Obtendrá los 10 vecinos más cercanos porque seleccionó k=10). Por ejemplo, si se ejecuta con sólo los primeros 1000 filas de la formación y las pruebas de ajuste (para que sea computacionalmente más fácil):

train <- read.csv("train.csv", header=TRUE)[1:1000, ] 
test <- read.csv("test.csv", header=TRUE)[1:1000, ] 

labels <- train[,1] 
train <- train[,-1] 

k <- knn(train, test, labels, k = 10, algorithm="cover_tree") 
indices = attr(k, "nn.index") 

print(indices[20, ]) 
# output: 
# [1] 829 539 784 487 293 882 367 268 201 277 

Esos son los índices dentro del conjunto de entrenamiento de 1000 que están más cerca de la fila 20 del conjunto de prueba.

2) Depende de lo que quiere decir con "modificar". Para empezar, usted puede obtener los índices de cada una de las 10 etiquetas más cercanos a cada fila de esta manera:

closest.labels = apply(indices, 2, function(col) labels[col]) 

A continuación, puede ver las etiquetas de los 10 puntos más cercanos al punto de formación 20 de la siguiente manera:

closest.labels[20, ] 
# [1] 0 0 0 0 0 0 0 0 0 0 

Esto indica que los 10 puntos más cercanos a la fila 20 están todos en el grupo etiquetado 0. knn simplemente elige la etiqueta por mayoría de votos (con vínculos rotos aleatoriamente), pero puede elegir algún tipo de esquema de ponderación si prefieres.

ETA: Si usted está interesado en la ponderación de los elementos más estrechos en mayor medida en su esquema de votación, tenga en cuenta que también se puede obtener las distancias para cada uno de los k vecinos como esta:

dists = attr(k, "nn.dist") 
dists[20, ] 
# output: 
# [1] 1238.777 1243.581 1323.538 1398.060 1503.371 1529.660 1538.128 1609.730 
# [9] 1630.910 1667.014 
+0

respuesta maravilloso, gracias ¡tú! Tenía algunas preguntas. Cada vez que trato de imprimir 'índices', devuelve nulo, ¿debería hacer algo diferente de tu ejemplo? ¿Puede recomendar algún recurso para investigar más sobre cómo crear un esquema de ponderación personalizado? ¿O ejemplos de alguien que crea uno que pueda mirar? –

+0

Eso es muy extraño.¿Qué obtienes si haces 'print (k)'? En cuanto a otros esquemas de ponderación, tendrías tanta suerte como buscar la frase "ponderado por KNN" en Google. Pero estoy escribiendo un poco más sobre la ponderación en mi respuesta. –

+0

Ok, solo para aclarar que estoy usando 'results' en lugar de' k'. Supongo que esto no hace la diferencia, pero pensé que debería arrojar eso por ahí. Cuando lo hago 'print (results)' Imprime los 1000 elementos que finalmente se escriben en mi archivo csv. –

Cuestiones relacionadas