What is K-Nearest Neighbors (KNN) Algorithm in Machine Learning? An Essential Guide

The K-Nearest Neighbors (KNN) algorithm is a supervised machine learning algorithm that can solve classification and regression problems. In this comprehensive article from Zilliz, a leading vector database company for production-ready AI, we’ll answer questions such as: what is KNN, how does KNN work, what is KNN in machine learning, why is KNN needed, and what are some ways to improve KNN? We’ll also demonstrate a KNN model implementation using Python.
What is K-Nearest Neighbor?
Let’s start with what the KNN algorithm is essentially for. The K-Nearest Neighbors algorithm estimates how likely a data point is to belong to one group or another based on the data points closest to it.
A KNN algorithm can be used for classification as well as regression problems. It is categorized as a lazy learner, which means it only stores a training dataset rather than going through a training stage.
Additionally, it implies that all computation is performed when a classification or prediction is made. Since it uses memory to store all of its training data, it is also known as memory-based learning.
KNN has two key characteristics. Firstly, KNN is a non-parametric algorithm: it makes no assumptions about the underlying distribution of the dataset. Instead, the model is built entirely from the data that is provided.
Secondly, KNN has no explicit training phase. There are no parameters to learn, so all of the available data is kept in memory and consulted whenever a prediction is requested (a test set can still be held out to evaluate the model, as we do in the example later in this article).
How the K-Nearest Neighbors algorithm works and how to calculate it
To determine the class of a new, unlabeled observation, KNN essentially uses a voting mechanism: the class that receives the most votes among the neighbors becomes the class of the data point in question.
If K is equal to 1, we will only consider a data point's closest neighbor when determining its class. The 10 closest neighbors will be used if K is equal to 10, and so on.
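To make the voting step concrete, here is a minimal NumPy sketch of a single prediction (the points and the predict_knn helper are made up purely for illustration):
import numpy as np
from collections import Counter

def predict_knn(X_train, y_train, query, k):
    # Euclidean distance from the query point to every stored training point
    distances = np.linalg.norm(X_train - query, axis=1)
    # indices of the k smallest distances
    nearest = np.argsort(distances)[:k]
    # majority vote among the neighbors' labels
    return Counter(y_train[nearest]).most_common(1)[0][0]

X_train = np.array([[1.0, 2.0], [1.5, 1.8], [5.0, 8.0], [6.0, 9.0]])
y_train = np.array([0, 0, 1, 1])
print(predict_knn(X_train, y_train, np.array([1.2, 1.9]), k=1))   # -> 0
print(predict_knn(X_train, y_train, np.array([5.5, 8.5]), k=3))   # -> 1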
How KNN works between two classes. Source: https://www.ibm.com/in-en/topics/knn
Consider two classes: A and B. The algorithm examines the data points near the query point to decide whether it belongs to Class A or Class B. If most of its nearest neighbors lie in group A, the data point in question almost certainly belongs to group A as well.
Now, you may wonder how the distance metric is calculated to determine whether or not a data point is a neighbor, right? There are numerous methods for calculating the distance between a data point and its nearest neighbor. They include Euclidean distance, Manhattan distance, Hamming distance, Cosine distance, Jaccard distance, Minkowski distance, and a few others.
Euclidean distance is the most common default; Manhattan distance sums absolute coordinate differences, Cosine distance compares the orientation of two vectors rather than their magnitude, and Minkowski distance generalizes both Euclidean and Manhattan.
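As a quick illustration (plain NumPy; the vectors a and b are arbitrary examples), a few of these metrics can be computed like this:
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 0.0, 3.0])

# Euclidean distance: straight-line distance between the two points
euclidean = np.sqrt(np.sum((a - b) ** 2))
# Manhattan distance: sum of absolute coordinate differences
manhattan = np.sum(np.abs(a - b))
# Cosine distance: 1 minus the cosine of the angle between the vectors
cosine = 1 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
# Minkowski distance of order p (p=2 is Euclidean, p=1 is Manhattan)
p = 3
minkowski = np.sum(np.abs(a - b) ** p) ** (1 / p)
print(euclidean, manhattan, cosine, minkowski)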
Most of the steps in K-Nearest Neighbors regression are the same as in classification. The difference is that, instead of assigning the class with the most votes, the unknown data point is assigned the average of its neighbors’ values.
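For example, here is a minimal KNeighborsRegressor sketch on a small synthetic dataset (the data and parameter choices are illustrative only):
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# a small synthetic regression problem: y is a noisy function of x
rng = np.random.RandomState(0)
X = np.sort(rng.uniform(0, 5, size=(40, 1)), axis=0)
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=40)

# the prediction is the average target value of the 5 nearest neighbors
reg = KNeighborsRegressor(n_neighbors=5)
reg.fit(X, y)
print(reg.predict([[2.5]]))   # roughly sin(2.5), about 0.6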
The Value of ‘K’ in K-Nearest Neighbors Classification
Choosing the correct value of K is known as hyperparameter tuning, which is required for better results. There is no defined method for determining the best value of K. It is determined by the specific type of problem.
The K value specifies how many neighbors will be checked to determine the classification of a specific query point. If k=1, for example, the instance is assigned to the same class as its single nearest neighbor.
Different values of K can lead to overfitting or underfitting, so defining it can be a balancing act. Lower K values can have high variance but low bias, while higher K values can have high bias but low variance.
Why Do We Need the KNN Algorithm?
KNN can make highly accurate predictions. On many problems it is competitive with far more complex state-of-the-art (SOTA) models. As a result, the K-Nearest Neighbors algorithm can be used for applications that require high accuracy but do not require a human-readable model.
The accuracy of the predictions depends heavily on the distance metric used. Thus, the KNN algorithm is best suited to applications where there is enough domain knowledge to select an appropriate metric.
Improving K-Nearest Neighbors: how to do it
Normalizing the features so they are on the same scale is recommended for better results. In general, the normalization range is between 0 and 1. Apart from this, tuning the hyperparameters K and the distance metric is also critical.
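For instance, scikit-learn's MinMaxScaler rescales every feature to the 0-to-1 range; a minimal sketch with made-up data might look like this:
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.pipeline import make_pipeline
from sklearn.neighbors import KNeighborsClassifier

# two features on very different scales: age in years, income in dollars
X = np.array([[25, 40000], [32, 120000], [47, 65000], [51, 90000]])
y = np.array([0, 1, 0, 1])

# rescale every feature to the 0-to-1 range so income does not dominate the distance
knn = make_pipeline(MinMaxScaler(), KNeighborsClassifier(n_neighbors=3))
knn.fit(X, y)
print(knn.predict([[30, 100000]]))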
We can test the K Nearest Neighbor algorithm with different values of K using the cross-validation technique. The model with the highest accuracy can be considered the best option.
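As a rough sketch (using a synthetic dataset purely for illustration), scikit-learn's cross_val_score makes this straightforward:
from sklearn.datasets import make_blobs
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# a small synthetic dataset, just to have something to score
X, y = make_blobs(n_samples=500, centers=3, cluster_std=2, random_state=80)

# score several candidate values of K with 5-fold cross-validation
for k in [1, 3, 5, 7, 9]:
    scores = cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5)
    print(k, round(scores.mean(), 3))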
K Nearest Neighbor Example: Python Implementation of KNN Algorithm
Let’s now get into the implementation of the KNN model in Python. We are using Python 3.8.5 in a Jupyter notebook. We’ll go over the steps to help you break the code down.
Here it goes:
Importing the modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
Dataset
Scikit-learn can be used to create synthetic datasets, which are great for demo purposes.
X, y = make_blobs(n_samples=4000, n_features=3, centers=3, cluster_std=2, random_state=80)
X
array([[ 7.60190561, 4.86336321, 6.97616573],
[ 5.97809745, 7.69910922, 2.77419701],
[-4.36024844, -2.23247572, -5.29113293],
...,
[-8.22252297, -6.88609334, -6.52102135],
[-3.96254707, -5.27559922, -2.70880022],
[-4.25865881, -1.67791521, -3.70523373]])
y
array([1, 1, 2, ..., 2, 2, 2])
Plot
plt.figure(figsize = (6,6))
plt.scatter(X[:, 0], X[:, 1], c=y, marker='.', s=10, edgecolors='blue')
plt.show()
df = pd.DataFrame(X)
df.head()
plt.rcParams['figure.figsize']=(10,15)
df.plot(kind='hist', bins=100, subplots=True, layout=(5,2), sharex=False, sharey=False)
plt.show()
The K-Nearest Neighbors Classifier implementation
The first step is to figure out the value of K. How K is chosen varies greatly depending on the situation. When using the scikit-learn library, the default value of K is 5 and the default distance metric is Euclidean.
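In other words, the two classifiers below are equivalent; the second simply spells out scikit-learn's defaults:
from sklearn.neighbors import KNeighborsClassifier

# the default classifier: 5 neighbors, Minkowski distance with p=2, i.e. Euclidean
knn_default = KNeighborsClassifier()
knn_explicit = KNeighborsClassifier(n_neighbors=5, metric='minkowski', p=2)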
Tuning the Model to Get High K Nearest Neighbor Accuracy
from sklearn.model_selection import GridSearchCV
param_grid = {'n_neighbors': np.arange(1, 4)}
knn = KNeighborsClassifier()
knn_cv = GridSearchCV(knn, param_grid, cv=5)
knn_cv.fit(X,y)
print(knn_cv.best_params_)
print(knn_cv.best_score_)
{'n_neighbors': 3}
0.9887499999999999
#train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 80)
# instantiate the model
knn = KNeighborsClassifier(n_neighbors=3)
# fit the model to the training set
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print('Model accuracy score: {0:0.4f}'.format(accuracy_score(y_test, y_pred)))
Model accuracy score: 0.9890
We got an accuracy rate of 98.90%, which is considered very good. We searched K values from 1 to 3, and the model performed best at k=3.
Crucial benefits of KNN
1. Time-efficient
The K-Nearest Neighbors model does not involve a training period, since the stored data itself serves as the model that future predictions reference. As a result, it is time-efficient and makes it easy to experiment quickly with whatever data is available.
2. Simple to tune
KNN only requires two hyperparameters, a K value and a distance metric, making it simpler to tune than other machine learning algorithms.
3. Easy adaptability
Most classifier algorithms are easy to implement for binary problems but require extra effort to implement for multi-class problems. In contrast, KNN adapts to multi-class problems without any extra effort.
And a few drawbacks
1. High-dimensional data
KNN does not work well with large or high-dimensional data because calculating distances between each data instance would be prohibitively expensive.
2. Noisy or missing data
KNN is sensitive to noisy data and does not handle missing values well.
3. Unbalanced data
K-Nearest Neighbors also performs poorly on unbalanced data, because the majority class tends to dominate the vote.
4. Curse of dimensionality
Because of the curse of dimensionality, KNN is more prone to overfitting. While feature selection and dimensionality reduction techniques are used to prevent this, the value of K can impact the model's behavior.
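One common mitigation, sketched below with illustrative data, is to reduce the dimensionality before running KNN, for example with PCA:
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline

# a synthetic high-dimensional dataset: 100 features, only a few informative
X, y = make_classification(n_samples=1000, n_features=100, n_informative=10, random_state=0)

# project onto the top 10 principal components before computing distances
model = make_pipeline(PCA(n_components=10), KNeighborsClassifier(n_neighbors=5))
model.fit(X, y)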
When to use KNN and Why
For each prediction, the time complexity of the K-Nearest Neighbors algorithm is O(MN log k), where M is the dimension of the data and N is the number of instances in the training data. Still, there are multiple specialized ways of organizing the data to address this and make prediction more efficient, as sketched below.
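In scikit-learn, for example, you can ask the classifier to build a KD-tree or ball-tree index instead of using brute-force distance computation (which option helps most depends on the data's dimensionality):
from sklearn.neighbors import KNeighborsClassifier

# brute force: compute the distance to every training point at query time
knn_brute = KNeighborsClassifier(n_neighbors=5, algorithm='brute')
# tree-based indexes speed up neighbor lookups in low-to-moderate dimensions
knn_kd = KNeighborsClassifier(n_neighbors=5, algorithm='kd_tree')
knn_ball = KNeighborsClassifier(n_neighbors=5, algorithm='ball_tree')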
Additionally, several preprocessing techniques can be used to eliminate missing data and noise and ensure that the dataset is balanced. Due to this, KNN is one of the most widely used algorithms.
The bottom line: vector databases
We can use the Cosine function and the K Nearest Neighbor algorithm to determine how similar or different two sets of items are and then use that information to classify them.
In high-dimensional space, the Cosine function is used to calculate the similarity or distance between observations. Such high-dimensional data makes the computation very slow when handled through a traditional database; hence, it is better stored in a vector database.
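As a small sketch, scikit-learn's KNeighborsClassifier can use cosine distance directly with brute-force search (the "embedding" vectors below are made up):
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# tiny made-up "embedding" vectors belonging to two classes
X = np.array([[0.9, 0.1, 0.0],
              [0.8, 0.2, 0.1],
              [0.1, 0.0, 0.9],
              [0.2, 0.1, 0.8]])
y = np.array([0, 0, 1, 1])

# cosine distance compares the direction of the vectors, not their magnitude
knn = KNeighborsClassifier(n_neighbors=3, metric='cosine', algorithm='brute')
knn.fit(X, y)
print(knn.predict([[0.7, 0.3, 0.2]]))   # expected: class 0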
This is one of the many use cases of vector databases. And this is precisely where Zilliz steps in, especially if you’re wondering ‘what is a vector database?’ and whether to explore it.
There’s no denying that vector databases are the need of the hour in the modern era of AI. Zilliz offers a one-stop solution for challenges in handling unstructured data, especially for enterprises that build AI/ML applications that leverage vector similarity search.