from transformers import PretrainedConfig, PreTrainedModel

from diffusionLM.model.diffusionLM import LLaDAModel


class DiffusionConfig(PretrainedConfig):
    """Configuration class for the Diffusion-LLM model."""

    model_type = "diffusionLM"

    def __init__(
        self,
        vocab_size: int = 50257,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1024,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        pad_token_id: int = 0,
        mask_token_id: int = 50256,
        eos_token_id: int = 50256,
        num_timesteps: int = 100,
        time_embed_dim: int = 128,
        **kwargs,
    ):
        # Register the special token ids through PretrainedConfig so that
        # generation utilities can find them; mask_token_id drives the
        # diffusion masking process and is stored on the config directly.
        super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.mask_token_id = mask_token_id
        self.num_timesteps = num_timesteps
        self.time_embed_dim = time_embed_dim
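
# A minimal construction sketch (values here are illustrative, mirroring the
# GPT-2-style defaults above; DiffusionLLM is defined below in this module):
#
#   config = DiffusionConfig(vocab_size=50257, num_timesteps=100)
#   model = DiffusionLLM(config)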


class DiffusionLLM(PreTrainedModel):
    """Diffusion-LLM model: a PreTrainedModel wrapper around the LLaDAModel
    diffusion backbone, exposing Hugging Face-compatible configuration and
    generation interfaces."""

    config_class = DiffusionConfig
    base_model_prefix = "diffusionLM"

    def __init__(self, config: DiffusionConfig):
        super().__init__(config)
        self.model = LLaDAModel(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        timesteps=None,
        labels=None,
        return_dict=True,  # accepted for HF API compatibility; not forwarded
    ):
        """Delegate the denoising forward pass to the underlying LLaDAModel."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            timesteps=timesteps,
            labels=labels,
        )
        return outputs
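
    # Example forward call (a sketch; exact tensor shapes and outputs depend
    # on LLaDAModel, and `torch` would need to be imported):
    #
    #   input_ids = torch.randint(0, config.vocab_size, (2, 64))
    #   timesteps = torch.randint(0, config.num_timesteps, (2,))
    #   outputs = model(input_ids=input_ids, timesteps=timesteps,
    #                   labels=input_ids)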

    def generate(
        self,
        prompt=None,
        max_length=100,
        num_inference_steps=50,
        temperature=1.0,
        strategy="random",
        top_p=0.9,
        top_k=50,
        num_beams=5,
        return_scores=False,
        use_streaming=False,
        callback_fn=None,
    ):
| """Unified generation interface""" |
| if use_streaming: |
| return self.generate_stream( |
| prompt=prompt, |
| max_length=max_length, |
| num_inference_steps=num_inference_steps, |
| temperature=temperature, |
| strategy=strategy, |
| top_p=top_p, |
| top_k=top_k, |
| num_beams=num_beams, |
| callback_fn=callback_fn |
| ) |
| else: |
| return self.model.generate( |
| prompt=prompt, |
| max_length=max_length, |
| num_inference_steps=num_inference_steps, |
| temperature=temperature, |
| strategy=strategy, |
| top_p=top_p, |
| top_k=top_k, |
| num_beams=num_beams, |
| return_scores=return_scores |
| ) |
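
    # Example (a sketch; assumes the prompt is already a tensor of token ids
    # and that LLaDAModel.generate accepts it as such):
    #
    #   output_ids = model.generate(prompt=prompt_ids, max_length=64,
    #                               num_inference_steps=25)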

    def generate_stream(self, **kwargs):
        """Streaming generation wrapper; forwards all keyword arguments to
        the underlying model's generate_stream()."""
        return self.model.generate_stream(**kwargs)
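
    # Streaming sketch (the callback signature below is an assumption; the
    # actual contract is defined by LLaDAModel.generate_stream):
    #
    #   model.generate(prompt=prompt_ids, use_streaming=True,
    #                  callback_fn=lambda step, ids: print(step))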

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Assemble model inputs for Hugging Face generation compatibility."""
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
            "timesteps": kwargs.get("timesteps"),
        }

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """No-op cache reordering for beam-search compatibility: the
        diffusion model re-predicts the full sequence at every step and
        keeps no key/value cache, so there is nothing to reorder."""
        return past
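

if __name__ == "__main__":
    # Minimal smoke test (a sketch: hyperparameters are illustrative and it
    # assumes LLaDAModel builds from this config).
    config = DiffusionConfig(
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=256,
    )
    model = DiffusionLLM(config)
    print(f"parameter count: {sum(p.numel() for p in model.parameters()):,}")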
|