Neighborhood Components Analysis Illustration
=============================================

This example illustrates a learned distance metric that maximizes the nearest neighbors classification accuracy. It provides a visual representation of this metric compared to the original point space. Please refer to the :ref:`User Guide <nca>` for more information.

Imports for Visualizing NCA’s Learned Metric Space

NCA transforms distances to improve stochastic neighbor assignment: NeighborhoodComponentsAnalysis learns a linear transformation A that maps points from the original feature space to a new space. In that space, the expected (leave-one-out) accuracy of stochastic nearest-neighbor classification is maximized.

The probability that point i selects point j (j ≠ i) as its neighbor is proportional to exp(-||Ax_i - Ax_j||^2), normalized over all candidates k ≠ i. A point never selects itself (p_ii = 0). In this example, that softmax normalization is computed with logsumexp for numerical stability.

In the original space, a query point may have neighbors from several classes at similar distances. After the NCA transformation, same-class points tend to cluster tightly while points from other classes are pushed apart, which makes the correct class assignment much more probable.
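The probabilities described above can be computed for all pairs at once. The following is a minimal NumPy sketch, not scikit-learn's internal code. The helper names `neighbor_probabilities` and `nca_objective` are hypothetical.

The objective sums, over all points, the probability of selecting a same-class neighbor. As we understand it, this mirrors the quantity NCA maximizes. The sketch also assumes, and checks, that `transform` applies `X @ components_.T`, so that the fitted `components_` plays the role of A.

```python
import numpy as np
from scipy.special import logsumexp
from sklearn.datasets import make_classification
from sklearn.neighbors import NeighborhoodComponentsAnalysis


def neighbor_probabilities(X, A):
    """Hypothetical helper: matrix P with P[i, j] = p_ij under the linear map A."""
    Z = X @ A.T  # embed every point: z_i = A x_i
    # squared Euclidean distances ||A x_i - A x_j||^2 for all pairs
    sq_dist = np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    np.fill_diagonal(sq_dist, np.inf)  # a point never selects itself: p_ii = 0
    # row-wise softmax of -sq_dist, normalized in log space with logsumexp
    log_p = -sq_dist - logsumexp(-sq_dist, axis=1, keepdims=True)
    return np.exp(log_p)


def nca_objective(X, y, A):
    """Hypothetical helper: expected number of points whose neighbor shares their label."""
    P = neighbor_probabilities(X, A)
    same_class = y[:, None] == y[None, :]
    return np.sum(P * same_class)


# same toy data as in the example
X, y = make_classification(
    n_samples=9, n_features=2, n_informative=2, n_redundant=0,
    n_classes=3, n_clusters_per_class=1, class_sep=1.0, random_state=0,
)
nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0).fit(X, y)

# the learned map is applied as X @ components_.T
assert np.allclose(nca.transform(X), X @ nca.components_.T)

print("objective, original space:", nca_objective(X, y, np.eye(2)))
print("objective, NCA embedding :", nca_objective(X, y, nca.components_))
```

Row 3 of `neighbor_probabilities(X, np.eye(2))` corresponds to the link thicknesses drawn for point no. 3 in the first figure.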

Visual encoding of neighbor probabilities as line thickness: The link_thickness_i helper uses squared Euclidean distances to compute the softmax-normalized probabilities that a reference point selects each other point as its neighbor. The relate_point function draws a line from the reference point to every other point. Each line's width is proportional to that selection probability, and its color is given by the other point's class label.

Before NCA, the lines show that the reference point has non-negligible probabilities of selecting points from multiple classes. After the learned transformation, the line thickness concentrates on same-class points. This visually demonstrates how NCA reshapes the geometry so that nearest-neighbor classification becomes reliable even with k=1.
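The reason link_thickness_i normalizes with the log-sum-exp trick, rather than dividing exponentials directly, can be shown in a small sketch. The squared distances are made-up values chosen to force floating-point underflow; they do not come from this example's dataset.

```python
import numpy as np
from scipy.special import logsumexp

# Hypothetical squared distances from a query point (index 0) to three others.
# The query's own entry is inf, as link_thickness_i sets it.
sq_dist = np.array([np.inf, 800.0, 810.0, 900.0])

# Naive softmax: exp(-800) underflows to 0.0 in float64, so every entry is 0 / 0.
with np.errstate(invalid="ignore"):
    naive = np.exp(-sq_dist) / np.sum(np.exp(-sq_dist))

# Log-sum-exp form, as in link_thickness_i: subtract the log-normalizer
# before exponentiating, so the largest probability is computed near exp(0).
stable = np.exp(-sq_dist - logsumexp(-sq_dist))

print(naive)   # all nan
print(stable)  # finite probabilities summing to 1, with 0.0 for the query itself
```

Distances this large are unlikely for the nine toy points. They can occur, however, after a learned transformation stretches the space, which is where the stable form matters.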

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from scipy.special import logsumexp

from sklearn.datasets import make_classification
from sklearn.neighbors import NeighborhoodComponentsAnalysis

# %%
# Original points
# ---------------
# First we create a data set of 9 samples from 3 classes, and plot the points
# in the original space. For this example, we focus on the classification of
# point no. 3. The thickness of a link between point no. 3 and another point
# is proportional to the probability that point no. 3 selects that point as
# its stochastic nearest neighbor, which decreases with their squared distance.

X, y = make_classification(
    n_samples=9,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_classes=3,
    n_clusters_per_class=1,
    class_sep=1.0,
    random_state=0,
)

plt.figure(1)
ax = plt.gca()
for i in range(X.shape[0]):
    ax.text(X[i, 0], X[i, 1], str(i), va="center", ha="center")
    ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

ax.set_title("Original points")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
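ax.axis("equal")  # so that boundaries are displayed correctly as circles


def link_thickness_i(X, i):
    """Probability that point ``i`` selects each point as its neighbor."""
    diff_embedded = X[i] - X
    # squared Euclidean distance from point i to every point
    dist_embedded = np.einsum("ij,ij->i", diff_embedded, diff_embedded)
    # a point never selects itself: exp(-inf) = 0
    dist_embedded[i] = np.inf

    # compute exponentiated distances (use the log-sum-exp trick to
    # avoid numerical instabilities)
    exp_dist_embedded = np.exp(-dist_embedded - logsumexp(-dist_embedded))
    return exp_dist_embedded


def relate_point(X, i, ax):
    """Draw links from point ``i``; width ~ selection probability, color ~ class.

    Uses the module-level labels ``y`` for coloring.
    """
    pt_i = X[i]
    # the probabilities do not depend on j, so compute them once
    thickness = link_thickness_i(X, i)
    for j, pt_j in enumerate(X):
        if i != j:
            line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])
            ax.plot(*line, c=cm.Set1(y[j]), linewidth=5 * thickness[j])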


i = 3
relate_point(X, i, ax)
plt.show()

# %%
# Learning an embedding
# ---------------------
# We use :class:`~sklearn.neighbors.NeighborhoodComponentsAnalysis` to learn an
# embedding and plot the points after the transformation. We then redraw the
# links from point no. 3, with the neighbor probabilities now computed from
# distances in the embedded space.

nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)
nca = nca.fit(X, y)

plt.figure(2)
ax2 = plt.gca()
X_embedded = nca.transform(X)
relate_point(X_embedded, i, ax2)

for i in range(len(X)):
    ax2.text(X_embedded[i, 0], X_embedded[i, 1], str(i), va="center", ha="center")
    ax2.scatter(
        X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4
    )

ax2.set_title("NCA embedding")
ax2.axes.get_xaxis().set_visible(False)
ax2.axes.get_yaxis().set_visible(False)
ax2.axis("equal")
plt.show()
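Beyond visualization, the learned transformation can be placed in front of a 1-nearest-neighbor classifier. This is the k=1 setting the plots are meant to motivate.

The following is a minimal sketch under stated assumptions:
- The 300-sample, 5-feature dataset is an assumption. With n_informative=2 and n_redundant=0, three features are pure noise, and we chose it so that 5-fold cross-validation has enough samples per class, which the 9-point set does not.
- The use of StandardScaler is also our choice.

Whether NCA improves accuracy on a given dataset is not guaranteed, so compare the two scores rather than assuming a result.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier, NeighborhoodComponentsAnalysis
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Assumed larger toy problem: 2 informative features plus 3 noise features
X, y = make_classification(
    n_samples=300, n_features=5, n_informative=2, n_redundant=0,
    n_classes=3, n_clusters_per_class=1, class_sep=1.0, random_state=0,
)

# Baseline: 1-NN with Euclidean distances on standardized features
knn = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=1))

# Same classifier, but distances are measured after the learned NCA map
nca_knn = make_pipeline(
    StandardScaler(),
    NeighborhoodComponentsAnalysis(random_state=0),
    KNeighborsClassifier(n_neighbors=1),
)

knn_scores = cross_val_score(knn, X, y, cv=5)
nca_scores = cross_val_score(nca_knn, X, y, cv=5)
print(f"1-NN          mean accuracy: {knn_scores.mean():.3f}")
print(f"NCA then 1-NN mean accuracy: {nca_scores.mean():.3f}")
```

Putting NCA inside the pipeline means it is refit on each training fold only, so the cross-validated scores do not leak information from the held-out fold.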