Teaching a Perceptron to See

9 minute read

Ten years after the ImageNet Challenge thawed the last AI winter, ChatGPT and generative AI have become part of our everyday lives and colloquial language, in a way (almost) no one imagined just two years ago. As more and more folks aspire to foray into the field of ML/AI, I can’t help but think of a lesson from my guitar teacher:

Everyone wants to start playing the songs they love right off the bat, but without nailing seemingly “boring” building blocks such as scales, harmonies, and rhythms, the songs you love will sound like a nightmare…

Not to sound discouraging (yup, when people say “not to do X”, they’re doing exactly X…), but I think those who can ride the AI wave are those who take the time to build up from the foundations. Let’s begin with the oldest and simplest neural net: the Perceptron (Rosenblatt, 1958).

Seeing “7” vs. “L”

Which of the two images below is a “7” and which one is an “L”?

To us humans, answering this question is a piece of cake: the left image is an “L” and the right one a “7”. But how is it that we know (other than appealing to “intuition”)?

How do humans see, in plain English?

We can divide each image into 4 pixels and check whether each pixel is filled (1) or empty (0). Based on these values, we can come up with rules to classify the images (see the code sketch after the rules):

  • “L”: top right is empty (0) + lower left is filled (1)
  • “7”: lower left is empty (0) + top right is filled (1)
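To make these rules concrete, here is a minimal sketch in Python (the pixel order, top left, top right, lower left, lower right, is an assumption that matches the worked example later in this post):

```python
def classify_by_rule(pixels):
    """hand-written rule; pixels = [top_left, top_right, lower_left, lower_right]"""
    top_right, lower_left = pixels[1], pixels[2]
    if top_right == 0 and lower_left == 1:
        return "L"
    if top_right == 1 and lower_left == 0:
        return "7"
    return "unknown"  # neither rule applies


print(classify_by_rule([1, 0, 1, 1]))  # "L"
print(classify_by_rule([1, 1, 0, 1]))  # "7"
```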

But how can a machine see?

Machines, however, don’t have eyes, and, more importantly, it’d be cumbersome to write bespoke rules for each and every use case. As a common strategy in machine learning, we can instead learn a mapping from an input vector (in this case, a 4-vector of pixel values) to an output (in this case, a binary class label). Two vectors are involved:

  • $\mathbf{x}$: a vector encoding the value of each pixel πŸ‘‰ already known (the green arrow indicates the direction from the first to the last pixel)
  • $\mathbf{w}$: a vector encoding the weight of each pixel πŸ‘‰ unknown

There are an infinite number of rules that can map $\mathbf{x}$ and $\mathbf{w}$ to class labels. Let’s use the sign of their dot product $\mathbf{w}^T \mathbf{x}$ (np.dot(x, w)) to determine the class label 👉 negative: 0 (“L”) vs. positive: 1 (“7”); which class counts as positive is arbitrary.
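For instance, with 4-pixel encodings of the two images and one hand-picked (not learned, purely illustrative) weight vector, the sign rule plays out like this:

```python
import numpy as np

x_L = np.array([1, 0, 1, 1])  # "L": top right empty, lower left filled
x_7 = np.array([1, 1, 0, 1])  # "7": top right filled, lower left empty
w = np.array([0, 1, -1, 0])   # hand-picked weights, just for illustration

print(np.dot(x_L, w))  # -1 -> negative -> class 0 ("L")
print(np.dot(x_7, w))  #  1 -> positive -> class 1 ("7")
```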

In our case, the weights of pixels 1 and 4 can be 0: these 2 pixels are filled in both images regardless of the label, so they aren’t really pulling their weight in classification. The weight of pixel 2 (top right, filled only in the “7”) should be positive and the weight of pixel 3 (lower left, filled only in the “L”) should be negative, which pushes the dot product toward the desired sign for each image.

Wait… how do we “learn” the weights?

Ummm, aren’t we back in the realm of relying on human eyes to see the images and human brains to hand-engineer rules? Not quite: we can initialize $\mathbf{w}$ with 4 random floats, say [0.5, 0.9, -0.3, 0.5], and let the model “learn from mistakes”.

  • Attempt #1: Say the first training example is an “L”. Using the random weights above, we get $\mathbf{w}^T \mathbf{x} = 0.5 \times 1 + 0.9 \times 0 -0.3 \times 1 + 0.5 \times 1 = 0.7$. Because the dot product is positive, we predict “7”, which is incorrect.

    • Correction: Because the dot product is too large (should be < 0 to classify the image as “L”!), we decrease the weights $\mathbf{w}$.
    • New weights: A crude way to do so is to subtract $\mathbf{x}$ from $\mathbf{w}$ (let’s save gradient descent for the future…) 👉 $\mathbf{w} = \mathbf{w} - \mathbf{x} = [0.5 - 1,\ 0.9 - 0,\ -0.3 - 1,\ 0.5 - 1] = [-0.5, 0.9, -1.3, -0.5]$
  • Attempt #2: The second training example is a “7”. Using the updated weights, we get $\mathbf{w}^T \mathbf{x} = -0.5 \times 1 + 0.9 \times 1 - 1.3 \times 0 - 0.5 \times 1 = -0.1$. Since the dot product is negative, we predict “L”, which is incorrect again.

    • Correction: Because the dot product is too small (should be > 0 to classify the image as “7”), we increase the weights $\mathbf{w}$.
    • New weights: This time we add 👉 $\mathbf{w} = \mathbf{w} + \mathbf{x} = [-0.5 + 1,\ 0.9 + 1,\ -1.3 + 0,\ -0.5 + 1] = [0.5, 1.9, -1.3, 0.5]$
  • Final & successful attempt: The third training example is an “L”. Using the newest weights, we get $\mathbf{w}^T \mathbf{x} = 0.5 \times 1 + 1.9 \times 0 -1.3 \times 1 + 0.5 \times 1 = -0.3$. Since the dot product is negative, we correctly predict “L”!

To achieve good performance, neural networks are usually trained on (far) more than 3 examples. Early stopping can be applied when the model hasn’t improved for a while, as sketched below.
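In code, early stopping is often just a patience counter. Here is a minimal sketch on a made-up accuracy curve (the numbers are invented purely for illustration):

```python
accuracies = [0.90, 0.95, 0.96, 0.96, 0.96, 0.95]  # made-up per-epoch accuracies
best_acc, patience, wait = 0.0, 3, 0  # stop after 3 epochs without improvement

for epoch, acc in enumerate(accuracies):
    if acc > best_acc:
        best_acc, wait = acc, 0  # new best: reset the counter
    else:
        wait += 1
        if wait >= patience:
            print(f"early stop at epoch {epoch}; best accuracy {best_acc}")
            break
```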

(The slides above are adapted from the Spring 2022 offering of UC Berkeley’s CogSci 131, which I taught.)

Code Up A Perceptron

In the L vs. 7 toy example, each image has only 4 pixels. To practice what we’ve learned, let’s solve a slightly more complex problem: classifying images of handwritten digits (e.g., a 28 $\times$ 28 image of “1”) into numerical labels (e.g., 1).

(All images and code snippets used below can be found in this repo.)

Read image inputs

If using a convolutional neural net (check out this great intuitive explanation by 3Blue1Brown), you may want to keep input images as 28 $\times$ 28 matrices. In this case, however, we can simply flatten each image into a 28 $\times$ 28 = 784-dimensional vector.

```python
import os

import numpy as np

# dimension of all images
DIM = (28, 28)

# flattened image dimension
N = DIM[0] * DIM[1]


def load_image_files(n, path="images/"):
    """loads images of a given digit and returns a list of flattened vectors"""
    # initialize empty list to collect vectors
    images = []
    # read files in the path
    for f in sorted(os.listdir(os.path.join(path, str(n)))):
        p = os.path.join(path, str(n), f)
        if os.path.isfile(p):
            i = np.loadtxt(p)
            # check image dimension
            assert i.shape == DIM
            # flatten i into a single vector
            images.append(i.flatten())
    return images
```
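For example, assuming the repo keeps the “0” images under images/0/ as plain-text arrays (the default path above), we can load and inspect them like so:

```python
# load all "0" images; each is a flattened 784-vector
zeros = load_image_files(0)
print(len(zeros), zeros[0].shape)  # number of images found and (784,)
```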

Classify 0 vs. 1

As a starter, let’s classify 0 vs. 1. First off, let’s translate some conceptual operations mentioned in the 7 vs. L example into helper functions.

  • Use dot product to classify one image: Given a known weight vector, we can compute $\mathbf{w}^T \mathbf{x}$ and return 0 or 1 as the classification result
```python
def classify_image(W, image):
    """use weight vector W to determine the digit in image"""
    # if dot product > 0, return 1; otherwise 0
    y = (np.dot(image, W) > 0).astype(int)
    return y
```
  • Update weights upon feedback from each true label: Subtract $\mathbf{x}$ from $\mathbf{w}$ in the case of a false positive (should be 0 but predicted 1) and add $\mathbf{x}$ to $\mathbf{w}$ in the case of a false negative (should be 1 but predicted 0)
```python
def update_weights(W, images, labels):
    """updates weight vector W based on training (image, label) pairs"""

    # loop through images and labels
    for image, label in zip(images, labels):
        # predict image label
        pred_label = classify_image(W, image)
        # update only if prediction is wrong
        if pred_label != label:
            # predicted 0 but label is 1
            if label == 1:
                W += image
            # predicted 1 but label is 0
            else:
                W -= image

    # return updated weights
    return W
```
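As a side note, the two branches above compress into the classic perceptron update rule: with true label $y$ and prediction $\hat{y}$, $\mathbf{w} \leftarrow \mathbf{w} + (y - \hat{y})\,\mathbf{x}$ adds $\mathbf{x}$ on a false negative, subtracts it on a false positive, and leaves $\mathbf{w}$ untouched when the prediction is correct.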

Finally, we can write the training function, which takes 4 arguments: the name of the 1st digit (class 0), the name of the 2nd digit (class 1), the # of training examples per class, and the # of epochs (i.e., how many times we want to repeat the full training loop).

```python
import random

from tqdm import tqdm


def train_perceptron(digit1, digit2, n_samples=25, epoch=200):
    """train perceptron on (image, label) pairs for a given number of epochs"""

    # load images of the two digits to compare
    img1 = load_image_files(digit1)
    img2 = load_image_files(digit2)

    # initialize empty list to collect accuracies
    accuracies = []
    # initialize random weights from a standard normal
    W = np.random.normal(0, 1, size=N)

    # iterate through each epoch
    for i in tqdm(range(epoch)):

        # randomly sample images of the first digit
        sample1 = random.sample(img1, n_samples)
        # randomly sample images of the second digit
        sample2 = random.sample(img2, n_samples)

        # train on the chosen samples
        W = update_weights(W, sample1 + sample2, [0] * n_samples + [1] * n_samples)
        # evaluate performance on all images
        accuracy = compute_accuracy(W, img1 + img2, [0] * len(img1) + [1] * len(img2))
        accuracies.append(accuracy)

    # return accuracies and final weights
    return accuracies, W
```

As you may have noticed above, after each epoch we check the model accuracy, defined as the # of correctly classified examples divided by the # of examples classified.

```python
def compute_accuracy(W, images, labels):
    """computes classification accuracy on a list of images"""
    # initialize count at 0
    n_correct = 0
    # loop through the image list
    for i, image in enumerate(images):
        # increment count each time we correctly classify an image
        n_correct += (labels[i] == classify_image(W, image)).astype(int)

    # accuracy is total correct / total
    return n_correct / len(images)
```
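If you prefer vectorized NumPy, an equivalent computation needs no explicit loop (a sketch assuming images is a list of equal-length vectors, as returned by load_image_files):

```python
def compute_accuracy_vectorized(W, images, labels):
    """same computation on a stacked (n_images, 784) matrix"""
    preds = (np.stack(images) @ W > 0).astype(int)
    return np.mean(preds == np.array(labels))
```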

Since the Perceptron algorithm is quite simple, training is lightning fast: we can train for 500 epochs in a matter of minutes, whereas training GPT can take months.

```python
import matplotlib.pyplot as plt

# train on 25 examples per class for a given number of epochs
n_samples, epochs = 25, 500
accuracies, trained_W = train_perceptron(0, 1, n_samples, epochs)

# plot accuracy as a function of epoch
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(range(epochs), accuracies)
ax.set_xlabel("epoch", fontsize=20)
ax.set_ylabel("mean accuracy", fontsize=20)
ax.set_title("training accuracy", fontsize=25)
```

We achieve near-perfect (though not quite 100%) accuracy on 0 vs. 1 classification.

Interpret learned weights

People say neural nets are black boxes, which may well be true of complex networks. In our simple case, though, the learned weights may simply capture the exclusive disjunction of where “0” pixels and “1” pixels fall in the input images: positive weights where “1” strokes tend to be (pushing the dot product above 0) and negative weights where “0” strokes tend to be.

```python
# reshape weights back to 28 x 28
trained_W_2d = np.reshape(trained_W, DIM)

# plot the weight matrix
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(trained_W_2d)
ax.axis("off")
```
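Since the learned weights can be negative or positive, a diverging colormap centered at zero (an optional tweak, not in the original plot) makes the sign structure easier to read:

```python
# symmetric color limits so that 0 maps to the middle of the colormap
lim = np.abs(trained_W_2d).max()
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(trained_W_2d, cmap="RdBu_r", vmin=-lim, vmax=lim)
ax.axis("off")
```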

Generalize to all image pairs

Finally, if we can classify 0 vs. 1 and check model performance, we can do the same for every digit pair (e.g., 0 vs. 1, 1 vs. 2…). The code below trains one classifier per digit pair and tracks each classifier’s accuracy after the last epoch.

```python
def get_accuracy_matrix(n_num):
    """get a matrix of pairwise classification accuracies"""

    # initialize matrix with all 1's
    acc_matrix = np.ones([n_num, n_num])

    # loop through digit pairs
    for digit1 in range(n_num - 1):
        for digit2 in range(digit1 + 1, n_num):
            # get accuracies after 100 epochs
            acc, _ = train_perceptron(digit1, digit2, 25, 100)
            # replace values in the matrix with the final accuracy
            acc_matrix[digit1, digit2] = acc[-1]
            acc_matrix[digit2, digit1] = acc[-1]

    # return the matrix
    return acc_matrix


def plot_accuracy_matrix(acc_matrix, n_num):
    """plot classification accuracies as a heatmap"""

    # generate heatmap
    fig, ax = plt.subplots(figsize=(8, 8))
    g = ax.imshow(acc_matrix, cmap="Blues")
    ax.set_xticks(range(0, n_num))
    ax.set_yticks(range(0, n_num))
    ax.set_title("accuracy matrix", fontsize=20)

    # add color bar
    fig.colorbar(g, ax=ax)


# plot accuracy for each digit pair
acc_matrix = get_accuracy_matrix(10)
plot_accuracy_matrix(acc_matrix, 10)
```

As we might imagine, 5 and 8 are pretty hard to tell apart as handwritten digits, as are 7 vs. 9 and 6 vs. 9, but the remaining digit pairs reach high accuracy.