import os
from argparse import Namespace
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from open_lm.attention import get_attn_func, torch_attn, xformers_attn
from open_lm.model import Transformer, create_params
from open_lm.norms import get_norm_class
from open_lm.utils.transformers.hf_config import OpenLMConfig


class OpenLMModel(PreTrainedModel):
    config_class = OpenLMConfig

    def __init__(self, config, **kwargs):
        # Resolve the open_lm params before calling the base constructor,
        # filling in the callables (norm class, attention function) that are
        # stored only as string names in the serialized HF config.
        if hasattr(config, "params"):
            params = config.params
        else:
            params_args_dict = config.params_args_dict
            if not params_args_dict.get("norm_type"):
                params_args_dict["norm_type"] = get_norm_class(params_args_dict["model_norm"])
            if not params_args_dict.get("attn_func"):
                params_args_dict["attn_func"] = get_attn_func(
                    params_args_dict["attn_name"],
                    params_args_dict["attn_activation"],
                    params_args_dict["attn_seq_scalar"],
                    params_args_dict["attn_seq_scalar_alpha"],
                )
            params = create_params(Namespace(**params_args_dict))
        config.set_params(params)
        super().__init__(config, **kwargs)

        self.supports_gradient_checkpointing = True
        self.model = Transformer(params)

    @property
    def gradient_checkpointing(self):
        return self.model.grad_checkpointing

    @gradient_checkpointing.setter
    def gradient_checkpointing(self, value):
        self.model.grad_checkpointing = value

    def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
        return self.model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
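
    # A minimal sketch of how the property pair above is exercised (assuming
    # the caller uses the standard HF attribute name): assigning to
    # `gradient_checkpointing` simply forwards to the wrapped open_lm
    # Transformer's `grad_checkpointing` flag.
    #
    #     model = OpenLMModel(config)
    #     model.gradient_checkpointing = True
    #     assert model.model.grad_checkpointing is True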


class OpenLMforCausalLM(OpenLMModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # The output projection lives inside the open_lm Transformer (see
        # get_output_embeddings below), so no separate lm_head is materialized.
        self.lm_head = None
        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        return self.model.tok_embeddings

    def set_input_embeddings(self, value):
        self.model.tok_embeddings = value

    def get_output_embeddings(self):
        return self.model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        raise NotImplementedError

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer

        >>> model = OpenLMforCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        assert position_ids is None, "Position IDs are not supported"

        logits, _, past_key_values = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            attention_mask=attention_mask,
        )
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens and compute the cross-entropy loss.
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, loss=loss)
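
    # Worked example of the label shift above (illustrative, not executed):
    # with labels [t0, t1, t2, t3], the loss pairs the logits at position i
    # with the label at position i + 1, i.e. (logits[0], t1), (logits[1], t2),
    # (logits[2], t3); positions labeled -100 are ignored by CrossEntropyLoss.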

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[1]

            # Keep only the tokens the cache has not seen yet. Some generation
            # paths already pass just the last input ID; in that case fall back
            # to keeping only the final token.
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        # If `inputs_embeds` are passed, use them only for the first generation
        # step; afterwards the cache plus `input_ids` drive decoding.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

|
| | @staticmethod |
| | def _reorder_cache(past_key_values, beam_idx): |
| | reordered_cache = () |
| | for layer_past in past_key_values: |
| | reordered_cache += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) |
| | return reordered_cache |
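
    # Beam-search sketch (shapes are illustrative assumptions): with
    # num_beams=3, `generate` may decide that the new beams 0 and 1 both fork
    # from old beam 1, giving beam_idx = tensor([1, 1, 2]); each cached tensor
    # of shape (num_beams, seq_len, ...) is then gathered with
    # index_select(0, beam_idx) so its rows line up with the new beams.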

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        if (
            os.path.isdir(pretrained_model_name_or_path)
            and kwargs.get("config", None) is not None
            and getattr(kwargs["config"], "checkpoint_file", None) is not None
        ):
            # An explicit open_lm checkpoint file was supplied via the config.
            torch_dtype = getattr(kwargs["config"], "torch_dtype", None)
            if isinstance(torch_dtype, str):
                torch_dtype = getattr(torch, torch_dtype)
            if torch_dtype is not None:
                torch.set_default_dtype(torch_dtype)

            print("Loading checkpoint from directory")
            checkpoint_path = kwargs["config"].checkpoint_file
            checkpoint = torch.load(checkpoint_path)

            # Strip the DDP "module." prefix and namespace every key under
            # "model." to match this wrapper's attribute layout.
            state_dict = checkpoint["state_dict"]
            state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}
            state_dict = {f"model.{x}": y for x, y in state_dict.items()}

            return super().from_pretrained(None, state_dict=state_dict, **kwargs)
        elif os.path.isdir(pretrained_model_name_or_path):
            # Directory without an explicit checkpoint file: expect an HF-style
            # pytorch_model.bin inside it.
            print("Loading checkpoint from directory")
            checkpoint_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            state_dict = torch.load(checkpoint_path)

            # Prefix keys with "model." unless they are already namespaced.
            state_dict = {f"model.{x}" if "model." not in x else x: y for x, y in state_dict.items()}

            return super().from_pretrained(pretrained_model_name_or_path, state_dict=state_dict, **kwargs)
        else:
            # Fall back to the standard HF loading path (hub name, etc.).
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
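

# Usage sketch: a hedged, illustrative example of loading a checkpoint through
# this wrapper and generating text. Paths and the tokenizer location are
# hypothetical; adjust them to your setup.
#
#     from transformers import AutoTokenizer
#     from open_lm.utils.transformers.hf_config import OpenLMConfig
#
#     config = OpenLMConfig.from_pretrained("path/to/model_dir")
#     model = OpenLMforCausalLM.from_pretrained("path/to/model_dir", config=config)
#     tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")
#
#     inputs = tokenizer("The capital of France is", return_tensors="pt")
#     out = model.generate(inputs.input_ids, max_length=20)
#     print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])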