Introduce funnel operator, i.e. '|>', to allow for generator pipelines

Regarding examples, I listed a handful in another thread; see my last comment there for the link. Adding to that, a very common and repetitive pattern shows up in deep neural networks.

import torch
from torch import nn

class CNN(nn.Module):
    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        dropout: float,
    ):
        super().__init__()
        self.conv0 = nn.Conv2d(input_channels, 16, kernel_size=3, padding=1)
        self.bn0 = nn.BatchNorm2d(16)
        self.conv1 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)

        self.pool = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(256 * 20 * 20, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, output_channels)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.relu(self.bn0(self.conv0(x))))
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        x = self.pool(self.relu(self.conv3(x)))
        x = self.pool(self.relu(self.conv4(x)))

        x = x.view(-1, 256 * 20 * 20)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.sigmoid(self.fc3(x))

        return x

With the funnel operator, the forward pass could look like this:

    def forward(self, x):
        return (
            x                                         # 640 x 640 (features)
            |> self.conv0 |> self.bn0 |> self.relu    # 3 -> 16   (channels)
            |> self.pool                              # 640 -> 320
            |> self.conv1 |> self.bn1 |> self.relu    # 16 -> 32
            |> self.pool                              # 320 -> 160
            |> self.conv2 |> self.bn2 |> self.relu    # 32 -> 64
            |> self.pool                              # 160 -> 80
            |> self.conv3 |> self.relu                # 64 -> 128
            |> self.pool                              # 80 -> 40
            |> self.conv4 |> self.relu                # 128 -> 256
            |> self.pool                              # 40 -> 20
            |> ~.view(-1, 256 * 20 * 20)              # flatten
            |> self.fc1 |> self.relu |> self.dropout  # fully conn 1 + dropout
            |> self.fc2 |> self.relu |> self.dropout  # fully conn 2 + dropout
            |> self.fc3 |> self.sigmoid               # output layer
       )    # fmt: skip

Not only is this representative of how the data flows through the network, it is also extremely easy to adapt. For example, if I wanted to turn batch normalization off in some layers and on in others, the change would be trivial, whereas with the previous version I would have to hop around parentheses and make sure they stay balanced. Not to mention that I never have to allocate a single intermediate variable along the way.
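To make that concrete, here is what toggling batch normalization on one block would look like under the proposed (hypothetical) syntax; it is a single-stage edit, and nothing else in the chain moves:

    # with batch normalization
    |> self.conv1 |> self.bn1 |> self.relu    # 16 -> 32
    # without: delete exactly one stage
    |> self.conv1 |> self.relu                # 16 -> 32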

Arguably, the first version is marginally easier to debug, but for simple networks the forward pass is generally a set pattern, and you rarely have to debug at the batch-norm level. And even if you had to, the pipeline is quite straightforward to split up.
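For instance, still in the hypothetical syntax, breaking the chain at the point of interest costs exactly one intermediate name:

    def forward(self, x):
        # run the first convolutional block, then stop to inspect it
        h = x |> self.conv0 |> self.bn0 |> self.relu |> self.pool
        print(h.shape)  # debug the activation here
        return h |> self.conv1 |> self.bn1 |> self.relu |> self.pool  # and so on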

All of this raises another question, though: how would formatters handle '|>'? I would guess like any other binary operator. Also, as mentioned earlier in the thread, a way to store a pipeline for later use would be very handy. One suggestion was:

    pipeline = lambda arg: arg |> f1 |> f2 |> ...
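Incidentally, something close to that suggestion can already be approximated in today's Python with a tiny helper. This is just a sketch; make_pipeline and the f1/f2 stand-ins are hypothetical names, not an existing API:

    from functools import reduce

    def make_pipeline(*stages):
        # Compose callables left to right, so the call order matches
        # the reading order of a |> chain.
        return lambda arg: reduce(lambda acc, stage: stage(acc), stages, arg)

    f1 = lambda v: v + 1   # placeholder stages
    f2 = lambda v: v * 2

    pipeline = make_pipeline(f1, f2)   # ~ lambda arg: arg |> f1 |> f2
    assert pipeline(3) == 8            # f2(f1(3))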

But if '~' (tilde) is used as an implicit partial/lambda, maybe it makes sense to treat an expression beginning with ~ |> ... as a pipeline stored for later use?

e.g.

# in constructor
def __init__(self):
    self.layer1 = (~ |> nn.Conv2d(3, 16, kernel_size=3, padding=1) |> nn.BatchNorm2d(16) |> nn.ReLU())
    .
    .
    .

# then in the forward pass
def forward(self, x):
    return x |> self.layer1 |> self.layer2 |> ...

NOTE: PyTorch itself already provides ways to compose layers and functions (e.g. nn.Sequential, or torchvision.transforms.Compose for transforms).
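For comparison, the layer1 stage sketched above can be expressed today with nn.Sequential, which composes modules in call order:

    import torch
    from torch import nn

    class CNN(nn.Module):
        def __init__(self):
            super().__init__()
            # same stage as the hypothetical (~ |> ...) pipeline above
            self.layer1 = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1),
                nn.BatchNorm2d(16),
                nn.ReLU(),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.layer1(x)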
