Principal Component Analysis on the scikit-learn handwritten digits dataset (an MNIST-like dataset)

../_images/light-pca.svg ../_images/dark-pca.svg
  1. Load data

# pip install scikit-learn
import detroit as d3
import polars as pl # for data manipulation
from collections import namedtuple
from sklearn.datasets import load_digits # for MNIST dataset
from sklearn.decomposition import PCA # PCA decomposition
from sklearn.preprocessing import StandardScaler # To normalize data

# Data preparation: load the digits, standardize them, then project with PCA.
mnist = load_digits()

# Fit the scaler on the raw samples and standardize them in one step.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(mnist.data)

# Keep only two components so that every sample becomes a 2D point.
pca = PCA(n_components=2)
components = pca.fit_transform(X_scaled)

# One row per sample: both projected coordinates, followed by the digit label
# (appended as the last column; it drives the symbol and color of each point).
component_columns = ["Component 1", "Component 2"]
digit_column = pl.Series("digit", mnist.target)
df = pl.DataFrame(components, schema=component_columns)
df = df.insert_column(len(component_columns), digit_column)
shape: (1_797, 3)
┌─────────────┬─────────────┬───────┐
│ Component 1 ┆ Component 2 ┆ digit │
│ ---         ┆ ---         ┆ ---   │
│ f64         ┆ f64         ┆ i64   │
╞═════════════╪═════════════╪═══════╡
│ 1.914217    ┆ -0.954454   ┆ 0     │
│ 0.58898     ┆ 0.924555    ┆ 1     │
│ 1.302033    ┆ -0.317105   ┆ 2     │
│ -3.02077    ┆ -0.868772   ┆ 3     │
│ 4.528946    ┆ -1.093314   ┆ 4     │
│ …           ┆ …           ┆ …     │
│ 0.104338    ┆ 0.254933    ┆ 9     │
│ 2.423241    ┆ -1.429641   ┆ 0     │
│ 1.022598    ┆ -0.148016   ┆ 8     │
│ 1.076056    ┆ -0.380974   ┆ 9     │
│ -1.257696   ┆ -2.227632   ┆ 8     │
└─────────────┴─────────────┴───────┘
  2. Make the PCA chart

# Overall size of the chart
width, height = 928, 600

# Blank space reserved on each side of the plotting area
Margin = namedtuple("Margin", "top right bottom left")
margin = Margin(top=50, right=30, bottom=30, left=40)


def _component_scale(column, output_range):
    """Build a linear scale from the [min, max] of `column` in `df` to `output_range`.

    The domain is rounded with `.nice()` before the output range is set.
    """
    values = df[column]
    scale = d3.scale_linear().set_domain([values.min(), values.max()])
    return scale.nice().set_range(output_range)


# Horizontal scale: "Component 1" spans the chart from the left to the right margin
x = _component_scale("Component 1", [margin.left, width - margin.right])

# Vertical scale: "Component 2" spans from the bottom margin up to the top margin
# (the range is given bottom-first, so larger values are drawn higher)
y = _component_scale("Component 2", [height - margin.bottom, margin.top])

# Root <svg> element; its viewBox covers the same width x height as the element
svg = d3.create("svg")
svg = svg.attr("width", width).attr("height", height)
svg = svg.attr("viewBox", f"0 0 {width} {height}")

# Append the axes


def _remove_domain(g):
    """Remove the element matching ".domain" from the axis group `g`."""
    return g.select(".domain").remove()


def _add_x_title(g):
    """Append the bold "Component 1" title, end-anchored at the right edge of the axis."""
    title = g.append("text")
    title = title.attr("x", width - margin.right).attr("y", -4)
    title = title.attr("fill", "#000").attr("font-weight", "bold")
    return title.attr("text-anchor", "end").text("Component 1")


def _add_y_title(g):
    """Clone the text of the last tick and turn the copy into the "Component 2" title."""
    title = g.select(".tick:last-of-type").select("text").clone()
    title = title.attr("x", 4).attr("text-anchor", "start")
    return title.attr("font-weight", "bold").text("Component 2")


# Bottom axis, moved down to the lower edge of the plotting area
x_axis_group = svg.append("g").attr(
    "transform", f"translate(0, {height - margin.bottom})"
)
x_axis_group.call(d3.axis_bottom(x)).call(_remove_domain).call(_add_x_title)

# Left axis, moved right to the left edge of the plotting area
y_axis_group = svg.append("g").attr("transform", f"translate({margin.left}, 0)")
y_axis_group.call(d3.axis_left(y)).call(_remove_domain).call(_add_y_title)

# Append the symbols

# Each distinct digit is mapped to a stroke-only symbol shape.
# NOTE(review): if d3.SYMBOLS_STROKE holds fewer shapes than there are digits,
# some digits presumably share a shape and are told apart by color only — confirm.
symbol_type = d3.scale_ordinal(df["digit"].unique().sort().to_list(), d3.SYMBOLS_STROKE)

# The rainbow interpolator is cyclic: both ends of its input interval give the
# same color. The upper bound of the domain is therefore pushed one past the
# largest digit, so that the smallest and the largest digit get distinct colors.
color = d3.scale_sequential(
    [df["digit"].min(), df["digit"].max() + 1], d3.interpolate_rainbow
)

# One <g> per row of `df` (Component 1, Component 2, digit): the group is moved
# to the projected position and holds a single <path> drawing the digit's symbol.
(
    svg.append("g")
    .attr("fill", "none")
    .attr("stroke-width", 1.5)
    .select_all("symbol")
    .data(df.iter_rows())
    .join("g")
    .attr("transform", lambda d: f"translate({x(d[0])}, {y(d[1])})")
    .append("path")
    .attr("d", lambda d: d3.symbol(symbol_type(d[2]))())
    .attr("stroke", lambda d: color(d[2]))
)

# Legend: a single row with one symbol and one text label per digit

labels = df["digit"].unique().sort()
nb_columns = labels.len()  # count of distinct digits
offset = 40  # horizontal step from one legend entry to the next
symbol_size = 3


def _legend_symbol_transform(_, i):
    """Position the symbol of the i-th entry to the left of, and a bit above, its text."""
    return f"translate({i * offset + margin.left - symbol_size * 4}, {30 - symbol_size * 1.5})"


def _legend_text_x(_, i):
    """Horizontal position of the text of the i-th entry."""
    return i * offset + margin.left


# Enter selection shared by the symbols and the text labels
legend = svg.append("g").select_all("legend").data(labels.to_list()).enter()

# Symbols: same shape and stroke color as the points of the matching digit
legend_symbols = legend.append("g").attr("transform", _legend_symbol_transform)
legend_symbols = legend_symbols.append("path")
legend_symbols = legend_symbols.attr("d", lambda d: d3.symbol(symbol_type(d))())
legend_symbols = legend_symbols.style("stroke-width", 1.5)
legend_symbols.style("stroke", lambda d: color(d)).style("fill", "none")

# Text labels: the digit itself, written in black
legend_texts = legend.append("text").attr("x", _legend_text_x).attr("y", 30)
legend_texts = legend_texts.text(lambda d: str(d))
legend_texts.style("fill", "black").style("font-size", 15)
  3. Save your chart

# Write the chart markup to "pca.svg". The encoding is given explicitly so that
# the written bytes do not depend on the platform's default text encoding.
with open("pca.svg", "w", encoding="utf-8") as file:
    file.write(str(svg))