from transformers import Starcoder2Model
import sys
from config import ModularStarEncoderConfig
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    ModelOutput,
    logging,
)


logger = logging.get_logger(__name__)


class StarEncoder2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ModularStarEncoderConfig
    base_model_prefix = "ModularStarEncoder"
    model_type = "ModularStarEncoder"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class StarEncoder2Pooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pool" the sequence by taking the hidden state of the last token.
        last_token_tensor = hidden_states[:, -1]
        pooled_output = self.dense(last_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


@dataclass
class ModularStarEncoderOutput(ModelOutput):
    """
    Output type of [`ModularStarEncoder`].

    Args:
        projected_pooled_normalized (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, or a list of such tensors, one per matryoshka layer):
            Pooled and projected sequence representations, rescaled so that no vector has norm greater than 1.
        raw_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    projected_pooled_normalized: Optional[List[torch.FloatTensor]] = None
    raw_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


# Forward pass of the pre-training heads: masked-language-modeling logits over the vocabulary and
# in-context (same-repository) classification logits. When the matryoshka loss is active, a learned
# per-layer embedding is concatenated to the features before each head.
def forward(self, sequence_output, pooled_output, idx_layer: Optional[torch.Tensor] = None):
    if self.is_matryoshka:
        device_sequence = sequence_output.get_device()
        if device_sequence < 0:
            device_sequence = "cpu"
        layer_embedding = self.conditional_embeddings(torch.tensor(idx_layer, device=device_sequence).int())
        prediction_scores = self.predictions(
            torch.cat([sequence_output, layer_embedding.expand(sequence_output.size(0), sequence_output.size(1), -1)], dim=-1)
        )
        seq_relationship_score = self.seq_relationship(
            torch.cat([pooled_output, layer_embedding.expand(pooled_output.size(0), -1)], dim=-1)
        )
    else:
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
    return prediction_scores, seq_relationship_score


def normalize(my_tensor):
    embedding_norms = my_tensor.norm(dim=0)

    # Only rescale entries whose norm exceeds 1; everything else is left untouched.
    normalizing_factor = torch.where(
        embedding_norms > 1.0, embedding_norms, torch.ones_like(embedding_norms)
    )

    normalized_tensor = my_tensor / normalizing_factor
    return normalized_tensor


def pooling(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Pools a batch of vector sequences into a batch of global vector representations.
    It does so by taking the average representation of the sequence, as indicated by the mask.

    Args:
        x (torch.Tensor): Batch of vector sequences with shape [B, T, F].
        mask (torch.Tensor): Batch of masks with shape [B, T].

    Returns:
        torch.Tensor: Pooled version of the input batch with shape [B, F].
    """
    # Expand the mask to [B, T, 1] so it broadcasts over the feature dimension.
    mask_expanded = mask.unsqueeze(-1)

    # Zero out masked positions, then average over the number of valid positions.
    masked_x = x * mask_expanded
    sum_x = masked_x.sum(dim=1)
    valid_lengths = mask.sum(dim=1).clamp(min=1).unsqueeze(-1)
    pooled_x = sum_x / valid_lengths

    return pooled_x
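

# Illustrative sketch (not part of the original module): mean-pooling a toy batch with `pooling`.
# Shapes follow the docstring above; the values are made up.
#
#   x = torch.ones(2, 4, 8)                       # [B=2, T=4, F=8]
#   mask = torch.tensor([[1, 1, 1, 0],
#                        [1, 1, 1, 1]])           # last position of the first row is padding
#   pooled = pooling(x, mask)                     # -> shape [2, 8]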


def pool_and_normalize(
    features_sequence: torch.Tensor,
    attention_masks: torch.Tensor,
    return_norms: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Temporal pooling of sequences of vectors, rescaled so that no vector has norm greater than 1.

    Args:
        features_sequence (torch.Tensor): Input features with shape [B, T, F].
        attention_masks (torch.Tensor): Pooling masks with shape [B, T].
        return_norms (bool, optional): Whether to additionally return the norms. Defaults to False.

    Returns:
        Union[torch.Tensor, List[torch.Tensor]]: Pooled and normalized vectors with shape [B, F],
        plus the per-sample norms when `return_norms=True`.
    """
    pooled_embeddings = pooling(features_sequence, attention_masks)
    embedding_norms = pooled_embeddings.norm(dim=1)

    # Only divide by the norm when it exceeds 1, so long vectors land on the unit sphere
    # and short ones are kept as they are.
    normalizing_factor = torch.where(
        embedding_norms > 1.0, embedding_norms, torch.ones_like(embedding_norms)
    )

    pooled_normalized_embeddings = pooled_embeddings / normalizing_factor[:, None]

    if return_norms:
        return pooled_normalized_embeddings, embedding_norms
    else:
        return pooled_normalized_embeddings
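

# Illustrative sketch (not part of the original module): pooling followed by norm clipping.
# The tensors are made up; only the shapes matter.
#
#   feats = torch.randn(2, 4, 8)                                    # [B, T, F]
#   mask = torch.ones(2, 4)                                         # attend to every position
#   emb, norms = pool_and_normalize(feats, mask, return_norms=True)
#   # emb has shape [2, 8] and every row has norm <= 1.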


def get_pooling_mask(
    input_ids: torch.Tensor, sep_token_id: Union[int, float]
) -> torch.Tensor:
    """Gets pooling masks. For a sequence of input tokens, the mask is zero up until the
    last [SEP] occurrence, and 1 from that position onward.

    Args:
        input_ids (torch.Tensor): Batch of input ids with shape [B, T].
        sep_token_id (Union[int, float]): Id of the [SEP] token.

    Returns:
        torch.Tensor: Batch of pooling masks with shape [B, T].
    """
    # Position of the last [SEP] in each sequence: flip, take the first match, then map back.
    idx = (input_ids == sep_token_id).float().flip(1).argmax(1)
    idx = input_ids.size(-1) - idx - 1

    repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))

    # Build the ranges on the same device as the input ids to avoid a device mismatch.
    ranges = torch.arange(input_ids.size(1), device=input_ids.device).repeat(input_ids.size(0), 1)

    pooling_mask = (repeated_idx <= ranges).long()

    return pooling_mask
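

# Illustrative sketch (not part of the original module): the mask is 0 before the last [SEP]
# and 1 from that position onward. The ids are toy values; 99 stands in for the [SEP] id.
#
#   ids = torch.tensor([[5, 7, 99, 3, 4]])
#   get_pooling_mask(ids, sep_token_id=99)   # -> tensor([[0, 0, 1, 1, 1]])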


def adapt_model(model, config, till_layer: int = 27):
    """Prepares the encoder for embedding extraction: optionally prunes it to `till_layer` layers,
    attaches the projection heads, and disables causal attention."""
    model = model.starEncoder2

    encoder_config = config
    layers = encoder_config.matryoshka_layers
    feature_dim = encoder_config.hidden_size

    model.projection_heads = torch.nn.ModuleList()
    if till_layer:
        print(f"ATTENTION: till_layer is on, you are pruning the model keeping just the first {till_layer} layers")
        model.layers = model.layers[:till_layer]
        model.projection_heads.append(
            torch.nn.Sequential(
                torch.nn.Linear(feature_dim, feature_dim),
                torch.nn.LeakyReLU(),
                torch.nn.Linear(feature_dim, feature_dim),
            )
        )
    else:
        # One projection head per matryoshka layer.
        for _ in layers:
            model.projection_heads.append(
                torch.nn.Sequential(
                    torch.nn.Linear(feature_dim, feature_dim),
                    torch.nn.LeakyReLU(),
                    torch.nn.Linear(feature_dim, feature_dim),
                )
            )

    # The encoder is used bidirectionally, so causal masking is switched off in every layer.
    for layer in model.layers:
        layer.self_attn.is_causal = False

    model.temperature_coef = torch.nn.Parameter(torch.Tensor([10.0]), requires_grad=False)

    return model
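

# Illustrative sketch (not part of the original module): `adapt_model` is called from
# `ModularStarEncoder.__init__` below. With `till_layer=9` it keeps only the first nine
# transformer layers and attaches a single projection head on top of them; `full_model`
# is a hypothetical, already-constructed ModularStarEncoder instance.
#
#   pruned_encoder = adapt_model(full_model, config=full_model.config, till_layer=9)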


class ModularStarEncoder(StarEncoder2PreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
    config_class = ModularStarEncoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.model_type = "ModularStarEncoder"
        # Some config fields may arrive as one-element tuples/lists; unwrap them to scalars.
        for element in dir(config):
            value = getattr(config, element)
            if (isinstance(value, tuple) or isinstance(value, list)) and len(value) > 0:
                setattr(config, element, value[0])
        self.layer_matryoshka_loss = config.layer_matryoshka_loss
        self.matryoshka_layers = config.matryoshka_layers

        self.starEncoder2 = Starcoder2Model(config)

        # The encoder is bidirectional: disable causal attention in every layer.
        for layer in self.starEncoder2.layers:
            layer.self_attn.is_causal = False

        self.post_init()
        self.till_layer = 9
        self.starEncoder2 = adapt_model(self, config=config, till_layer=self.till_layer)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        sep_token_id: Optional[int] = 49152,
    ) -> Union[Tuple[torch.Tensor], ModularStarEncoderOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            This label is assigned to the in-context loss:
            - 0 indicates sequence B belongs to the same repository as A,
            - 1 indicates sequence B comes from a random repository.
        kwargs (`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.
        """
        source_embedding = self.starEncoder2(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )

        DEVICE = source_embedding.hidden_states[-1].get_device()
        if DEVICE < 0:
            DEVICE = "cpu"

        # The encoder may be wrapped (e.g. by DataParallel); fall back to the unwrapped attributes.
        try:
            projection_fn = self.starEncoder2.module.projection_heads
            temp_coef = self.starEncoder2.module.temperature_coef
        except AttributeError:
            projection_fn = self.starEncoder2.projection_heads
            temp_coef = self.starEncoder2.temperature_coef

        for head in projection_fn:
            head.to(DEVICE)
        temp_coef.to(DEVICE)

        # Pooling mask: 1 from the last [SEP] token onward, 0 before it.
        pooling_mask_source_targets = get_pooling_mask(input_ids, sep_token_id)

        if self.till_layer:
            self.matryoshka_layers = [self.till_layer]

        pooled_and_normalized = []
        for idx, matr_layer in enumerate(self.matryoshka_layers):
            source_embedding_proj = projection_fn[idx](source_embedding.hidden_states[matr_layer])

            normalized_source_embedding, embedding_norms = pool_and_normalize(
                source_embedding_proj,
                pooling_mask_source_targets,
                return_norms=True,
            )

            pooled_and_normalized.append(normalized_source_embedding)

        if not self.till_layer:
            return ModularStarEncoderOutput(
                projected_pooled_normalized=pooled_and_normalized,
                raw_hidden_states=source_embedding.hidden_states,
                attentions=source_embedding.attentions,
            )
        else:
            return ModularStarEncoderOutput(
                projected_pooled_normalized=pooled_and_normalized[0],
                raw_hidden_states=source_embedding.hidden_states,
                attentions=source_embedding.attentions,
            )
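

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: encode one snippet and read out the
    # pooled embedding. "CHECKPOINT" is a placeholder; substitute the actual ModularStarEncoder
    # checkpoint id or local path, and make sure `config.py` is importable next to this file.
    from transformers import AutoTokenizer

    CHECKPOINT = "path-or-hub-id-of-ModularStarEncoder"  # placeholder, not a real id
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
    model = ModularStarEncoder.from_pretrained(CHECKPOINT)
    model.eval()

    inputs = tokenizer("def add(a, b):\n    return a + b", return_tensors="pt")
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    # With `till_layer` set (the default in __init__), `projected_pooled_normalized` is a single
    # [batch_size, hidden_size] tensor; otherwise it is a list with one tensor per matryoshka layer.
    print(outputs.projected_pooled_normalized.shape)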