| | from torch import nn |
| | from transformers.modeling_utils import PreTrainedModel |
| |
|
| |
|
class BasePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    supports_gradient_checkpointing = True

    # Standard deviation of the normal distribution used by ``_init_weights``
    # for Conv2d / Embedding / Linear weights. Subclasses may override it.
    _init_std = 0.02

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        - ``nn.Conv2d`` / ``nn.Embedding`` / ``nn.Linear``: weights are drawn
          from a normal distribution with mean 0 and std ``_init_std``; biases
          (when present) are set to zero. For an ``nn.Embedding`` with a
          ``padding_idx``, that row is re-zeroed afterwards.
        - ``nn.LayerNorm``: weight is set to 1 and bias to 0, skipping either
          one when the layer was built without it.
        - Any other module is left untouched.

        Args:
            module: The submodule to initialize.

        Raises:
            ValueError: If ``module`` is an ``nn.Parameter`` instead of a module.
        """
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=self._init_std)
            # nn.Embedding has no ``bias`` attribute; Conv2d/Linear may have bias=None.
            if getattr(module, "bias", None) is not None:
                module.bias.data.zero_()
            # normal_ just overwrote the padding row that nn.Embedding zeroes at
            # construction; that row receives no gradient, so restore it to zero.
            if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None when built with elementwise_affine=False
            # (or bias=False), so guard before touching .data.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Parameter):
            raise ValueError(
                f"{type(self).__name__}._init_weights expects an nn.Module, "
                "got an nn.Parameter"
            )
| |
|