TensorFlow.js tf.Sequential class .trainOnBatch() Method

TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js.
The .trainOnBatch() method runs a single gradient update on a single batch of data.
Note:
This method differs from fit() and fitDataset() in the following ways:
- It operates on exactly one batch of data.
- It returns only the loss and metric values for that batch, instead of a batch-by-batch history of loss and metric values.
- It does not support fine-grained options such as verbosity and callbacks.
Syntax:
trainOnBatch(x, y)
Parameters:
- x: The input data. It can be of type tf.Tensor, tf.Tensor[], or {[inputName: string]: tf.Tensor}, and can be any one of the following:
  - A single tf.Tensor, or an array of tf.Tensors if the model has multiple inputs.
  - An object mapping input names to the corresponding tf.Tensor, if the model has named inputs.
- y: The target data. It can be of type tf.Tensor, tf.Tensor[], or {[inputName: string]: tf.Tensor}, and must be consistent with x (for example, the same number of samples).
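Return Value: It returns a Promise<number|number[]>: a single number (the loss) for a single-output model compiled without metrics, or otherwise an array whose first element is the total loss, followed by the remaining loss and metric values.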
Example 1:
```javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// Training model
const gfg = tf.sequential();

// Adding a layer to the model
const layer = tf.layers.dense({units: 3, inputShape: [5]});
gfg.add(layer);

// Compiling our model
const config = {optimizer: 'sgd', loss: 'meanSquaredError'};
gfg.compile(config);

// Test tensor and target tensor
const xs = tf.ones([3, 5]);
const ys = tf.ones([3, 3]);

// Calling trainOnBatch() method
// (top-level await requires running this file as an ES module)
const result = await gfg.trainOnBatch(xs, ys);

// Printing output
console.log(result);
```
Output (the exact value varies between runs because the layer weights are initialized randomly):
0.3589147925376892
Example 2:
```javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

async function run() {
  // Training model
  const gfg = tf.sequential();

  // Adding a layer to the model
  const layer = tf.layers.dense({units: 2, inputShape: [2]});
  gfg.add(layer);

  // Compiling our model
  const config = {optimizer: 'sgd', loss: 'meanSquaredError'};
  gfg.compile(config);

  // Test tensor and target tensor (random values)
  const xs = tf.truncatedNormal([3, 2]);
  const ys = tf.randomNormal([3, 2]);

  // Calling trainOnBatch() method
  const result = await gfg.trainOnBatch(xs, ys);

  // Printing output (a single number, since no metrics were compiled)
  console.log(result);
}

// Function call (top-level await requires an ES module)
await run();
```
Output (the exact value varies between runs because both the data and the initial weights are random):
1.6889342069625854
Reference: https://js.tensorflow.org/api/latest/#tf.Sequential.trainOnBatch



