Plot Random Multilabel Dataset

Plot randomly generated multilabel dataset

This illustrates the :func:`~sklearn.datasets.make_multilabel_classification` dataset generator. Each sample consists of counts of two features (up to 50 in total), which are distributed differently in each of three classes.

Points are labeled as follows, where Y means the class is present:

=====  =====  =====  ======
  1      2      3    Color
=====  =====  =====  ======
  Y      N      N    Red
  N      Y      N    Blue
  N      N      Y    Yellow
  Y      Y      N    Purple
  Y      N      Y    Orange
  N      Y      Y    Green
  Y      Y      Y    Brown
=====  =====  =====  ======
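The color assignment in the table can be reproduced by weighting the three label columns by 1, 2 and 4, which is the indexing the plotting code applies to its COLORS array. The sketch below is illustrative only: COLOR_NAMES and the hand-written indicator rows in Y are hypothetical stand-ins, not part of the example script.

```python
import numpy as np

# Color names in the same order as the COLORS array of this example.
# Index 0 stands for "no label", which never occurs when
# allow_unlabeled=False. COLOR_NAMES is a hypothetical helper.
COLOR_NAMES = np.array(
    ["(none)", "red", "blue", "purple", "yellow", "orange", "green", "brown"]
)

# Hand-written indicator rows; columns are classes 1, 2 and 3.
Y = np.array(
    [
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 1, 0],
        [1, 0, 1],
        [0, 1, 1],
        [1, 1, 1],
    ]
)

# Weights 1, 2 and 4 give every label combination a unique index in 1..7.
color_index = (Y * [1, 2, 4]).sum(axis=1)

for row, name in zip(Y, COLOR_NAMES.take(color_index)):
    print(row, name)
```

Each label combination maps to exactly one color, so a point's color identifies its full label set.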

A star marks the expected sample for each class; its size reflects the probability of selecting that class label.

The left and right examples highlight the n_labels parameter: more of the samples in the right plot have 2 or 3 labels.

Note that this two-dimensional example is very degenerate: generally the number of features would be much greater than the “document length”, while here we have much larger documents than vocabulary. Similarly, with n_classes > n_features, it is much less likely that a feature distinguishes a particular class.

Imports for Visualizing Randomly Generated Multilabel Datasets

make_multilabel_classification generates synthetic multilabel data by sampling from class-conditional word distributions, loosely mimicking a bag-of-words model of documents. Each sample is a “document” whose feature counts are drawn from a mixture of the multinomial distributions of its classes. The number of labels per sample is drawn from a Poisson distribution with mean n_labels, using rejection sampling so that it never exceeds n_classes (and is never zero when allow_unlabeled=False). Passing return_distributions=True additionally returns the underlying generative parameters: p_c (the prior probability of each class) and p_w_c (the probability of each feature given each class). These are drawn as star markers whose sizes grow with the class prior and whose positions give the expected feature counts for each class.
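To make the returned generative parameters concrete, the following minimal sketch calls the generator with the same settings as the plots and inspects the shapes and normalization of p_c and p_w_c. The fixed random_state=0 and the name star_positions are assumptions made for this illustration; the example script itself draws a random seed.

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification

length = 50

X, Y, p_c, p_w_c = make_multilabel_classification(
    n_samples=150,
    n_features=2,
    n_classes=3,
    n_labels=1,
    length=length,
    allow_unlabeled=False,
    return_distributions=True,
    random_state=0,  # arbitrary seed chosen for this sketch
)

# X holds feature counts, Y is a dense 0/1 label indicator matrix.
print("X:", X.shape, "Y:", Y.shape)

# p_c has one prior per class; p_w_c has one column of feature
# probabilities per class.
print("p_c:", p_c.shape, "p_w_c:", p_w_c.shape)
print("sum of class priors:", p_c.sum())
print("column sums of p_w_c:", p_w_c.sum(axis=0))

# Expected feature counts per class: row 0 gives the x-coordinates and
# row 1 the y-coordinates of the star markers.
star_positions = p_w_c * length
print(star_positions)
```

Since each column of p_w_c sums to one, every star lies on the line where the two feature counts add up to length.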

The side-by-side comparison of n_labels=1 and n_labels=3 shows how label multiplicity affects the data distribution and class overlap. With n_labels=1, most samples carry a single label and scatter around the star of their class, so the groups tend to be easier to tell apart; how well separated they are depends on the randomly drawn distributions. Points are colored with the 7-color scheme above: single classes get primary colors, and combinations get mixed colors such as purple for classes 1 and 2. With n_labels=3, samples often belong to several classes at once, so they fall between the class stars and the classes overlap more. Because the true generative process is known, this generator is useful for benchmarking multilabel classifiers such as OneVsRestClassifier and ClassifierChain under controlled conditions.
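The effect of n_labels can also be checked numerically rather than visually. The sketch below counts the labels per sample for both settings; the helper mean_labels_per_sample, the sample size of 1000 and the seed are assumptions made for this illustration.

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification


def mean_labels_per_sample(n_labels, seed=0):
    """Average number of labels per sample for a given n_labels."""
    _, Y = make_multilabel_classification(
        n_samples=1000,  # larger than in the plots, for a stabler average
        n_features=2,
        n_classes=3,
        n_labels=n_labels,
        length=50,
        allow_unlabeled=False,
        random_state=seed,
    )
    # Each row of Y is a 0/1 indicator vector, so its sum is the label count.
    return Y.sum(axis=1).mean()


mean_1 = mean_labels_per_sample(1)
mean_3 = mean_labels_per_sample(3)
print("n_labels=1: %.2f labels per sample on average" % mean_1)
print("n_labels=3: %.2f labels per sample on average" % mean_3)
```

Because label counts of zero or above n_classes are redrawn, neither average equals n_labels exactly, but the n_labels=3 setting should yield noticeably more labels per sample.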

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import make_multilabel_classification as make_ml_clf

COLORS = np.array(
    [
        "!",  # placeholder for "no label" (unused: allow_unlabeled=False)
        "#FF3333",  # red
        "#0198E1",  # blue
        "#BF5FFF",  # purple
        "#FCD116",  # yellow
        "#FF7216",  # orange
        "#4DBD33",  # green
        "#87421F",  # brown
    ]
)

# Use same random seed for multiple calls to make_multilabel_classification to
# ensure same distributions
RANDOM_SEED = np.random.randint(2**10)


def plot_2d(ax, n_labels=1, n_classes=3, length=50):
    X, Y, p_c, p_w_c = make_ml_clf(
        n_samples=150,
        n_features=2,
        n_classes=n_classes,
        n_labels=n_labels,
        length=length,
        allow_unlabeled=False,
        return_distributions=True,
        random_state=RANDOM_SEED,
    )

    ax.scatter(
        X[:, 0], X[:, 1], color=COLORS.take((Y * [1, 2, 4]).sum(axis=1)), marker="."
    )
    ax.scatter(
        p_w_c[0] * length,
        p_w_c[1] * length,
        marker="*",
        linewidth=0.5,
        edgecolor="black",
        s=20 + 1500 * p_c**2,
        color=COLORS.take([1, 2, 4]),
    )
    ax.set_xlabel("Feature 0 count")
    return p_c, p_w_c


_, (ax1, ax2) = plt.subplots(1, 2, sharex="row", sharey="row", figsize=(8, 4))
plt.subplots_adjust(bottom=0.15)

p_c, p_w_c = plot_2d(ax1, n_labels=1)
ax1.set_title("n_labels=1, length=50")
ax1.set_ylabel("Feature 1 count")

plot_2d(ax2, n_labels=3)
ax2.set_title("n_labels=3, length=50")
ax2.set_xlim(left=0, auto=True)
ax2.set_ylim(bottom=0, auto=True)

plt.show()

print("The data was generated from (random_state=%d):" % RANDOM_SEED)
print("Class", "P(C)", "P(w0|C)", "P(w1|C)", sep="\t")
for k, p, p_w in zip(["red", "blue", "yellow"], p_c, p_w_c.T):
    print("%s\t%0.2f\t%0.2f\t%0.2f" % (k, p, p_w[0], p_w[1]))