12.4 TensorFlow.js

TensorFlow.js is a library for machine learning in JavaScript

Created Date: 2025-07-25

TensorFlow.js is a JavaScript library for training and deploying machine learning models in the web browser and in Node.js. This tutorial shows you how to get started with TensorFlow.js by training a minimal model in the browser and using the model to make a prediction.

12.4.1 Simple Demo

This demo trains a simple linear regression model on points that lie on the line \(y = 2x - 1\), then uses the trained model to predict \(y\) for a new input.
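For reference, the single dense unit computes \(\hat{y} = wx + b\). The `model.compile` call in the code asks TensorFlow.js to minimize the mean squared error over the \(n\) training pairs \((x_i, y_i)\) with stochastic gradient descent. Assuming the default batch size of 32 in `model.fit`, all seven samples fit in one batch, so each epoch performs a single gradient step with learning rate \(\eta\):

\[
L(w, b) = \frac{1}{n}\sum_{i=1}^{n}\left(w x_i + b - y_i\right)^2
\]

\[
w \leftarrow w - \eta\,\frac{2}{n}\sum_{i=1}^{n}\left(w x_i + b - y_i\right)x_i, \qquad
b \leftarrow b - \eta\,\frac{2}{n}\sum_{i=1}^{n}\left(w x_i + b - y_i\right)
\]

When every training point lies exactly on \(y = 2x - 1\), the loss reaches its minimum \(L = 0\) at \(w = 2\), \(b = -1\). This is why the learned weight and bias shown during training should drift toward those values.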

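In the interactive version of this page, clicking the train button trains the model for 500 epochs. After every epoch the page updates the epoch count, the current loss, and the learned weight \(w\) and bias \(b\). Once training finishes, it shows the model's prediction.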

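// Load TensorFlow.js. This assumes an ES module setup (bundler or import map),
// which also permits the top-level await used below. If tf is loaded from a
// <script> tag instead, drop this line and run the code inside an async function.
import * as tf from '@tensorflow/tfjs';

const numEpochs = 500;

// A single dense layer with one unit and one input computes y = w * x + b.
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));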

// Prepare the model for training: Specify the loss and the optimizer.
model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });

// Generate synthetic training data that lies exactly on the line y = 2x - 1.
const xs = tf.tensor2d([1, 2, 3, 4, 5, 6, 7], [7, 1]);
const ys = tf.tensor2d([1, 3, 5, 7, 9, 11, 13], [7, 1]);

// Train the model, updating the page after every epoch via the onEpochEnd callback.
await model.fit(xs, ys, {
    epochs: numEpochs,
    callbacks: {
        onEpochEnd: async (epoch, logs) => {
            // Get the current weights and biases
            const weights = model.getWeights();
            const learnedWeight = weights[0].dataSync()[0]; // The 'w' value
            const learnedBias = weights[1].dataSync()[0];   // The 'b' value

            // Update HTML elements in real-time
            document.getElementById('epochCount').innerText = epoch + 1; // Epochs are 0-indexed
            document.getElementById('currentLoss').innerText = logs.loss.toFixed(5); // Display loss
            document.getElementById('learnedWeight').innerText = learnedWeight.toFixed(5);
            document.getElementById('learnedBias').innerText = learnedBias.toFixed(5);
        }
    }
});

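// Use the model to do inference on a data point it hasn't seen before.
// For x = 10 the true line gives y = 2 * 10 - 1 = 19; the prediction should be approximately that.
const predictionTensor = model.predict(tf.tensor2d([10], [1, 1]));
const predictedValue = predictionTensor.dataSync()[0];
console.log(`Prediction for x = 10: ${predictedValue.toFixed(3)}`);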
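To make the optimization inside `model.fit` concrete, here is a minimal plain-JavaScript sketch of the same computation: full-batch gradient descent on the mean squared error of \(wx + b\). It is an illustration, not TensorFlow.js internals. `trainLinearRegression` is a hypothetical helper, both parameters start at 0 (TensorFlow.js initializes the weight randomly), and the learning rate of 0.01 is an assumption about what the `'sgd'` optimizer string resolves to.

```javascript
// Illustrative sketch of what model.fit does for this demo (not TensorFlow.js code):
// full-batch gradient descent on MSE for y_hat = w * x + b.
// Assumptions: w and b start at 0 and the learning rate is 0.01.
function trainLinearRegression(xs, ys, numEpochs = 500, learningRate = 0.01) {
  const n = xs.length;
  let w = 0;
  let b = 0;
  let loss = NaN;
  for (let epoch = 0; epoch < numEpochs; epoch++) {
    let gradW = 0;
    let gradB = 0;
    let sumSq = 0;
    for (let i = 0; i < n; i++) {
      const error = w * xs[i] + b - ys[i]; // residual for sample i
      sumSq += error * error;
      gradW += (2 / n) * error * xs[i];    // contribution to dL/dw
      gradB += (2 / n) * error;            // contribution to dL/db
    }
    loss = sumSq / n; // MSE computed before this epoch's update
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return { w, b, loss };
}

// Usage: the same seven points on y = 2x - 1 as in the demo.
const xs = [1, 2, 3, 4, 5, 6, 7];
const ys = xs.map((x) => 2 * x - 1);
const { w, b, loss } = trainLinearRegression(xs, ys);
console.log(`w = ${w.toFixed(3)}, b = ${b.toFixed(3)}, loss = ${loss.toFixed(5)}`);
console.log(`prediction for x = 10: ${(w * 10 + b).toFixed(3)}`);
```

Because every \(x_i\) is positive, \(w\) and \(b\) are strongly coupled, and gradient descent progresses slowly along one combined direction. In this sketch the weight ends near 2 after 500 epochs while the bias is still noticeably above -1; raising the epoch count closes the gap. The browser demo can behave similarly, although its exact numbers depend on the random initialization.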

12.4.2 MNIST

TensorFlow.js can also run larger models, such as a classifier for the handwritten digits (0 to 9) of the MNIST dataset.

In the interactive version of this page, you draw a digit on the canvas and click the predict button; the model classifies the drawing and displays the predicted digit.
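The section does not show the MNIST demo's code, so the following is only a hypothetical sketch of two steps such a page typically needs: shrinking the canvas drawing to the 28 × 28 grayscale format of MNIST images, and turning the model's ten output scores into a digit. It assumes RGBA pixel data laid out like `ImageData.data`, a square canvas whose side is a multiple of 28, and a white digit on a black background (alpha is ignored). `toMnistInput` and `argmax` are illustrative names. Within TensorFlow.js the same preprocessing could instead be done with `tf.browser.fromPixels` and `tf.image.resizeBilinear`.

```javascript
// Hypothetical preprocessing for a canvas drawing before MNIST inference.
// Assumes RGBA data (4 bytes per pixel, row-major) from a square canvas whose
// side is a multiple of 28, with a white-on-black drawing. Alpha is ignored.
function toMnistInput(rgba, size) {
  const target = 28;
  if (size % target !== 0) {
    throw new Error(`canvas size ${size} must be a multiple of ${target}`);
  }
  const block = size / target; // source pixels per output pixel along each axis
  const out = new Float32Array(target * target);
  for (let row = 0; row < target; row++) {
    for (let col = 0; col < target; col++) {
      let sum = 0;
      for (let dy = 0; dy < block; dy++) {
        for (let dx = 0; dx < block; dx++) {
          const i = ((row * block + dy) * size + (col * block + dx)) * 4;
          // Grayscale as the mean of R, G, B, scaled to [0, 1].
          sum += (rgba[i] + rgba[i + 1] + rgba[i + 2]) / (3 * 255);
        }
      }
      out[row * target + col] = sum / (block * block); // box-average downsampling
    }
  }
  return out; // flattened 28x28 image, row-major
}

// Return the index of the highest score, i.e. the predicted digit.
function argmax(scores) {
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}

// Usage on a synthetic 56x56 black canvas with a 2x2 white square in the top-left corner.
const size = 56;
const rgba = new Uint8ClampedArray(size * size * 4);
for (let y = 0; y < 2; y++) {
  for (let x = 0; x < 2; x++) {
    const i = (y * size + x) * 4;
    rgba[i] = rgba[i + 1] = rgba[i + 2] = 255; // white
    rgba[i + 3] = 255;                          // opaque
  }
}
const input = toMnistInput(rgba, size);
console.log(input.length, input[0], input[1]); // 784 1 0
console.log(argmax([0.01, 0.02, 0.8, 0.03, 0.02, 0.04, 0.03, 0.02, 0.02, 0.01])); // 2
```

The flattened array would still need to be reshaped to whatever input shape the demo's model expects (for example `[1, 28, 28, 1]` for a convolutional network), which this section does not specify.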

12.4.3 Piano Demo