| import math |
| import warnings |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import CrossEntropyLoss, LayerNorm |
| from torch.nn import functional as F |
| from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, \ |
| add_start_docstrings_to_model_forward |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| ) |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import logging |
|
|
| from .configuration_codify import CodifyConfig |
|
|
# Module-level logger from `transformers.utils.logging` (imported as `logging` at the top of this file).
logger = logging.get_logger(__name__)
|
|
# Identifiers consumed by the `add_code_sample_docstrings` decorators on the `forward` methods below.
_CHECKPOINT_FOR_DOC = "smallcloudai/codify_medium_multi"
_CONFIG_FOR_DOC = "CodifyConfig"
_TOKENIZER_FOR_DOC = "CodifyTokenizerFast"

# Known pretrained checkpoint names for this architecture.
CODIFY_PRETRAINED_MODEL_ARCHIVE_LIST = [
    _CHECKPOINT_FOR_DOC,
    "smallcloudai/codify_3b_multi",
]
|
|
def _make_causal_mask(
    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
    """
    Build the boolean causal mask for self-attention (True = position may NOT be attended).

    Query row ``i`` sits at absolute position ``i + past_key_values_length``; key column ``j`` is absolute position
    ``j``. A key is masked exactly when it lies in the future of the query, so every cached (past) column is visible.

    Args:
        input_ids_shape: ``(batch_size, target_length)`` of the current input.
        device: device to allocate the mask on.
        past_key_values_length: number of cached positions preceding the current input.

    Returns:
        Bool tensor of shape ``(batch_size, 1, target_length, target_length + past_key_values_length)``
        (an expanded view, not a copy, along the batch dimension).
    """
    batch_size, target_length = input_ids_shape
    source_length = target_length + past_key_values_length

    query_positions = torch.arange(target_length, device=device)[:, None]
    key_positions = torch.arange(source_length, device=device)[None, :]
    future = key_positions > query_positions + past_key_values_length

    return future[None, None, :, :].expand(batch_size, 1, target_length, source_length)
|
|
|
|
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
    """
    Turn a padding mask into a broadcastable "blocked" mask.

    The input ``mask`` is ``[batch_size, src_length]`` with non-zero entries for tokens to keep. The result is the
    boolean complement (True = blocked) broadcast to ``[batch_size, 1, tgt_length, src_length]``. When ``tgt_length``
    is None, ``src_length`` is used.
    """
    batch_size, src_length = mask.shape
    if tgt_length is None:
        tgt_length = src_length

    blocked = mask.to(torch.bool).logical_not()
    return blocked[:, None, None, :].expand(batch_size, 1, tgt_length, src_length)
|
|
|
|
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    """
    Build ALiBi attention biases (https://arxiv.org/abs/2108.12409).

    Each head ``h`` gets a slope ``m_h`` from a geometric sequence. The bias added for key position ``j`` is
    ``m_h * p_j``, where ``p_j`` is the cumulative count of attended tokens up to and including ``j`` (so the first
    attended token gets 1), and ``p_j = 0`` wherever ``attention_mask`` is 0.

    The bias is not causal by itself. Softmax is invariant to adding the same offset to every score in a row, so
    only relative differences between key positions matter. Masking still differs slightly from the reference
    implementation. Slope construction follows
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise mask of shape (batch_size, max_seq_len); expected to hold 1 for attended tokens, 0 otherwise.
        num_heads (`int`):
            Number of attention heads.
        dtype (`torch.dtype`):
            dtype of the returned tensor.

    Returns:
        `torch.Tensor` of shape (batch_size * num_heads, 1, max_seq_len).
    """
    device = attention_mask.device
    batch_size, seq_length = attention_mask.shape

    def geometric_slopes(head_count: int, start: int, stop: int, step: int) -> torch.Tensor:
        # ratio = 2 ** (-8 / head_count), computed exactly as in the reference formula.
        ratio = torch.tensor(
            2 ** (-(2 ** -(math.log2(head_count) - 3))), device=device, dtype=torch.float32
        )
        exponents = torch.arange(start, stop, step, device=device, dtype=torch.int32)
        return torch.pow(ratio, exponents)

    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    slopes = geometric_slopes(closest_power_of_2, 1, 1 + closest_power_of_2, 1)

    if num_heads != closest_power_of_2:
        # Non power-of-2 head counts borrow every other slope from the next power of 2.
        leftover = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra = geometric_slopes(2 * closest_power_of_2, 1, 1 + 2 * leftover, 2)
        slopes = torch.cat([slopes, extra], dim=0)

    positions = (attention_mask.cumsum(dim=-1) * attention_mask)[:, None, :]
    biases = slopes[..., None] * positions
    return biases.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
|
|
|
|
|
|
def codify_gelu_forward(x: torch.Tensor) -> torch.Tensor:
    """
    Tanh approximation of GELU, adapted from Megatron-DeepSpeed.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x**2)))`` with plain tensor ops so it stays
    scriptable/jitable (no custom autograd).

    Args:
        x (`torch.tensor`, *required*):
            input hidden states
    """
    # 0.79788456 ~= sqrt(2 / pi)
    inner = 0.79788456 * x * (1 + 0.044715 * x * x)
    return x * 0.5 * (1.0 + torch.tanh(inner))
|
|
|
|
def codify_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Backward pass of the tanh-approximate GELU computed by `codify_gelu_forward`.

    With ``u = 0.79788456 * x * (1 + 0.044715 * x**2)`` the forward is ``0.5 * x * (1 + tanh(u))``. Its derivative is
    ``0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)**2) * du/dx``, where ``du/dx = 0.79788456 + 0.1070322243 * x**2``.
    For reference, the gradient of exact GELU is
    ``0.5 * (1 + erf(x * 0.70710678)) + 0.3989423 * x * exp(-0.5 * x * x)``.

    Args:
        g (`torch.tensor`, *required*):
            gradient output tensor
        x (`torch.tensor` or 1-tuple of `torch.tensor`, *required*):
            input tensor of the forward pass. A tuple/list such as `ctx.saved_tensors` is also accepted; its first
            element is then used.

    Returns:
        `torch.Tensor`: ``g`` multiplied element-wise by the local derivative at ``x``.
    """
    if isinstance(x, (tuple, list)):
        # Autograd returns saved tensors as a tuple. Unwrap only in that case so that a plain tensor argument is
        # not silently indexed along its first dimension.
        x = x[0]
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # 0.1070322243 == 3 * 0.044715 * 0.79788456: derivative of the cubic term inside tanh.
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g
|
|
|
|
class GeLUFunction(torch.autograd.Function):
    """Autograd function pairing `codify_gelu_forward` with the hand-written gradient `codify_gelu_back`."""

    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        # The raw input is all the backward pass needs.
        ctx.save_for_backward(input)
        return codify_gelu_forward(input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # `saved_tensors` is a tuple; codify_gelu_back takes care of unwrapping its first element.
        saved = ctx.saved_tensors
        return codify_gelu_back(grad_output, saved)
|
|
|
|
class CodifyGelu(nn.Module):
    """
    Tanh-approximate GELU activation module.

    In training mode it routes through `GeLUFunction` (custom forward/backward); otherwise it calls
    `codify_gelu_forward` directly.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return GeLUFunction.apply(x) if self.training else codify_gelu_forward(x)
|
|
|
|
class CodifyAttention(nn.Module):
    """
    Multi-head self-attention with additive ALiBi biases and an optional key/value cache.

    Cached keys are stored as ``[batch_size * num_heads, head_dim, kv_length]`` and cached values as
    ``[batch_size * num_heads, kv_length, head_dim]``.
    """

    def __init__(self, config: CodifyConfig):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size

        if self.hidden_size != self.num_heads * self.head_dim:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
                f" {self.num_heads})."
            )

        # Q.K^T is scaled by 8 / head_dim (rather than 1 / sqrt(head_dim)); the alibi bias is added with weight 1.
        self.inv_norm_factor = 8.0 / self.head_dim
        self.beta = 1.0

        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)

    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the fused projection into per-head query, key and value views (no copies; they share storage with
        `fused_qkv`).

        Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
            query, key, value: each [batch_size, seq_length, num_heads, head_dim]
        """
        batch_size, seq_length, _ = fused_qkv.shape
        per_head_shape = (batch_size, seq_length, self.num_heads, self.head_dim)
        query, key, value = fused_qkv.chunk(3, dim=-1)
        return query.view(per_head_shape), key.view(per_head_shape), value.view(per_head_shape)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Merge the heads back into the last dimension.

        Args:
            x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]

        Returns:
            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
        """
        flat_batch, seq_length, _ = x.shape
        batch_size = flat_batch // self.num_heads

        per_head = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
        # [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim] before flattening heads together.
        seq_major = per_head.permute(0, 2, 1, 3)
        return seq_major.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        Attend over the current tokens plus any cached ones.

        Args:
            hidden_states: [batch_size, q_length, hidden_size].
            alibi: additive bias broadcastable to [batch_size * num_heads, q_length, kv_length] (as produced by
                `build_alibi_tensor`).
            attention_mask: bool mask broadcastable to [batch_size, num_heads, q_length, kv_length]; True entries are
                filled with the dtype minimum before softmax.
            layer_past: optional cached (key, value) in the layout described on the class.
            head_mask: optional multiplier applied to the attention probabilities.
            use_cache: return the updated (key, value) as `present` only when this is exactly True.
            output_attentions: also return the attention probabilities.

        Returns:
            `(output, present)` or `(output, present, attention_probs)`.
        """
        query, key, value = self._split_heads(self.query_key_value(hidden_states))
        batch_size, q_length, _, _ = query.shape
        flat_batch = batch_size * self.num_heads

        # Queries/values -> [B*H, S, D]; keys -> [B*H, D, S] so they can be fed to bmm without another transpose.
        query = query.transpose(1, 2).reshape(flat_batch, q_length, self.head_dim)
        key = key.permute(0, 2, 3, 1).reshape(flat_batch, self.head_dim, q_length)
        value = value.transpose(1, 2).reshape(flat_batch, q_length, self.head_dim)

        if layer_past is not None:
            cached_key, cached_value = layer_past
            key = torch.cat((cached_key, key), dim=2)
            value = torch.cat((cached_value, value), dim=1)

        kv_length = key.size(2)
        present = (key, value) if use_cache is True else None

        # scores = beta * alibi + inv_norm_factor * (Q @ K), with alibi broadcast over the query dimension.
        scores = alibi.baddbmm(
            batch1=query,
            batch2=key,
            beta=self.beta,
            alpha=self.inv_norm_factor,
        )
        scores = scores.view(batch_size, self.num_heads, q_length, kv_length)

        original_dtype = scores.dtype
        if original_dtype == torch.float16:
            # Mask and softmax in fp32 to avoid fp16 overflow/underflow.
            scores = scores.to(torch.float)
        masked_scores = torch.masked_fill(scores, attention_mask, torch.finfo(scores.dtype).min)
        probs = F.softmax(masked_scores, dim=-1, dtype=torch.float32).to(original_dtype)

        if head_mask is not None:
            probs = probs * head_mask

        context = torch.bmm(probs.view(flat_batch, q_length, kv_length), value)
        output = self.dense(self._merge_heads(context))

        result = (output, present)
        if output_attentions:
            result += (probs,)
        return result
|
|
|
|
class CodifyMLP(nn.Module):
    """Feed-forward sublayer: Linear(hidden -> mlp_mult * hidden) -> GELU -> Linear(back to hidden)."""

    def __init__(self, config: CodifyConfig):
        super().__init__()
        width = config.hidden_size
        inner_width = config.mlp_mult * width
        self.dense_h_to_4h = nn.Linear(width, inner_width)
        self.gelu_impl = CodifyGelu()
        self.dense_4h_to_h = nn.Linear(inner_width, width)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.dense_h_to_4h(hidden_states)
        activated = self.gelu_impl(expanded)
        return self.dense_4h_to_h(activated)
|
|
|
|
class CodifyBlock(nn.Module):
    """
    One decoder layer: pre-LayerNorm self-attention followed by an MLP.

    Both sublayer outputs are added residually onto the block input.
    """

    def __init__(self, config: CodifyConfig):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_epsilon

        self.input_layernorm = LayerNorm(width, eps=eps)
        self.num_heads = config.num_attention_heads
        self.self_attention = CodifyAttention(config)
        self.post_attention_layernorm = LayerNorm(width, eps=eps)
        self.mlp = CodifyMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        Run the layer.

        Returns `(output, present, [attention_probs])` when `use_cache` is truthy, otherwise
        `(output, [attention_probs])`. The attention probabilities are present only if `output_attentions`.
        """
        residual = hidden_states

        attn_outputs = self.self_attention(
            self.input_layernorm(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attention_output = attn_outputs[0]
        # attn_outputs[1:] is (present, [attention_probs]); drop `present` when caching is off.
        extras = attn_outputs[1:] if use_cache else attn_outputs[2:]

        attention_mix = attention_output + residual
        mlp_output = self.mlp(self.post_attention_layernorm(attention_mix))
        output = mlp_output + attention_output + residual

        return (output,) + extras
|
|
class CodifyPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Regex patterns of weights that may be missing from a checkpoint without a load warning (transformers hook).
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    config_class = CodifyConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # Module classes transformers/accelerate should keep on a single device when sharding with a device map.
    _no_split_modules = ["CodifyBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """
        Initialize the parameters of ``module`` in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from N(0, ``config.initializer_range``). Linear biases and
        the embedding row at ``padding_idx`` are zeroed. ``LayerNorm`` is reset to weight 1 and bias 0.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
        # Only the base model owns the per-layer loop that honours this flag.
        if isinstance(module, CodifyModel):
            module.gradient_checkpointing = value

    @staticmethod
    def _convert_to_standard_cache(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Convert the Codify cache layout into the common per-batch layout.

        Codify stores keys as ``[batch_size * num_heads, head_dim, seq_length]`` and values as
        ``[batch_size * num_heads, seq_length, head_dim]``. The result is ``tuple(tuple([batch_size, num_heads, ...]))``
        with the same trailing dimensions. The returned tensors are views of the inputs.
        """
        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
        num_heads = batch_size_times_num_heads // batch_size
        return tuple(
            (
                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )

    @staticmethod
    def _convert_to_codify_cache(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Inverse of `_convert_to_standard_cache`.

        Folds ``[batch_size, num_heads, ...]`` back into ``[batch_size * num_heads, ...]`` (views, no copies).
        """
        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
        batch_size_times_num_heads = batch_size * num_heads
        return tuple(
            (
                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )
|
|
class CodifyModel(CodifyPreTrainedModel):
    """Bare Codify decoder: token embeddings -> stack of `CodifyBlock` -> final LayerNorm (no LM head)."""

    def __init__(self, config: CodifyConfig):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        # Token embeddings only; position information enters through the ALiBi biases built in `forward`.
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)

        # Decoder layers.
        self.h = nn.ModuleList([CodifyBlock(config) for _ in range(config.num_hidden_layers)])

        # Final LayerNorm applied to the last layer's output.
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        """
        Combine the causal mask with the padding mask.

        Returns a bool mask of shape [batch_size, 1, src_length, past_key_values_length + src_length] in which True
        marks positions that must not be attended. With a single new token there is no future to hide, so only the
        padding mask is used.
        """
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape, device=device, past_key_values_length=past_key_values_length
            )

        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
        # `position_ids` may legitimately be None, so pop with a False default to detect explicit passing.
        if deprecated_arguments.pop("position_ids", False) is not False:
            warnings.warn(
                "`position_ids` have no functionality in Codify and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # One entry per layer (None when no head mask is given); the attention multiplies its probabilities by it.
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = inputs_embeds

        # Decide about caching BEFORE initialising `presents`. Otherwise disabling the cache for gradient
        # checkpointing would leave `presents == ()` and report an empty cache instead of None.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Cached keys are [batch_size * num_heads, head_dim, kv_length], so dim 2 is the cached length.
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                # Positional order must match CodifyBlock.forward:
                # (hidden_states, alibi, attention_mask, layer_past, head_mask). Passing head_mask[i] directly after
                # the mask would land it in the `layer_past` slot.
                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    layer_past,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
|
|
|
|
class CodifyForCausalLM(CodifyPreTrainedModel):
    """Codify decoder (`CodifyModel`) with a bias-free linear language-modeling head."""

    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config: CodifyConfig):
        super().__init__(config)
        self.transformer = CodifyModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> dict:
        """
        Assemble the `forward` kwargs for one generation step.

        With a cache, only the last token of `input_ids` is fed. A cache whose leading dimension equals the batch size
        is assumed to be in the standard `[batch_size, num_heads, ...]` layout and is folded back into the Codify
        layout.
        """
        # Newer `generate` implementations pass the cache as `past_key_values` rather than `past`.
        if past is None:
            past = kwargs.get("past_key_values")

        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

            if past[0][0].shape[0] == input_ids.shape[0]:
                past = self._convert_to_codify_cache(past)

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }

    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        # `position_ids` may legitimately be None, so pop with a False default to detect explicit passing.
        if deprecated_arguments.pop("position_ids", False) is not False:
            warnings.warn(
                "`position_ids` have no functionality in Codify and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # NOTE(review): hidden states are halved before the LM head; this is model-specific and presumably matches how
        # the checkpoint was trained, so do not remove it.
        lm_logits = self.lm_head(hidden_states / 2.0)

        loss = None
        if labels is not None:
            # Token t predicts token t + 1: drop the last logit and the first label.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def _reorder_cache(
        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        The cache is unfolded to `[batch_size, num_heads, ...]` so that `index_select` along dim 0 picks whole beams,
        and is then folded back to the Codify layout.
        """
        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))

        # Copy `beam_idx` once per device the cache lives on (relevant when layers are spread across devices).
        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
        }
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[1].device]),
            )
            for layer_past in standardized_past
        )
        return self._convert_to_codify_cache(reordered_past)
|
|