
Supervised Learning

23 Jul 2019 · CPOL
An article about supervised learning

Supervised learning is one of the two major branches of machine learning. In a way, it is similar to how humans learn a new skill: someone else shows us what to do, and we are then able to learn by following their example. In the case of supervised learning algorithms, we usually need lots of examples, that is, lots of data providing the input to our algorithm and what the expected output should be. The algorithm will learn from this data, and then predict the output based on new inputs that it hasn’t seen before.

A surprising number of problems can be addressed by using supervised learning. Many email systems use it to classify each new message as either important or unimportant automatically, as soon as it arrives in the inbox. More complex examples include image recognition systems, which can identify what an image contains purely from the input pixel values.

These systems start by learning from huge datasets of images that have been labelled manually by humans but are then able to categorize completely new images automatically. It is even possible to use supervised learning to steer a car automatically around a racing track: the algorithm starts by learning how a human driver controls the vehicle and is eventually able to replicate this behaviour.

By the end of this article, you will be able to use Go to implement two types of supervised learning:

  • Classification, where an algorithm must learn to classify the input into two or more discrete categories. We will build a simple image recognition system to demonstrate how this works.
  • Regression, in which the algorithm must learn to predict a continuous variable, for example, the price of an item for sale on a website. For our example, we will predict house prices based on inputs, such as the location, size, and age of the house.

In this article, we will be covering the following topics:

  • When to use regression and classification
  • How to implement regression and classification using Go machine learning libraries
  • How to measure the performance of an algorithm

We will cover the two stages involved in building a supervised learning system:

  • Training, which is the learning phase where we use labelled data to calibrate an algorithm
  • Inference or prediction, where we use the trained algorithm for its intended purpose: to make predictions from input data

Classification

When starting any supervised learning problem, the first step is to load and prepare the data. We are going to start by loading the Fashion-MNIST dataset, a collection of small grayscale images showing different items of clothing. Our job is to build a system that can recognize what is in each image; that is, does it contain a dress, a shoe, a coat, and so on?

First, we need to download the dataset by running the download-fashion-mnist.sh script in the code repository. Then, we will load it into Go:

Go
import (
    "fmt"
     mnist "github.com/petar/GoMNIST"
    "github.com/kniren/gota/dataframe"
    "github.com/kniren/gota/series"
    "math/rand"
    "github.com/cdipaolo/goml/linear"
    "github.com/cdipaolo/goml/base"
    "image"
    "bytes"
    "math"
    "github.com/gonum/stat"
    "github.com/gonum/integrate"
)
set, err := mnist.ReadSet("../datasets/mnist/images.gz", "../datasets/mnist/labels.gz")
if err != nil {
    panic(err)
}

Let's start by taking a look at a sample of the images. Each one is 28 x 28 pixels, and each pixel has a value between 0 and 255. We are going to use these pixel values as the inputs to our algorithm: our system will accept 784 inputs from an image and use them to classify the image according to which item of clothing it contains. In Jupyter, you can view an image as follows:

Go
set.Images[1]

This will display one of the 28 x 28 images from the dataset, as shown in the following image:

Image 1

To make this data suitable for a machine learning algorithm, we need to convert it into a data frame format. To start, we will load the first 1,000 images from the dataset:

Go
func MNISTSetToDataframe(st *mnist.Set, maxExamples int) dataframe.DataFrame {
    length := maxExamples
    if length > len(st.Images) {
        length = len(st.Images)
    }
    s := make([]string, length, length)
    l := make([]int, length, length)
    for i := 0; i < length; i++ {
        s[i] = string(st.Images[i])
        l[i] = int(st.Labels[i])
    }
    var df dataframe.DataFrame
    images := series.Strings(s)
    images.Name = "Image"
    labels := series.Ints(l)
    labels.Name = "Label"
    df = dataframe.New(images, labels)
    return df
}

df := MNISTSetToDataframe(set, 1000)

We also need a string array that contains the possible labels for each image:

Go
categories := []string{"tshirt", "trouser", "pullover", 
                       "dress", "coat", "sandal", "shirt", "shoe", "bag", "boot"}

It is very important to start by reserving a small proportion of your data in order to test the finished algorithm. This allows us to measure how well the algorithm works on new data that was not used during training. If you do not do this, you will most likely build a system that works really well during training but performs badly when faced with new data. To start with, we are going to use 75% of the images to train our model and 25% of the images to test it.

Note that splitting your data into a training set and a test set is a crucial step when using supervised learning. It is normal to reserve 20-30% of the data for testing, but if your dataset is very large, you may be able to use less than this.

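Use the Split(df dataframe.DataFrame, valFraction float64) function from the previous chapter of the book to prepare these two datasets. Despite the parameter name, the 0.75 passed here is the fraction of rows kept for training, so the remaining 25% become the validation set: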

Go
training, validation := Split(df, 0.75)

A Simple Model – the Logistic Classifier

One of the simplest algorithms that can solve our problem is the logistic classifier. This is what mathematicians call a linear model, which we can understand by thinking about a simple example where we are trying to classify the points on the following two charts as either circles or squares. A linear model will try to do this by drawing a straight line to separate the two types of points. This works very well on the left-hand chart, where the relationship between the inputs (on the chart axes) and the output (circle or square) is simple. However, it does not work on the right-hand chart, where it is not possible to split the points into two correct groups using a straight line:

Image 2

When faced with a new machine learning problem, it is a good idea to start with a linear model as a baseline, and then compare other models to it. Although linear models cannot capture complex relationships in the input data, they are easy to understand and normally quick to implement and train. You might find that a linear model is good enough for the problem you are working on, which saves you the time of implementing anything more complex. If not, you can try different algorithms and use the linear model to understand how much better they work.

Note that a baseline is a simple model that you can use as a point of reference when comparing different machine learning algorithms.

Going back to our image dataset, we are going to use a logistic classifier to decide whether an image contains trousers or not. First, let's do some final data preparation: simplify the labels to be either trousers (true) or not-trousers (false):

Go
func EqualsInt(s series.Series, to int) (*series.Series, error) {
    eq := make([]int, s.Len(), s.Len())
    ints, err := s.Int()
    if err != nil {
        return nil, err
    }
    for i := range ints {
        if ints[i] == to {
            eq[i] = 1
        }
    }
    ret := series.Ints(eq)
    return &ret, nil
}

trainingIsTrouser, err1 := EqualsInt(training.Col("Label"), 1)
validationIsTrouser, err2 := EqualsInt(validation.Col("Label"), 1)
if err1 != nil || err2 != nil {
    fmt.Println("Error", err1, err2)
}

We are also going to normalize the pixel data so that, instead of being stored as integers between 0 and 255, it will be represented by floats between 0 and 1:

Note that many supervised machine learning algorithms only work properly if the data is normalized, that is, rescaled so that it lies between 0 and 1. If you are having trouble getting an algorithm to train, check that you have normalized the data correctly.

Go
func NormalizeBytes(bs []byte) []float64 {
    ret := make([]float64, len(bs), len(bs))
    for i := range bs {
        ret[i] = float64(bs[i])/255.
    }
    return ret
}

func ImageSeriesToFloats(df dataframe.DataFrame, col string) [][]float64 {
    s := df.Col(col)
    ret := make([][]float64, s.Len(), s.Len())
    for i := 0; i < s.Len(); i++ {
        b := []byte(s.Elem(i).String())
        ret[i] = NormalizeBytes(b)
    }
    return ret
}

trainingImages := ImageSeriesToFloats(training, "Image")
validationImages := ImageSeriesToFloats(validation, "Image")

After preparing the data properly, it is finally time to create a logistic classifier and train it:

Go
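// Parameters: optimization method (batch gradient ascent), learning rate (1e-4),
// regularization (1), maximum iterations (150), inputs, 0/1 labels
model := linear.NewLogistic(base.BatchGA, 1e-4, 1, 150, 
         trainingImages, trainingIsTrouser.Float())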

//Train
err := model.Learn()
if err != nil {
  fmt.Println(err)
}
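To give a feel for what happens inside model.Learn(), here is a minimal from-scratch sketch of a logistic classifier trained with batch gradient ascent, using only the Go standard library. It is not goml's implementation: the helper names (sigmoid, predict, trainLogistic), the L2 regularization convention and the toy one-dimensional data are assumptions made for illustration. Each prediction is the logistic (sigmoid) function applied to a weighted sum of the inputs, and every iteration uses the whole training set to nudge the weights in the direction that increases the log-likelihood of the labels.

```go
package main

import (
	"fmt"
	"math"
)

// sigmoid squashes any real number into the range (0, 1).
func sigmoid(z float64) float64 { return 1 / (1 + math.Exp(-z)) }

// predict returns the probability that x belongs to the positive class.
// theta[0] is the bias term; theta[1:] holds one weight per feature.
func predict(theta, x []float64) float64 {
	z := theta[0]
	for j, v := range x {
		z += theta[j+1] * v
	}
	return sigmoid(z)
}

// trainLogistic fits theta by batch gradient ascent on the log-likelihood.
// Each iteration sums the gradient over the whole training set and takes one
// step of size alpha. lambda applies an (assumed) L2 penalty to the weights,
// but not to the bias.
func trainLogistic(xs [][]float64, ys []float64, alpha, lambda float64, iters int) []float64 {
	theta := make([]float64, len(xs[0])+1)
	for it := 0; it < iters; it++ {
		grad := make([]float64, len(theta))
		for i, x := range xs {
			diff := ys[i] - predict(theta, x) // label minus predicted probability
			grad[0] += diff
			for j, v := range x {
				grad[j+1] += diff * v
			}
		}
		for j := range theta {
			g := grad[j]
			if j > 0 {
				g -= lambda * theta[j]
			}
			theta[j] += alpha * g
		}
	}
	return theta
}

func main() {
	// Toy one-feature data: values above 0.5 belong to the positive class.
	xs := [][]float64{{0.1}, {0.2}, {0.3}, {0.7}, {0.8}, {0.9}}
	ys := []float64{0, 0, 0, 1, 1, 1}
	theta := trainLogistic(xs, ys, 0.5, 0, 2000)
	for _, x := range xs {
		fmt.Printf("%.1f -> %.3f\n", x[0], predict(theta, x))
	}
}
```

Rounding the predicted probability at 0.5, as the accuracy code in the next section does with math.Round, turns each probability into a trousers/not-trousers decision.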

Measuring Performance

Now that we have our trained model, we need to measure how well it is performing by comparing the predictions it makes on each image with the ground truth (whether or not the image is a pair of trousers). A simple way to do this is to measure accuracy.

Accuracy measures what proportion of the input data can be classified correctly by the algorithm, for example, 90%, if 90 out of 100 predictions from the algorithm are correct.

In our Go code example, we can test the model by looping over the validation dataset and counting how many images are classified correctly. This will output a model accuracy of 98.8%:

Go
//Count correct classifications
var correct = 0.
for i := range validationImages {
  prediction, err := model.Predict(validationImages[i])
  if err != nil {
    panic(err)
  }

  if math.Round(prediction[0]) == validationIsTrouser.Elem(i).Float() {
    correct++
  }
}

//accuracy
correct / float64(len(validationImages))

Precision and Recall

Measuring accuracy can be very misleading. Suppose you are building a system to classify whether medical patients will test positive for a rare disease, and in the dataset, only 0.1% of examples are in fact positive. A really bad algorithm might predict that nobody will test positive, and yet it has an accuracy of 99.9% simply because the disease is rare.
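The arithmetic behind this example is easy to check. The sketch below uses a made-up population of 1,000 patients of whom exactly one (0.1%) is positive, and shows that a model that always answers "negative" still scores 99.9% accuracy while finding none of the positive cases. The accuracy helper is written for this illustration only.

```go
package main

import "fmt"

// accuracy returns the fraction of predictions that match the actual labels.
func accuracy(pred, actual []bool) float64 {
	correct := 0
	for i := range pred {
		if pred[i] == actual[i] {
			correct++
		}
	}
	return float64(correct) / float64(len(pred))
}

func main() {
	// 1,000 made-up patients; only the first one (0.1%) is actually positive.
	actual := make([]bool, 1000)
	actual[0] = true

	// A useless model that predicts "negative" (false) for everybody.
	pred := make([]bool, 1000)

	fmt.Println(accuracy(pred, actual)) // 0.999
}
```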

Note that a dataset that has many more examples of one class than another is known as unbalanced (or imbalanced). Unbalanced datasets need to be treated carefully when measuring algorithm performance.

A better way to measure performance starts by putting each prediction from the algorithm into one of the following four categories:

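  • True positive: predicted true, actually true
  • False positive: predicted true, actually false
  • False negative: predicted false, actually true
  • True negative: predicted false, actually false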

We can now define some new performance metrics:

  • Precision measures what fraction of the model's positive predictions are actually correct. In the following diagram, it is the true positives predicted by the model (the left-hand side of the circle) divided by all of the model's positive predictions (everything in the circle).
  • Recall measures how good the model is at identifying all the positive examples. In other words, the true positives (left-hand side of the circle) divided by all the data points that are actually positive (the entire left-hand side):

Image 4

The preceding diagram shows data points that have been predicted as true by the model in the central circle. The points that are actually true are on the left half of the diagram.

Note that precision and recall are more robust performance metrics when working with unbalanced datasets. Both range between 0 and 1, where 1 indicates perfect performance.

The following code counts the true positives, false positives, and false negatives:

Go
//Count true positives, false positives and false negatives
var truePositives = 0.
var falsePositives = 0.
var falseNegatives = 0.
for i := range validationImages {
  prediction, err := model.Predict(validationImages[i])
  if err != nil {
    panic(err)
  }
  if validationIsTrouser.Elem(i).Float() == 1 {
    if math.Round(prediction[0]) == 0 {
      // Predicted false, but actually true
      falseNegatives++
    } else {
      // Predicted true, correctly
      truePositives++
    }
  } else {
    if math.Round(prediction[0]) == 1 {
      // Predicted true, but actually false
      falsePositives++
    }
  }
}

We can now calculate the precision and recall with the following code:

Go
//precision
truePositives / (truePositives + falsePositives)
//recall
truePositives / (truePositives + falseNegatives)

For our linear model, we get a precision of 100%, meaning that there are no false positives, and a recall of 90.3%.
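The same bookkeeping can be wrapped in a small, reusable type. The sketch below is a hypothetical helper (Confusion and Count are names invented here) that tallies the four outcome categories from boolean predictions and labels. Returning 0 when a denominator is zero is an assumed convention, since precision is undefined if the model never predicts a positive.

```go
package main

import "fmt"

// Confusion holds the four outcome counts for a binary classifier.
type Confusion struct {
	TP, FP, FN, TN int
}

// Count tallies each (prediction, actual) pair into one of the four categories.
func Count(pred, actual []bool) Confusion {
	var c Confusion
	for i := range pred {
		switch {
		case pred[i] && actual[i]:
			c.TP++ // predicted true, actually true
		case pred[i] && !actual[i]:
			c.FP++ // predicted true, actually false
		case !pred[i] && actual[i]:
			c.FN++ // predicted false, actually true
		default:
			c.TN++ // predicted false, actually false
		}
	}
	return c
}

// Precision is TP / (TP + FP), or 0 if the model made no positive predictions.
func (c Confusion) Precision() float64 {
	if c.TP+c.FP == 0 {
		return 0
	}
	return float64(c.TP) / float64(c.TP+c.FP)
}

// Recall is TP / (TP + FN), or 0 if there are no actual positives.
func (c Confusion) Recall() float64 {
	if c.TP+c.FN == 0 {
		return 0
	}
	return float64(c.TP) / float64(c.TP+c.FN)
}

func main() {
	pred := []bool{true, true, false, false, false}
	actual := []bool{true, true, true, true, false}
	c := Count(pred, actual)
	fmt.Printf("%+v precision=%.2f recall=%.2f\n", c, c.Precision(), c.Recall())
}
```

The toy data in main reproduces the same pattern as the linear model: every positive prediction is correct (precision 1.0), but half of the actual positives are missed (recall 0.5).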

ROC Curves

Another way to measure performance involves looking at how the classifier works in more detail. Inside our model, two things happen:

  • First, the model calculates a value between 0 and 1, indicating how likely it is that a given image should be classified as a pair of trousers.
  • A threshold is set, so that only images scoring more than the threshold get classified as trousers. Setting different thresholds can improve precision at the expense of recall and vice versa.

If we look at the model output across all the different thresholds from 0 to 1, we can understand more about how useful it is. We do this using something called the receiver operating characteristic (ROC) curve, which is a plot of the true positive rate versus the false positive rate across the dataset for different threshold values. The following three examples show ROC curves for a bad, moderate, and very good classifier:

Image 5

By measuring the shaded area under these ROC curves, we get a simple metric of how good the model is, which is known as area under curve (AUC). For the bad model, this is close to 0.5, but for the very good model, it is close to 1.0, indicating that the model can achieve both a high true positive rate and a low false positive rate.

The gonum/stat package provides a useful function for computing ROC curves, which we will use once we have extended the model to work with each of the different items of clothing in the dataset.

Note that the receiver operating characteristic, or ROC curve, is a plot of true positive rate versus false positive rate for different threshold values. It allows us to visualize how good the model is at classification. The AUC gives a simple measure of how good the classifier is.
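To make the threshold sweep concrete, here is a from-scratch sketch of an ROC curve and its AUC that does not depend on gonum. It illustrates the idea under stated assumptions rather than reproducing gonum's algorithm: an example counts as predicted positive when its score is greater than or equal to the threshold, tied scores are processed together, and the area is computed with the trapezoidal rule (the same rule used by the integrate.Trapezoidal call later in this article). The names ROC, AUC and rocPoint are hypothetical.

```go
package main

import (
	"fmt"
	"sort"
)

// rocPoint is one (false positive rate, true positive rate) pair.
type rocPoint struct{ FPR, TPR float64 }

// ROC sweeps the threshold from the highest score down to the lowest.
// At each step, every example scoring >= the threshold is predicted positive.
// It assumes there is at least one positive and one negative example.
func ROC(scores []float64, positive []bool) []rocPoint {
	// Visit examples in order of decreasing score.
	idx := make([]int, len(scores))
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return scores[idx[a]] > scores[idx[b]] })

	var nPos, nNeg float64
	for _, p := range positive {
		if p {
			nPos++
		} else {
			nNeg++
		}
	}

	pts := []rocPoint{{0, 0}} // threshold above every score: nothing predicted positive
	var tp, fp float64
	for k, i := range idx {
		if positive[i] {
			tp++
		} else {
			fp++
		}
		// Emit a point only after all examples tied at this score are counted.
		if k == len(idx)-1 || scores[idx[k+1]] != scores[i] {
			pts = append(pts, rocPoint{FPR: fp / nNeg, TPR: tp / nPos})
		}
	}
	return pts
}

// AUC integrates the ROC curve with the trapezoidal rule.
func AUC(pts []rocPoint) float64 {
	area := 0.0
	for k := 1; k < len(pts); k++ {
		area += (pts[k].FPR - pts[k-1].FPR) * (pts[k].TPR + pts[k-1].TPR) / 2
	}
	return area
}

func main() {
	scores := []float64{0.9, 0.8, 0.7, 0.6}
	positive := []bool{true, false, true, false}
	fmt.Println(AUC(ROC(scores, positive))) // 0.75
}
```

A model that ranks every positive above every negative gets an AUC of 1, while one whose scores carry no information (for example, a constant score) gets 0.5, matching the "bad model" case described above.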

Multi-Class Models

Up until now, we have been using binary classification; that is, the model should output true if the image shows a pair of trousers, and false otherwise. For some problems, such as detecting whether an email is important or not, this is all we need. But in this example, what we really want is a model that can identify all the different types of clothing in our dataset, that is, shirt, boot, dress, and so on.

With some algorithm implementations, we would need to start by applying one-hot encoding to the output. However, for our example, we will use softmax regression in goml/linear, which performs this step automatically. We can train the model by simply feeding it the input (pixel values) and the integer output (0, 1, 2 ... representing t-shirt, trouser, pullover, and so on):

Go
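// Parameters: optimization method, learning rate (1e-4), regularization (1),
// number of classes (10), maximum iterations (100), inputs, integer labels
model2 := linear.NewSoftmax(base.BatchGA, 1e-4, 1, 10, 100, 
          trainingImages, training.Col("Label").Float())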

//Train
err := model2.Learn()
if err != nil {
  fmt.Println(err)
}

When using this model for inference, it will output a vector containing one probability per class; that is, it tells us how likely it is that the input image is a t-shirt, a pair of trousers, and so on. This is exactly what we need for the ROC analysis, but if we want a single prediction for each image, we can use the following function to find the class that has the highest probability:

Go
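// MaxIndex returns the index of the largest value in f, or -1 if f is empty
func MaxIndex(f []float64) int {
  // Start below any possible value so that the first element always wins
  curr := math.Inf(-1)
  ix := -1
  for i := range f {
    if f[i] > curr {
      curr = f[i]
      ix = i
    }
  }
  return ix
}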

Next, we can compute the ROC curve and the AUC for each individual class. The following code loops over each example in the validation dataset and uses the new model to predict probabilities for each class:

Go
//create objects for ROC generation
//as per https://godoc.org/github.com/gonum/stat#ROC
y := make([][]float64, len(categories), len(categories))
classes := make([][]bool, len(categories), len(categories))
//Validate
for i := 0; i < validation.Col("Image").Len(); i++ {
  prediction, err := model2.Predict(validationImages[i])
  if err != nil {
    panic(err)
  }
  for j := range categories {
    y[j] = append(y[j], prediction[j])
    classes[j] = append(classes[j], 
    validation.Col("Label").Elem(i).Float() != float64(j))
  }
}

//Calculate ROC
tprs := make([][]float64, len(categories), len(categories))
fprs := make([][]float64, len(categories), len(categories))

for i := range categories {
  stat.SortWeightedLabeled(y[i], classes[i], nil)
  tprs[i], fprs[i] = stat.ROC(0, y[i], classes[i], nil)
}

We can now compute AUC values for each class, which shows that our model performs better on some classes than others:

Go
for i := range categories {
  fmt.Println(categories[i])
  auc := integrate.Trapezoidal(fprs[i], tprs[i])
  fmt.Println(auc)
}

For trousers, the AUC value is 0.96, showing that even a simple linear model works really well in this case. However, a shirt and pullover both score close to 0.6. This makes intuitive sense: shirts and pullovers look very similar and are therefore much harder for the model to recognize correctly. We can see this more clearly by plotting the ROC curve for each class as separate lines: the model clearly performs the worst on shirts and pullovers, and the best on the clothes that have a very distinctive shape (boots, trousers, sandals, and so on).

The following code loads gonum's plotting packages, creates the ROC plot, saves it as a JPEG image, and returns the image bytes:

Go
import (
  "gonum.org/v1/plot"
  "gonum.org/v1/plot/plotter"
  "gonum.org/v1/plot/plotutil"
  "gonum.org/v1/plot/vg"
  "bufio"
)

func plotROCBytes(fprs, tprs [][]float64, labels []string) []byte {
  p, err := plot.New()
  if err != nil {
    panic(err)
  }

  p.Title.Text = "ROC Curves"
  p.X.Label.Text = "False Positive Rate"
  p.Y.Label.Text = "True Positive Rate"

  for i := range labels {
    pts := make(plotter.XYs, len(fprs[i]))
    for j := range fprs[i] {
      pts[j].X = fprs[i][j]
      pts[j].Y = tprs[i][j]
    }
    lines, points, err := plotter.NewLinePoints(pts)
    if err != nil {
      panic(err)
    }
    lines.Color = plotutil.Color(i)
    lines.Width = 2
    points.Shape = nil

    p.Add(lines, points)
    p.Legend.Add(labels[i], lines, points)
  }

  w, err := p.WriterTo(5*vg.Inch, 4*vg.Inch, "jpg")
  if err != nil {
    panic(err)
  }
  if err := p.Save(5*vg.Inch, 4*vg.Inch, "Multi-class ROC.jpg"); err != nil {
    panic(err)
  }
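  var b bytes.Buffer
  writer := bufio.NewWriter(&b)
  if _, err := w.WriteTo(writer); err != nil {
    panic(err)
  }
  // Flush the buffered writer so that all of the image bytes reach b
  if err := writer.Flush(); err != nil {
    panic(err)
  }
  return b.Bytes()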
}

If we view the plot in Jupyter, we can see that the worst classes follow lines close to the diagonal, again indicating an AUC close to 0.5:

Image 6

A Non-Linear Model – The Support Vector Machine

To move forward, we need to use a different machine learning algorithm: one that is able to model more complex, non-linear relationships between the pixel inputs and the output classes. While some of the mainstream Go machine learning libraries, such as GoLearn, support basic algorithms like local least squares, there is no single library that supports as broad a set of algorithms as Python's scikit-learn or R's standard library. For this reason, it is often necessary to search for alternative libraries that implement bindings to a widely used C library, or that contain a configurable implementation of an algorithm that is suited to a particular problem.

For this example, we are going to use an algorithm called the support vector machine (SVM). SVMs can be more difficult to use than linear models—they have more parameters to tune—but have the advantage of being able to model much more complex patterns in the data.

Note that an SVM is a more advanced machine learning method that can be used for both classification and regression. SVMs allow us to apply kernels to the input data, which means that they can model non-linear relationships between the inputs and outputs.

An important feature of SVM models is their ability to use a kernel function. Put simply, this means that the algorithm can apply a transformation to the input data so that non-linear patterns can be found. For our example, we will use the LIBSVM library to train an SVM on the image data. LIBSVM is an open source library with bindings for many different languages, meaning that it is also useful if you want to port a model that has been built in Python's popular scikit-learn library. First, we need to do some data preparation to make our input/output data suitable for feeding into the Go library:

Go
trainingOutputs := make([]float64, len(trainingImages))
validationOutputs := make([]float64, len(validationImages))

ltCol := training.Col("Label")
for i := range trainingImages {
    trainingOutputs[i] = ltCol.Elem(i).Float()
}

lvCol := validation.Col("Label")
for i := range validationImages {
    validationOutputs[i] = lvCol.Elem(i).Float()
}

// FloatsToSVMNode converts a slice of float64 to SVMNode 
// with sequential indices starting at 1
func FloatsToSVMNode(f []float64) []libsvm.SVMNode {
    ret := make([]libsvm.SVMNode, len(f), len(f))
    for i := range f {
        ret[i] = libsvm.SVMNode{
            Index: i+1,
            Value: f[i],
        }
    }
    //End of Vector
    ret = append(ret, libsvm.SVMNode{
        Index: -1,
        Value: 0,
    })
    return ret
}

Next, we can set up the SVM model and configure it with a radial basis function (RBF) kernel. RBF kernels are a common choice when using SVMs, but do take longer to train than linear models:

Go
var (
  trainingProblem libsvm.SVMProblem
  validationProblem libsvm.SVMProblem
)

trainingProblem.L = len(trainingImages)
validationProblem.L = len(validationImages)
for i := range trainingImages {
  trainingProblem.X = append(trainingProblem.X, FloatsToSVMNode(trainingImages[i]))
}
trainingProblem.Y = trainingOutputs

for i := range validationImages {
  validationProblem.X = append(validationProblem.X, FloatsToSVMNode(validationImages[i]))
}
validationProblem.Y = validationOutputs

// configure SVM
svm := libsvm.NewSvm()
param := libsvm.SVMParameter{
  SvmType: libsvm.CSVC,
  KernelType: libsvm.RBF,
  C: 100,
  Gamma: 0.01,
  Coef0: 0,
  Degree: 3,
  Eps: 0.001,
  Probability: 1,
}

Finally, we can fit our model to the training data of 750 images and then use svm.SVMPredictProbability to predict probabilities, as we did with the linear multi-class model:

Go
model := svm.SVMTrain(&trainingProblem, &param)

As we did previously, we compute the AUC and ROC curves, which demonstrate that this model performs much better across the board, including the difficult classes, like shirt and pullover:

Image 7

Overfitting and Underfitting

The SVM model is performing much better on our validation dataset than the linear model, but, in order to understand what to do next, we need to introduce two important concepts in machine learning: overfitting and underfitting. These both refer to problems that can occur when training a model.

If a model underfits the data, it is too simple to explain the patterns in the input data, and therefore performs poorly when evaluated against both the training dataset and the validation dataset. This is also described as the model having high bias. If a model overfits the data, it is too complex, and will not generalize well to new data points that were not included as part of the training. This means that the model will perform well when evaluated against the training data, but poorly when evaluated against the validation dataset. This is also described as the model having high variance.

An easy way to understand the difference between overfitting and underfitting is to look at the following simple example: when building a model, our aim is to build something that is just right for the dataset. The example on the left underfits because a straight line model cannot accurately divide the circles and squares. The model on the right is too complex: it separates all the circles and squares correctly but is unlikely to work well on new data:

Image 8

Our linear model suffered from underfitting: it was too simplistic to model the difference between all the classes. Looking at the accuracy of the SVM, we can see that it scores 100% on the training data, but only 82% on validation. This is a clear sign that it is overfitting: it is much worse at classifying new images compared with those on which it was trained.

One way of dealing with overfitting is to use more training data: even a complex model will not be able to overfit if the training dataset is large enough. Another way to do this is to introduce regularization: many machine learning models have a parameter that you can adjust to reduce overfitting.
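As an illustration of how such a parameter works, the sketch below fits a one-feature ridge regression (least squares with an L2 penalty) through the origin. This deliberately simple model is chosen for the example and is not used elsewhere in the article. Setting the derivative of the penalized loss to zero gives the weight in closed form, and a larger penalty strength lambda pulls that weight towards zero, limiting how closely the model can chase the training data. In the SVM configuration shown earlier, C plays the opposite role: a smaller C means stronger regularization.

```go
package main

import "fmt"

// RidgeSlope fits y ~ w*x (no intercept) by minimising
// sum((y - w*x)^2) + lambda*w^2. Setting the derivative to zero gives
// w = sum(x*y) / (sum(x^2) + lambda), so a larger lambda shrinks w towards 0.
func RidgeSlope(xs, ys []float64, lambda float64) float64 {
	var sxy, sxx float64
	for i := range xs {
		sxy += xs[i] * ys[i]
		sxx += xs[i] * xs[i]
	}
	return sxy / (sxx + lambda)
}

func main() {
	xs := []float64{1, 2, 3}
	ys := []float64{2, 4, 6} // exactly y = 2x
	for _, lambda := range []float64{0, 1, 10} {
		fmt.Println(lambda, RidgeSlope(xs, ys, lambda))
	}
}
```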

Deep Learning

So far, we have improved our model's performance using an SVM, but we still face two problems:

  • Our SVM is overfitting the training data.
  • It is also difficult to scale to the full dataset of 60,000 images: try training the last example with more images and you will find that it gets much slower. If we double the number of data points, the SVM algorithm takes more than double the amount of time.

In this section, we are going to tackle these problems using a deep neural network. These types of model have achieved state-of-the-art performance on image classification tasks, as well as many other machine learning problems. They are able to model complex non-linear patterns, and also scale well to large datasets.

Data scientists will often use Python to develop and train neural networks because it has access to extremely well-supported deep learning frameworks such as TensorFlow and Keras. These frameworks make it easier than ever to build complex neural networks and train them on large datasets. They are usually the best choice for building sophisticated deep learning models. In this section, we will build a much simpler neural network from scratch using the go-deep library to demonstrate the key concepts.

Neural Networks

The basic building block of a neural network is a neuron (also known as a perceptron). This is actually the same as our simple linear model: it combines all of its inputs, x1, x2, x3, and so on, into a single output, y, according to the following formula:

$$ y = w_0 + w_1 x_1 + w_2 x_2 + w_3 x_3 + \dots $$

The magic of neural networks comes from what happens when we combine these simple neurons:

  1. First, we create a layer of many neurons into which we feed the input data.
  2. At the output of each neuron, we introduce an activation function.
  3. The output of this input layer is then fed to another layer of neurons and activations, known as a hidden layer.
  4. This gets repeated for multiple hidden layers—the more layers there are, the deeper the network is said to be.
  5. A final output layer of neurons combines the result of the network into the final output.
  6. Using a technique known as backpropagation, we can train the network by finding the weights, w0, w1, w2..., for each neuron that allow the whole network to fit the training data.
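The steps above can be sketched as a forward pass through dense layers. This is a minimal illustration rather than go-deep's implementation: Dense and ForwardNetwork are hypothetical names, the weights are set by hand instead of being learned by backpropagation, and ReLU is the only activation offered. The tiny two-layer network in main computes the absolute difference |x1 - x2|, something a single linear neuron cannot represent, which shows how stacking neurons with a non-linear activation adds expressive power.

```go
package main

import "fmt"

// Dense is a hypothetical fully-connected layer: W[j][i] is the weight from
// input i to neuron j, and B[j] is neuron j's bias term.
type Dense struct {
	W    [][]float64
	B    []float64
	ReLU bool // apply the ReLU activation to each neuron's output
}

// Forward computes b_j + sum_i W[j][i]*x[i] for every neuron j, then
// optionally applies ReLU, max(0, value).
func (d Dense) Forward(x []float64) []float64 {
	out := make([]float64, len(d.W))
	for j, row := range d.W {
		sum := d.B[j]
		for i, w := range row {
			sum += w * x[i]
		}
		if d.ReLU && sum < 0 {
			sum = 0
		}
		out[j] = sum
	}
	return out
}

// ForwardNetwork feeds the input through each layer in turn.
func ForwardNetwork(layers []Dense, x []float64) []float64 {
	for _, l := range layers {
		x = l.Forward(x)
	}
	return x
}

func main() {
	// Hidden layer: one neuron computes relu(x1 - x2), the other relu(x2 - x1).
	hidden := Dense{W: [][]float64{{1, -1}, {-1, 1}}, B: []float64{0, 0}, ReLU: true}
	// Output layer: adds the two hidden outputs, giving |x1 - x2|.
	output := Dense{W: [][]float64{{1, 1}}, B: []float64{0}}
	fmt.Println(ForwardNetwork([]Dense{hidden, output}, []float64{3, 1})) // [2]
}
```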

The following diagram shows this layout: the arrows represent the output of each neuron, which is feeding into the input of the neurons in the next layer:

Image 10

The neurons in this network are said to be arranged in fully-connected, or dense, layers. Recent advances in both computing power and software have allowed researchers to build and train more complex neural network architectures than ever before. For instance, a state-of-the-art image recognition system might contain millions of individual weights and require many days of computing time to train all of these parameters to fit a large dataset. Such systems often contain different arrangements of neurons, for instance, convolutional layers, which perform more specialized learning.

Much of the skill that is required to use deep learning successfully in practice involves a broad understanding of how to select and tune a network to get good performance. There are many blogs and online resources that provide more detail on how these networks work and the types of problems that they have been applied to.

A fully-connected layer in a neural network is one where the inputs of each neuron are connected to the outputs of all the neurons in the previous layer.

A Simple Deep Learning Model Architecture

Much of the skill in building a successful deep learning model involves choosing the correct model architecture: the number/size/type of layers, and the activation functions for each neuron. Before starting, it is worth researching to see if someone else has already tackled a similar problem to yours using deep learning and published an architecture that works well. As always, it is best to start with something simple and then modify the network iteratively to improve its performance.

For our example, we will start with the following architecture:

  • An input layer
  • Two hidden layers containing 128 neurons each
  • An output layer of 10 neurons (one for each output class in the dataset)
  • Each neuron in the hidden layers will use a rectified linear unit (ReLU) as its activation function

ReLUs are a common choice of activation function in neural networks. They are a very simple way to introduce non-linearity into a model. Other common activation functions include the logistic function and the tanh function.

The go-deep library lets us build this architecture very quickly:

Go
import (
 "github.com/patrikeh/go-deep"
 "github.com/patrikeh/go-deep/training"
)

network := deep.NewNeural(&deep.Config{
 // Input size: 784 in our case (number of pixels in each image)
 Inputs: len(trainingImages[0]),
 // Two hidden layers of 128 neurons each, and an output layer 10 neurons 
 // (one for each class)
 Layout: []int{128, 128, len(categories)},
 // ReLU activation to introduce some additional non-linearity
 Activation: deep.ActivationReLU,
 // We need a multi-class model
 Mode: deep.ModeMultiClass,
 // Initialise the weights of each neuron using normally distributed random numbers
 Weight: deep.NewNormal(0.5, 0.1),
 Bias: true,
})
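It can be useful to know how many weights such a network has to learn. Assuming each neuron has one bias weight (which is what Bias: true suggests), a fully-connected layer with n neurons fed by m inputs has m*n + n parameters. The hypothetical CountParams helper below applies this to the 784-128-128-10 layout: 100,480 parameters in the first hidden layer, 16,512 in the second and 1,290 in the output layer, 118,282 in total.

```go
package main

import "fmt"

// CountParams returns the number of trainable parameters in a stack of
// fully-connected layers: each layer has (inputs x neurons) weights plus
// one bias per neuron, and its outputs become the next layer's inputs.
func CountParams(inputs int, layout []int) int {
	total := 0
	for _, n := range layout {
		total += inputs*n + n
		inputs = n
	}
	return total
}

func main() {
	// 784 pixel inputs, two hidden layers of 128 neurons, 10 output neurons
	fmt.Println(CountParams(784, []int{128, 128, 10})) // 118282
}
```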

Neural Network Training

Training a neural network is another area in which you need to make skilful adjustments in order to get good results. The training algorithm works by calculating how well the model fits a small batch of training data (known as the loss) and then making small adjustments to the weights to improve the fit. This process then gets repeated over and over again on different batches of training data. The learning rate is an important parameter that controls how quickly the algorithm will adjust the neuron weights.

When training a neural network, the algorithm will feed all of the input data into the network repeatedly, and adjust the network weights as it goes. Each full pass through the data is known as an epoch.

When training a neural network, monitor the accuracy and loss of the network after each epoch (accuracy should increase, while loss should decrease). If the accuracy is not improving, try lowering the learning rate. Keep training the network until accuracy stops improving: at this point, the network is said to have converged.

The following code trains our model using a learning rate of 0.006 for 500 iterations (epochs) and prints out the accuracy after each epoch:

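Go
// Convert images and labels into go-deep examples with one-hot responses
toExamples := func(images [][]float64, labels []float64) (ex training.Examples) {
  for i := range images {
    response := make([]float64, len(categories))
    response[int(labels[i])] = 1
    ex = append(ex, training.Example{Input: images[i], Response: response})
  }
  return ex
}
// Parameters: learning rate, momentum, alpha decay, nesterov
optimizer := training.NewSGD(0.006, 0.1, 1e-6, true)
trainer := training.NewTrainer(optimizer, 1) // print progress every epoch
trainer.Train(network, toExamples(trainingImages, trainingOutputs),
  toExamples(validationImages, validationOutputs), 500) // iterations (epochs)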

This neural network provides an accuracy of 80% on the training and validation datasets, a good sign that the model is not overfitting.

In this article, we covered a lot of ground and explored many important machine learning concepts. The first step in tackling a supervised learning problem is to collect and preprocess the data, making sure that it is normalized and split into training and validation sets. We then covered a range of different algorithms for classification.

If you found this article useful, you might also find ‘Machine Learning with Go Quick Start Guide’ helpful. This book helps you develop machine learning applications in Go efficiently. With it, you will be able to understand the types of problems that ML solves and the various approaches to them. You’ll also be able to visualize data with gonum/plot and Gophernotes. If you want to learn how to set up a machine learning project for success, this is the right pick for you.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


Written By
United Kingdom
Founded in 2004 in Birmingham, UK, Packt's mission is to help the world put software to work in new ways, through the delivery of effective learning and information services to IT professionals.

Working towards that vision, we have published over 5000 books and videos so far, providing IT professionals with the actionable knowledge they need to get the job done - whether that's specific learning on an emerging technology or optimizing key skills in more established tools.
