
Let’s get started and write a basic U-Net in PyTorch based on the diagram in Figure 1. To make the code simpler and cleaner we can define several modules, starting with the ConvLayer, which is just a 2d convolutional layer followed by 2d batch normalization and a ReLU activation. Notice that the bias is set to False in the convolutional layer, since the batch norm that follows already has a bias term.
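Here is a minimal sketch of such a layer. The 3x3 kernel with padding 1 is an assumption on my part, chosen so that the spatial size is preserved:

```python
import torch
import torch.nn as nn

class ConvLayer(nn.Sequential):
    """3x3 convolution -> 2d batch norm -> ReLU."""
    def __init__(self, in_channels, out_channels):
        super().__init__(
            # bias=False: the BatchNorm2d that follows already has a bias term
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
```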
Next, let’s define the ConvBlock: a set of two ConvLayers as defined above, followed by a max-pooling that halves the image size.
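A possible sketch, assuming a 2x2 max-pooling (which is what reduces the spatial size by half):

```python
class ConvBlock(nn.Sequential):
    """Two ConvLayers followed by a 2x2 max-pooling that halves the spatial size."""
    def __init__(self, in_channels, out_channels):
        super().__init__(
            ConvLayer(in_channels, out_channels),
            ConvLayer(out_channels, out_channels),
            nn.MaxPool2d(kernel_size=2),
        )
```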
Note: I defined the ConvLayer and ConvBlock as subclasses of nn.Sequential. That’s why there is no forward method, as there would have to be if they were subclasses of nn.Module. I borrowed this idea from the fastai course lessons 🙂 Also, note that the fastai library already implements more complete versions of ConvLayer and ConvBlock (with extra options).
Moving on to the decoder, each block will need two inputs: one corresponding to the green arrow in the U-Net diagram (Figure 1) and the other to the grey arrow.
- The green arrow corresponds to an upscaling, so a 2d transposed convolution with kernel size 2 and stride 2 is applied to the first input (x1 in the code below). This makes the output twice the size of the input. (Check this story if you want to read more about convolutions.)
- Then the upscaled feature map (x_up in the code below) is concatenated with the input coming from the encoder (x2 in the code below).
- Finally, two ConvLayers are applied, concluding our UpConvBlock.
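Putting these three steps together, the UpConvBlock could look like the sketch below. The extra skip_channels argument is an assumption of mine, added so the block also works when the skip connection has a different number of channels than the upscaled feature map:

```python
class UpConvBlock(nn.Module):
    """Transposed conv (upscale by 2) -> concat with encoder features -> two ConvLayers."""
    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        # kernel size 2 and stride 2 double the spatial size
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        # skip_channels is my addition: channels of the feature map coming from the encoder
        self.conv1 = ConvLayer(in_channels // 2 + skip_channels, out_channels)
        self.conv2 = ConvLayer(out_channels, out_channels)

    def forward(self, x1, x2):
        x_up = self.up(x1)                 # green arrow: upscale x1
        x = torch.cat([x_up, x2], dim=1)   # grey arrow: concatenate with encoder features
        return self.conv2(self.conv1(x))   # two ConvLayers
```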
That’s it! With three modules of a few lines of code each, we now have the building blocks to create a U-Net!
The code below defines the U-Net model; you can see we have 5 encoder blocks and 5 decoder blocks. I included the shapes of each feature map as comments, starting from an input tensor of size batch-size x 3 x 128 x 128.
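Here is a sketch of the full model. The channel widths (64 up to 1024) and the choice to reuse the input image as the skip connection of the last decoder block are assumptions on my part; they follow the structure of Figure 1 but may differ in detail. The shape comments assume an input of batch-size x 3 x 128 x 128:

```python
class UNet(nn.Module):
    def __init__(self, in_channels=3, n_classes=1):
        super().__init__()
        # encoder: each ConvBlock halves the spatial size
        self.enc1 = ConvBlock(in_channels, 64)        # -> 64 x 64 x 64
        self.enc2 = ConvBlock(64, 128)                # -> 128 x 32 x 32
        self.enc3 = ConvBlock(128, 256)               # -> 256 x 16 x 16
        self.enc4 = ConvBlock(256, 512)               # -> 512 x 8 x 8
        self.enc5 = ConvBlock(512, 1024)              # -> 1024 x 4 x 4
        # decoder: each UpConvBlock doubles the spatial size
        self.dec1 = UpConvBlock(1024, 512, 512)       # -> 512 x 8 x 8
        self.dec2 = UpConvBlock(512, 256, 256)        # -> 256 x 16 x 16
        self.dec3 = UpConvBlock(256, 128, 128)        # -> 128 x 32 x 32
        self.dec4 = UpConvBlock(128, 64, 64)          # -> 64 x 64 x 64
        # assumption: the last block uses the input image as its skip connection
        self.dec5 = UpConvBlock(64, in_channels, 64)  # -> 64 x 128 x 128
        self.head = nn.Conv2d(64, n_classes, kernel_size=1)  # -> n_classes x 128 x 128

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        d1 = self.dec1(e5, e4)
        d2 = self.dec2(d1, e3)
        d3 = self.dec3(d2, e2)
        d4 = self.dec4(d3, e1)
        d5 = self.dec5(d4, x)   # last skip connection: the input image itself
        return self.head(d5)
```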
Take some time to look at the numbers and make sure you understand how the calculations lead to a final output with the same spatial size as the input.
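A quick sanity check of the sketch above (with the assumed default of a single output class) confirms the shapes:

```python
model = UNet(in_channels=3, n_classes=1)
x = torch.randn(2, 3, 128, 128)   # batch-size x 3 x 128 x 128
print(model(x).shape)             # torch.Size([2, 1, 128, 128])
```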
It’s now time to put this model into practice!