Algorithm explained: Linear regression using gradient descent with PHP

Part 4 of Algorithms explained! Every few weeks I write about an algorithm and explain and implement it!
Is there an algorithm you always wanted to know about? Leave a comment!

I once took an online class on machine learning by the amazing Andrew Ng from Stanford for my bachelor's studies. One of the first ML algorithms I ever had the chance to look at was linear regression.

This algorithm really struck me. It allows machines to do what humans are incredibly good at: pattern recognition. By applying mathematics to seemingly unrelated data, it lets a machine create a linear function that more or less matches that data!

In this post I’ll explain how linear regression works and implement it in PHP!



So what does linear regression actually do?

In one sentence: It creates a linear function that represents the trend of something. The go-to example is usually house prices compared to their size. The larger a house, the more expensive it is. Sounds reasonable, right? You would put a known house size, for example in square meters, into the function and you would get an estimated price out of it.

Let’s have a look at this data set:

House prices and sizes compared to each other in a two dimensional graph. There's some correlation visible.

At first glance, the data seems to correlate. The larger the house, the more expensive it is. The data is a bit scattered, but there’s an overall trend visible:

The same graph as above, now with an added trend line

(This is an approximation)

Of course, there are usually more factors to the price of a house in real life: Is there a grocery store nearby, is it modern, was it once considered haunted, what’s the size of its backyard? These factors account for the variation in the data, which is why it isn’t perfectly linear.

The algorithm itself works like this:

  1. Define any linear function as a starting trend line (usually a random one)
  2. Measure how far off it is by calculating the average distance between the predicted Y and the actual Y of every data point (the so-called “error”)
  3. Adjust the trend line based on the measurement (the “gradient descent”)
  4. Repeat steps 2 and 3 until the average distance has reached a minimum (not necessarily 0)

Let me illustrate with the above example:

The different steps illustrated

(Again, this is an approximation)
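
Step 2 of the list above can be sketched in a few lines of PHP. This is only a minimal illustration, not the implementation this post builds: the helper name `averageDistance`, the `[x, y]` point layout and the sample points are all assumptions made up for the example, and it uses the plain mean absolute distance as the “error”.

```php
<?php

declare(strict_types=1);

/**
 * Average vertical distance between a trend line and a list of [x, y] points.
 * Hypothetical helper for illustration only; expects at least one point.
 */
function averageDistance(array $points, float $slope, float $intercept): float
{
    $sum = 0.0;

    foreach ($points as [$x, $y]) {
        $predictedY = $slope * $x + $intercept;
        $sum += abs($predictedY - $y);
    }

    return $sum / count($points);
}

// Made-up points that lie exactly on y = 2x + 1
$points = [[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]];

var_dump(averageDistance($points, 2.0, 1.0)); // float(0): the line fits perfectly
var_dump(averageDistance($points, 2.0, 0.0)); // float(1): the same line shifted down by 1
```

The smaller the returned value, the better the trend line matches the points, which is the quantity steps 3 and 4 try to drive towards a minimum.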

So, to implement the algorithm, we need four different pieces:

  • The definition of a linear function
  • The measurement
  • The adjustment
  • Repeating until it’s at a minimum

Let’s start with…



The definition of a linear function

A linear function $f(x)$ usually comes in the form of:

$$f(x) = ax + b$$

where $a$ defines the slope of the function and $b$ moves it along the Y axis.

We can make a more general definition of a linear function by introducing a variable called $x_0$ into the function, renaming $x$ to $x_1$, and renaming $b$ and $a$ to $c_0$ and $c_1$ (“c” stands for “coefficient”):

$$f(x_0, x_1) = c_0x_0 + c_1x_1$$

In PHP, this is pretty straightforward:

<?php

declare(strict_types=1);

function linearFunction (float $c0, float $x0, float $c1, float $x1) : float {
    return $c0 * $x0 + $c1 * $x1;
}
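
To see the function in action, here is a small usage sketch. The coefficient and input values are arbitrary numbers chosen purely for illustration, and the function is repeated so the snippet runs on its own.

```php
<?php

declare(strict_types=1);

// Same function as above, repeated so this snippet is self-contained
function linearFunction (float $c0, float $x0, float $c1, float $x1) : float {
    return $c0 * $x0 + $c1 * $x1;
}

// Arbitrary example values: 1 * 1 + 2 * 3
var_dump(linearFunction(1.0, 1.0, 2.0, 3.0)); // float(7)

// Even with strict_types enabled, ints are accepted for float parameters
var_dump(linearFunction(1, 1, 2, 3)); // float(7)
```

The second call matters because `mt_rand()` returns integers, which can therefore be passed to the function directly.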

We can randomly define $c0 and $c1 to get a first starting point:

$c0 = mt_rand(-100, 100); // random starting coefficients
$c1 = mt_rand(-100, 100); // as a slope, -100 or 100 is almost parallel to the Y axis

So far, so good. Let’s generate some data next!



Getting data

I’m going to generate the data randomly with a correlation. To get the linear function $f(x_0, x_1) = c_0x_0 + c_1x_1$ back into its original form $f(x) = c_0 + c_1x$, we set $x_0$ to always be 1. Each row contains $x_0$, $x_1$ and the actual Y value:

$data = [
    [1, mt_rand(0, 100), mt_rand(0, 100)],
    [1, mt_rand(100, 200), mt_rand(100, 200)],
    [1, mt_rand(200, 300), mt_rand(200, 300)],
    [1, mt_rand(300, 400), mt_rand(300, 400)],
    [1, mt_rand(400, 500), mt_rand(400, 500)],
    [1, mt_rand(500, 600), mt_rand(500, 600)],
    [1, mt_rand(600, 700), mt_rand(600, 700)],
    [1, mt_rand(700, 800), mt_rand(700, 800)],
    [1, mt_rand(800, 900), mt_rand(800, 900)],
    [1, mt_rand(900, 1000), mt_rand(900, 1000)],
];
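
Since the ten rows follow a regular pattern, the same kind of data set could also be produced in a loop. The following is only a sketch under that assumption: the helper name `generateData` and its parameters `$rows` and `$step` are hypothetical and not part of the code above.

```php
<?php

declare(strict_types=1);

/**
 * Builds rows of [x0, x1, y], where x0 is always 1 and x1 and y are
 * drawn from the same range, which moves up by $step from row to row.
 * Hypothetical helper for illustration only.
 */
function generateData(int $rows = 10, int $step = 100): array
{
    $data = [];

    for ($i = 0; $i < $rows; $i++) {
        $min = $i * $step;
        $max = $min + $step;

        $data[] = [1, mt_rand($min, $max), mt_rand($min, $max)];
    }

    return $data;
}

$data = generateData(); // ten rows covering the ranges 0-100 up to 900-1000
```

Because both random values of a row come from the same range, larger inputs go along with larger Y values, which gives the data its correlation.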