在浏览器中运行 TensorFlow.js 来训练模型并给出预测结果(Iris 数据集)
warning:
这篇文章距离上次修改已过190天,其中的内容可能已经有所变动。
// 引入TensorFlow.js库
const tf = require('@tensorflow/tfjs-node');
const iris = require('./data/iris');
// 定义模型参数
const BATCH_SIZE = 100;
const TRAIN_EPOCHS = 200;
// 准备数据
const xTrain = tf.tensor2d(iris.train.xs, [iris.train.xs.length, 4]);
const yTrain = tf.tensor2d(iris.train.ys, [iris.train.ys.length, 3]);
const xTest = tf.tensor2d(iris.test.xs, [iris.test.xs.length, 4]);
const yTest = tf.tensor2d(iris.test.ys, [iris.test.ys.length, 3]);
// 创建模型
const model = tf.sequential({
layers: [
tf.layers.dense({inputShape: [4], units: 10, activation: 'relu'}),
tf.layers.dense({units: 3, activation: 'softmax'})
]
});
// 编译模型
model.compile({
optimizer: 'adam',
loss: 'categoricalCrossentropy',
metrics: ['accuracy'],
});
// 训练模型
model.fit(xTrain, yTrain, {
batchSize: BATCH_SIZE,
epochs: TRAIN_EPOCHS,
validationData: [xTest, yTest],
}).then(() => {
// 进行预测
const result = model.predict(tf.tensor2d([[6.4, 3.2, 4.5, 1.5]], [1, 4]))
.arraySync();
console.log('Predicted output:', result);
});
这段代码使用TensorFlow.js在Node.js环境中创建了一个多层感知器模型,用于识别鸢尾花(Iris Flower)数据集中的不同种类。在模型训练过程中,使用了训练数据和测试数据,并在最后输出了对单个样本的预测结果。这个过程展示了如何在Node.js中进行机器学习模型的训练和预测,对于开发者来说这是一个很好的学习示例。
评论已关闭