12.4 TensorFlow.js
TensorFlow.js is a library for machine learning in JavaScript
Created Date: 2025-07-25
TensorFlow.js is a JavaScript library for training and deploying machine learning models in the web browser and in Node.js. This tutorial shows you how to get started with TensorFlow.js by training a minimal model in the browser and using the model to make a prediction.
12.4.1 Simple Demo
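This demo trains a simple linear regression model to learn the relationship \(y = 2x - 1\) from a handful of sample points, then uses the trained model to predict the output for an input it has not seen.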
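The sample points lie exactly on a line, so ordinary least squares recovers that line in closed form. The sketch below is plain JavaScript, not TensorFlow.js. It is only a reference point for the values that gradient-based training should approach. The function name `fitLine` is our own and is not part of any library.

```javascript
// Closed-form ordinary least squares for y = w * x + b.
// Plain JavaScript reference sketch; not a TensorFlow.js API.
function fitLine(xs, ys) {
  const n = xs.length;
  const meanX = xs.reduce((s, x) => s + x, 0) / n;
  const meanY = ys.reduce((s, y) => s + y, 0) / n;
  let cov = 0; // sum of (x - meanX) * (y - meanY)
  let varX = 0; // sum of (x - meanX)^2
  for (let i = 0; i < n; i++) {
    cov += (xs[i] - meanX) * (ys[i] - meanY);
    varX += (xs[i] - meanX) ** 2;
  }
  const w = cov / varX; // slope
  const b = meanY - w * meanX; // intercept
  return { w, b };
}

// The same inputs as the demo, with targets generated from y = 2x - 1.
const xs = [1, 2, 3, 4, 5, 6, 7];
const ys = xs.map((x) => 2 * x - 1);
const { w, b } = fitLine(xs, ys);
console.log(w, b); // 2 -1
```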
Clicking the demo's button trains the model for 500 epochs. During training, the page shows the current epoch (out of 500), the current loss, and the learned weight (w) and bias (b). These values update in real time after every epoch. When training finishes, the page displays the model's prediction.
```javascript
// Or load TensorFlow.js with a <script> tag, which defines the global `tf`.
import * as tf from '@tensorflow/tfjs';

// Hypothetical entry point, called when the user clicks the demo's button.
async function trainAndPredict() {
  const numEpochs = 500;

  // A single dense unit with one input computes y = w * x + b.
  const model = tf.sequential();
  model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

  // Prepare the model for training: specify the loss and the optimizer.
  model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' });

  // Synthetic training data that follows y = 2x - 1.
  const xs = tf.tensor2d([1, 2, 3, 4, 5, 6, 7], [7, 1]);
  const ys = tf.tensor2d([1, 3, 5, 7, 9, 11, 13], [7, 1]);

  // Train the model, updating the page after every epoch.
  await model.fit(xs, ys, {
    epochs: numEpochs,
    callbacks: {
      onEpochEnd: async (epoch, logs) => {
        // getWeights() returns [kernel, bias]; each holds one value here.
        const weights = model.getWeights();
        const learnedWeight = weights[0].dataSync()[0]; // w
        const learnedBias = weights[1].dataSync()[0]; // b

        // Update the progress display (epoch is 0-indexed).
        document.getElementById('epochCount').innerText = epoch + 1;
        document.getElementById('currentLoss').innerText = logs.loss.toFixed(5);
        document.getElementById('learnedWeight').innerText = learnedWeight.toFixed(5);
        document.getElementById('learnedBias').innerText = learnedBias.toFixed(5);
      },
    },
  });

  // Run inference on an input that was not in the training data.
  const predictionTensor = model.predict(tf.tensor2d([10], [1, 1]));
  predictionTensor.print(); // should be close to 2 * 10 - 1 = 19
}
```
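To make `compile({ loss: 'meanSquaredError', optimizer: 'sgd' })` and `fit()` less of a black box, the following plain-JavaScript sketch runs gradient descent on the mean squared error for a single weight and bias. It is not TensorFlow.js's implementation. It rests on three assumptions.

* The run is full batch. The seven samples fit inside `fit()`'s default batch size of 32.
* The learning rate is 0.01.
* Both parameters start at zero. TensorFlow.js, by contrast, initializes the kernel randomly.

The names `trainLinear` and `meanSquaredError` are hypothetical helpers.

```javascript
// Mean squared error of the line y = w * x + b over the samples.
function meanSquaredError(xs, ys, w, b) {
  let sum = 0;
  for (let i = 0; i < xs.length; i++) {
    const err = w * xs[i] + b - ys[i];
    sum += err * err;
  }
  return sum / xs.length;
}

// Full-batch gradient descent on the MSE loss (a sketch, not TensorFlow.js internals).
function trainLinear(xs, ys, { epochs = 500, learningRate = 0.01 } = {}) {
  const n = xs.length;
  let w = 0; // assumption: start from zero
  let b = 0;
  for (let epoch = 0; epoch < epochs; epoch++) {
    // Gradients of the MSE loss with respect to w and b.
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const err = w * xs[i] + b - ys[i];
      gradW += (2 / n) * err * xs[i];
      gradB += (2 / n) * err;
    }
    // One SGD step: move each parameter against its gradient.
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return { w, b, loss: meanSquaredError(xs, ys, w, b) };
}

const xs = [1, 2, 3, 4, 5, 6, 7];
const ys = [1, 3, 5, 7, 9, 11, 13];
const result = trainLinear(xs, ys, { epochs: 500, learningRate: 0.01 });
console.log(result.w.toFixed(3), result.b.toFixed(3), result.loss.toFixed(5));
console.log((result.w * 10 + result.b).toFixed(3)); // prediction for x = 10
```

With these settings the sketch ends near, but not exactly at, w = 2 and b = -1. The intercept converges much more slowly than the slope, which is why the learned values in the demo keep drifting even late in training.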
12.4.2 MNIST
TensorFlow.js can also run more complex models, such as a classifier for MNIST, a dataset of 28×28 grayscale images of handwritten digits (0 to 9).
In this demo, you draw a digit on the canvas and click the button to classify it. Once the model has loaded, it predicts which digit you drew and the page displays the result.
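The page's own preprocessing code is not shown, so the following is only a minimal sketch of one plausible approach. It assumes the drawing is available as a square, row-major array of grayscale values in 0..255, for example one channel taken from the canvas's `getImageData()`. It also assumes the side length is a multiple of 28. The sketch average-pools the drawing down to MNIST's 28×28 size, scales the values to [0, 1], and picks the digit with the highest score from a 10-way output.

Some further steps depend on the page and are not shown:

* The result might be wrapped with `tf.tensor4d(pixels, [1, 28, 28, 1])` if the model expects that shape.
* Colors may need inverting, since MNIST digits are light strokes on a dark background.

The names `downsampleToMnist` and `argmax` are hypothetical.

```javascript
// Average-pool a square grayscale image (values 0..255, row-major) to 28x28
// and scale to [0, 1]. Hypothetical preprocessing; the demo's real code may differ.
function downsampleToMnist(pixels, size) {
  const out = 28;
  const block = size / out; // source pixels per output pixel along each axis
  if (!Number.isInteger(block)) throw new Error("size must be a multiple of 28");
  const result = new Float32Array(out * out);
  for (let row = 0; row < out; row++) {
    for (let col = 0; col < out; col++) {
      let sum = 0;
      // Sum one block x block patch of the source image.
      for (let dy = 0; dy < block; dy++) {
        for (let dx = 0; dx < block; dx++) {
          sum += pixels[(row * block + dy) * size + (col * block + dx)];
        }
      }
      result[row * out + col] = sum / (block * block) / 255;
    }
  }
  return result;
}

// Index of the largest score, i.e. the predicted digit of a 10-way classifier.
function argmax(scores) {
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}
```

For a 280×280 canvas, each output pixel averages a 10×10 patch. Average pooling keeps thin strokes visible after shrinking, whereas picking every tenth pixel could miss them.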