
Trickle AI
一站式无代码 AI 开发平台
相关链接:Gitee项目地址GitHub项目地址
昇思MindSpore是华为推出的适用端边云场景的新型开源全场景深度学习框架,昇思MindSpore具备强大的分布式训练能力,内置多种并行策略,简化大模型开发。昇思MindSpore与昇腾处理器深度适配,充分发挥硬件性能,缩短训练时间并提升推理效率。昇思MindSpore支持AI与高性能计算(HPC)融合,满足AI for Science场景需求。昇思MindSpore生态丰富,提供开源项目、案例和SOTA模型,方便开发者快速上手和应用。

pip install mindspore
pip install mindspore-gpu==1.10.0
pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.6.0/MindSpore/unified/aarch64/mindspore-2.6.0-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple
import mindspore
print(mindspore.__version__)
import mindspore.dataset as ds
from mindspore.dataset.transforms import Compose, ToTensor, Normalize
# 加载MNIST数据集
dataset = ds.MnistDataset("path/to/mnist_dataset")
transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
dataset = dataset.map(operations=transforms, input_columns=["image"])
dataset = dataset.batch(batch_size=64)
import mindspore.nn as nn
import mindspore.ops as ops
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2)
self.flatten = ops.Flatten()
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet5()
Model类,简化训练和评估流程。from mindspore.train import Model
from mindspore.nn import SoftmaxCrossEntropyWithLogits, Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
# 定义损失函数和优化器
loss_fn = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = Momentum(model.trainable_params(), learning_rate=0.01, momentum=0.9)
# 创建Model实例
model = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})
# 设置保存检查点的配置
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="lenet", config=config_ck)
# 开始训练
model.train(10, dataset, callbacks=[ckpoint_cb])
# 加载测试数据集
test_dataset = ds.MnistDataset("path/to/mnist_test_dataset")
test_dataset = test_dataset.map(operations=transforms, input_columns=["image"])
test_dataset = test_dataset.batch(batch_size=64)
# 评估模型
acc = model.eval(test_dataset)
print(f"Accuracy: {acc['accuracy']}")