137

機械学習関連の技術記事を投稿します。137と言えば微細構造定数

【PyTorch】C++ APIを使ってResNet50を実装する

はじめに

Experimentalな機能として提供されてきたPyTorchのC++ APIが、バージョン1.5よりStableになった1

C++ APIを利用することで、Pythonを利用できない環境や速度を求められる環境においてもPyTorchを使って深層学習のモデルを記述して学習できるようになる。 PyTorchのC++ APIは、既存のPython APIと可能な限り似た記述になるようにAPIが設計されている。 このため、Python APIを使って深層学習のモデルをこれまで書いてきた人であれば、C++ APIを使いこなすのはそこまでハードルが高くないと思われる。 ただしC++ APIは現在も開発中であり、Python APIに対応するAPIが存在しないことが多いため注意が必要である。 例えば、Forward/BackwardのHook登録一部のOptimizer は未実装となっている。

本記事では、ResNet50をC++ APIで実装するための解説を行う。 いまさらResNet50かとツッコミどころはあるが、著名で実装しやすいこともあり、あえてこのモデルを選んだ。 Python APIC++ APIを比較しやすくするために、Python APIC++ APIの両方の実装を比較しながら説明していく。 本記事を読むことにより、Python APIC++ APIがよく似た仕様となっていることを理解できるだろう。

なお、PyTorchのC++ APIを使ってプログラムを書くための準備は、PyTorch公式のチュートリアル も参考になる。 チュートリアルでは、入力層と出力層から構成される簡単なニューラルネットワークの作成から始まり、DCGANの実装までを丁寧に説明している。 必要に応じてこちらも参照されたい。

本記事で紹介するソースコード一式は、GitHub にて公開している。

注意点

本記事で紹介するResNet50の実装は、以下の点で手を抜いた実装となっている。 ソースコードを参考にする際は注意が必要である。

  • ImageNetの代わりにMNISTをデータセットとして使用
    • 入力画像サイズ:[224, 224] → [28, 28]
    • 入力画像Channel数:3(RGB) → 1(白黒)
    • クラス数:1000 → 10
  • Learning Rate固定
    • 本来であればepoch数にしたがってLearning Rateを減衰させる必要があるが、該当する機能がC++ APIとして提供されていないため実装をサボった

前提知識

本記事は、以下の読者を対象としている。

前準備

PyTorchのC++ APIを利用するためには、libtorchと呼ばれるPyTorchのライブラリをダウンロードする必要がある。 CUDAなしの環境とCUDAありの環境とでライブラリが異なるため注意が必要である。

  • CUDAなし(CPU)
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip
  • CUDAあり(CPU+GPU
wget https://download.pytorch.org/libtorch/cu102/libtorch-shared-with-deps-1.5.1.zip
unzip libtorch-shared-with-deps-1.5.1.zip

また、本プログラムでは前処理のためにOpenCVを利用しているため、こちらの記事 を参考にしてインストールする。

プログラムのビルドにはCMakeを使うため、次のようにCMakeLists.txtを作成する。

対応するGitHub上のコード

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(pytorch-cpp-example)

find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)

include_directories(
    ${PROJECT_SOURCE_DIR}
    ${OPENCV_INCLUDE_DIRS}
)
add_executable(train model.cpp train.cpp)
add_executable(predict model.cpp predict.cpp)
target_link_libraries(train ${TORCH_LIBRARIES})
target_link_libraries(predict ${TORCH_LIBRARIES} ${OpenCV_LIBRARIES})
set_property(TARGET train PROPERTY CXX_STANDARD 14)
set_property(TARGET predict PROPERTY CXX_STANDARD 14)

ResNet50のプログラムを書く

ResNet50の論文 を参考に、プログラムを作成する。

1. Residual Blockの実装

ResNetは残差ブロック(Residual Block)を導入することにより、層の深いネットワークにおける勾配損失問題を解消したことで有名なニューラルネットワークモデルである。 Residual Blockの実装は、Python APIで記述すると次のようになる。

対応するGitHub上のコード

class ResidualBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        width = out_channels // 4

        # (1.a-1)
        self.conv1 = torch.nn.Conv2d(in_channels, width, kernel_size=(1, 1),
                                     stride=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(width)
        self.relu1 = torch.nn.ReLU(inplace=True)

        self.conv2 = torch.nn.Conv2d(width, width, kernel_size=(3, 3),
                                     stride=stride, padding=1, groups=1,
                                     bias=False, dilation=1)
        self.bn2 = torch.nn.BatchNorm2d(width)
        self.relu2 = torch.nn.ReLU(inplace=True)

        self.conv3 = torch.nn.Conv2d(width, out_channels, kernel_size=(1, 1),
                                     stride=1, padding=0, bias=False)
        self.bn3 = torch.nn.BatchNorm2d(out_channels)

        # (1.a-2)
        def shortcut(in_, out):
            if in_ != out:
                return torch.nn.Sequential(
                    torch.nn.Conv2d(in_, out, kernel_size=(1, 1),
                                    stride=stride, padding=0, bias=False),
                    torch.nn.BatchNorm2d(out),
                )
            else:
                return lambda x: x
        self.shortcut = shortcut(in_channels, out_channels)

        self.relu3 = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        shortcut = self.shortcut(x)

        out = self.relu3(out + shortcut)

        return out

torch.nn.Module を継承したResidualBlockクラスの中で、torch.nn モジュールで提供される各層に対応する演算APIを利用してResidual Blockを定義する (1.a-1)。 入力channelと出力channelが異なる場合に、1x1のConvolutionによってUp-sampling (1.a-2) を行う shortcut 関数に注意しよう。

つづいて、C++ APIを用いてResidual Blockを実装した場合のソースコードを示す。

対応するGitHub上のコード

Python APIで記述した場合とC++ APIで記述した場合の違いは、おもに2つある。

1つ目は、Conv2dのStrideなどの設定について、PythonではAPIの引数に渡すことで実現しているのに対し、C++では Conv2dOptions 構造体などをAPIの引数に渡すことで実現している点である (1.b-1)

2つ目は、ResidualBlock 構造体のコンストラクタの最後の処理で register_module 関数を呼び出している点である (1.b-3)register_module 関数は学習する層を登録するための関数で、この登録処理を忘れてしまうと学習時の誤差逆伝搬の処理が行えなくなってしまう。

ResidualBlockImpl::ResidualBlockImpl(int in_channels, int out_channels,
                                     int stride) {
    int width = out_channels / 4;

    // (1.b-1)
    conv1 = Conv2d(Conv2dOptions(in_channels, width, {1, 1})
                   .stride(1).bias(false));
    bn1 = BatchNorm2d(BatchNormOptions(width));
    relu1 = ReLU(ReLUOptions().inplace(true));

    conv2 = Conv2d(Conv2dOptions(width, width, {3, 3})
                   .stride(stride).padding(1).groups(1)
                   .bias(false).dilation(1));
    bn2 = BatchNorm2d(BatchNormOptions(width));
    relu2 = ReLU(ReLUOptions().inplace(true));

    conv3 = Conv2d(Conv2dOptions(width, out_channels, {1, 1})
                   .stride(1).padding(0).bias(false));
    bn3 = BatchNorm2d(BatchNormOptions(out_channels));

    // (1.b-2)
    Sequential shortcut(
        Conv2d(Conv2dOptions(in_channels, out_channels, {1, 1})
               .stride(stride).padding(0).bias(false)),
        BatchNorm2d(BatchNormOptions(out_channels))
    );
    this->shortcut = shortcut;
    relu3 = ReLU(ReLUOptions().inplace(true));

    this->in_channels = in_channels;
    this->out_channels = out_channels;

    // // (1.b-3)
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("relu1", relu1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    register_module("relu2", relu2);
    register_module("conv3", conv3);
    register_module("bn3", bn3);
    register_module("shortcut", shortcut);
    register_module("relu3", relu3);
}

torch::Tensor ResidualBlockImpl::forward(torch::Tensor input) {
    torch::Tensor out;
    torch::Tensor tmp;

    out = conv1->forward(input);
    out = bn1->forward(out);
    out = relu1->forward(out);

    out = conv2->forward(out);
    out = bn2->forward(out);
    out = relu2->forward(out);

    out = conv3->forward(out);
    out = bn3->forward(out);

    if (in_channels != out_channels) {
        tmp = shortcut->forward(input);
    } else {
        tmp = input;
    }
    out = relu3->forward(out + tmp);

    return out;
}

2. ResNet50の実装

Residual Blockを組み合わせて、ResNet50を実装する。 最初にPython APIによるResNet50の実装を示す。

対応するGitHub上のコード

学習データとしてMNISTを利用するため、ResNet50モデルの入力データのサイズや出力クラス数に気をつけよう。 ResNet50モデルでは、最初の torch.nn.Conv2d (2.a-1) (2.b-1) と最後の torch.nn.Linear (2.a-2) (2.b-2) の引数に気をつける。

class ResNet50(torch.nn.Module):
    def __init__(self):
        super(ResNet50, self).__init__()
        # (2.a-1)
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7),
                                     stride=2, padding=3, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=(3, 3), stride=2,
                                          padding=1)

        self.layer1 = torch.nn.Sequential(
            ResidualBlock(64, 256),
            ResidualBlock(256, 256),
            ResidualBlock(256, 256),
        )

        self.layer2 = torch.nn.Sequential(
            ResidualBlock(256, 512, stride=2),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
        )

        self.layer3 = torch.nn.Sequential(
            ResidualBlock(512, 1024, stride=2),
            ResidualBlock(1024, 1024),
            ResidualBlock(1024, 1024),
            ResidualBlock(1024, 1024),
            ResidualBlock(1024, 1024),
            ResidualBlock(1024, 1024),
        )

        self.layer4 = torch.nn.Sequential(
            ResidualBlock(1024, 2048, stride=2),
            ResidualBlock(2048, 2048),
            ResidualBlock(2048, 2048),
        )

        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = torch.nn.Flatten(1)
        # (2.a-2)
        self.fc = torch.nn.Linear(2048, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x

続いて、C++ APIによる実装を示す。

対応するGitHub上のコード

ResNet50Impl::ResNet50Impl() {
    // (2.b-1)
    conv1 = Conv2d(Conv2dOptions(1, 64, {7, 7})
                   .stride(2).padding(3).bias(false));
    bn1 = BatchNorm2d(BatchNormOptions(64));
    relu = ReLU(ReLUOptions().inplace(true));
    maxpool = MaxPool2d(MaxPoolOptions<2>({3, 3}).stride(2).padding(1));

    Sequential layer1(
        ResidualBlock(64, 256),
        ResidualBlock(256, 256),
        ResidualBlock(256, 256)
    );
    this->layer1 = layer1;

    Sequential layer2(
        ResidualBlock(256, 512, 2),
        ResidualBlock(512, 512),
        ResidualBlock(512, 512),
        ResidualBlock(512, 512)
    );
    this->layer2 = layer2;

    Sequential layer3(
        ResidualBlock(512, 1024, 2),
        ResidualBlock(1024, 1024),
        ResidualBlock(1024, 1024),
        ResidualBlock(1024, 1024),
        ResidualBlock(1024, 1024),
        ResidualBlock(1024, 1024)
    );
    this->layer3 = layer3;

    Sequential layer4(
        ResidualBlock(1024, 2048, 2),
        ResidualBlock(2048, 2048),
        ResidualBlock(2048, 2048)
    );
    this->layer4 = layer4;

    avgpool = AdaptiveAvgPool2d(AdaptiveAvgPool2dOptions({1, 1}));
    flatten = Flatten(FlattenOptions().start_dim(1));

    // (2.b-2)
    fc = Linear(2048, 10);

    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("relu", relu);
    register_module("maxpool", maxpool);
    register_module("layer1", this->layer1);
    register_module("layer2", this->layer2);
    register_module("layer3", this->layer3);
    register_module("layer4", this->layer4);
    register_module("avgpool", avgpool);
    register_module("flatten", flatten);
    register_module("fc", fc);
}

torch::Tensor ResNet50Impl::forward(torch::Tensor input) {
    torch::Tensor out;

    out = conv1->forward(input);
    out = bn1->forward(out);
    out = relu->forward(out);
    out = maxpool->forward(out);

    out = layer1->forward(out);
    out = layer2->forward(out);
    out = layer3->forward(out);
    out = layer4->forward(out);

    out = avgpool->forward(out);
    out = flatten->forward(out);
    out = fc->forward(out);

    return out;
}

ResidualBlock の実装と同じような要領で実装していけばよく、C++ APIでの実装に関して特筆すべきことはない。

ここで、実装が ResidualBlockImpl 構造体であるのに対して ResidualBlock としてアクセスできていることが気になるかもしれない。 これはPyTorchをC++で書いた時の流儀で、実装を XXXXImpl という構造体名として TORCH_MODULE マクロでその構造体を登録する ことで、XXXX 構造体としてアクセスできるようになる。

3. 学習処理の実装

ResNet50を学習するための処理を実装する。 実装は次の順番で行っていく。

i) コマンドライン引数の解析、乱数固定
ii) 学習を実行するデバイス(CPUかGPUか)の決定
iii) ResNet50モデルとOptimizerの定義
iv) MNISTデータセットの読み込み
v) 学習
vi) 評価
vii) 学習済みモデルの保存

i) コマンドライン引数の解析、乱数固定

最初にコマンドラインの解析と、再現性確保のための乱数固定を行う。 乱数固定に関しては、こちらの記事 を参照のこと。

対応するGitHub上のコード(Python API)

対応するGitHub上のコード(C++ API)

ii) 学習を実行するデバイス(CPUかGPUか)の決定

学習を実行するデバイスを決定する処理をPython APIで実装すると、次のようになる。

対応するGitHub上のコード

    # (3.a-1)
    # Parse arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m", dest="saved_model_path", type=str,
        help="Path to saved model", required=True)
    args = parser.parse_args()

    fix_randomness(1)

torch.cuda.is_available 関数を使い、CUDAが利用可能な場合はGPU、利用できない場合はCPUで学習を実行するようにデバイスを決定する (3.a-1)

同様の処理をC++ APIで実装した場合は次のようになる。

対応するGitHub上のコード

    // (3.b-1)
    // Create device.
    torch::DeviceType device_type;
    if (torch::cuda::is_available()) {
        std::cout << "Train on GPU." << std::endl;
        device_type = torch::kCUDA;
    } else {
        std::cout << "Train on CPU." << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);

対応するC++ APIに置き換えればよいだけなので、特別に注意する点はない。

iii) ResNet50モデルとOptimizerの定義

ResNet50モデルとOptimizerの定義は、Python APIでは次のようになる。

対応するGitHub上のコード

    # Build model.
    model = ResNet50()
    model.to(device)
    summary(model, (1, 28, 28))
    # (3.a-2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

本記事の最初にも書いたが、OptimizerのLearning Rateはepochが進むにつれて減衰させるべきところを決め打ちで値を指定している (3.a-2) (3.b-2)。 Learning Rateを逐次変更したい場合は、Pythonでは torch.optim.lr_scheduler モジュールを利用できるが、C++ APIに同様のAPIが存在しないため自力で実装する必要がある。

C++ APIにおける実装を次に示す。

対応するGitHub上のコード

    // Build model.
    ResNet50 model;
    model->to(device);
    // (3.b-2)
    torch::optim::Adam optimizer(
        model->parameters(), torch::optim::AdamOptions(0.01));

ここに関しても特別に注意する点はないだろう。

iv) MNISTデータセットの読み込み

次に、MNISTのデータセットを読み込む。

対応するGitHub上のコード

    # Load dataset.
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(DATA_ROOT, train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                                   ])
        ),
        batch_size=TRAIN_BATCH_SIZE,
    )
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(DATA_ROOT, train=False, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                                   ])
        ),
        batch_size=TEST_BATCH_SIZE,
    )

Python APIの場合は、torch.utils.data.DataLoader クラスを利用し、引数に torch.utils.data.Dataset から継承されたMNISTのデータセットを読み込むための便利クラス torchvision.datasets.MNISTインスタンスを渡す。 torchvision.datasets.MNIST の引数 downloadTrue を指定することで、DATA_ROOT にデータが存在しない場合は、MNISTのデータセットをダウンロードしてから読み込んでくれる。 引数 shuffle が指定されていないことから、引数 shuffle はデフォルトで False となり、シーケンシャルにデータセットからサンプリングするようになる。

データ読み込み時は、transform 引数に入力データに対して適用する変形処理を渡す。 ここでは、次のような変形処理を行った。

  • テンソルの各要素の値を[0, 255]から[0.0, 1.0]に正規化し、データレイアウトを(H, W, C)から(C, H, W)へ変換
  • 各要素を平均0.1307、標準偏差0.3081に標準化

続いてC++ APIを用いた実装を示す。

対応するGitHub上のコード

    // Load dataset.
    auto train_dataset = torch::data::datasets::MNIST(kDataRoot)
        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
        .map(torch::data::transforms::Stack<>());
    auto train_loader =
        torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
        std::move(train_dataset), kTrainBatchSize);

    auto test_dataset = torch::data::datasets::MNIST(
        kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
        .map(torch::data::transforms::Stack<>());
    auto test_loader =
        torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
        std::move(test_dataset), kTestBatchSize);

Python APItorch.utils.data.DataLoader に対応する DataLoader は、テンプレート関数 torch::data::make_data_loader を使って生成できる。 テンプレート引数には、バッチを作るためのデータのサンプリング方法を指定可能である。 今回はPython APIと同様にシーケンシャルにサンプリングするために、torch::data::samplers::SequentialSampler を利用した。

v) 学習

学習のループ処理を実装する。 最初に、DataLoaderから学習用のデータを使ってResNet50のモデルを訓練するコードを示す。

対応するGitHub上のコード

    # Train loop.
    for epoch in range(NUMBER_OF_EPOCHS):
        print("Epoch {}:".format(epoch))

        # Train.
        print("Start train.")
        model.train()    # (3.a-3)
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()    # (3.a-4)
            data = data.to(device)    # (3.a-5)
            target = target.to(device)

            output = model(data)    # (3.a-6)

            # (3.a-7)
            prob = F.log_softmax(output, dim=1)
            loss = F.nll_loss(prob, target)
            loss.backward()
            optimizer.step()

            if batch_idx % LOG_INTERVAL == 0:
                print("Batch: {}, Loss: {}".format(batch_idx, loss.item()))

ResNet50のモデル model をtraining modeに変更したあと (3.a-3) (3.b-3)、DataLoaderから学習用のデータを1バッチずつ取得する。 DataLoaderから読み込んだデータは、data に説明変数、target に目的変数がテンソルtorch.Tensor)として保存されている。 batch_idx は、一定間隔でloss値をログとして使用するときに利用する。

バッチを読み込んだあと、Backward Propagationの初期値をリセットするために optimizer.zero_grad を呼びだす (3.a-4) (3.b-4)。 そして、dataとtargetのテンソルデータを学習を実行するデバイスのメモリに移動させ (3.a-5) (3.b-5)、ResNet50モデルのForward Propagationを実行する (3.a-6) (3.b-6)。 最後に、Forward Propagationの結果を使ってロス値を計算(Log Softmax + NLL-Loss = CrossEntropyLoss)したあと、Backward Propagationを行って各パラメータの勾配を求め、求めた勾配を使ってパラメータを更新する (3.a-7) (3.b-7)

この処理をC++ APIで実装すると、次のようになる。

対応するGitHub上のコード

    // Train loop.
    for (size_t epoch = 0; epoch < kNumberOfEpochs; ++epoch) {
        std::cout << "Epoch " << epoch << ":" << std::endl;

        // Train.
        std::cout << "Start train." << std::endl;
        size_t batch_idx = 0;
        model->train();    // (3.b-3)
        for (auto& batch : *train_loader) {
            optimizer.zero_grad();    // (3.b-4)
            auto data = batch.data.to(device);    // (3.b-5)
            auto target = batch.target.to(device);

            auto output = model->forward(data);    // (3.b-6)

            // (3.b-7)
            auto prob = F::log_softmax(output, 1);
            auto loss = F::nll_loss(prob, target);
            AT_ASSERT(!std::isnan(loss.template item<float>()));
            loss.backward();
            optimizer.step();

            if ((batch_idx % kLogInterval) == 0) {
                std::cout << "Batch: " << batch_idx << ", Loss: "
                          << loss.template item<float>() << std::endl;
            }
            batch_idx++;
        }

vi) 評価

学習結果を評価する処理を実装する。

対応するGitHub上のコード

        # Evaluate.
        print("Start eval.")
        model.eval()    # (3.a-8)
        test_loss = 0
        correct = 0.0
        total = 0
        with torch.no_grad():    # (3.a-9)
            for data, target in test_loader:
                data = data.to(device)
                target = target.to(device)

                output = model(data)    # (3.a-10)

                prob = F.log_softmax(output, dim=1)
                test_loss += F.nll_loss(prob, target, reduction="sum").item()    # (3.a-11)
                pred = output.argmax(1, keepdim=True)    # (3.a-12)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += TEST_BATCH_SIZE

        print("Average loss: {}, Accuracy: {}"
              .format(test_loss / loss, correct / total))

ResNet50のモデル model をeval modeに変更したあと (3.a-8) (3.b.8)、DataLoaderから評価用のデータを1バッチずつ取得する。

評価時は勾配計算が行われないように、torch.no_grad コンテキスト内で評価のためのコードを実行させる (3.a-9) (3.b-9)torch.no_grad コンテキストにより、勾配のためにテンソルデータを保持する必要がなくなり、メモリの消費量を抑えることができる。

評価用の画像データに対してクラスを予測するため、ResNet50のForward Propagationを行う (3.a-10) (3.a-10)。 その結果を利用し、loss値の計算 (3.a-11) (3.b-11) と評価用データのクラス予測 (3.a-12) (3.b-12) を行う。

上記の処理をC++ APIで実装すると、次のようになる。

対応するGitHub上のコード

        // Evaluate.
        std::cout << "Start eval." << std::endl;
        torch::NoGradGuard no_grad;    // (3.b-9)
        model->eval();    // (3.b-8)
        double test_loss = 0.0;
        size_t correct = 0;
        size_t total = 0;
        for (auto& batch : *test_loader) {
            auto data = batch.data.to(device);
            auto target = batch.target.to(device);
            auto output = model->forward(data);    // (3.b-10)

            auto prob = F::log_softmax(output, 1);
            test_loss += F::nll_loss(
                prob, target,
                F::NLLLossFuncOptions().reduction(torch::kSum)).template item<double>();    // (3.b-11)
            auto pred = output.argmax(1, true);    // (3.b-12)
            correct += pred.eq(target.view_as(pred)).sum().template item<int64_t>();
            total += kTestBatchSize;
        }

        std::cout << "Average loss: " << test_loss / total
                  << ", Accuracy: " << static_cast<double>(correct) / total
                  << std::endl;

torch.no_grad コンテキストに相当する torch::NoGradGuard の変数 no_grad が、変数の生存期間中有効になる点に注意が必要である。

vii) 学習済みモデルの保存

最後に学習済みのモデルを保存する。

対応するGitHub上のコード

    # Save trained model.
    os.makedirs(args.saved_model_path, exist_ok=True)
    model_path = "{}/{}".format(args.saved_model_path, SAVED_MODEL_NAME)
    torch.save(model.state_dict(), model_path)    # (3.a-13)
    print("Saved model to '{}'".format(model_path))

Python APIでは、torch.save の引数にResNet50モデルの state_dict を渡すことで、Optimizerの状態も含めてファイルに保存できる。

続いて、C++ APIを用いた実装を示す。

対応するGitHub上のコード

    // Save trained model.
    std::string model_path = args.saved_model_path + "/"
        + kSavedModelNamePrefix + "_model.pth";
    std::string optimizer_path = args.saved_model_path + "/"
        + kSavedModelNamePrefix + "_optimizer.pth";
    struct stat buf;
    if (stat(args.saved_model_path.c_str(), &buf)) {
        int rc;
        rc = mkdir(args.saved_model_path.c_str(), 0755);
        if (rc < 0) {
            std::cout << "Error: Failed to create diretory '"
                      << args.saved_model_path  << "' ("
                      << errno << ": " << strerror(errno) << ")" << std::endl;
            return 1;
        }
    }
    torch::save(model, model_path);
    torch::save(optimizer, optimizer_path);
    std::cout << "Saved model." << std::endl;
    std::cout << "  Model: " << model_path << std::endl;
    std::cout << "  Optimizer: " << model_path << std::endl;

ディレクトリ作成処理がC++ではやや煩雑になっているが、学習済みモデルの保存は torch::save を呼び出すだけで完了する。 なお、C++ APIでは state_dict が提供されていないため、Optimizerの状態を別途 torch::save 関数に渡して保存しなければならないことに注意しよう。

以上で学習処理の実装が完了した。

ビルド

作成したソースコードをビルドする。

cd cpp
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release

ビルドが完了したら、同ディレクトリに trainpredict の実行プログラムが作られていることを確認しよう。

MNISTデータセットのダウンロード

MNISTデータセット をダウンロードし、ビルド時に作成した実行プログラムと同じディレクトリにMNISTデータセットを配置する。

mkdir mnist
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
gunzip *.gz

結果

ResNet50を学習するために次のコマンドを実行する。 オプション -m は、学習済みモデルを保存するディレクトリを指定するものである。 ディレクトリが存在しない場合は、自動的に指定されたディレクトリが作成される。

./train -m saved_model

プログラムの実行が完了すると、学習済みモデルがオプションで指定したディレクトリに保存される。

続いて、学習済みのモデルを使って、手書きの数字の画像が正しく推論できるか確認する。 今回利用した画像は、新たに作成した手書きの数字の画像 であり実行プログラム predict からの相対パス ../../data/digit.png に配置されているものとする。

https://raw.githubusercontent.com/nuka137/pytorch-cpp-example/6d82b0240af6cb33af015e30b01ee3f0fc3deec2/resnet/data/digit.png

次のコマンドを実行することで、saved_model に保存された学習済みモデルを利用して画像データ ../../data/digit.png について推論する。

./predict -i ../../data/digit.png -m saved_model

上記のコマンドを実行すると次のような出力が得られ、画像データが正しく推論できていることがわかる。

Predict: 7

なお、学習済みのモデルを使って推論するためのソースコードGitHub上で公開している。 OpenCVを使って画像データの読み込み&前処理を行っていることを除いて特に難しいところは無いため、具体的なソースコードの説明に関しては省略する。

Pythonによる実装

C++による実装

おわりに

CNNであるResNet50をPyTorchのC++ APIを使って実行し、MNISTデータセットを使って実際に学習させた。 また学習したモデルを使用し、手書きの数字の画像が正しく推論できていることを確認した。

Python APIによる記述とC++ APIはどちらも似たようなAPI仕様となっているため、これまでPython APIを使ってPyTorchを使っていた人がC++ APIに移行するのはそれほど大変ではないだろう。 一方で、C++ APIと同等の機能を持ったAPIが存在しないことも多く、C++ APIが安定版として提供され始めたとはいえ、まだまだ機能不足感は否めない。 このため、ドキュメントやPyTorchのIssueを確認しながら、使いたいC++ APIがサポートされているか否かを事前に確認することが必要になるだろう。


  1. C++ APIの開発は2019/9頃に GitHub にてContributorの募集があり、私もこの開発に参加してBatchNormなど12個のAPIを実装した。