Imports for Plotting Hierarchical Clustering Dendrograms
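Dendrograms visualize the full merge history of agglomerative (bottom-up) hierarchical clustering, showing which clusters were merged at each step and at what distance. Setting distance_threshold=0 and n_clusters=None in AgglomerativeClustering computes the full tree without stopping early. The fitted model's children_ and distances_ attributes then record the complete linkage structure (distances_ is only populated when distance_threshold is set or compute_distances=True). The plot_dendrogram helper converts this into scipy's linkage matrix format for rendering with scipy.cluster.hierarchy.dendrogram. That format has one row per merge, holding the indices of the two merged nodes, the merge distance, and the number of original samples in the newly formed node.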
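The conversion can be checked on a small example. The sketch below uses hypothetical toy data (six points in two well-separated groups) and a hypothetical helper, to_linkage_matrix, that follows the same counting logic as plot_dendrogram but returns the matrix instead of plotting it. SciPy's is_valid_linkage then confirms the result is a well-formed linkage matrix.

```python
import numpy as np
from scipy.cluster.hierarchy import is_valid_linkage
from sklearn.cluster import AgglomerativeClustering

# Hypothetical toy data: two tight groups of three points, far apart.
X = np.array(
    [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
     [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]]
)

# distance_threshold=0 with n_clusters=None builds the full tree and fills distances_.
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None).fit(X)


def to_linkage_matrix(model):
    """Hypothetical helper: return an (n_samples - 1, 4) SciPy-style linkage matrix."""
    n_samples = len(model.labels_)
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        # Index < n_samples is an original sample (size 1); otherwise it
        # refers to the node created by merge number (index - n_samples).
        counts[i] = sum(1 if c < n_samples else counts[c - n_samples] for c in merge)
    # Columns: child 1, child 2, merge distance, samples in the new node.
    return np.column_stack([model.children_, model.distances_, counts]).astype(float)


Z = to_linkage_matrix(model)
print(Z)
print(is_valid_linkage(Z))
```

The last row is always the root merge, so its count equals the number of samples; here it should also carry the largest distance, since it joins the two distant groups.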
Reading a dendrogram: the y-axis shows the merge distance (dissimilarity) at which two sub-clusters were joined, and the x-axis shows individual samples or sub-clusters. Tall vertical lines indicate large jumps in merge distance and suggest natural cluster boundaries; cutting the tree at those heights yields well-separated clusters. The truncate_mode='level' parameter limits the display to the top p levels of the tree, preventing visual clutter when the dataset is large. Applied to the Iris dataset (the species labels are not used for clustering), this approach reveals a clear hierarchical structure, with two groups merging at a lower distance before the third joins at a much higher distance.
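Cutting the tree at a large jump can also be done programmatically. The sketch below assumes a simple largest-gap heuristic, which is not part of the original example: it sorts the merge distances of the full Iris tree, places the cut halfway across the biggest jump between consecutive heights, and refits AgglomerativeClustering with that value as distance_threshold.

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris

X = load_iris().data

# Fit the full tree once to read off every merge distance.
full = AgglomerativeClustering(distance_threshold=0, n_clusters=None).fit(X)
heights = np.sort(full.distances_)

# Assumed heuristic: cut in the middle of the largest jump between
# consecutive merge heights.
gaps = np.diff(heights)
k = int(np.argmax(gaps))
cut = (heights[k] + heights[k + 1]) / 2.0

# Refit with the cut height; merges at or above the threshold are not performed,
# so every merge above the cut adds one cluster.
cut_model = AgglomerativeClustering(distance_threshold=cut, n_clusters=None).fit(X)
print(f"cut height = {cut:.2f}, clusters = {cut_model.n_clusters_}")
```

The largest-gap rule is only one way to pick a height; inspecting the plotted dendrogram, or choosing n_clusters directly, are equally valid.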
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
"""
=========================================
Plot Hierarchical Clustering Dendrogram
=========================================
This example plots the corresponding dendrogram of a hierarchical clustering
using AgglomerativeClustering and the dendrogram function available in scipy.
"""
import numpy as np
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris
def plot_dendrogram(model, **kwargs):
    # Create linkage matrix and then plot the dendrogram

    # create the counts of samples under each node
    counts = np.zeros(model.children_.shape[0])
    n_samples = len(model.labels_)
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    # Plot the corresponding dendrogram
    dendrogram(linkage_matrix, **kwargs)

iris = load_iris()
X = iris.data
# setting distance_threshold=0 ensures we compute the full tree.
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)
model = model.fit(X)
plt.title("Hierarchical Clustering Dendrogram")
# plot the top three levels of the dendrogram
plot_dendrogram(model, truncate_mode="level", p=3)
plt.xlabel("Number of points in node (or index of point if no parenthesis).")
plt.show()
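The x-axis label's convention (a count in parentheses for a collapsed node, a bare number for a single sample index) can be inspected without drawing anything. The sketch below rebuilds the same linkage matrix as plot_dendrogram and calls scipy's dendrogram with no_plot=True, once untruncated and once with truncate_mode="level", p=3. The small function points_under is a hypothetical helper that reads the number of samples represented by each displayed leaf label.

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris

X = load_iris().data
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None).fit(X)

# Same linkage-matrix construction as plot_dendrogram.
n_samples = len(model.labels_)
counts = np.zeros(model.children_.shape[0])
for i, merge in enumerate(model.children_):
    counts[i] = sum(1 if c < n_samples else counts[c - n_samples] for c in merge)
Z = np.column_stack([model.children_, model.distances_, counts]).astype(float)

# no_plot=True computes the layout only; "ivl" holds the leaf labels left to right.
full_tree = dendrogram(Z, no_plot=True)
top3 = dendrogram(Z, truncate_mode="level", p=3, no_plot=True)


def points_under(label):
    """Hypothetical helper: "(k)" is a collapsed node of k samples, a bare number is one sample."""
    return int(label[1:-1]) if label.startswith("(") else 1


print(len(full_tree["ivl"]), len(top3["ivl"]))
print(sum(points_under(lbl) for lbl in top3["ivl"]))
```

Truncation collapses whole subtrees into single leaves, so the displayed leaves still account for every sample exactly once; only the number of labels on the axis shrinks.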