Learning MXNet by Comparison with PyTorch

Tech · 2022-07-10


This post is a translated summary of the MXNet tutorial site, aimed at readers with some PyTorch experience who are moving to MXNet.

1. Performance

According to NVIDIA performance benchmarks from April 2019, Apache MXNet outperforms PyTorch by roughly 77% when training ResNet-50: 10,925 images per second vs. 6,175.
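The ~77% figure follows directly from the two throughput numbers; a quick sanity check in Python:

```python
# Throughput figures from the NVIDIA ResNet-50 benchmark cited above (images/sec)
mxnet_throughput = 10925
pytorch_throughput = 6175

# Relative speedup of MXNet over PyTorch
speedup = mxnet_throughput / pytorch_throughput - 1
print(f"{speedup:.0%}")  # → 77%
```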

2. Creating Tensors

PyTorch:

```python
import torch

x = torch.ones(5, 3)
y = x + 1
y
```

MXNet:

When creating a tensor, MXNet expects the shape to be passed as a single tuple:

```python
from mxnet import nd

x = nd.ones((5, 3))
y = x + 1
y
```
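The same tuple-shape convention is used by NumPy, so the MXNet call can be checked without MXNet installed; a minimal NumPy sketch of the computation above:

```python
import numpy as np

# Like MXNet's nd.ones, NumPy takes the shape as a single tuple
x = np.ones((5, 3))
y = x + 1  # the scalar 1 is broadcast to every element

print(y.shape)  # → (5, 3)
print(y[0, 0])  # → 2.0
```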

3. Defining a Model

PyTorch:

```python
import torch.nn as pt_nn

pt_net = pt_nn.Sequential(
    pt_nn.Linear(28*28, 256),
    pt_nn.ReLU(),
    pt_nn.Linear(256, 10))
```

MXNet:

Dense does not need an input size; MXNet infers it automatically on the first forward pass. Fully connected and convolutional layers can also take an activation function directly, e.g. activation='relu':

```python
import mxnet.gluon.nn as mx_nn

mx_net = mx_nn.Sequential()
mx_net.add(mx_nn.Dense(256, activation='relu'),
           mx_nn.Dense(10))
mx_net.initialize()
```
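The deferred initialization that lets Dense skip the input size can be sketched in pure Python. The LazyDense class below is a hypothetical toy, not part of either framework; it just illustrates the idea of fixing the weight shape on the first forward pass:

```python
import numpy as np

class LazyDense:
    """Toy dense layer that infers its input size on the first forward pass,
    mimicking MXNet's deferred initialization (illustrative only)."""
    def __init__(self, units):
        self.units = units
        self.weight = None  # shape unknown until we see data

    def __call__(self, x):
        if self.weight is None:
            # Infer in_features from the incoming batch, then initialize
            in_features = x.shape[1]
            self.weight = np.random.randn(in_features, self.units) * 0.01
            self.bias = np.zeros(self.units)
        return x @ self.weight + self.bias

layer = LazyDense(256)
out = layer(np.ones((32, 784)))  # first call fixes the weight shape to (784, 256)
print(out.shape)  # → (32, 256)
```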

4. Loss Functions and Optimizers

PyTorch:

```python
pt_loss_fn = pt_nn.CrossEntropyLoss()
pt_trainer = torch.optim.SGD(pt_net.parameters(), lr=0.1)
```

MXNet:

MXNet uses a Trainer class, which takes an optimizer by name, e.g. 'sgd'. The network's parameters are fetched with .collect_params():

```python
from mxnet import gluon

mx_loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
mx_trainer = gluon.Trainer(mx_net.collect_params(),
                           'sgd', {'learning_rate': 0.1})
```
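Despite the different names, both losses compute the same thing: log-softmax followed by negative log-likelihood, taken from raw logits. A NumPy sketch of that computation (illustrative, not either framework's implementation):

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Cross-entropy from raw logits: log-softmax + negative log-likelihood,
    the computation behind CrossEntropyLoss / SoftmaxCrossEntropyLoss."""
    # Subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the correct class for each sample
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, 0.1], [0.2, 3.0, 0.3]])
labels = np.array([0, 1])
print(softmax_cross_entropy(logits, labels))
```

With uniform logits the loss reduces to log(num_classes), which is a handy sanity check.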

5. Training

PyTorch:

```python
import time

for epoch in range(5):
    total_loss = .0
    tic = time.time()
    for X, y in pt_train_data:
        pt_trainer.zero_grad()
        loss = pt_loss_fn(pt_net(X.view(-1, 28*28)), y)
        loss.backward()
        pt_trainer.step()
        total_loss += loss.mean()
    print('epoch %d, avg loss %.4f, time %.2f' % (
        epoch, total_loss/len(pt_train_data), time.time()-tic))
```

MXNet:

A few differences from PyTorch:

- The forward computation must run inside an autograd.record() scope so that gradients can be computed automatically during the backward pass.
- There is no need to call optimizer.zero_grad() every iteration: by default MXNet overwrites gradients rather than accumulating them.
- When updating weights, trainer.step() takes the size of the update, typically the batch size.
- asscalar() converts a one-element NDArray to a Python scalar.

```python
from mxnet import autograd

for epoch in range(5):
    total_loss = .0
    tic = time.time()
    for X, y in mx_train_data:
        with autograd.record():
            loss = mx_loss_fn(mx_net(X), y)
        loss.backward()
        mx_trainer.step(batch_size=128)
        total_loss += loss.mean().asscalar()
    print('epoch %d, avg loss %.4f, time %.2f' % (
        epoch, total_loss/len(mx_train_data), time.time()-tic))
```
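The zero_grad difference comes down to write-vs-accumulate semantics on the gradient buffer. A toy sketch in plain Python (the Param class and method names are hypothetical, purely for illustration):

```python
import numpy as np

class Param:
    """Toy parameter with a gradient buffer (illustrative, not a real API)."""
    def __init__(self, shape):
        self.grad = np.zeros(shape)

    def accumulate(self, g):
        # PyTorch semantics: backward() adds into .grad,
        # which is why optimizer.zero_grad() is needed each step
        self.grad += g

    def write(self, g):
        # MXNet default semantics: backward() overwrites .grad,
        # so no explicit zeroing is needed
        self.grad = g.copy()

p = Param(3)
g = np.ones(3)

p.accumulate(g)
p.accumulate(g)
print(p.grad)  # → [2. 2. 2.]  gradients pile up without zeroing

p.write(g)
print(p.grad)  # → [1. 1. 1.]  fresh gradient each pass
```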