What is the K-Nearest Neighbors (KNN) Algorithm in Machine Learning?

Latest Update: March 1, 2025
The K-Nearest Neighbors (KNN) algorithm is a supervised machine learning algorithm that can solve classification and regression problems. In this comprehensive article from Zilliz, a leading vector database company for production-ready AI, we'll answer questions such as: what is KNN, how does KNN work, what is KNN in machine learning, why you need KNN, and what are some ways to improve KNN? We'll also demonstrate the implementation of a KNN model using Python.
As mentioned, the K-Nearest Neighbors (KNN) algorithm is a supervised machine learning algorithm that can solve both classification and regression problems. KNN estimates how likely a data point is to belong to one group or another based on the data points closest to it. It is categorized as a lazy learner, which means it only stores the training dataset rather than going through an explicit training stage; all computation is deferred until a classification or prediction is requested. Since it keeps all of its training data in memory, it is also known as memory-based learning.
KNN has two key characteristics. First, KNN is a non-parametric algorithm: it makes no assumptions about the underlying data distribution, and the model is built entirely from the data provided. Second, when using KNN, the dataset does not need to be divided into training and test sets; KNN does not distinguish between them and simply uses all available data when asked to make predictions.
Introduction to KNN
Definition of KNN
The K-Nearest Neighbors (KNN) algorithm is a supervised machine learning algorithm that can be used for both classification and regression tasks. It is a simple and intuitive algorithm that makes predictions based on the majority class of the K nearest neighbors to a given data point. In essence, the KNN algorithm leverages the proximity of data points to make informed predictions, making it a versatile tool in the machine learning toolkit.
How to calculate the K-Nearest Neighbor (KNN) algorithm and how the algorithm works
To determine the class of an unseen data point, K-Nearest Neighbors essentially uses a majority vote mechanism. Majority voting is a fundamental process in KNN: the algorithm classifies a data point by finding the category that most of its closest neighbors belong to, and the class that receives the most votes becomes the class of the data point in question.
If K is equal to 1, we will only consider a data point’s closest neighbor when determining its class. The 10 closest neighbors will be used if K is equal to 10, and so on. The chart below describes how KNN works between two classes of training points.
How KNN works between two classes. Source: https://www.ibm.com/in-en/topics/knn
Consider two classes, A and B. The algorithm examines the classes of nearby data points to determine whether a new data point belongs to Class A or Class B. If most of its neighbors are in group A, the data point in question most likely belongs to group A as well.
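To make the voting step concrete, here is a minimal sketch of the majority-vote idea in Python; the neighbor labels are made up purely for illustration.
from collections import Counter
# Hypothetical class labels of the K = 5 nearest neighbors of a query point
neighbor_labels = ['A', 'A', 'B', 'A', 'B']
# Majority vote: the most common class among the neighbors wins
predicted_class = Counter(neighbor_labels).most_common(1)[0][0]
print(predicted_class)  # -> 'A'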
Now, you may wonder how distance is calculated to determine whether or not a data point is a neighbor. There are numerous methods for calculating the distance between a data point and its nearest neighbors, including Euclidean distance, Cosine distance, Jaccard distance, Hamming distance, and a few others. The test point is evaluated by calculating its distances to the closest training data points, which ultimately determines its classification label.
Euclidean Distance is the true straight-line distance between two points in Euclidean space.
Cosine Distance is primarily used to calculate the similarity of two vectors.
Jaccard Distance, derived from the Jaccard Index, compares two binary data sets and measures how often their values disagree relative to the positions where at least one value is equal to one.
Hamming Distance is used to examine whether the value of a given data point is equal to the value of the data point from which the distance is being measured when dealing with categorical data.
This chart tells us what these distance metrics are all about:
Distance metrics.
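As a quick illustration, SciPy provides these distance functions directly; the vectors below are made up for demonstration.
from scipy.spatial import distance
a = [1, 0, 1, 1]
b = [1, 1, 0, 1]
print(distance.euclidean(a, b))  # straight-line distance between the two points
print(distance.cosine(a, b))     # 1 minus the cosine similarity of the two vectors
print(distance.jaccard(a, b))    # dissimilarity of the two binary vectors
print(distance.hamming(a, b))    # fraction of positions where the values differ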
Most of the steps in K-Nearest Neighbors regression are the same as for classification. Instead of being assigned the class with the most votes among its neighbors, the unknown data point is assigned the average of its neighbors' values.
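As a minimal sketch of KNN regression with scikit-learn (the toy data below is made up), the prediction for a new point is simply the mean of its neighbors' target values.
import numpy as np
from sklearn.neighbors import KNeighborsRegressor
# Toy one-feature regression data, for illustration only
X_reg = np.array([[1], [2], [3], [4], [5], [6]])
y_reg = np.array([1.2, 1.9, 3.1, 3.9, 5.2, 6.1])
reg = KNeighborsRegressor(n_neighbors=3)
reg.fit(X_reg, y_reg)
# The prediction is the average of the 3 nearest neighbors' target values
print(reg.predict([[3.5]]))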
Choosing the correct value of K is known as hyperparameter tuning, and it is required for better results. There is no defined method for determining the best value of K; it is determined by the specific type of problem.
The K value specifies how many neighbors will be checked to determine the K Nearest Neighbor classifier of a specific query point. If k=1, for example, the instance is assigned to the same class as its single nearest neighbor.
Different values of K can lead to overfitting or underfitting of new data, so defining it can be a balancing act. Lower K values can have high variance but low bias, while higher K values can have high bias but low variance.
KNN can make highly accurate predictions and, on suitable problems, can compete with state-of-the-art (SOTA) models. As a result, the K-Nearest Neighbor algorithm can be used for applications that require high accuracy but do not require a human-readable model.
The accuracy of the predictions depends on the distance metric used. Thus, the KNN algorithm is best suited to applications where there is sufficient domain knowledge to select an appropriate metric.
Normalizing the training data so that all features are on the same scale is recommended for better results; in general, features are normalized to the range between 0 and 1. Hyperparameter tuning of K and the distance metric is also critical.
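For example, scikit-learn's MinMaxScaler rescales each feature to the 0–1 range; the values below are made up to show the effect.
import numpy as np
from sklearn.preprocessing import MinMaxScaler
# Two features on very different scales (illustrative values)
X_raw = np.array([[1.0, 2000.0],
                  [2.0, 3000.0],
                  [3.0, 1000.0]])
scaler = MinMaxScaler()               # rescales each feature to the [0, 1] range
X_scaled = scaler.fit_transform(X_raw)
print(X_scaled)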
We can test the K Nearest Neighbor algorithm with different values of K using the cross-validation technique. The model with the highest accuracy can be considered the best option.
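A quick sketch of this idea with cross_val_score (synthetic data and illustrative K values) looks like the following; a fuller GridSearchCV walkthrough appears later in this article.
from sklearn.datasets import make_blobs
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
X_demo, y_demo = make_blobs(n_samples=500, centers=3, random_state=42)
# 5-fold cross-validated accuracy for each candidate value of K
for k in [1, 3, 5, 7, 9]:
    scores = cross_val_score(KNeighborsClassifier(n_neighbors=k), X_demo, y_demo, cv=5)
    print(k, scores.mean())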
Let’s now get into the implementation of the KNN model in Python. We are using Python 3.8.5 in a Jupyter notebook. We’ll go over the steps to help you break the code down.
Here it goes:
Importing the modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
Dataset
Scikit-learn can be used to create synthetic datasets of training samples, which are great for demo purposes.
X, y = make_blobs(n_samples=4000, n_features=3, centers=3, cluster_std=2, random_state=80)
X
array([[ 7.60190561, 4.86336321, 6.97616573],
[ 5.97809745, 7.69910922, 2.77419701],
[-4.36024844, -2.23247572, -5.29113293],
...,
[-8.22252297, -6.88609334, -6.52102135],
[-3.96254707, -5.27559922, -2.70880022],
[-4.25865881, -1.67791521, -3.70523373]])
y
array([1, 1, 2, ..., 2, 2, 2])
Plot
plt.figure(figsize = (6,6))
plt.scatter(X[:,0], X[:,1], c=y, marker= '.', s=10, edgecolors='blue')
plt.show()
df = pd.DataFrame(X)
df.head()
plt.rcParams['figure.figsize']=(10,15)
df.plot(kind='hist', bins=100, subplots=True, layout=(5,2), sharex=False, sharey=False)
plt.show()
The K-Nearest Neighbors Classifier implementation
The first step is to figure out the optimal value of K. The best value of K varies greatly depending on the situation. When using the scikit-learn library, the default value of K is 5 and the default distance metric is Euclidean.
Tuning the Model to Get High K Nearest Neighbor Accuracy
from sklearn.model_selection import GridSearchCV
param_grid = {'n_neighbors': np.arange(1, 4)}
knn = KNeighborsClassifier()
knn_cv = GridSearchCV(knn, param_grid, cv=5)
knn_cv.fit(X, y)
print(knn_cv.best_params_)
print(knn_cv.best_score_)
{'n_neighbors': 3}
0.9887499999999999
#train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 80)
# instantiate the model
knn = KNeighborsClassifier(n_neighbors=3)
# fit the model to the training set
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
print('Model accuracy score: {0:0.4f}'.format(accuracy_score(y_test, y_pred)))
Model accuracy score: 0.9890
We got an accuracy rate of 98.90%, which is considered very good. We searched over K values from 1 to 3, and the model performed best at k=3.
The K-Nearest Neighbor model does not involve a training period, since the stored data itself serves as the model that is consulted at prediction time. As a result, it is time-efficient and well suited to quick experimentation on whatever data is available.
KNN only requires two hyperparameters, a K value and a distance metric, making it simpler to tune than other machine learning algorithms.
Most classifier algorithms are easy to implement for binary classification problems but require extra effort to implement for multi-class problems. In contrast, KNN adapts to multi-class problems without any extra effort.
Main Mechanism
The main mechanism of the KNN algorithm involves identifying the K nearest neighbors to a given data point and using their class labels to make a prediction. For classification tasks, the algorithm assigns the class that is most common among the K nearest neighbors. For regression tasks, it averages the values of the K nearest neighbors to predict the value of the new data point. This approach is widely used in various domains due to its simplicity and effectiveness, allowing it to handle both classification and regression problems with ease.
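This mechanism can be sketched from scratch in a few lines of NumPy; the following is an illustrative implementation, not the scikit-learn one, and the tiny dataset at the end is made up.
import numpy as np

def knn_predict(X_train, y_train, x_query, k=3, task='classification'):
    # Euclidean distance from the query point to every training point
    dists = np.linalg.norm(X_train - x_query, axis=1)
    nearest = np.argsort(dists)[:k]          # indices of the k nearest neighbors
    neighbor_targets = y_train[nearest]
    if task == 'classification':
        # Majority vote among the neighbors' class labels
        values, counts = np.unique(neighbor_targets, return_counts=True)
        return values[np.argmax(counts)]
    # Regression: average of the neighbors' target values
    return neighbor_targets.mean()

# Tiny illustrative example: two well-separated classes
X_train = np.array([[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]])
y_train = np.array([0, 0, 0, 1, 1, 1])
print(knn_predict(X_train, y_train, np.array([4.5, 5.0])))  # -> 1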
Distance Metrics Used in KNN
The KNN algorithm relies on distance metrics to identify the K nearest neighbors to a given data point. The most commonly used distance metrics in KNN include:
Euclidean Distance: This is the straight-line distance between two points in a plane or space. It is the most commonly used distance metric in KNN due to its simplicity and effectiveness in measuring proximity.
Manhattan Distance: Also known as the L1 distance, this metric calculates the total distance you would travel if you could only move along horizontal and vertical lines. It is particularly useful in grid-like structures where movement is restricted to orthogonal directions.
Minkowski Distance: This is a generalization of both Euclidean and Manhattan distances. It includes these distances as special cases and allows for flexibility in measuring distance by adjusting the parameter p.
These distance metrics play a crucial role in determining the nearest neighbors and, consequently, the accuracy of the KNN algorithm.
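In scikit-learn, the metric is simply a hyperparameter of the classifier; here is a brief sketch of selecting each of the three metrics (the K value is chosen arbitrarily for illustration).
from sklearn.neighbors import KNeighborsClassifier
# Same classifier, three different distance metrics
knn_euclidean = KNeighborsClassifier(n_neighbors=5, metric='euclidean')
knn_manhattan = KNeighborsClassifier(n_neighbors=5, metric='manhattan')
# Minkowski generalizes both: p=2 is Euclidean, p=1 is Manhattan
knn_minkowski = KNeighborsClassifier(n_neighbors=5, metric='minkowski', p=3)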
Choosing the Value of K
Importance of Choosing the Right K
Choosing the right value of K is crucial in the KNN algorithm. The value of K directly affects the decision boundaries of the algorithm. A higher value of K results in smoother decision boundaries, which can help reduce the impact of noise in the data. However, choosing a very high value of K can lead to underfitting, where the model becomes too generalized and loses its ability to make accurate predictions on new data. Conversely, choosing a very low value of K can lead to overfitting, where the model becomes too sensitive to the training data and fails to generalize well.
To determine the optimal value of K, one can plot the training error rate and validation error rate curves. The point where the validation error rate is minimized typically indicates the best value of K. This process, known as hyperparameter tuning, is essential for achieving the best performance from the KNN algorithm.
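A minimal sketch of plotting these error-rate curves on synthetic data (all values below are illustrative):
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X_demo, y_demo = make_blobs(n_samples=1000, centers=3, cluster_std=2, random_state=0)
X_tr, X_val, y_tr, y_val = train_test_split(X_demo, y_demo, random_state=0)

k_values = range(1, 21)
train_err, val_err = [], []
for k in k_values:
    model = KNeighborsClassifier(n_neighbors=k).fit(X_tr, y_tr)
    train_err.append(1 - model.score(X_tr, y_tr))    # training error rate
    val_err.append(1 - model.score(X_val, y_val))    # validation error rate

plt.plot(k_values, train_err, label='training error')
plt.plot(k_values, val_err, label='validation error')
plt.xlabel('K')
plt.ylabel('error rate')
plt.legend()
plt.show()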
When to use KNN and Why?
For each prediction, the time complexity of the K-Nearest Neighbors algorithm is O(MN log k), where M is the dimension of the data and N is the number of instances in the training data set. Still, there are specialized ways of organizing the data that address this and make queries more efficient, as sketched below.
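For example, scikit-learn can index the training data in a KD-tree or ball tree so that each query avoids a brute-force scan of every training point; the K value here is illustrative.
from sklearn.neighbors import KNeighborsClassifier
# Index the training data in a space-partitioning tree to speed up neighbor queries
knn_kd = KNeighborsClassifier(n_neighbors=5, algorithm='kd_tree')
knn_ball = KNeighborsClassifier(n_neighbors=5, algorithm='ball_tree')
# algorithm='auto' (the default) picks a strategy based on the data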
Additionally, several preprocessing techniques can be used to eliminate missing data and noise and ensure that the dataset is balanced. Due to this, KNN is one of the most widely used algorithms.
We can use the Cosine function and the K Nearest Neighbor algorithm to determine how similar or different two sets of items are and then use that information to classify them.
In a high-dimensional vector space, the Cosine function calculates the similarity or distance between observations. Such high-dimensional data slows the computation process if served from a traditional database; hence, it can be stored in a vector database instead.
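A brief sketch of pairing KNN with cosine distance in scikit-learn (parameters chosen for illustration; this metric relies on the brute-force search strategy):
from sklearn.neighbors import KNeighborsClassifier
# Cosine distance compares the direction of vectors rather than their magnitude,
# which suits high-dimensional embeddings
knn_cosine = KNeighborsClassifier(n_neighbors=5, metric='cosine', algorithm='brute')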
This is one of the many use cases of vector databases, and it is precisely where Zilliz steps in, especially if you're wondering ‘What is a vector database?’ and whether to explore one. There's no denying that vector databases are the need of the hour in the modern era of AI. Zilliz offers a one-stop solution to the challenges of handling high-dimensional vector data, especially for enterprises that build AI/ML applications leveraging vector similarity search.