Implementing the K-nearest neighbors algorithm in Python

Narender Ravulakollu
3 min read · Aug 4, 2023


Introduction

The K-nearest neighbors (KNN) algorithm is a simple yet powerful supervised learning algorithm used for classification and regression tasks. It works by finding the K training examples closest to a given test example in the feature space and predicting the label by majority vote (for classification) or by averaging the neighbors' values (for regression). In this article, we will learn how to implement the K-nearest neighbors algorithm in Python.

About the K-nearest neighbors algorithm:

The K-nearest neighbors algorithm is a non-parametric algorithm that does not make any assumptions about the underlying data distribution. It is a lazy learning algorithm, meaning it does not build an explicit model during training but instead stores the training dataset for later use during the prediction phase. The algorithm calculates the distances between the test example and all training examples using a distance metric (e.g., Euclidean distance) and selects the K nearest neighbors based on these distances. The final prediction is made based on the majority class or the average of the labels of the K neighbors.
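
To make the mechanics concrete, here is a minimal from-scratch sketch of the prediction step (knn_predict is a hypothetical helper for illustration, assuming Euclidean distance and a plain majority vote; the scikit-learn classifier we use later handles all of this internally):

import numpy as np
from collections import Counter

# Hypothetical from-scratch helper, for illustration only
def knn_predict(X_train, y_train, x_test, k=3):
    # Euclidean distance from the test point to every training point
    distances = np.linalg.norm(X_train - x_test, axis=1)
    # Indices of the k closest training examples
    nearest = np.argsort(distances)[:k]
    # Majority vote among the labels of the k nearest neighbors
    return Counter(y_train[nearest]).most_common(1)[0][0]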

About the Dataset:

In this implementation, we will be using the Iris dataset, a popular dataset in machine learning. The Iris dataset consists of measurements of sepal length, sepal width, petal length, and petal width of three different species of Iris flowers. It is commonly used to demonstrate various machine learning algorithms.
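
A quick way to get a feel for the data before modeling (a small inspection snippet; load_iris ships with scikit-learn):

from sklearn.datasets import load_iris

iris = load_iris()
print(iris.data.shape)     # (150, 4): 150 flowers, 4 measurements each
print(iris.feature_names)  # sepal/petal length and width, in cm
print(iris.target_names)   # ['setosa' 'versicolor' 'virginica']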

Requirements:

To implement the K-nearest neighbors algorithm in Python, we need the following libraries:

- pandas: For data manipulation and analysis.

- scikit-learn (sklearn): For machine learning algorithms and evaluation metrics.

- matplotlib: For data visualization.

- seaborn: For creating a confusion matrix.
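
If any of these are missing, they can usually be installed with pip (the exact command may differ depending on your environment):

pip install pandas scikit-learn matplotlib seaborn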

Implementation:

First, we import the required libraries and load the Iris dataset:

import pandas as pd
from sklearn.datasets import load_iris
iris = load_iris()
# Create a DataFrame from the dataset
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target
df['flower_name'] = df.target.apply(lambda x: iris.target_names[x])

Next, we split the dataset into training and testing sets:

from sklearn.model_selection import train_test_split
X = df.drop(['target', 'flower_name'], axis='columns')
y = df.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

Then, we create the K-nearest neighbors classifier and fit it to the training data:

from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=10)
knn.fit(X_train, y_train)

We can evaluate the model's accuracy on the test data:

accuracy = knn.score(X_test, y_test)
print("Accuracy:", accuracy)

We can also make predictions for new data points using the trained model:

# Wrap the sample in a DataFrame with the same column names used for training
new_data = pd.DataFrame([[4.8, 3.0, 1.5, 0.3]], columns=iris.feature_names)
predicted_class = knn.predict(new_data)
# Map the numeric class index back to the species name
print("Predicted class:", iris.target_names[predicted_class[0]])

To visualize the performance of the model, we can plot a confusion matrix:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
y_pred = knn.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(7, 5))
# Label the axes with species names for readability
sns.heatmap(cm, annot=True, xticklabels=iris.target_names, yticklabels=iris.target_names)
plt.xlabel('Predicted')
plt.ylabel('Truth')
plt.show()

Finally, we can print a classification report, which provides precision, recall, and F1-score for each class:

from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred))

Conclusion:

In this article, we learned how to implement the K-nearest neighbors algorithm in Python using the Iris dataset. We covered the steps from data loading to model evaluation, including training the classifier, making predictions, visualizing the results using a confusion matrix, and printing a classification report. The K-nearest neighbors algorithm is a versatile and intuitive algorithm that can be applied to various classification tasks. It is worth exploring different values of K and experimenting with different datasets to understand its behavior and performance.
