Regarding examples, I listed a handful in another thread; see my last comment for the link. Adding to that, a very common and repetitive pattern shows up in deep neural networks:
import torch
from torch import nn


class CNN(nn.Module):
    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        dropout: float,
    ):
        super().__init__()
        self.conv0 = nn.Conv2d(input_channels, 16, kernel_size=3, padding=1)
        self.bn0 = nn.BatchNorm2d(16)
        self.conv1 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # assumes a 640 x 640 input: five 2x2 poolings bring it down to 20 x 20
        self.fc1 = nn.Linear(256 * 20 * 20, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, output_channels)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.relu(self.bn0(self.conv0(x))))
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        x = self.pool(self.relu(self.conv3(x)))
        x = self.pool(self.relu(self.conv4(x)))
        x = x.view(-1, 256 * 20 * 20)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.sigmoid(self.fc3(x))
        return x
With the pipe operator, the forward pass could look like this:
    def forward(self, x):
        return (
            x                                          # 640 x 640 (features)
            |> self.conv0 |> self.bn0 |> self.relu     # 3 -> 16 (channels)
            |> self.pool                               # 640 -> 320
            |> self.conv1 |> self.bn1 |> self.relu     # 16 -> 32
            |> self.pool                               # 320 -> 160
            |> self.conv2 |> self.bn2 |> self.relu     # 32 -> 64
            |> self.pool                               # 160 -> 80
            |> self.conv3 |> self.relu                 # 64 -> 128
            |> self.pool                               # 80 -> 40
            |> self.conv4 |> self.relu                 # 128 -> 256
            |> self.pool                               # 40 -> 20
            |> ~.view(-1, 256 * 20 * 20)               # flatten
            |> self.fc1 |> self.relu |> self.dropout   # fully connected 1 + dropout
            |> self.fc2 |> self.relu |> self.dropout   # fully connected 2 + dropout
            |> self.fc3 |> self.sigmoid                # output layer
        )  # fmt: skip
Not only is it representative of how the data flows through the network, it is also extremely easy to adapt. For example, if I wanted to turn batch normalization off in some layers and on in others, the change would be trivial. With the previous version, I would have to hop around parentheses and make sure they stay balanced. Not to mention that I never have to allocate a single intermediate variable throughout the process.
Arguably, the first version is marginally easier to debug, but for simple networks this is a well-established pattern, and you rarely have to debug at the level of individual batch-norm calls. And even when you do, it is quite straightforward to split the pipeline up.
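(For comparison, getting a similar per-layer toggle today means writing extra machinery. A minimal sketch with nn.Sequential; conv_block is a hypothetical helper name, not an existing API:)

from torch import nn

def conv_block(in_ch: int, out_ch: int, use_bn: bool = True) -> nn.Sequential:
    # One conv stage; batch norm can be toggled per block.
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)]
    if use_bn:
        layers.append(nn.BatchNorm2d(out_ch))
    layers += [nn.ReLU(), nn.MaxPool2d(2, 2)]
    return nn.Sequential(*layers)

# e.g. batch norm on for the first two stages, off for the third
stages = nn.Sequential(
    conv_block(3, 16),
    conv_block(16, 32),
    conv_block(32, 64, use_bn=False),
)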
This raises another question, though: how would formatters handle it? I suppose it would be treated like any other binary operator. Also, as mentioned earlier in the thread, a way to store a pipeline would be very handy. One suggestion was pipeline = lambda arg: arg |> f1 |> f2 |> ...
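(For what it's worth, storing a pipeline is already expressible today with a tiny helper. A sketch; compose is just an illustrative name:)

from functools import reduce

def compose(*fns):
    # Left-to-right composition: compose(f1, f2)(x) == f2(f1(x)).
    return lambda x: reduce(lambda acc, fn: fn(acc), fns, x)

pipeline = compose(abs, float, str)
print(pipeline(-3))  # prints 3.0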
But if ~ (tilde) is used as an implicit partial/lambda, maybe it makes sense to treat expressions beginning with ~ |> ... as a pipeline stored for later use?
e.g.
# in constructor
def __init__(self):
    self.layer1 = (~ |> nn.Conv2d(3, 16, 3, 1) |> nn.BatchNorm2d(16) |> nn.ReLU())
    ...

# then in the forward pass
def forward(self, x):
    return x |> self.layer1 |> self.layer2 |> ...
NOTE: PyTorch itself already provides ways to compose modules and functions (e.g. nn.Sequential, and torchvision.transforms.Compose for transform pipelines).
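For instance, the stored-pipeline idea above maps fairly directly onto nn.Sequential. A runnable sketch using the same layer sizes as the class at the top; the shape comment assumes a 640 x 640 RGB input:

import torch
from torch import nn

layer1 = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

x = torch.randn(1, 3, 640, 640)
print(layer1(x).shape)  # torch.Size([1, 16, 640, 640])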