Instructions to use diffusers/matrix-game-2-modular with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use diffusers/matrix-game-2-modular with Diffusers:
pip install -U diffusers transformers accelerate

import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("diffusers/matrix-game-2-modular", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
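Note that Matrix-Game-2 is an action-conditioned, image-to-video world model, so the generic text-to-image snippet above may need adapting. A minimal sketch, assuming the modular pipeline accepts the `image`, `actions`, and `num_frames` inputs defined by the blocks in the module below and that it returns decoded frames (the `frames` output attribute is an assumption):

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video, load_image

pipe = DiffusionPipeline.from_pretrained("diffusers/matrix-game-2-modular", dtype=torch.bfloat16, device_map="cuda")

# "first_frame.png" is a placeholder; supply your own conditioning image
image = load_image("first_frame.png")
output = pipe(image=image, actions=["forward", "camera_r", "forward_left"], num_frames=81)
export_to_video(output.frames[0], "matrix_game.mp4", fps=16)  # assumes a `frames` output

The pipeline blocks referenced above are defined in the following module.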
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Dict, List, Optional, Tuple, Union

import torch

from diffusers import AutoencoderKLWan
from diffusers.configuration_utils import FrozenDict
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Constants
FRAME_MULTIPLE = 4
DEFAULT_SAMPLES_PER_ACTION = 4
DEFAULT_FRAMES_PER_ACTION = 12
DEFAULT_MOUSE_DIM = 2
DEFAULT_KEYBOARD_DIM = 4

# Camera movement configuration
CAMERA_MOVEMENT_VALUE = 0.1
CAMERA_VALUE_MAP = {
    "camera_up": [CAMERA_MOVEMENT_VALUE, 0],
    "camera_down": [-CAMERA_MOVEMENT_VALUE, 0],
    "camera_l": [0, -CAMERA_MOVEMENT_VALUE],
    "camera_r": [0, CAMERA_MOVEMENT_VALUE],
    "camera_ur": [CAMERA_MOVEMENT_VALUE, CAMERA_MOVEMENT_VALUE],
    "camera_ul": [CAMERA_MOVEMENT_VALUE, -CAMERA_MOVEMENT_VALUE],
    "camera_dr": [-CAMERA_MOVEMENT_VALUE, CAMERA_MOVEMENT_VALUE],
    "camera_dl": [-CAMERA_MOVEMENT_VALUE, -CAMERA_MOVEMENT_VALUE],
}

# Define available actions
MOVEMENT_ACTIONS = ["forward", "left", "right"]
COMPOUND_MOVEMENTS = ["forward_left", "forward_right"]
CAMERA_ACTIONS = list(CAMERA_VALUE_MAP.keys())

# Keyboard action indices
KEYBOARD_ACTION_INDICES = {"forward": 0, "back": 1, "left": 2, "right": 3}
def sync_actions_to_frames(
    actions: List[str],
    num_frames: int,
    min_duration: int = DEFAULT_FRAMES_PER_ACTION,
) -> List[Dict[str, Union[str, int]]]:
    """
    Synchronize a list of actions to fit exactly within the given number of frames
    using an equal-distribution strategy.

    Args:
        actions: List of action names to perform
        num_frames: Total frames to fill
        min_duration: Minimum frames per action (should be a multiple of FRAME_MULTIPLE)

    Returns:
        List of action dictionaries with 'action_type', 'start_frame', and 'duration'
    """
    if not actions:
        raise ValueError("No actions provided")

    # Drop actions that cannot fit at the minimum duration
    max_possible_actions = num_frames // min_duration
    if len(actions) > max_possible_actions:
        actions = actions[:max_possible_actions]

    num_actions = len(actions)
    frames_per_action = num_frames // num_actions
    # Round down to a multiple of FRAME_MULTIPLE, but never below the per-action minimum
    frames_per_action = (frames_per_action // FRAME_MULTIPLE) * FRAME_MULTIPLE
    frames_per_action = max(min_duration, frames_per_action)

    output = []
    current_frame = 0
    for i, action in enumerate(actions):
        # The last action absorbs any leftover frames so the schedule sums to num_frames
        duration = frames_per_action if i != num_actions - 1 else num_frames - current_frame
        output.append({
            "action_type": action,
            "start_frame": current_frame,
            "duration": duration,
        })
        current_frame += duration
    return output
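# Quick illustration (hypothetical values): three actions spread over 40 frames.
#   sync_actions_to_frames(["forward", "camera_r", "left"], num_frames=40)
#   -> [{"action_type": "forward",  "start_frame": 0,  "duration": 12},
#       {"action_type": "camera_r", "start_frame": 12, "duration": 12},
#       {"action_type": "left",     "start_frame": 24, "duration": 16}]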
def actions_to_condition_tensors(
    actions: List[Dict[str, Union[str, int]]], num_frames: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert a synchronized action schedule into keyboard and mouse conditioning tensors."""
    keyboard_conditions = torch.zeros((num_frames, DEFAULT_KEYBOARD_DIM))
    mouse_conditions = torch.zeros((num_frames, DEFAULT_MOUSE_DIM))

    for action in actions:
        action_type = action["action_type"]
        start_frame = action["start_frame"]
        end_frame = start_frame + action["duration"]
        action_components = action_type.split("_")

        # Mark the keyboard channels that are active over this frame range
        for component in action_components:
            if component in KEYBOARD_ACTION_INDICES:
                action_idx = KEYBOARD_ACTION_INDICES[component]
                keyboard_conditions[start_frame:end_frame, action_idx] = 1.0

        if "camera" not in action_type:
            continue

        # Accumulate camera displacements, e.g. "forward_camera_ur" contributes "camera_ur"
        mouse_x = mouse_y = 0.0
        for idx, component in enumerate(action_components):
            if component != "camera" or idx + 1 >= len(action_components):
                continue
            camera_action = f"camera_{action_components[idx + 1]}"
            if camera_action not in CAMERA_VALUE_MAP:
                continue
            camera_values = CAMERA_VALUE_MAP[camera_action]
            mouse_x += camera_values[0]
            mouse_y += camera_values[1]
        mouse_conditions[start_frame:end_frame, 0] = mouse_x
        mouse_conditions[start_frame:end_frame, 1] = mouse_y

    return keyboard_conditions, mouse_conditions
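# Shape check (illustrative):
#   schedule = sync_actions_to_frames(["forward", "camera_r"], num_frames=24)
#   kb, mouse = actions_to_condition_tensors(schedule, num_frames=24)
#   kb.shape    -> torch.Size([24, 4])  # one column per forward/back/left/right key
#   mouse.shape -> torch.Size([24, 2])  # [:, 0] vertical, [:, 1] horizontal (see CAMERA_VALUE_MAP)
#   kb[:12, 0] is 1.0 for the "forward" span; mouse[12:, 1] is 0.1 for the "camera_r" span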
def _build_test_actions(
    movement_actions: List[str],
    compound_movements: List[str],
    camera_actions: List[str],
) -> List[str]:
    """Build a comprehensive list of test action combinations.

    Args:
        movement_actions: List of basic movement actions
        compound_movements: List of compound movement actions
        camera_actions: List of camera control actions

    Returns:
        List of all action combinations to test
    """
    # Create base test actions with repetition for variety
    test_actions = compound_movements * 5 + camera_actions * 5 + movement_actions * 5

    # Add combined movement + camera actions
    for movement in movement_actions + compound_movements:
        for camera in camera_actions:
            combined_action = f"{movement}_{camera}"
            test_actions.append(combined_action)
    return test_actions
def generate_random_condition_tensors(num_frames: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate benchmark action sequences for testing.

    Args:
        num_frames: Total number of frames to generate

    Returns:
        Tuple of keyboard and mouse condition tensors for the benchmark actions
    """
    # Build test action combinations
    actions = _build_test_actions(MOVEMENT_ACTIONS, COMPOUND_MOVEMENTS, CAMERA_ACTIONS)
    actions = sync_actions_to_frames(actions, num_frames)
    return actions_to_condition_tensors(actions, num_frames)
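# Fallback conditioning: when no `actions` are passed to the action-input step below,
# e.g. generate_random_condition_tensors(81) returns an (81, 4) keyboard tensor and an
# (81, 2) mouse tensor covering the benchmark action mix built by _build_test_actions.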
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
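# Typical use (sketch): with this pipeline's UniPC scheduler,
#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=4, device="cuda")
# or, assuming the scheduler's set_timesteps accepts a `sigmas` argument,
#   timesteps, _ = retrieve_timesteps(scheduler, device="cuda", sigmas=[1.0, 0.75, 0.5, 0.25])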
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
class MatrixGameWanActionInputStep(ModularPipelineBlocks):
    model_name = "MatrixGameWan"

    @property
    def description(self) -> str:
        return "Action input step that converts high-level action names into per-frame conditioning tensors"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [InputParam("num_frames", type_hint=int, required=True), InputParam("actions", type_hint=List[str])]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "keyboard_conditions",
                type_hint=torch.Tensor,
                description="per-frame keyboard conditioning used to guide the video generation",
            ),
            OutputParam(
                "mouse_conditions",
                type_hint=torch.Tensor,
                description="per-frame mouse/camera conditioning used to guide the video generation",
            ),
        ]

    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        actions = block_state.actions
        if actions is not None:
            actions = sync_actions_to_frames(actions, block_state.num_frames)
            keyboard_conditions, mouse_conditions = actions_to_condition_tensors(actions, block_state.num_frames)
        else:
            # Fall back to the benchmark action mix when no actions are provided
            keyboard_conditions, mouse_conditions = generate_random_condition_tensors(block_state.num_frames)

        block_state.keyboard_conditions = keyboard_conditions.to(block_state.device)
        block_state.mouse_conditions = mouse_conditions.to(block_state.device)

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state
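# Illustration: after this step runs, the pipeline state carries
#   keyboard_conditions -> (num_frames, 4) float tensor (forward/back/left/right flags)
#   mouse_conditions    -> (num_frames, 2) float tensor (camera deltas)
# derived from user-supplied `actions`, or from the benchmark mix when `actions` is None.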
class MatrixGameWanSetTimestepsStep(ModularPipelineBlocks):
    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return "Step that sets the scheduler's timesteps for inference"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=4),
            InputParam("timesteps"),
            InputParam("sigmas"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
        ]

    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device
        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            components.scheduler,
            block_state.num_inference_steps,
            block_state.device,
            block_state.timesteps,
            block_state.sigmas,
        )
        self.set_block_state(state, block_state)
        return components, state
class MatrixGameWanPrepareLatentsStep(ModularPipelineBlocks):
    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("vae", AutoencoderKLWan)]

    @property
    def description(self) -> str:
        return "Prepare latents step that prepares the initial noise latents for the video generation process"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("num_frames", type_hint=int),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_videos_per_prompt", type_hint=int, default=1),
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
            )
        ]

    @staticmethod
    def check_inputs(components, block_state):
        if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
            block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
        ):
            raise ValueError(
                f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
            )
        if block_state.num_frames is not None and (
            block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
        ):
            raise ValueError(
                f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
            )

    @staticmethod
    def prepare_latents(
        components,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 352,
        width: int = 640,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        num_latent_frames = (num_frames - 1) // components.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_channels_latents,
            num_latent_frames,
            int(height) // components.vae_scale_factor_spatial,
            int(width) // components.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width
        block_state.num_frames = block_state.num_frames or components.default_num_frames
        block_state.device = components._execution_device
        block_state.dtype = torch.float32  # Wan latents should be torch.float32 for best quality
        block_state.num_channels_latents = components.num_channels_latents

        self.check_inputs(components, block_state)
        block_state.latents = self.prepare_latents(
            components,
            1,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.num_frames,
            block_state.dtype,
            block_state.device,
            block_state.generator,
            block_state.latents,
        )
        self.set_block_state(state, block_state)
        return components, state
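# Shape sanity check (assuming Wan's usual 8x spatial / 4x temporal compression):
# the defaults height=352, width=640, num_frames=81 give initial latents of shape
#   (1, 16, (81 - 1) // 4 + 1, 352 // 8, 640 // 8) == (1, 16, 21, 44, 80)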
class MatrixGameWanPrepareImageMaskLatentsStep(ModularPipelineBlocks):
    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLWan),
            ComponentSpec("video_processor", VideoProcessor, config=FrozenDict({"vae_scale_factor": 8})),
        ]

    @property
    def description(self) -> str:
        return "Prepare latents step that encodes the conditioning image and builds the mask latents for image-conditioned video generation"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image"),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("num_frames", type_hint=int),
            InputParam("image_mask_latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_videos_per_prompt", type_hint=int, default=1),
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_mask_latents",
                type_hint=torch.Tensor,
                description="The first-frame mask concatenated with the encoded image conditioning latents",
            )
        ]

    @staticmethod
    def check_inputs(components, block_state):
        if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
            block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
        ):
            raise ValueError(
                f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
            )

    @staticmethod
    def prepare_latents(
        components,
        image,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 352,
        width: int = 640,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        # Build a video whose first frame is the conditioning image and whose remaining frames are zeros
        image = components.video_processor.preprocess(image, height, width).to(device, torch.float32)
        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]
        video_condition = torch.cat(
            [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
        )
        video_condition = video_condition.to(device=device, dtype=components.vae.dtype)

        latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
        latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

        # Normalize the conditioning latents with the VAE's per-channel statistics
        latents_mean = (
            torch.tensor(components.vae.config.latents_mean)
            .view(1, components.vae.config.z_dim, 1, 1, 1)
            .to(device, dtype)
        )
        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
            1, components.vae.config.z_dim, 1, 1, 1
        ).to(device, dtype)
        latent_condition = latent_condition.to(dtype)
        latent_condition = (latent_condition - latents_mean) * latents_std

        # The mask marks the first frame as known (1) and all later frames as unknown (0)
        latent_height = height // components.vae_scale_factor_spatial
        latent_width = width // components.vae_scale_factor_spatial
        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        mask_lat_size[:, :, list(range(1, num_frames))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal)
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width)
        mask_lat_size = mask_lat_size.transpose(1, 2).to(latent_condition.device)

        image_mask_latents = torch.concat([mask_lat_size, latent_condition], dim=1)
        return image_mask_latents

    def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width
        block_state.num_frames = block_state.num_frames or components.default_num_frames
        block_state.device = components._execution_device
        block_state.dtype = torch.float32  # Wan latents should be torch.float32 for best quality
        block_state.num_channels_latents = components.num_channels_latents

        self.check_inputs(components, block_state)
        block_state.image_mask_latents = self.prepare_latents(
            components,
            block_state.image,
            1,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.num_frames,
            block_state.dtype,
            block_state.device,
            block_state.generator,
            block_state.image_mask_latents,
        )
        self.set_block_state(state, block_state)
        return components, state
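These blocks are building blocks rather than a complete pipeline; the denoising and decoding steps live elsewhere in the repository. A minimal composition sketch, assuming diffusers' modular-pipelines API (`SequentialPipelineBlocks.from_blocks_dict`, `init_pipeline`, `load_default_components`; method names may vary across diffusers versions):

import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks

# Compose the blocks defined above in execution order (denoise/decode steps omitted here)
blocks = SequentialPipelineBlocks.from_blocks_dict({
    "action_input": MatrixGameWanActionInputStep(),
    "set_timesteps": MatrixGameWanSetTimestepsStep(),
    "prepare_latents": MatrixGameWanPrepareLatentsStep(),
    "prepare_image_mask_latents": MatrixGameWanPrepareImageMaskLatentsStep(),
})
pipe = blocks.init_pipeline("diffusers/matrix-game-2-modular")
pipe.load_default_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")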