The Main Class
class ViTModel(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
        super().__init__(config)
        self.config = config

        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = ViTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = ViTPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ViTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype:
            pixel_values = pixel_values.to(expected_dtype)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
The ViTModel class serves as the cornerstone of the Vision Transformer architecture. This class embodies the adaptation of the transformer model for visual tasks. Let's explore the ViTModel class and understand how it integrates its various components.
This is a comprehensive class that encapsulates the entire process of transforming an input image into a meaningful representation for downstream tasks such as image classification or object detection.
As I mentioned in previous sections, at its heart the ViTModel operates on the principle of viewing an image as a sequence of patches and applying the transformer mechanism to that sequence. Let's go through the ViTModel class line by line, explaining the functionality and significance of each part.
The ViTModel Class
class ViTModel(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
        super().__init__(config)
        self.config = config
The ViTModel class is derived from ViTPreTrainedModel, an abstract class that handles weight initialization and the loading of pretrained models. We'll go over it later in this quest.
The __init__ method is the constructor of the ViTModel instance. It takes a configuration object, ViTConfig, which contains all the parameters required by the model. Two optional arguments, add_pooling_layer and use_mask_token, control whether a pooling layer is added and whether a mask token is used.
We'll go over the initialized dependencies and the forward method of the ViTModel class to understand how the model processes input images; in the following sections, I'll explain each of these components in detail.
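Before dissecting each component, here is a minimal usage sketch of the class as a whole, so the pieces below have some context. The checkpoint name and image file are only illustrative placeholders:

import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTModel

# Illustrative checkpoint; any ViT checkpoint on the Hub works the same way
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

image = Image.open("cat.png")  # hypothetical local image
inputs = processor(images=image, return_tensors="pt")  # -> {"pixel_values": (1, 3, 224, 224)}

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (1, 197, 768) for the base 224/16 configuration
print(outputs.pooler_output.shape)      # (1, 768)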
Embeddings Layer Initialization
self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
The embeddings layer, ViTEmbeddings, is where the image is decomposed into patches and each patch is embedded into a higher-dimensional space. Whether to include a mask token, used in certain training scenarios like masked image modeling, is controlled by the use_mask_token flag.
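To make the patching concrete, a quick back-of-the-envelope calculation with the default configuration values (224-pixel images, 16-pixel patches, hidden size 768) shows where the sequence length of 197 in the earlier sketch comes from:

image_size, patch_size, hidden_size = 224, 16, 768   # ViT-Base defaults

patches_per_side = image_size // patch_size           # 14
num_patches = patches_per_side ** 2                   # 196
sequence_length = num_patches + 1                     # 197, the extra token is the CLS token

print(num_patches, sequence_length)  # 196 197
# Each of the 197 tokens is a vector of size hidden_size (768),
# so the embedding output has shape (batch_size, 197, 768).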
Encoder Initialization
self.encoder = ViTEncoder(config)
The ViTEncoder, consisting of multiple layers of the transformer model, is initialized here. The encoder is responsible for processing the sequence of embedded patches through self-attention mechanisms.
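Conceptually, the encoder is just a stack of identical transformer blocks applied one after another. The stripped-down sketch below illustrates only that idea; the real ViTEncoder (covered later) also handles head masks and the optional attention/hidden-state outputs, and make_layer here is a hypothetical factory for a single block:

import torch.nn as nn

class TinyEncoder(nn.Module):
    """Illustrative stand-in for ViTEncoder: a plain stack of transformer blocks."""
    def __init__(self, num_layers, make_layer):
        super().__init__()
        self.layer = nn.ModuleList([make_layer() for _ in range(num_layers)])

    def forward(self, hidden_states):
        for block in self.layer:
            hidden_states = block(hidden_states)  # each block applies self-attention + MLP
        return hidden_states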
Layer Normalization and Pooling
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = ViTPooler(config) if add_pooling_layer else None
A layer normalization, nn.LayerNorm, is applied to the output of the encoder. This is crucial for stabilizing the training process and accelerating convergence.
The ViTPooler is an optional component that can be added to the model. It processes the first token's output (corresponding to the CLS token) to be used in tasks such as classification.
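The pooler itself is small: to the best of my understanding it is a linear projection followed by a tanh, applied only to the first (CLS) token. A simplified sketch of that behavior, not the library's exact class:

import torch
import torch.nn as nn

class SimplePooler(nn.Module):
    """Sketch of what ViTPooler does: project the CLS token and squash it with tanh."""
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):       # hidden_states: (batch, seq_len, hidden)
        first_token = hidden_states[:, 0]    # take the CLS token: (batch, hidden)
        return self.activation(self.dense(first_token))

pooled = SimplePooler(768)(torch.randn(2, 197, 768))
print(pooled.shape)  # torch.Size([2, 768])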
Model Weights Initialization
self.post_init()
This call initializes the model weights properly. It is part of PreTrainedModel, which ViTModel inherits via ViTPreTrainedModel, so after this line the model starts with suitable weights for training.
Forward Method
def forward(
    self,
    pixel_values: Optional[torch.Tensor] = None,
    ...
) -> Union[Tuple, BaseModelOutputWithPooling]:
The forward method is where the actual processing of the input data (the pixel values of the image) happens. It sequentially passes the data through the embeddings, encoder, layer normalization, and pooler (if present).
It takes in the image tensor (pixel_values) and, optionally, other parameters such as a mask for attention heads, boolean flags for outputting attentions or hidden states, and a flag controlling the return type.
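For example, output_attentions=True returns the attention maps of every layer, and interpolate_pos_encoding=True resizes the position embeddings on the fly so images at a different resolution can be fed in. A sketch using random tensors instead of real images, assuming model is the ViTModel loaded in the earlier example (224x224 images, 16x16 patches):

import torch

pixel_values = torch.randn(1, 3, 224, 224)
outputs = model(pixel_values=pixel_values, output_attentions=True)
print(len(outputs.attentions), outputs.attentions[0].shape)
# num_hidden_layers tensors of shape (batch, num_heads, seq_len, seq_len), e.g. 12 x (1, 12, 197, 197)

# A larger input works if we ask for position-embedding interpolation
large = torch.randn(1, 3, 384, 384)
outputs = model(pixel_values=large, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)  # (1, 577, 768): (384 / 16)^2 = 576 patches + CLS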
Processing in Forward Method
embedding_output = self.embeddings(
    pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder(
    embedding_output, head_mask=head_mask, output_attentions=output_attentions, ...
)
Inside the forward method, the image is first passed through the embedding layer, and the output of this layer is then passed through the encoder. The embeddings layer converts the pixel values into a format suitable for the transformer, while the encoder processes these embeddings using self-attention.
Output of the Model
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
The output from the encoder is then normalized using layer normalization. After that, if a pooling layer is present, it processes the first token's output for tasks like classification.
At the end of its processing, the forward method returns the processed output, which can be utilized for downstream tasks like image classification or masked image modeling.
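The exact return type depends on return_dict: by default the model returns a BaseModelOutputWithPooling with named fields, while return_dict=False gives a plain tuple. Roughly, reusing the model and inputs from the earlier sketch:

# Named-field output (the default): attribute access
outputs = model(**inputs)
sequence_output = outputs.last_hidden_state   # (batch, 197, 768)
pooled_output = outputs.pooler_output         # (batch, 768)

# Tuple output: positional access
sequence_output, pooled_output = model(**inputs, return_dict=False)[:2]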
The ViTConfig Class
I held this part to the end of the explanation to make it easier to understand. The last component to go over is the configuration that ViT accepts.
This class defines and stores the configuration parameters for the model. It is a subclass of PretrainedConfig, which means it inherits methods and attributes from the base configuration class used for all pretrained models in Hugging Face's Transformers library. This standardization keeps configuration handling consistent across different models. Let's go over the parameters that the constructor of ViTConfig takes; they will make more sense as we go through the architecture of the model:
- hidden_size: This parameter sets the size of the hidden layers in the Transformer. It's a key factor in determining the model's capacity.
- num_hidden_layers: This defines the depth of the Transformer, i.e., how many layers it has.
- num_attention_heads: In each Transformer layer, multi-head attention mechanisms are used. This parameter defines how many such heads are in each layer.
- intermediate_size: The size of the "intermediate" layer in each Transformer block. This is typically larger than hidden_size and allows the model to capture more complex features.
- hidden_act: The activation function used in the hidden layers. Common options include "gelu", "relu", and others.
- hidden_dropout_prob & attention_probs_dropout_prob: These define the dropout probabilities for the fully connected layers and the attention probabilities, respectively. Dropout is a regularization technique to prevent overfitting.
- initializer_range: Sets the standard deviation for the truncated_normal_initializer, affecting how the model weights are initially set.
- layer_norm_eps: The epsilon value used for layer normalization, to prevent division by zero.
- image_size: The size of the input images.
- patch_size: The size of each image patch. The image is divided into patches of this size.
- num_channels: Number of channels in the input images, typically 3 for RGB images.
- qkv_bias: Determines whether to add a bias to the query, key, and value projections in the attention mechanism.
- encoder_stride: Used in the decoder for masked image modeling, indicating the factor for increasing spatial resolution.
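Putting it together, a ViTConfig can be built with the defaults (which correspond to ViT-Base) or with custom values, and a freshly, randomly initialized ViTModel can be created straight from it. The smaller numbers below are purely illustrative:

from transformers import ViTConfig, ViTModel

# Default configuration (ViT-Base): 12 layers, 12 heads, hidden_size 768, 224/16 images
config = ViTConfig()

# A smaller, illustrative variant
small_config = ViTConfig(
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,
    intermediate_size=1536,
    image_size=224,
    patch_size=16,
)

model = ViTModel(small_config)  # randomly initialized weights via post_init()
print(sum(p.numel() for p in model.parameters()))  # parameter count of the small variant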