2026a
# train
训练浅层神经网络
函数库: TyDeepLearning
# 语法
trainedNet = train(net, train_data, train_label; epochs=100, loss_fun="MSELoss", optimizer="Adam", lr=0.1)
# 说明
trainedNet = train(net, train_data, train_label; epochs=100, loss_fun="MSELoss", optimizer="Adam", lr=0.1) 根据参数设置训练一个网络。示例
# 示例
训练前馈神经网络
加载数据。
using TyDeepLearning
set_backend(:mindspore)
file = dataset_dir("simplefit")
x, t = simplefit_dataset(file)
训练一个前馈神经网络。
hiddenSizes = 10
net = feedforwardnet(hiddenSizes)
net_trained = train(net, x, t; epochs = 5000, lr = 0.05)
使用已训练的网络进行预测,并计算误差。
y = TyDeepLearning.predict(net_trained, x)
error = mse(t, y)
0.7069544690720578
神经网络每次训练具有随机性,输出结果不会完全一致。
# 输入参数
net-输入网络网络对象
输入网络,指定为网络对象。
数据类型: Float16 | Float32 | Float64 | Int64
train_data-训练输入数组、矩阵
训练输入,指定为矩阵或单元数组。
数据类型: Float16 | Float32 | Float64 | Int64
train_label-训练标签数组、矩阵
训练标签,指定为矩阵或单元数组。
数据类型: Float16 | Float32 | Float64 | Int64
epochs-训练次数标量
用于训练的最大次数,指定为正整数。 迭代是梯度下降算法中用于使用小批量最小化损失函数的一个步骤。epoch 是训练算法对整个训练集的完整传递。
loss_fun-损失函数字符串
损失函数,可以指定为以下:
| 损失函数 | 说明 |
|---|---|
| BCELoss | 计算目标值和预测值之间的二值交叉熵损失值 |
| BCEWithLogitsLoss | 输入经过sigmoid激活函数后作为预测值,BCEWithLogitsLoss 计算预测值和目标值之间的二值交叉熵损失 |
| CosineEmbeddingLoss | 余弦相似度损失函数,用于测量两个Tensor之间的相似性 |
| CrossEntropyLoss | 计算预测值和目标值之间的交叉熵损失 |
| DiceLoss | Dice系数是一个集合相似性loss,用于计算两个样本之间的相似性 |
| HuberLoss | HuberLoss计算预测值和目标值之间的误差 |
| L1Loss | L1Loss用于计算预测值和目标值之间的平均绝对误差 |
| MSELoss | 用于计算预测值与标签值之间的均方误差 |
| MultiClassDiceLoss | 对于多标签问题,可以将标签通过one-hot编码转换为多个二分类标签 |
| NLLLoss | 计算预测值和目标值之间的负对数似然损失 |
| RMSELoss | RMSELoss用来测量 x 和 y 元素之间的均方根误差,其中 x 是输入Tensor, y 是目标值 |
| SmoothL1Loss | SmoothL1损失函数,如果预测值和目标值的逐个元素绝对误差小于设定阈值 beta 则用平方项,否则用绝对误差项 |
| SoftmaxCrossEntropyWithLogits | 计算预测值与真实值之间的交叉熵 |
数据类型: string
optimizer-优化器字符串
优化器,可以指定为以下:
| 优化器 | 说明 |
|---|---|
| Adadelta | Adadelta算法的实现 |
| Adagrad | Adagrad算法的实现 |
| Adam | Adaptive Moment Estimation (Adam)算法的实现 |
| AdaMax | AdaMax算法是基于无穷范数的Adam的一种变体 |
| AdamOffload | 此优化器在主机CPU上运行Adam优化算法,设备上仅执行网络参数的更新,最大限度地降低内存成本 |
| AdamWeightDecay | 权重衰减Adam算法的实现 |
| ASGD | 实现平均随机梯度下降 |
| LazyAdam | Adaptive Moment Estimation (Adam)算法的实现 |
| Momentum | Momentum算法的实现 |
| RMSProp | 均方根传播(RMSProp)算法的实现 |
| Rprop | 实现弹性反向传播 |
| SGD | 随机梯度下降的实现 |
数据类型: string
lr-学习率标量
用于训练的初始学习率,指定为正标量。 如果学习率太低,则训练可能需要很长时间。如果学习率太高,则训练可能会达到次优结果或发散。
数据类型: Float16 | Float32 | Float64 | Int64
# 输出参数
trainedNet-训练的网络网络对象
已训练的网络,作为网络对象返回