Unfreezing specific layers of a pretrained model

I am using a pre-trained Vision Transformer (ViT) as my encoder. The encoder block of this model is shown below (it calls a few helpers, such as MLPBlock, that you can find in the source code; I omit them here to save space):

from functools import partial
from typing import Callable

import torch
import torch.nn as nn
from torchvision.models.vision_transformer import MLPBlock


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

I want to add the following module to the above-mentioned EncoderBlock:

class Adapter(nn.Module):
    def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU):
        super().__init__()
        # self.skip_connect = skip_connect
        D_hidden_features = int(D_features * mlp_ratio)
        self.act = act_layer()
        self.D_fc1 = nn.Linear(D_features, D_hidden_features)
        self.D_fc2 = nn.Linear(D_hidden_features, D_features)

    def forward(self, x):
        # x is (BT, HW+1, D)
        xs = self.D_fc1(x)   # down-project D -> D * mlp_ratio
        xs = self.act(xs)
        xs = self.D_fc2(xs)  # up-project back to D
        return xs
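Just to illustrate what the adapter does: it is a bottleneck MLP (D → D * mlp_ratio → D), so it leaves the token shape unchanged. A quick sanity check (assuming ViT-B/16's hidden_dim of 768):

import torch

adapter = Adapter(D_features=768)      # 768 is ViT-B/16's hidden_dim
tokens = torch.randn(2, 197, 768)      # (batch, 196 patches + class token, hidden_dim)
print(adapter(tokens).shape)           # torch.Size([2, 197, 768]) -- shape preserved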

To be clear, I added the Adapter to the EncoderBlock as follows:

class EncoderBlock(nn.Module):
    """Transformer encoder block with an Adapter after self-attention."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

        self.adapter = Adapter(D_features=hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)

        x = self.adapter(x)      # I added: apply the adapter to the attention output

        x = self.dropout(x)

        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y
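Side note: because the modified EncoderBlock has extra adapter parameters that a stock checkpoint does not contain, loading the pretrained weights needs strict=False. A minimal sketch, assuming a torchvision vit_b_16 checkpoint and that the adapted model otherwise keeps the same parameter names:

import torch
import torchvision


def load_pretrained_into_adapted_vit(adapted_vit: torch.nn.Module) -> None:
    """Copy pretrained ViT-B/16 weights into the adapted model; the adapter
    parameters are absent from the checkpoint and keep their random init."""
    pretrained = torchvision.models.vit_b_16(
        weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1
    )
    missing, unexpected = adapted_vit.load_state_dict(
        pretrained.state_dict(), strict=False
    )
    # `missing` should list only the adapter parameters; `unexpected` should be empty.
    print("missing (kept at random init):", missing)
    print("unexpected:", unexpected)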

I’m using the pre-trained ViT model, and I need to freeze all of its parameters except the Adapter modules added to its layers.
I know how to freeze the whole ViT; given that the model is built by stacking EncoderBlock layers, how can I keep the added Adapter layers unfrozen (trainable)?

This is a forum for Python the language; questions about specific non-stdlib libraries are better asked on a dedicated forum. PyTorch’s forum is here: https://discuss.pytorch.org/.

That is not to say you won’t get an answer here, but you’re more likely to get one there :slight_smile:

@ajoino Thank you for the pointer.

I found the answer: to freeze the whole encoder except the Adapter modules, iterate over the encoder’s named parameters and freeze every parameter whose name does not belong to an adapter (note named_parameters(), and that the module attribute is the lowercase adapter):

for name, param in model.encoder.named_parameters():
    if 'adapter' not in name:   # the adapter submodule is registered as `adapter`
        param.requires_grad = False
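To double-check the result, you can print the parameter names that remain trainable and, optionally, hand only those parameters to the optimizer (model is the same ViT instance as above; the learning rate is just a placeholder):

import torch

# Only the adapter weights/biases should still require gradients.
trainable = [name for name, p in model.encoder.named_parameters() if p.requires_grad]
print(trainable)

# Optionally pass only the trainable parameters to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)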