What is the K-Nearest Neighbors (KNN) Algorithm in Machine Learning?
Latest Update: July 31
The K-Nearest Neighbors (KNN) algorithm is a supervised machine learning algorithm that can solve classification and regression problems. In this comprehensive article from Zilliz, a leading vector database company for production-ready AI, we'll answer questions such as: what is KNN, how does KNN work, why do you need KNN, and what are some ways to improve it? We'll also demonstrate the implementation of a KNN model using Python.
What is K-Nearest Neighbor (KNN)?
As mentioned, the K-Nearest Neighbors (KNN) algorithm is a supervised machine learning algorithm that can solve both classification and regression problems. KNN estimates how likely a data point is to belong to one group or another based on which data points are closest to it. It is categorized as a lazy learner, which means it only stores the training dataset rather than going through a training stage; all computation is deferred until a classification or prediction is made. Since it keeps all of its training data in memory, it is also known as memory-based learning.
KNN has two key characteristics. Firstly, KNN is a non-parametric algorithm: no assumptions about the dataset are made when the model is used, and the model is built entirely from the data that is provided. Secondly, KNN does not fit a model on a training set ahead of time; it uses all of the data it has been given whenever it is asked to make a prediction.
How to calculate the K-Nearest Neighbor (KNN) algorithm and how the algorithm works
To determine the class of an unobserved data point, the K-Nearest Neighbors algorithm essentially uses a voting mechanism: the class that receives the most votes among the point's neighbors becomes the predicted class for that data point.
If K is equal to 1, we will only consider a data point's closest neighbor when determining its class. The 10 closest neighbors will be used if K is equal to 10, and so on. The chart below describes how KNN works between two classes of training points.
How KNN works between two classes. Source: https://www.ibm.com/in-en/topics/knn
Consider two classes, A and B. The algorithm examines the classes of the nearby data points to decide whether a new data point belongs to Class A or Class B. If most of its neighbors are in group A, it is very likely that the data point in question belongs to group A as well.
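To make the voting idea concrete, here is a minimal from-scratch sketch in Python (the toy data, the knn_predict helper, and the choice of Euclidean distance are illustrative assumptions, not part of any particular library):

import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, query, k=3):
    # Compute the Euclidean distance from the query point to every training point
    distances = np.linalg.norm(X_train - query, axis=1)
    # Take the labels of the k closest training points
    nearest_labels = y_train[np.argsort(distances)[:k]]
    # Majority vote among those labels decides the predicted class
    return Counter(nearest_labels).most_common(1)[0][0]

# Toy data: class A (label 0) clustered near the origin, class B (label 1) near (5, 5)
X_train = np.array([[0.0, 0.2], [0.3, 0.1], [0.1, 0.4], [5.0, 5.1], [4.8, 5.3], [5.2, 4.9]])
y_train = np.array([0, 0, 0, 1, 1, 1])

print(knn_predict(X_train, y_train, np.array([0.2, 0.3]), k=3))  # -> 0 (class A)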
Now, you may wonder how distance is calculated to determine whether or not a data point counts as a neighbor. There are numerous methods for calculating the distance between a data point and its nearest neighbors, including Euclidean distance, Cosine distance, Jaccard distance, Hamming distance, and a few others.
Euclidean Distance is the true straight-line distance between two points in Euclidean space.
Cosine Distance measures how dissimilar two vectors are based on the angle between them, and is primarily used to compare vectors.
Jaccard Distance (one minus the Jaccard Index) compares two binary data sets by looking at the positions where at least one value is one and measuring how often the two disagree.
Hamming Distance, typically used with categorical data, counts the positions at which two data points have different values.
This chart tells us what these distance metrics are all about:
Distance metrics.
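As a rough illustration, the snippet below computes each of these metrics with SciPy's distance module (the two sample vectors are made up purely for demonstration):

from scipy.spatial import distance

a = [1, 0, 1, 1, 0]
b = [1, 1, 0, 1, 0]

print(distance.euclidean(a, b))  # straight-line distance between the two points
print(distance.cosine(a, b))     # 1 minus the cosine similarity of the two vectors
print(distance.jaccard(a, b))    # fraction of disagreeing positions among non-zero positions (binary data)
print(distance.hamming(a, b))    # fraction of positions where the values differ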
Most of the steps in K-Nearest Neighbors regression are the same as in classification. Instead of assigning the target data point the class with the most votes, the unknown data point is assigned the average of its neighbors' values.
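As a small hedged sketch of the regression case with scikit-learn (the synthetic one-dimensional data and the choice of k=3 are assumptions made only for illustration):

import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# Toy 1-D regression data: y is roughly 2x plus a little noise
X = np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]])
y = np.array([2.1, 3.9, 6.2, 8.1, 9.8, 12.2])

reg = KNeighborsRegressor(n_neighbors=3)
reg.fit(X, y)

# The prediction for x = 3.5 is simply the mean of its 3 nearest neighbors' y values
print(reg.predict([[3.5]]))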
The Value of ‘K’ in K-Nearest Neighbors Classification
Choosing the correct value of K is known as hyperparameter tuning, and it is required for better results. There is no defined method for determining the best value of K; it is determined by the specific type of problem.
The K value specifies how many neighbors will be checked when classifying a specific query point. If k=1, for example, the instance is assigned to the same class as its single nearest neighbor.
Different values of K can lead to overfitting or underfitting of new data, so defining it can be a balancing act. Lower K values can have high variance but low bias, while higher K values can have high bias but low variance.
Why Do We Need the KNN Algorithm?
KNN can make highly accurate predictions and can compete with state-of-the-art (SOTA) models on many problems. As a result, the K-Nearest Neighbors algorithm can be used for applications that require high accuracy but do not require a human-readable model.
The accuracy of the predictions depends on the distance metric used. Thus, the KNN algorithm is well suited to applications where there is enough domain knowledge to select an appropriate measure.
Improving the K-Nearest Neighbors: how to do it
Normalizing the training data so that all features are on the same scale is recommended for better results; a common normalization range is 0 to 1. Hyperparameter tuning of K and the distance metric is also critical.
We can test the K Nearest Neighbor algorithm with different values of K using the cross-validation technique. The model with the highest accuracy can be considered the best option.
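A minimal sketch combining both ideas, scaling features to the 0-1 range and cross-validating several candidate values of K, might look like the following (it assumes X and y are your feature matrix and labels; the candidate K values are arbitrary):

from sklearn.preprocessing import MinMaxScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

for k in [1, 3, 5, 7, 9]:
    # Scale features to [0, 1] inside the pipeline so the scaler never sees the validation folds
    model = make_pipeline(MinMaxScaler(), KNeighborsClassifier(n_neighbors=k))
    scores = cross_val_score(model, X, y, cv=5)
    print(k, scores.mean())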
K-Nearest Neighbor Example: Python Implementation of KNN Algorithm
Let's now get into the implementation of the KNN model in Python. We are using Python 3.8.5 in a Jupyter notebook. We'll go over the steps to help you break the code down.
Here it goes:
Importing the modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
Dataset
Scikit-learn can be used to create synthetic datasets of training samples, which are great for demo purposes.
X, y = make_blobs(n_samples=4000, n_features=3, centers=3, cluster_std=2, random_state=80)
X
array([[ 7.60190561, 4.86336321, 6.97616573],
[ 5.97809745, 7.69910922, 2.77419701],
[-4.36024844, -2.23247572, -5.29113293],
...,
[-8.22252297, -6.88609334, -6.52102135],
[-3.96254707, -5.27559922, -2.70880022],
[-4.25865881, -1.67791521, -3.70523373]])
y
array([1, 1, 2, ..., 2, 2, 2])
Plot
plt.figure(figsize = (6,6))
plt.scatter(X[:,0], X[:,1], c=y, marker= '.', s=10, edgecolors='blue')
plt.show()
df = pd.DataFrame(X)
df.head()
plt.rcParams['figure.figsize']=(10,15)
df.plot(kind='hist', bins=100, subplots=True, layout=(5,2), sharex=False, sharey=False)
plt.show()
The K-Nearest Neighbors Classifier implementation
The first step is to figure out the optimal value of k. The best K value varies greatly depending on the situation. When using the scikit-learn library, the default value of K is 5 and the default distance metric is Euclidean.
Tuning the Model to Get High K Nearest Neighbor Accuracy
from sklearn.model_selection import GridSearchCV
# Search K = 1, 2, 3 with 5-fold cross-validation
param_grid = {'n_neighbors': np.arange(1, 4)}
knn = KNeighborsClassifier()
knn_cv = GridSearchCV(knn, param_grid, cv=5)
knn_cv.fit(X,y)
print(knn_cv.best_params_)
print(knn_cv.best_score_)
{'n_neighbors': 3}
0.9887499999999999
#train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 80)
# instantiate the model
knn = KNeighborsClassifier(n_neighbors=3)
# fit the model to the training set
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print('Model accuracy score: {0:0.4f}'.format(accuracy_score(y_test, y_pred)))
Model accuracy score: 0.9890
We got an accuracy rate of 98.90%, which is considered very good. We searched over K values from 1 to 3, and the model performed best at k=3.
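If you want to look beyond a single accuracy number, one optional follow-up check (reusing y_test and y_pred from above) is to print the confusion matrix and per-class metrics:

from sklearn.metrics import classification_report, confusion_matrix

# Rows are true classes, columns are predicted classes
print(confusion_matrix(y_test, y_pred))
# Precision, recall, and F1-score for each class
print(classification_report(y_test, y_pred))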
Crucial benefits of KNN algorithm
1. Time-efficient
The K-Nearest Neighbors model does not involve a training period, since the data itself serves as the model that is consulted at prediction time. As a result, it is time-efficient and can be applied quickly to whatever data is available.
2. Simple to tune
KNN only requires two hyperparameters, a K value and a distance metric, making it simpler to tune than other machine learning algorithms.
3. Easy adaptability
Most classifier algorithms are easy to implement for binary classification problems but require extra effort to implement for multi-class problems. In contrast, KNN adapts to multi-class problems without any extra effort.
The Drawbacks of the K-Nearest Neighbors (KNN) algorithm
1. High-dimensional data
KNN does not work well with large or high-dimensional data because calculating distances between each data instance would be prohibitively expensive.
2. Noisy or missing data
KNN is sensitive to noise and does not work well when the data is noisy or has missing values.
3. Unbalanced data
K-Nearest Neighbors also performs poorly on unbalanced data, because the majority class tends to dominate the vote.
4. Curse of dimensionality
Because of the curse of dimensionality, KNN is more prone to overfitting. While feature selection and dimensionality reduction techniques are used to prevent this, the value of K can impact the model's behavior.
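As a rough sketch of that mitigation, dimensionality reduction can be applied before fitting KNN; here PCA is used purely as an example, and the synthetic data and the number of components are assumptions you would tune for a real problem:

from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Synthetic high-dimensional data (100 features), used only for illustration
X, y = make_classification(n_samples=1000, n_features=100, n_informative=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Project onto 10 principal components before running KNN
model = make_pipeline(PCA(n_components=10), KNeighborsClassifier(n_neighbors=5))
model.fit(X_train, y_train)
print(model.score(X_test, y_test))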
When to use KNN and Why?
For each prediction, the time complexity of the K-Nearest Neighbors algorithm is O(MN log(k)), where M is the dimension of the data and N is the number of instances in the training data set. Still, there are multiple specialized ways of organizing the data that address this and make it more efficient.
Additionally, several preprocessing techniques can be used to eliminate missing data and noise and ensure that the dataset is balanced. Due to this, KNN is one of the most widely used algorithms.
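For example, scikit-learn can index the training points with a KD-tree or ball tree instead of brute-force search; the following sketch assumes the X_train, y_train, X_test, and y_test variables from the earlier train-test split:

from sklearn.neighbors import KNeighborsClassifier

# Ask scikit-learn to build a KD-tree over the training points;
# 'ball_tree' and 'auto' are also available options
knn_fast = KNeighborsClassifier(n_neighbors=3, algorithm='kd_tree')
knn_fast.fit(X_train, y_train)
print(knn_fast.score(X_test, y_test))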
The bottom line: vector databases
We can use the Cosine function and the K Nearest Neighbor algorithm to determine how similar or different two sets of items are and then use that information to classify them.
In a high-dimensional vector space, the Cosine function calculates the similarity or distance between observations. High-dimensional data slows the computation process when stored in a traditional database; hence, it is better stored in a vector database.
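As a rough sketch of that idea (the toy "embedding" vectors and labels below are made up for illustration), cosine distance can be plugged directly into the KNN classifier:

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Toy embedding vectors with two classes; in practice these would come from a model
vectors = np.array([[0.9, 0.1, 0.0], [0.8, 0.2, 0.1], [0.1, 0.9, 0.2], [0.0, 0.8, 0.3]])
labels = np.array([0, 0, 1, 1])

knn = KNeighborsClassifier(n_neighbors=3, metric='cosine')
knn.fit(vectors, labels)
print(knn.predict([[0.85, 0.15, 0.05]]))  # the nearest vectors by cosine distance vote on the class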
This is one of the many use cases of vector databases. This is precisely where Zilliz steps in, especially if you're wondering, ‘What is a vector database?' and whether to explore it.
There's no denying that vector databases are the need of the hour in the modern era of AI. Zilliz offers a one-stop solution for the challenges of handling high-dimensional vector data, especially for enterprises that build AI/ML applications that leverage vector similarity search.