5
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
# Generate random data and labels
data = np.random.rand(100)
labels = ["Class1" if x <= 0.5 else "Class2"
for x in data[:50]]
# Euclidean distance function
def euclidean_distance(x1, x2):
return abs(x1 - x2)
# k-NN classifier
def knn_classifier(train_data,
train_labels, test_point, k):
distances =
[(euclidean_distance(test_point, train_data[i]), train_labels[i]) for i in
range(len(train_data))]
distances.sort(key=lambda
x: x[0])
k_nearest_neighbors =
distances[:k]
k_nearest_labels =
[label for _, label in k_nearest_neighbors]
return Counter(k_nearest_labels).most_common(1)[0][0]
train_data =
data[:50]
train_labels = labels
test_data = data[50:]
k_values = [1, 2, 3,
4, 5, 20, 30]
print("---
k-Nearest Neighbors Classification ---")
print("Training
dataset: First 50 points labeled based on the rule (x <= 0.5 -> Class1, x
> 0.5 ->
Class2)")
print("Testing
dataset: Remaining 50 points to be classified\n")
results = {}
for k in k_values:
print(f"Results
for k = {k}:")
classified_labels = [knn_classifier(train_data, train_labels, test_point,
k) for test_point in
test_data]
results[k] =
classified_labels
for i, label in
enumerate(classified_labels, start=51):
print(f"Point
x{i} (value: {test_data[i - 51]:.4f}) is classified as {label}")
print("\n")
print("Classification
complete.\n")
for k in k_values:
classified_labels =
results[k]
class1_points =
[test_data[i] for i in range(len(test_data)) if classified_labels[i]
== "Class1"]
class2_points =
[test_data[i] for i in range(len(test_data)) if classified_labels[i]
== "Class2"]
plt.figure(figsize=(10,
6))
plt.scatter(train_data,
[0] * len(train_data), c=["blue" if label ==
"Class1" else "red" for label in
train_labels],
label="Training
Data", marker="o")
plt.scatter(class1_points,
[1] * len(class1_points), c="blue", label="Class1
(Test)", marker="x")
plt.scatter(class2_points,
[1] * len(class2_points), c="red", label="Class2
(Test)", marker="x")
plt.title(f"k-NN
Classification Results for k = {k}")
plt.xlabel("Data
Points")
plt.ylabel("Classification
Level")
plt.legend()
plt.grid(True)
plt.show()
Comments
Post a Comment