2026a

# train


训练浅层神经网络

函数库: TyDeepLearning

# 语法

trainedNet = train(net, train_data, train_label; epochs=100, loss_fun="MSELoss", optimizer="Adam", lr=0.1)

# 说明

trainedNet = train(net, train_data, train_label; epochs=100, loss_fun="MSELoss", optimizer="Adam", lr=0.1) 根据参数设置训练一个网络。示例

# 示例

训练前馈神经网络

加载数据。

using TyDeepLearning
set_backend(:mindspore)
file = dataset_dir("simplefit")
x, t = simplefit_dataset(file)

训练一个前馈神经网络。

hiddenSizes = 10
net = feedforwardnet(hiddenSizes)
net_trained = train(net, x, t; epochs = 5000, lr = 0.05)

使用已训练的网络进行预测,并计算误差。

y = TyDeepLearning.predict(net_trained, x)
error = mse(t, y)
0.7069544690720578

神经网络每次训练具有随机性,输出结果不会完全一致。

# 输入参数

net-输入网络
网络对象

输入网络,指定为网络对象。

数据类型: Float16 | Float32 | Float64 | Int64

train_data-训练输入
数组、矩阵

训练输入,指定为矩阵或单元数组。

数据类型: Float16 | Float32 | Float64 | Int64

train_label-训练标签
数组、矩阵

训练标签,指定为矩阵或单元数组。

数据类型: Float16 | Float32 | Float64 | Int64

epochs-训练次数
标量

用于训练的最大次数,指定为正整数。 迭代是梯度下降算法中用于使用小批量最小化损失函数的一个步骤。epoch 是训练算法对整个训练集的完整传递。

loss_fun-损失函数
字符串

损失函数,可以指定为以下:

损失函数 说明
BCELoss 计算目标值和预测值之间的二值交叉熵损失值
BCEWithLogitsLoss 输入经过sigmoid激活函数后作为预测值,BCEWithLogitsLoss 计算预测值和目标值之间的二值交叉熵损失
CosineEmbeddingLoss 余弦相似度损失函数,用于测量两个Tensor之间的相似性
CrossEntropyLoss 计算预测值和目标值之间的交叉熵损失
DiceLoss Dice系数是一个集合相似性loss,用于计算两个样本之间的相似性
HuberLoss HuberLoss计算预测值和目标值之间的误差
L1Loss L1Loss用于计算预测值和目标值之间的平均绝对误差
MSELoss 用于计算预测值与标签值之间的均方误差
MultiClassDiceLoss 对于多标签问题,可以将标签通过one-hot编码转换为多个二分类标签
NLLLoss 计算预测值和目标值之间的负对数似然损失
RMSELoss RMSELoss用来测量 x 和 y 元素之间的均方根误差,其中 x 是输入Tensor, y 是目标值
SmoothL1Loss SmoothL1损失函数,如果预测值和目标值的逐个元素绝对误差小于设定阈值 beta 则用平方项,否则用绝对误差项
SoftmaxCrossEntropyWithLogits 计算预测值与真实值之间的交叉熵

数据类型: string

optimizer-优化器
字符串

优化器,可以指定为以下:

优化器 说明
Adadelta Adadelta算法的实现
Adagrad Adagrad算法的实现
Adam Adaptive Moment Estimation (Adam)算法的实现
AdaMax AdaMax算法是基于无穷范数的Adam的一种变体
AdamOffload 此优化器在主机CPU上运行Adam优化算法,设备上仅执行网络参数的更新,最大限度地降低内存成本
AdamWeightDecay 权重衰减Adam算法的实现
ASGD 实现平均随机梯度下降
LazyAdam Adaptive Moment Estimation (Adam)算法的实现
Momentum Momentum算法的实现
RMSProp 均方根传播(RMSProp)算法的实现
Rprop 实现弹性反向传播
SGD 随机梯度下降的实现

数据类型: string

lr-学习率
标量

用于训练的初始学习率,指定为正标量。 如果学习率太低,则训练可能需要很长时间。如果学习率太高,则训练可能会达到次优结果或发散。

数据类型: Float16 | Float32 | Float64 | Int64

# 输出参数

trainedNet-训练的网络
网络对象

已训练的网络,作为网络对象返回

# 另请参阅

feedforwardnet | predict