Neural Networks can do a lot of things. Personally, the way they find patterns in a set of complex data always felt like magic to me, even after I started understanding how they work. They have a lot of potential uses, and being able to run them in a browser opens even more doors. So, how do you do it?

Three Questions

This article will answer three main questions regarding browser-side Neural Networks:

  • What is a Neural Network?
  • How do you create one and how do you run it in a browser?
  • Why would you want to do it in the first place?

So, without further ado, let’s get to it!

What is a Neural Network?

An (artificial) Neural Network is a computing system inspired by biological neural networks – our brains. They’re really good at finding relations and recognizing patterns in a large set of complex data.

A typical neural network is a graph that consists of:

  • Input nodes
  • Output nodes
  • Hidden nodes in between
  • Edges between the nodes

Long story short, given a set of edge weights and some input values, the network can “push” the values through the layers and produce a prediction based on the computed values of the output nodes.

The simplest example of this would be a Classification Neural Network – each of its output nodes is assigned a “class” of the input data (e.g., “car”, “bike” and “person” if you’re classifying images), and, after the input values are pushed through, the output node with the highest value will be the prediction of the network.
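
To make this a bit more concrete, here’s a tiny toy example (the weights are completely made up, purely for illustration) of how input values get “pushed” through one hidden layer and how the highest output value becomes the prediction:

import numpy as np

# Toy network: 2 input nodes -> 3 hidden nodes -> 2 output classes
# The weight values below are made up purely for illustration
W_hidden = np.array([[0.2, -0.5, 0.8],
                     [0.7,  0.1, -0.3]])
W_output = np.array([[ 0.5, -0.2],
                     [-0.4,  0.9],
                     [ 0.3,  0.6]])

x = np.array([0.9, 0.1])              # input values
hidden = np.maximum(0, x @ W_hidden)  # weighted sums + ReLU activation
output = hidden @ W_output            # one value per output node (class)
print(np.argmax(output))              # index of the highest value = the predicted class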

That’s exactly what we’re going to be making later: a Neural Network that can classify images of handwritten characters.

How do you create a Neural Network?

While it is possible to write custom code that creates and trains a Neural Network, it’s a lot easier to use third-party tools.

In our case we will be using Python and TensorFlow – an open-source platform with lots of tools for machine learning. A great thing about TensorFlow is its JavaScript counterpart, TensorFlowJS – a package that makes it easy to integrate machine learning into your web app.

Before we get to the code though, there’s one thing we need to do first – collect a dataset. Why? Well, the easiest way to find the weights of the edges, to “train the network,” is to give it a large set of pre-classified data to study. The network will basically make predictions, compare its results to the actual labels and learn from its mistakes over and over again.

For this example, I will be using a slightly modified version of the EMNIST Dataset. It contains over 100,000 images of handwritten letters and digits. The link to my version of the dataset and to the full source code is available at the bottom of this article.

Once you have the dataset, creating and training the model in Python only takes a few steps. First, we need to load the dataset and split it into two parts – a training dataset (for the actual training) and a validation dataset (for validating the accuracy of the model during and after training).
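
The snippets below assume a handful of imports and constants. The exact values are up to you – here’s a minimal setup sketch, with the numbers being reasonable guesses rather than the exact ones from my project:

import tensorflow as tf
import tensorflowjs as tfjs  # installed separately with `pip install tensorflowjs`
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# Assumed values – tune them for your own dataset and hardware
IMAGE_SIZE = 28          # e.g. 28 for EMNIST-style 28x28 images
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2   # 20% of the images go to the validation dataset
RANDOM_SEED = 123        # any fixed number, so both splits below line up
EPOCHS = 10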

raw_train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    'characters',
    validation_split=VALIDATION_SPLIT,
    subset='training',
    seed=RANDOM_SEED,
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE
)

raw_val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    'characters',
    validation_split=VALIDATION_SPLIT,
    subset='validation',
    seed=RANDOM_SEED,
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE
)
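
By the way, the `num_classes` value used later when building the model can be taken straight from the loaded dataset – image_dataset_from_directory creates one class per sub-folder of 'characters':

class_names = raw_train_ds.class_names  # one class per sub-folder, i.e. one per character
num_classes = len(class_names)          # used below when building the model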

The pixel values of the images in the dataset are in the range from 0 to 255. Neural networks work best if the inputs are normalized to be between 0 and 1, so the next step is to rescale the raw datasets:

normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

train_ds = raw_train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds = raw_val_ds.map(lambda x, y: (normalization_layer(x), y))

After we have our datasets, it’s time to create the neural network model itself. I won’t explain exactly how this code works because it’s not the main focus of this article, but if you’re interested in more details, check out the TensorFlow Quickstart. Also, just a disclaimer: the model presented here is not optimized for accuracy or speed, so you could definitely get better results with a more complex model.

model = Sequential([
    layers.Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(num_classes)
])

model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

After the model is created and compiled, we need to train it. This can take anywhere from minutes to days or even weeks, depending on the size of the dataset and the complexity of the model. Personally, I managed to get decent accuracy after training this model on the dataset above for 3 hours.

history = model.fit(
    train_ds, 
    validation_data=val_ds,
    epochs=EPOCHS,
)
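
Keras records the metrics for every epoch in the returned `history` object, so once training is done you can quickly check how the model performs on the validation dataset – for example:

# Validation accuracy after the last epoch
print('validation accuracy:', history.history['val_accuracy'][-1])

# Or re-evaluate the trained model on the whole validation dataset
val_loss, val_accuracy = model.evaluate(val_ds)
print('validation accuracy:', val_accuracy)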

Finally, once the model has finished training, the only thing left to do is to convert it to the TensorFlowJS format and save it, using the `tensorflowjs` Python package (imported as `tfjs`).

tfjs.converters.save_keras_model(model, 'tfjs_model')

How do you use a Network in a Browser?

Once we have a trained model, running predictions in a browser is really simple.

First, install and import `TensorFlowJS` following the instructions on their website.

import * as tf from '@tensorflow/tfjs';

Then, drop the model saved in the previous step into the public folder of your project and use `tf.loadLayersModel` to load it into a variable.

const model = await tf.loadLayersModel(`${process.env.PUBLIC_URL}/model/model.json`);

Once the model is loaded, you need to create an input for it to classify. In this case, an input is a tensor with the pixel values of the image we want to run predictions on. Here, we’re using `tf.tensor` to create a tensor from a `pixels` array, with shape [1, IMAGE_SIZE, IMAGE_SIZE, 3] (1 square image of size IMAGE_SIZE x IMAGE_SIZE with 3 channels for each pixel). The pixel values should be scaled to the 0 to 1 range, just like the training data was.

const inputTensor = tf.tensor([pixels], [1, IMAGE_SIZE, IMAGE_SIZE, 3]);
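
If the image comes from a canvas (or an img) element, you don’t have to build the pixel array by hand – TensorFlowJS can read it for you. Here’s one possible way to do it, assuming `canvas` is a reference to the canvas element the user draws on:

// Read the canvas pixels as a [height, width, 3] tensor with values 0-255
const imageTensor = tf.browser.fromPixels(canvas);

// Resize to the model's input size, scale to 0-1 and add the batch dimension
const resized = tf.image.resizeBilinear(imageTensor, [IMAGE_SIZE, IMAGE_SIZE]);
const inputTensor = resized.div(255).expandDims(0); // shape [1, IMAGE_SIZE, IMAGE_SIZE, 3]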

Finally, getting the predictions is as simple as just calling one method. The `result` will be a tensor with one raw score (logit) per class, and the class with the highest score is the network’s prediction for the given input.

const result = model.predict(inputTensor);
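
Since `result` is a tensor, its values need to be read back into JavaScript before you can use them. A simple way to get the predicted class and its probability could look like this (inside an async function, like the loading code above):

// Turn the raw scores into probabilities and pull them out of the tensor
const probabilities = await tf.softmax(result).data();

// The index of the highest probability is the predicted class
const predictedClass = probabilities.indexOf(Math.max(...probabilities));
console.log(predictedClass, probabilities[predictedClass]);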

And that’s it! With only a few lines we’ve created, trained and used a neural network for character recognition in a browser. Check out a working example yourself.

Just a disclaimer: the network layout I used was not ideal in any sense, and I only trained the model for a few hours, so the accuracy is not perfect.

Why use Neural Networks?

So, now we know how to train a model and how to run it in a browser. The final question is – why would we want to do it? Actually, there are two questions here – why run a network in a browser, and why use a neural network in the first place.

Why in a browser?

The opposite of running a Neural Network in a browser would be running it on a server and communicating with it via requests. Both of these approaches are valid and have their own pros and cons, so let’s take a look at a few:

  • Actual prediction speed
    • In a browser the speed depends on the user’s device
    • On a server the prediction speed is consistent
  • Perceived prediction speed
    • In a browser the speed doesn’t depend on the number of users
    • On a server the app could be bottlenecked during periods of high demand
  • Network usage
    • In a browser users must download the whole model before they can use it (which can be a very large file)
    • On a server users only need to send requests with their input data
  • Privacy
    • In a browser users don’t have to share their input data
    • On a server you don’t have to share the trained model
  • Connectivity
    • Browser-side networks can run offline (e.g. for PWAs)
    • Server-side networks require an internet connection

There are many more points, but these five provide a basic comparison. And, as you can see, the best approach depends heavily on the type of application you’re developing.

Why use Neural Networks in the first place?

You might want to use Neural Networks whenever you need to find patterns in a complex set of data. Potential uses include:

  • Image classification
  • Computer vision
  • Voice recognition and generation
  • Optimizing search results

And many others. A lot of companies already use machine learning, for example:

  • Google (Translate, Assistant, others)
  • Pinterest (Optimized content discovery)
  • Airbnb (Optimized search ranking)

And this list can go on forever.

This answers the “Why” question, and we already know the “What” and the “How”, so let’s sum it up.

To sum up

Neural Networks can do a lot of things. Coming up with a suitable model layout can be challenging, but for many tasks it’s still easier than writing an actual algorithm. The potential uses are endless, and if you ask me – Neural Networks are the future.

Source code?

The full source code for everything covered above is available at my GitHub.


Watch me talk about this topic at the Reactive Online Meetup.

