What is K-Nearest Neighbors (KNN) Algorithm in Machine Learning? An Essential Guide

By Zilliz on Oct 17, 2022

The K Nearest Neighbor algorithm is a supervised machine learning algorithm that can be used to solve both classification and regression problems. In this comprehensive article from Zilliz, a leading vector database company for production-ready AI, we’ll dive deep into what KNN algorithm in machine learning is, why it’s needed, how KNN works, what its benefits are, and how to improve KNN. We’ll also demonstrate a KNN model implementation using Python.

What is a KNN Algorithm?

Let’s start with what the KNN algorithm is for. The K Nearest Neighbor algorithm estimates how likely a data point is to belong to one group or another, based on which group the data points closest to it belong to.

A KNN algorithm can be used for classification as well as regression problems. It is categorized as a lazy learner, which means it only stores a training dataset rather than going through a training stage.

Additionally, it implies that all computation is performed when a classification or prediction is made. Since it uses memory to store all of its training data, it is also known as memory-based learning.

KNN has two key characteristics. Firstly, KNN is a non-parametric algorithm. This means it makes no assumptions about the underlying distribution of the data. Instead, the model is built entirely from the data that is provided.

Secondly, KNN has no separate training phase: fitting the model simply stores the data, and every stored point can take part when the model is asked to predict. You should still hold out a test set to evaluate the model, as we do in the Python example later in this article.

How Does a KNN Algorithm Work?

To determine the class of an unseen observation, KNN essentially uses a voting mechanism. The class that receives the most votes among the nearest neighbors becomes the class of the data point in question.

If K is equal to 1, we will only consider a data point’s closest neighbor when determining its class. The 10 closest neighbors will be used if K is equal to 10, and so on.
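The voting idea can be sketched in a few lines of NumPy. This is a minimal illustration, not a production implementation: the helper name `knn_predict` and the toy arrays `X_toy` and `y_toy` are made up for this example, and Euclidean distance is assumed.

```python
from collections import Counter

import numpy as np


def knn_predict(X_train, y_train, query, k=3):
    """Predict the class of `query` by majority vote among its k nearest training points."""
    # Euclidean distance from the query to every stored training point
    distances = np.linalg.norm(X_train - query, axis=1)
    # Indices of the k smallest distances
    nearest = np.argsort(distances)[:k]
    # Count the labels of those neighbors and return the most common one
    votes = Counter(y_train[nearest].tolist())
    return votes.most_common(1)[0][0]


# Toy data (made up for illustration): two small groups in 2-D
X_toy = np.array([[1.0, 1.0], [1.2, 0.8], [0.9, 1.1],
                  [5.0, 5.0], [5.2, 4.9], [4.8, 5.1]])
y_toy = np.array(["A", "A", "A", "B", "B", "B"])

print(knn_predict(X_toy, y_toy, np.array([1.1, 1.0]), k=3))  # A
print(knn_predict(X_toy, y_toy, np.array([4.9, 5.0]), k=1))  # B
```

Note that all the work happens inside the prediction call, which is what makes KNN a lazy learner.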

How KNN works between two classes. Source: https://www.ibm.com/in-en/topics/knn

Consider two classes: A and B. To determine whether a data point belongs to Class A or Class B, the algorithm examines the classes of the data points nearest to it. If most of those neighbors are in Class A, the data point in question is assigned to Class A.

Now, you may wonder how the distance is calculated to determine whether or not a data point is a neighbor, right? There are numerous methods for calculating the distance between a data point and the other points in the dataset. They include Euclidean distance, Manhattan distance, Hamming distance, Cosine distance, Jaccard distance, Minkowski distance, and a few others.

This chart tells us what these distance metrics are all about:

Distance metrics.
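To make these metrics concrete, here is a small sketch that computes each of them by hand with NumPy. The vectors and sets are made up for illustration, and the definitions used are the common textbook ones (for example, Hamming distance as the count of differing positions).

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 6.0, 3.0])

# Euclidean: straight-line distance
euclidean = np.sqrt(np.sum((a - b) ** 2))
# Manhattan: sum of absolute coordinate differences
manhattan = np.sum(np.abs(a - b))
# Minkowski: generalizes both (p=1 is Manhattan, p=2 is Euclidean)
p = 3
minkowski = np.sum(np.abs(a - b) ** p) ** (1 / p)
# Cosine distance: 1 minus the cosine of the angle between the vectors
cosine_distance = 1 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# Hamming: number of positions at which two equal-length sequences differ
u = np.array([1, 0, 1, 1])
v = np.array([1, 1, 0, 1])
hamming = int(np.sum(u != v))

# Jaccard distance: 1 minus (size of intersection / size of union) of two sets
s1, s2 = {1, 2, 3}, {2, 3, 4}
jaccard = 1 - len(s1 & s2) / len(s1 | s2)

print(euclidean, manhattan, hamming, jaccard)  # 5.0 7.0 2 0.5
```

Which metric is appropriate depends on the data: Hamming and Jaccard suit categorical or set-valued features, while the others assume numeric vectors.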

The majority of steps in KNN regression are the same as classification. Instead of assigning the class with the most votes, the unknown data point is assigned the average of its neighbors’ values.
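As a quick illustration of the regression case, the sketch below uses scikit-learn's `KNeighborsRegressor` on a tiny made-up dataset (`X_reg` and `y_reg` are hypothetical names for this example).

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

# Toy 1-D data (made up for illustration): y is exactly 2 * x
X_reg = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])
y_reg = np.array([2.0, 4.0, 6.0, 8.0, 10.0])

reg = KNeighborsRegressor(n_neighbors=2)
reg.fit(X_reg, y_reg)

# The two nearest neighbors of x=2.4 are x=2 and x=3,
# so the prediction is the mean of their targets, 4 and 6.
print(reg.predict([[2.4]]))  # [5.]
```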

The Value of ‘K’ in KNN Algorithm

Choosing the right value of K is a form of hyperparameter tuning, and it is required for better results. There is no single formula for the best value of K. It depends on the dataset and the problem, and it is usually found by trying several values.

The K value specifies how many neighbors will be checked to determine the classification of a specific query point. If k=1, for example, the instance is assigned to the same class as its single nearest neighbor.

Different values of K can lead to overfitting or underfitting, so defining it can be a balancing act. Lower K values can have high variance but low bias, while higher K values can have high bias but low variance.
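The trade-off can be observed by comparing training and test accuracy for a few values of K. The sketch below uses a synthetic dataset with overlapping clusters; the dataset settings and the K values 1, 15, and 151 are arbitrary choices for illustration.

```python
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Overlapping clusters so that the choice of K actually matters
X_demo, y_demo = make_blobs(n_samples=600, centers=3, cluster_std=3.0, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, random_state=0)

scores = {}
for k in (1, 15, 151):
    model = KNeighborsClassifier(n_neighbors=k).fit(X_tr, y_tr)
    # Store (training accuracy, test accuracy) for each K
    scores[k] = (model.score(X_tr, y_tr), model.score(X_te, y_te))
    print(k, scores[k])
```

With K=1, every training point is its own nearest neighbor, so the training accuracy is perfect even though the test accuracy is not. That gap is the typical signature of a low-bias, high-variance model.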

Why Do We Need KNN Algorithm?

KNN can make highly accurate predictions. On suitable problems, it can compete with more complex models. As a result, the KNN algorithm can be used for applications that require high accuracy but, at the same time, do not require a human-readable model.

The accuracy of the predictions depends heavily on the distance measure used. Thus, the KNN algorithm is appropriate for applications where sufficient domain knowledge is available, because that knowledge helps in selecting an appropriate measure.

Improving KNN: How to Do it

Normalizing the features to the same scale is recommended for better results, because otherwise features with large numeric ranges dominate the distance calculation. In general, the normalization range is between 0 and 1. Apart from this, tuning K and the distance metric is also critical.

We can test the KNN algorithm with different values of K using the cross-validation technique. The model with the highest accuracy can be considered the best option.
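One way to combine normalization and cross-validation is sketched below with a scikit-learn pipeline. The dataset, the range of K values, and the variable names are assumptions made for this example.

```python
from sklearn.datasets import make_blobs
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler

X_cv, y_cv = make_blobs(n_samples=500, n_features=3, centers=3,
                        cluster_std=2, random_state=80)

cv_scores = {}
for k in range(1, 11):
    # Scaling lives inside the pipeline so it is re-fit on each training fold
    pipe = make_pipeline(MinMaxScaler(), KNeighborsClassifier(n_neighbors=k))
    cv_scores[k] = cross_val_score(pipe, X_cv, y_cv, cv=5).mean()

# Pick the K with the highest mean cross-validated accuracy
best_k = max(cv_scores, key=cv_scores.get)
print(best_k, cv_scores[best_k])
```

Putting the scaler inside the pipeline keeps information from the validation fold out of the scaling step.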

KNN in Practice: Python Implementation of KNN Algorithm

Let’s now get into the implementation of the KNN model in Python. We are using Python 3.8.5 in a Jupyter notebook for this. We’ll go over the steps to help you break the code down.

Here it goes:

Importing the modules

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from sklearn.datasets import make_blobs
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

Dataset

Scikit-learn can be used for creating synthetic datasets, which are great for demo purposes.

X, y = make_blobs(n_samples = 4000, n_features = 3, centers = 3 ,cluster_std = 2, random_state = 80)

X

array([[ 7.60190561,  4.86336321,  6.97616573],
       [ 5.97809745,  7.69910922,  2.77419701],
       [-4.36024844, -2.23247572, -5.29113293],
       ...,
       [-8.22252297, -6.88609334, -6.52102135],
       [-3.96254707, -5.27559922, -2.70880022],
       [-4.25865881, -1.67791521, -3.70523373]])

y

array([1, 1, 2, ..., 2, 2, 2])

Plot

plt.figure(figsize = (6,6))
plt.scatter(X[:,0], X[:,1], c=y, marker= '.', s=10, edgecolors='blue')
plt.show()

df = pd.DataFrame(X)
df.head()
plt.rcParams['figure.figsize']=(10,15)
df.plot(kind='hist', bins=100, subplots=True, layout=(5,2), sharex=False, sharey=False)
plt.show()

KNN Classifier Implementation

The first step is to choose K. The best value varies greatly depending on the situation. In the Scikit-Learn library, the default value of K is 5, and the default distance metric is Minkowski with p=2, which is equivalent to Euclidean distance.
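These defaults can be checked directly on an untrained classifier:

```python
from sklearn.neighbors import KNeighborsClassifier

# Inspect the default hyperparameters of an untrained classifier
params = KNeighborsClassifier().get_params()
print(params["n_neighbors"], params["metric"], params["p"])
```

scikit-learn reports the metric as `minkowski` with `p=2`, which is the Euclidean distance.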

Tuning the Model

from sklearn.model_selection import GridSearchCV
param_grid = {'n_neighbors':np.arange(1,4)}

knn = KNeighborsClassifier()
knn_cv= GridSearchCV(knn,param_grid,cv=5)
knn_cv.fit(X,y)

print(knn_cv.best_params_)
print(knn_cv.best_score_)

{'n_neighbors': 3}
0.9887499999999999

# train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 80)
# instantiate the model
knn = KNeighborsClassifier(n_neighbors=3)

# fit the model to the training set
knn.fit(X_train, y_train)

y_pred = knn.predict(X_test)

print('Model accuracy score: {0:0.4f}'.format(accuracy_score(y_test, y_pred)))

Model accuracy score: 0.9890

We got an accuracy of 98.90%, which is very good for this synthetic dataset. The grid search tried K values from 1 to 3, and the model performed best at K=3.
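One caveat: the grid search above was fit on the full dataset, including the points that later end up in the test set. A stricter variant, sketched here as a suggestion rather than as the original workflow, splits first and tunes K on the training portion only. The variable name `search` is introduced for this example.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = make_blobs(n_samples=4000, n_features=3, centers=3,
                  cluster_std=2, random_state=80)

# Split first, so the test set plays no part in choosing K
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=80)

search = GridSearchCV(KNeighborsClassifier(), {"n_neighbors": np.arange(1, 4)}, cv=5)
search.fit(X_train, y_train)

# GridSearchCV refits the best model on X_train, so it can predict directly
y_pred = search.predict(X_test)
print(search.best_params_)
print("Held-out accuracy: {:.4f}".format(accuracy_score(y_test, y_pred)))
```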

Crucial Benefits of KNN

1. Time-efficient

KNN modeling does not involve any training period, since the stored data itself is the model that serves as the reference for future predictions. As a result, setup is fast, and new data can be added without retraining. The trade-off is that the computation is deferred to prediction time.

2. Simple to Tune

KNN only requires two hyperparameters, a K value and a distance metric, making it simpler to tune than other machine learning algorithms.

3. Easy Adaptability

Some classifier algorithms are easy to implement for binary problems but require extra effort to implement for multi-class problems. In contrast, KNN adapts to multi-class problems without any extra effort.

And a Few Drawbacks

1. High-dimensional Data

KNN does not work well with large or high-dimensional data, because each prediction requires calculating the distance from the query point to every stored data instance, which becomes prohibitively expensive.

2. Sensitive or Missing Data

KNN is sensitive to noise and outliers in the data, and it does not handle missing values on its own.

3. Unbalanced Data

KNN also does not perform well with unbalanced data, because the majority class tends to dominate the vote among the neighbors.

4. Curse of Dimensionality

Because of the curse of dimensionality, KNN is more prone to overfitting. While feature selection and dimensionality reduction techniques are used to prevent this, the value of K can impact the model’s behavior.

Why KNN is a Widely Used Algorithm

For each prediction, the time complexity of the KNN algorithm is O(MN log(k)), where M is the dimension of the data and N is the number of instances in the training data. Still, there are specialized data structures, such as k-d trees and ball trees, that organize the data to make the neighbor search more efficient.
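As one example of such a specialized structure, scikit-learn can build a k-d tree instead of scanning every point. The sketch below, on random data generated for illustration, checks that the tree-based search returns the same neighbor distances as the brute-force search.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
points = rng.normal(size=(2000, 3))   # stored data
queries = rng.normal(size=(5, 3))     # points to look up

# Same search, two strategies: exhaustive scan vs. k-d tree index
brute = NearestNeighbors(n_neighbors=5, algorithm="brute").fit(points)
tree = NearestNeighbors(n_neighbors=5, algorithm="kd_tree").fit(points)

dist_brute, idx_brute = brute.kneighbors(queries)
dist_tree, idx_tree = tree.kneighbors(queries)

# Both strategies are exact, so the distances agree up to rounding
print(np.allclose(dist_brute, dist_tree))
```

Tree-based indexes tend to help most in low dimensions; as the number of dimensions grows, their advantage over a brute-force scan shrinks.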

Additionally, several preprocessing techniques can be used to eliminate missing data and noise and ensure that the dataset is balanced. Due to this, KNN is one of the most widely used algorithms.

The Bottom Line: Vector Databases

We can use cosine similarity together with the K Nearest Neighbor algorithm to determine how similar or different two sets of items are and then use that information to classify them.

In high-dimensional spaces, cosine similarity is often used to measure the similarity or distance between observations. Searching such high-dimensional data is very slow in a traditional database; hence, it can be stored in a vector database instead.
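As a small in-memory illustration of cosine-based neighbor search (not a vector database, just scikit-learn on four made-up vectors), consider:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Four made-up vectors: rows 0 and 1 point roughly along the first axis,
# rows 2 and 3 roughly along the second
vectors = np.array([
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.9, 0.1],
])

index = NearestNeighbors(n_neighbors=2, metric="cosine").fit(vectors)

# The query is much longer than the stored vectors, but cosine distance
# depends only on direction, so rows 0 and 1 are still the closest
distances, indices = index.kneighbors([[2.0, 0.1, 0.0]])
print(indices[0])  # [0 1]
```

A vector database applies the same idea at a much larger scale, typically with approximate indexes in place of an exhaustive scan.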

This is one of the many use cases of vector databases. And this is precisely where Zilliz steps in, especially if you’re wondering ‘what is a vector database?’ and whether to explore it.

There’s no denying that vector databases are the need of the hour in the modern era of AI. Zilliz offers a one-stop solution for challenges in handling unstructured data, especially for enterprises that build AI/ML applications that leverage vector similarity search.