Instructions to use diffusers/matrix-game-2-modular with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use diffusers/matrix-game-2-modular with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("diffusers/matrix-game-2-modular", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
| from typing import Any, List, Tuple, Optional, Union, Dict | |
| from einops import rearrange | |
| from flash_attn import flash_attn_func | |
| import torch | |
| import torch.nn as nn | |
| import math | |
| from torch.nn.attention.flex_attention import flex_attention | |
| try: | |
| import flash_attn | |
| except: | |
| from flash_attn import flash_attn_func | |
| FLASH_ATTN_3_AVAILABLE = False | |
| DISABLE_COMPILE = False # get os env | |
| flex_attention = torch.compile( | |
| flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs" | |
| ) | |
| import torch | |
| from typing import Union, Tuple, List | |
| def _to_tuple(x, dim=2): | |
| if isinstance(x, int): | |
| return (x,) * dim | |
| elif len(x) == dim: | |
| return x | |
| else: | |
| raise ValueError(f"Expected length {dim} or int, but got {x}") | |
| def get_meshgrid_nd(start, *args, dim=2): | |
| """ | |
| Get n-D meshgrid with start, stop and num. | |
| Args: | |
| start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, | |
| step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num | |
| should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in | |
| n-tuples. | |
| *args: See above. | |
| dim (int): Dimension of the meshgrid. Defaults to 2. | |
| Returns: | |
| grid (np.ndarray): [dim, ...] | |
| """ | |
| if len(args) == 0: | |
| # start is grid_size | |
| num = _to_tuple(start, dim=dim) | |
| start = (0,) * dim | |
| stop = num | |
| elif len(args) == 1: | |
| # start is start, args[0] is stop, step is 1 | |
| start = _to_tuple(start, dim=dim) | |
| stop = _to_tuple(args[0], dim=dim) | |
| num = [stop[i] - start[i] for i in range(dim)] | |
| elif len(args) == 2: | |
| # start is start, args[0] is stop, args[1] is num | |
| start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 | |
| stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 | |
| num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 | |
| else: | |
| raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") | |
| # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) | |
| axis_grid = [] | |
| for i in range(dim): | |
| a, b, n = start[i], stop[i], num[i] | |
| g = torch.linspace(a, b, n + 1, dtype=torch.float32, device=torch.cuda.current_device())[:n] | |
| axis_grid.append(g) | |
| grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] | |
| grid = torch.stack(grid, dim=0) # [dim, W, H, D] | |
| return grid | |
| ################################################################################# | |
| # Rotary Positional Embedding Functions # | |
| ################################################################################# | |
| # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 | |
| def reshape_for_broadcast( | |
| freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], | |
| x: torch.Tensor, | |
| head_first=False, | |
| ): | |
| """ | |
| Reshape frequency tensor for broadcasting it with another tensor. | |
| This function reshapes the frequency tensor to have the same shape as the target tensor 'x' | |
| for the purpose of broadcasting the frequency tensor during element-wise operations. | |
| Notes: | |
| When using FlashMHAModified, head_first should be False. | |
| When using Attention, head_first should be True. | |
| Args: | |
| freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. | |
| x (torch.Tensor): Target tensor for broadcasting compatibility. | |
| head_first (bool): head dimension first (except batch dim) or not. | |
| Returns: | |
| torch.Tensor: Reshaped frequency tensor. | |
| Raises: | |
| AssertionError: If the frequency tensor doesn't match the expected shape. | |
| AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. | |
| """ | |
| ndim = x.ndim | |
| assert 0 <= 1 < ndim | |
| if isinstance(freqs_cis, tuple): | |
| # freqs_cis: (cos, sin) in real space | |
| if head_first: | |
| assert freqs_cis[0].shape == ( | |
| x.shape[-2], | |
| x.shape[-1], | |
| ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" | |
| shape = [ | |
| d if i == ndim - 2 or i == ndim - 1 else 1 | |
| for i, d in enumerate(x.shape) | |
| ] | |
| else: | |
| # assert freqs_cis[0].shape == ( | |
| # x.shape[1], | |
| # x.shape[-1], | |
| # ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" | |
| # shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] | |
| shape = [1, freqs_cis[0].shape[0], 1, freqs_cis[0].shape[1]] | |
| return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) | |
| else: | |
| # freqs_cis: values in complex space | |
| if head_first: | |
| assert freqs_cis.shape == ( | |
| x.shape[-2], | |
| x.shape[-1], | |
| ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" | |
| shape = [ | |
| d if i == ndim - 2 or i == ndim - 1 else 1 | |
| for i, d in enumerate(x.shape) | |
| ] | |
| else: | |
| assert freqs_cis.shape == ( | |
| x.shape[1], | |
| x.shape[-1], | |
| ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" | |
| shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] | |
| return freqs_cis.view(*shape) | |
| def rotate_half(x): | |
| x_real, x_imag = ( | |
| x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) | |
| ) # [B, S, H, D//2] | |
| return torch.stack([-x_imag, x_real], dim=-1).flatten(3) | |
| def apply_rotary_emb( | |
| xq: torch.Tensor, | |
| xk: torch.Tensor, | |
| freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], | |
| head_first: bool = False, | |
| start_offset: int = 0, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Apply rotary embeddings to input tensors using the given frequency tensor. | |
| This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided | |
| frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor | |
| is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are | |
| returned as real tensors. | |
| Args: | |
| xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] | |
| xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] | |
| freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. | |
| head_first (bool): head dimension first (except batch dim) or not. | |
| Returns: | |
| Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. | |
| """ | |
| # print(freqs_cis[0].shape, xq.shape, xk.shape) | |
| xk_out = None | |
| assert isinstance(freqs_cis, tuple) | |
| if isinstance(freqs_cis, tuple): | |
| cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] | |
| cos, sin = cos.to(xq.device), sin.to(xq.device) | |
| # real * cos - imag * sin | |
| # imag * cos + real * sin | |
| xq_out = (xq.float() * cos[:, start_offset:start_offset + xq.shape[1], :, :] + rotate_half(xq.float()) * sin[:, start_offset:start_offset + xq.shape[1], :, :]).type_as(xq) | |
| xk_out = (xk.float() * cos[:, start_offset:start_offset + xk.shape[1], :, :] + rotate_half(xk.float()) * sin[:, start_offset:start_offset + xk.shape[1], :, :]).type_as(xk) | |
| else: | |
| # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) | |
| xq_ = torch.view_as_complex( | |
| xq.float().reshape(*xq.shape[:-1], -1, 2) | |
| ) # [B, S, H, D//2] | |
| freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to( | |
| xq.device | |
| ) # [S, D//2] --> [1, S, 1, D//2] | |
| # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) | |
| # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) | |
| xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) | |
| xk_ = torch.view_as_complex( | |
| xk.float().reshape(*xk.shape[:-1], -1, 2) | |
| ) # [B, S, H, D//2] | |
| xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) | |
| return xq_out, xk_out | |
| def get_nd_rotary_pos_embed( | |
| rope_dim_list, | |
| start, | |
| *args, | |
| theta=10000.0, | |
| use_real=False, | |
| theta_rescale_factor: Union[float, List[float]] = 1.0, | |
| interpolation_factor: Union[float, List[float]] = 1.0, | |
| ): | |
| """ | |
| This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. | |
| Args: | |
| rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. | |
| sum(rope_dim_list) should equal to head_dim of attention layer. | |
| start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, | |
| args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. | |
| *args: See above. | |
| theta (float): Scaling factor for frequency computation. Defaults to 10000.0. | |
| use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. | |
| Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real | |
| part and an imaginary part separately. | |
| theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. | |
| Returns: | |
| pos_embed (torch.Tensor): [HW, D/2] | |
| """ | |
| grid = get_meshgrid_nd( | |
| start, *args, dim=len(rope_dim_list) | |
| ) # [3, W, H, D] / [2, W, H] | |
| if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): | |
| theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) | |
| elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: | |
| theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) | |
| assert len(theta_rescale_factor) == len( | |
| rope_dim_list | |
| ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" | |
| if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): | |
| interpolation_factor = [interpolation_factor] * len(rope_dim_list) | |
| elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: | |
| interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) | |
| assert len(interpolation_factor) == len( | |
| rope_dim_list | |
| ), "len(interpolation_factor) should equal to len(rope_dim_list)" | |
| # use 1/ndim of dimensions to encode grid_axis | |
| embs = [] | |
| for i in range(len(rope_dim_list)): | |
| emb = get_1d_rotary_pos_embed( | |
| rope_dim_list[i], | |
| grid[i].reshape(-1), | |
| theta, | |
| use_real=use_real, | |
| theta_rescale_factor=theta_rescale_factor[i], | |
| interpolation_factor=interpolation_factor[i], | |
| ) # 2 x [WHD, rope_dim_list[i]] | |
| embs.append(emb) | |
| if use_real: | |
| cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) | |
| sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) | |
| return cos, sin | |
| else: | |
| emb = torch.cat(embs, dim=1) # (WHD, D/2) | |
| return emb | |
| def get_1d_rotary_pos_embed( | |
| dim: int, | |
| pos: Union[torch.FloatTensor, int], | |
| theta: float = 10000.0, | |
| use_real: bool = False, | |
| theta_rescale_factor: float = 1.0, | |
| interpolation_factor: float = 1.0, | |
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | |
| """ | |
| Precompute the frequency tensor for complex exponential (cis) with given dimensions. | |
| (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) | |
| This function calculates a frequency tensor with complex exponential using the given dimension 'dim' | |
| and the end index 'end'. The 'theta' parameter scales the frequencies. | |
| The returned tensor contains complex values in complex64 data type. | |
| Args: | |
| dim (int): Dimension of the frequency tensor. | |
| pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar | |
| theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. | |
| use_real (bool, optional): If True, return real part and imaginary part separately. | |
| Otherwise, return complex numbers. | |
| theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. | |
| Returns: | |
| freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] | |
| freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] | |
| """ | |
| if isinstance(pos, int): | |
| pos = torch.arange(pos, device=torch.cuda.current_device()).float() | |
| # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning | |
| # has some connection to NTK literature | |
| if theta_rescale_factor != 1.0: | |
| theta *= theta_rescale_factor ** (dim / (dim - 2)) | |
| freqs = 1.0 / ( | |
| theta ** (torch.arange(0, dim, 2, device=torch.cuda.current_device())[: (dim // 2)].float() / dim) | |
| ) # [D/2] | |
| # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" | |
| freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] | |
| if use_real: | |
| freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] | |
| freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] | |
| return freqs_cos, freqs_sin | |
| else: | |
| freqs_cis = torch.polar( | |
| torch.ones_like(freqs), freqs | |
| ) # complex64 # [S, D/2] | |
| return freqs_cis | |
| class MatrixGameWanRMSNorm(nn.Module): | |
| def __init__(self, dim, eps=1e-5): | |
| super().__init__() | |
| self.dim = dim | |
| self.eps = eps | |
| self.weight = nn.Parameter(torch.ones(dim)) | |
| def forward(self, x): | |
| r""" | |
| Args: | |
| x(Tensor): Shape [B, L, C] | |
| """ | |
| return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) | |
| class ActionModule(nn.Module): | |
| """ | |
| action module from https://arxiv.org/pdf/2501.08325 | |
| 鼠标控制信号的输入是一个 L*D 的向量 | |
| 键盘同样 | |
| """ | |
| def __init__( | |
| self, | |
| mouse_dim_in: int = 2, | |
| keyboard_dim_in: int = 6, | |
| hidden_size: int = 128, | |
| img_hidden_size: int = 1536, | |
| keyboard_hidden_dim: int = 1024, | |
| mouse_hidden_dim: int = 1024, | |
| vae_time_compression_ratio: int = 4, | |
| windows_size: int = 3, | |
| heads_num: int = 16, | |
| patch_size: list = [1, 2, 2], | |
| qk_norm: bool = True, | |
| qkv_bias: bool = False, | |
| rope_dim_list: list = [8, 28, 28], | |
| rope_theta=256, | |
| mouse_qk_dim_list=[8, 28, 28], | |
| enable_mouse=True, | |
| enable_keyboard=True, | |
| local_attn_size=6, | |
| blocks=[], | |
| ): | |
| device = None | |
| super().__init__() | |
| self.local_attn_size = local_attn_size | |
| self.enable_mouse = enable_mouse | |
| self.enable_keyboard = enable_keyboard | |
| self.rope_dim_list = rope_dim_list | |
| self.rope_theta = rope_theta | |
| if self.enable_keyboard: | |
| self.keyboard_embed = nn.Sequential( | |
| nn.Linear(keyboard_dim_in, hidden_size, bias=True), | |
| nn.SiLU(), | |
| nn.Linear(hidden_size, hidden_size, bias=True), | |
| ) | |
| self.mouse_qk_dim_list = mouse_qk_dim_list | |
| self.heads_num = heads_num | |
| if self.enable_mouse: | |
| c = mouse_hidden_dim | |
| self.mouse_mlp = torch.nn.Sequential( | |
| torch.nn.Linear( | |
| mouse_dim_in * vae_time_compression_ratio * windows_size | |
| + img_hidden_size, | |
| c, | |
| bias=True, | |
| ), | |
| torch.nn.GELU(approximate="tanh"), | |
| torch.nn.Linear(c, c), | |
| torch.nn.LayerNorm(c), | |
| ) | |
| head_dim = c // heads_num | |
| self.t_qkv = nn.Linear(c, c * 3, bias=qkv_bias) | |
| self.img_attn_q_norm = ( | |
| MatrixGameWanRMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity() | |
| ) | |
| self.img_attn_k_norm = ( | |
| MatrixGameWanRMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity() | |
| ) | |
| self.proj_mouse = nn.Linear(c, img_hidden_size, bias=qkv_bias) | |
| if self.enable_keyboard: | |
| head_dim_key = keyboard_hidden_dim // heads_num | |
| self.key_attn_q_norm = ( | |
| MatrixGameWanRMSNorm(head_dim_key, eps=1e-6) if qk_norm else nn.Identity() | |
| ) | |
| self.key_attn_k_norm = ( | |
| MatrixGameWanRMSNorm(head_dim_key, eps=1e-6) if qk_norm else nn.Identity() | |
| ) | |
| self.mouse_attn_q = nn.Linear( | |
| img_hidden_size, keyboard_hidden_dim, bias=qkv_bias | |
| ) | |
| self.keyboard_attn_kv = nn.Linear( | |
| hidden_size * windows_size * vae_time_compression_ratio, | |
| keyboard_hidden_dim * 2, | |
| bias=qkv_bias, | |
| ) | |
| self.proj_keyboard = nn.Linear( | |
| keyboard_hidden_dim, img_hidden_size, bias=qkv_bias | |
| ) | |
| self.vae_time_compression_ratio = vae_time_compression_ratio | |
| self.windows_size = windows_size | |
| self.patch_size = patch_size | |
| self.freqs_cos, self.freqs_sin = self.get_rotary_pos_embed( | |
| 7500, | |
| self.patch_size[1], | |
| self.patch_size[2], | |
| 64, | |
| self.mouse_qk_dim_list, | |
| start_offset=0, | |
| ) | |
| def patchify(self, x, patch_size): | |
| """ | |
| x : (N C T H W) | |
| """ | |
| pt, ph, pw = self.patch_size | |
| t, h, w = x.shape[2] // pt, x.shape[3] // ph, x.shape[4] // pw | |
| c = x.shape[1] | |
| x = x.reshape(shape=(x.shape[0], c, t, pt, h, ph, w, pw)) | |
| x = torch.einsum("nctohpwq->nthwcopq", x) | |
| x = x.reshape(shape=(x.shape[0], t * h * w, c * pt * ph * pw)) | |
| return x | |
| def unpatchify(self, x, t, h, w, patch_size): | |
| """ | |
| x: (N, T, patch_size**2 * C) | |
| imgs: (N, H, W, C) | |
| """ | |
| c = x.shape[2] // patch_size # self.unpatchify_channels | |
| pt, ph, pw = self.patch_size | |
| assert t * h * w == x.shape[1] | |
| x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) | |
| x = torch.einsum("nthwcopq->nctohpwq", x) | |
| imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) | |
| return imgs | |
| def get_rotary_pos_embed( | |
| self, video_length, height, width, head_dim, rope_dim_list=None, start_offset=0 | |
| ): | |
| target_ndim = 3 | |
| ndim = 5 - 2 | |
| latents_size = [video_length + start_offset, height, width] | |
| if isinstance(self.patch_size, int): | |
| assert all(s % self.patch_size == 0 for s in latents_size), ( | |
| f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), " | |
| f"but got {latents_size}." | |
| ) | |
| rope_sizes = [s // self.patch_size for s in latents_size] | |
| elif isinstance(self.patch_size, list): | |
| assert all( | |
| s % self.patch_size[idx] == 0 for idx, s in enumerate(latents_size) | |
| ), ( | |
| f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), " | |
| f"but got {latents_size}." | |
| ) | |
| rope_sizes = [ | |
| s // self.patch_size[idx] for idx, s in enumerate(latents_size) | |
| ] | |
| if len(rope_sizes) != target_ndim: | |
| rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis | |
| if rope_dim_list is None: | |
| rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] | |
| assert ( | |
| sum(rope_dim_list) == head_dim | |
| ), "sum(rope_dim_list) should equal to head_dim of attention layer" | |
| freqs_cos, freqs_sin = get_nd_rotary_pos_embed( | |
| rope_dim_list, | |
| rope_sizes, | |
| theta=self.rope_theta, | |
| use_real=True, | |
| theta_rescale_factor=1, | |
| ) | |
| return freqs_cos[ | |
| -video_length * rope_sizes[1] * rope_sizes[2] // self.patch_size[0] : | |
| ], freqs_sin[ | |
| -video_length * rope_sizes[1] * rope_sizes[2] // self.patch_size[0] : | |
| ] | |
| def forward( | |
| self, | |
| x, | |
| tt, | |
| th, | |
| tw, | |
| mouse_condition=None, | |
| keyboard_condition=None, | |
| block_mask_mouse=None, | |
| block_mask_keyboard=None, | |
| is_causal=False, | |
| kv_cache_mouse=None, | |
| kv_cache_keyboard=None, | |
| start_frame=0, | |
| use_rope_keyboard=True, | |
| num_frame_per_block=3, | |
| ): | |
| """ | |
| hidden_states: B, tt*th*tw, C | |
| mouse_condition: B, N_frames, C1 | |
| keyboard_condition: B, N_frames, C2 | |
| """ | |
| assert use_rope_keyboard == True | |
| B, N_frames, C = keyboard_condition.shape | |
| assert tt * th * tw == x.shape[1] | |
| assert ( | |
| (N_frames - 1) + self.vae_time_compression_ratio | |
| ) % self.vae_time_compression_ratio == 0 | |
| N_feats = int((N_frames - 1) / self.vae_time_compression_ratio) + 1 | |
| # Defined freqs_cis early so it's available for both mouse and keyboard | |
| freqs_cis = (self.freqs_cos, self.freqs_sin) | |
| assert ( | |
| N_feats == tt and ((is_causal and kv_cache_mouse == None) or not is_causal) | |
| ) or ( | |
| (N_frames - 1) // self.vae_time_compression_ratio + 1 == start_frame + num_frame_per_block and is_causal | |
| ) | |
| if self.enable_mouse and mouse_condition is not None: | |
| hidden_states = rearrange( | |
| x, "B (T S) C -> (B S) T C", T=tt, S=th * tw | |
| ) # 65*272*480 -> 17*(272//16)*(480//16) -> 8670 | |
| B, N_frames, C = mouse_condition.shape | |
| else: | |
| hidden_states = x | |
| # padding | |
| pad_t = self.vae_time_compression_ratio * self.windows_size | |
| if self.enable_mouse and mouse_condition is not None: | |
| pad = mouse_condition[:, 0:1, :].expand(-1, pad_t, -1) | |
| mouse_condition = torch.cat([pad, mouse_condition], dim=1) | |
| if is_causal and kv_cache_mouse is not None: | |
| mouse_condition = mouse_condition[ | |
| :, | |
| self.vae_time_compression_ratio | |
| * (N_feats - num_frame_per_block - self.windows_size) | |
| + pad_t :, | |
| :, | |
| ] | |
| group_mouse = [ | |
| mouse_condition[ | |
| :, | |
| self.vae_time_compression_ratio * (i - self.windows_size) | |
| + pad_t : i * self.vae_time_compression_ratio + pad_t, | |
| :, | |
| ] | |
| for i in range(num_frame_per_block) | |
| ] | |
| else: | |
| group_mouse = [ | |
| mouse_condition[ | |
| :, | |
| self.vae_time_compression_ratio * (i - self.windows_size) | |
| + pad_t : i * self.vae_time_compression_ratio + pad_t, | |
| :, | |
| ] | |
| for i in range(N_feats) | |
| ] | |
| group_mouse = torch.stack(group_mouse, dim=1) | |
| S = th * tw | |
| group_mouse = group_mouse.unsqueeze(-1).expand( | |
| B, num_frame_per_block, pad_t, C, S | |
| ) | |
| group_mouse = group_mouse.permute(0, 4, 1, 2, 3).reshape( | |
| B * S, num_frame_per_block, pad_t * C | |
| ) | |
| group_mouse = torch.cat([hidden_states, group_mouse], dim=-1) | |
| group_mouse = self.mouse_mlp(group_mouse) | |
| # qkv | |
| mouse_qkv = self.t_qkv(group_mouse) | |
| q, k, v = rearrange( | |
| mouse_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num | |
| ) # BHW F H C | |
| q = self.img_attn_q_norm(q).to(v) | |
| k = self.img_attn_k_norm(k).to(v) | |
| # rope embd | |
| # freqs_cis = (self.freqs_cos, self.freqs_sin) | |
| q, k = apply_rotary_emb( | |
| q, k, freqs_cis, start_offset=start_frame, head_first=False | |
| ) | |
| ## TODO: adding cache here | |
| if is_causal: | |
| if kv_cache_mouse is None: | |
| assert ( | |
| q.shape[0] == k.shape[0] and q.shape[0] % 880 == 0 | |
| ) # == 880, f"{q.shape[0]},{k.shape[0]}" | |
| padded_length = math.ceil(q.shape[1] / 32) * 32 - q.shape[1] | |
| padded_q = torch.cat( | |
| [ | |
| q, | |
| torch.zeros( | |
| [q.shape[0], padded_length, q.shape[2], q.shape[3]], | |
| device=q.device, | |
| dtype=v.dtype, | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| padded_k = torch.cat( | |
| [ | |
| k, | |
| torch.zeros( | |
| [k.shape[0], padded_length, k.shape[2], k.shape[3]], | |
| device=k.device, | |
| dtype=v.dtype, | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| padded_v = torch.cat( | |
| [ | |
| v, | |
| torch.zeros( | |
| [v.shape[0], padded_length, v.shape[2], v.shape[3]], | |
| device=v.device, | |
| dtype=v.dtype, | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| attn = flex_attention( | |
| query=padded_q.transpose(2, 1), # after: B, HW, F, C | |
| key=padded_k.transpose(2, 1), | |
| value=padded_v.transpose(2, 1), | |
| block_mask=block_mask_mouse, | |
| )[:, :, :-padded_length].transpose(2, 1) | |
| else: | |
| current_start = start_frame | |
| current_end = current_start + q.shape[1] | |
| assert q.shape[1] == num_frame_per_block | |
| sink_size = 0 | |
| max_attention_size = self.local_attn_size | |
| sink_tokens = sink_size * 1 | |
| kv_cache_size = kv_cache_mouse["k"].shape[1] | |
| num_new_tokens = q.shape[1] | |
| if (current_end > kv_cache_mouse["global_end_index"].item()) and ( | |
| num_new_tokens + kv_cache_mouse["local_end_index"].item() | |
| > kv_cache_size | |
| ): | |
| num_evicted_tokens = ( | |
| num_new_tokens | |
| + kv_cache_mouse["local_end_index"].item() | |
| - kv_cache_size | |
| ) | |
| num_rolled_tokens = ( | |
| kv_cache_mouse["local_end_index"].item() | |
| - num_evicted_tokens | |
| - sink_tokens | |
| ) | |
| kv_cache_mouse["k"][ | |
| :, sink_tokens : sink_tokens + num_rolled_tokens | |
| ] = kv_cache_mouse["k"][ | |
| :, | |
| sink_tokens + num_evicted_tokens : sink_tokens | |
| + num_evicted_tokens | |
| + num_rolled_tokens, | |
| ].clone() | |
| kv_cache_mouse["v"][ | |
| :, sink_tokens : sink_tokens + num_rolled_tokens | |
| ] = kv_cache_mouse["v"][ | |
| :, | |
| sink_tokens + num_evicted_tokens : sink_tokens | |
| + num_evicted_tokens | |
| + num_rolled_tokens, | |
| ].clone() | |
| # Insert the new keys/values at the end | |
| local_end_index = ( | |
| kv_cache_mouse["local_end_index"].item() | |
| + current_end | |
| - kv_cache_mouse["global_end_index"].item() | |
| - num_evicted_tokens | |
| ) | |
| local_start_index = local_end_index - num_new_tokens | |
| else: | |
| local_end_index = ( | |
| kv_cache_mouse["local_end_index"].item() | |
| + current_end | |
| - kv_cache_mouse["global_end_index"].item() | |
| ) | |
| local_start_index = local_end_index - num_new_tokens | |
| kv_cache_mouse["k"][:, local_start_index:local_end_index] = k | |
| kv_cache_mouse["v"][:, local_start_index:local_end_index] = v | |
| if FLASH_ATTN_3_AVAILABLE: | |
| attn, attn_prob = flash_attn.flash_attn_func( | |
| q, | |
| kv_cache_mouse["k"][ | |
| :, | |
| max( | |
| 0, local_end_index - max_attention_size | |
| ) : local_end_index, | |
| ], | |
| kv_cache_mouse["v"][ | |
| :, | |
| max( | |
| 0, local_end_index - max_attention_size | |
| ) : local_end_index, | |
| ], | |
| ) | |
| else: | |
| attn = flash_attn_func( | |
| q, | |
| kv_cache_mouse["k"][ | |
| :, | |
| max( | |
| 0, local_end_index - max_attention_size | |
| ) : local_end_index, | |
| ], | |
| kv_cache_mouse["v"][ | |
| :, | |
| max( | |
| 0, local_end_index - max_attention_size | |
| ) : local_end_index, | |
| ], | |
| ) | |
| kv_cache_mouse["global_end_index"].fill_(current_end) | |
| kv_cache_mouse["local_end_index"].fill_(local_end_index) | |
| else: | |
| attn = flash_attn_func( | |
| q, # 880, f, 16, 64 | |
| k, # 880, f, 16, 64 | |
| v, # 880, f, 16, 64 | |
| ) | |
| # Compute cu_squlens and max_seqlen for flash attention | |
| # qk norm | |
| attn = rearrange(attn, "(b S) T h d -> b (T S) (h d)", b=B) | |
| hidden_states = rearrange(x, "(B S) T C -> B (T S) C", B=B) | |
| attn = self.proj_mouse(attn) | |
| hidden_states = hidden_states + attn | |
| if self.enable_keyboard and keyboard_condition is not None: | |
| pad = keyboard_condition[:, 0:1, :].expand(-1, pad_t, -1) | |
| keyboard_condition = torch.cat([pad, keyboard_condition], dim=1) | |
| if is_causal and kv_cache_keyboard is not None: | |
| keyboard_condition = keyboard_condition[ | |
| :, | |
| self.vae_time_compression_ratio | |
| * (N_feats - num_frame_per_block - self.windows_size) | |
| + pad_t :, | |
| :, | |
| ] # keyboard_condition[:, self.vae_time_compression_ratio*(start_frame - self.windows_size) + pad_t:start_frame * self.vae_time_compression_ratio + pad_t,:] | |
| keyboard_condition = self.keyboard_embed(keyboard_condition) | |
| group_keyboard = [ | |
| keyboard_condition[ | |
| :, | |
| self.vae_time_compression_ratio * (i - self.windows_size) | |
| + pad_t : i * self.vae_time_compression_ratio + pad_t, | |
| :, | |
| ] | |
| for i in range(num_frame_per_block) | |
| ] | |
| else: | |
| keyboard_condition = self.keyboard_embed(keyboard_condition) | |
| group_keyboard = [ | |
| keyboard_condition[ | |
| :, | |
| self.vae_time_compression_ratio * (i - self.windows_size) | |
| + pad_t : i * self.vae_time_compression_ratio + pad_t, | |
| :, | |
| ] | |
| for i in range(N_feats) | |
| ] | |
| group_keyboard = torch.stack(group_keyboard, dim=1) # B F RW C | |
| group_keyboard = group_keyboard.reshape( | |
| shape=(group_keyboard.shape[0], group_keyboard.shape[1], -1) | |
| ) | |
| # apply cross attn | |
| mouse_q = self.mouse_attn_q(hidden_states) | |
| keyboard_kv = self.keyboard_attn_kv(group_keyboard) | |
| B, L, HD = mouse_q.shape | |
| D = HD // self.heads_num | |
| q = mouse_q.view(B, L, self.heads_num, D) | |
| B, L, KHD = keyboard_kv.shape | |
| k, v = keyboard_kv.view(B, L, 2, self.heads_num, D).permute(2, 0, 1, 3, 4) | |
| # Compute cu_squlens and max_seqlen for flash attention | |
| # qk norm | |
| q = self.key_attn_q_norm(q).to(v) | |
| k = self.key_attn_k_norm(k).to(v) | |
| S = th * tw | |
| assert S == 880 | |
| # position embed | |
| if use_rope_keyboard: | |
| B, TS, H, D = q.shape | |
| T_ = TS // S | |
| q = q.view(B, T_, S, H, D).transpose(1, 2).reshape(B * S, T_, H, D) | |
| q, k = apply_rotary_emb( | |
| q, k, freqs_cis, start_offset=start_frame, head_first=False | |
| ) | |
| k1, k2, k3, k4 = k.shape | |
| k = k.expand(S, k2, k3, k4) | |
| v = v.expand(S, k2, k3, k4) | |
| if is_causal: | |
| if kv_cache_keyboard is None: | |
| assert q.shape[0] == k.shape[0] and q.shape[0] % 880 == 0 | |
| padded_length = math.ceil(q.shape[1] / 32) * 32 - q.shape[1] | |
| padded_q = torch.cat( | |
| [ | |
| q, | |
| torch.zeros( | |
| [q.shape[0], padded_length, q.shape[2], q.shape[3]], | |
| device=q.device, | |
| dtype=v.dtype, | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| padded_k = torch.cat( | |
| [ | |
| k, | |
| torch.zeros( | |
| [k.shape[0], padded_length, k.shape[2], k.shape[3]], | |
| device=k.device, | |
| dtype=v.dtype, | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| padded_v = torch.cat( | |
| [ | |
| v, | |
| torch.zeros( | |
| [v.shape[0], padded_length, v.shape[2], v.shape[3]], | |
| device=v.device, | |
| dtype=v.dtype, | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| attn = flex_attention( | |
| query=padded_q.transpose(2, 1), # after: B, HW, F, C | |
| key=padded_k.transpose(2, 1), | |
| value=padded_v.transpose(2, 1), | |
| block_mask=block_mask_keyboard, | |
| )[:, :, :-padded_length].transpose(2, 1) | |
| else: | |
| current_start = start_frame | |
| current_end = current_start + k.shape[1] | |
| assert k.shape[1] == num_frame_per_block | |
| sink_size = 0 | |
| max_attention_size = self.local_attn_size | |
| sink_tokens = sink_size * 1 | |
| kv_cache_size = kv_cache_keyboard["k"].shape[1] | |
| num_new_tokens = k.shape[1] | |
| if ( | |
| current_end > kv_cache_keyboard["global_end_index"].item() | |
| ) and ( | |
| num_new_tokens + kv_cache_keyboard["local_end_index"].item() | |
| > kv_cache_size | |
| ): | |
| num_evicted_tokens = ( | |
| num_new_tokens | |
| + kv_cache_keyboard["local_end_index"].item() | |
| - kv_cache_size | |
| ) | |
| num_rolled_tokens = ( | |
| kv_cache_keyboard["local_end_index"].item() | |
| - num_evicted_tokens | |
| - sink_tokens | |
| ) | |
| kv_cache_keyboard["k"][ | |
| :, sink_tokens : sink_tokens + num_rolled_tokens | |
| ] = kv_cache_keyboard["k"][ | |
| :, | |
| sink_tokens + num_evicted_tokens : sink_tokens | |
| + num_evicted_tokens | |
| + num_rolled_tokens, | |
| ].clone() | |
| kv_cache_keyboard["v"][ | |
| :, sink_tokens : sink_tokens + num_rolled_tokens | |
| ] = kv_cache_keyboard["v"][ | |
| :, | |
| sink_tokens + num_evicted_tokens : sink_tokens | |
| + num_evicted_tokens | |
| + num_rolled_tokens, | |
| ].clone() | |
| # Insert the new keys/values at the end | |
| local_end_index = ( | |
| kv_cache_keyboard["local_end_index"].item() | |
| + current_end | |
| - kv_cache_keyboard["global_end_index"].item() | |
| - num_evicted_tokens | |
| ) | |
| local_start_index = local_end_index - num_new_tokens | |
| else: | |
| local_end_index = ( | |
| kv_cache_keyboard["local_end_index"].item() | |
| + current_end | |
| - kv_cache_keyboard["global_end_index"].item() | |
| ) | |
| local_start_index = local_end_index - num_new_tokens | |
| assert ( | |
| k.shape[0] == 880 | |
| ) # BS == 1 or the cache should not be saved/ load method should be modified | |
| kv_cache_keyboard["k"][:, local_start_index:local_end_index] = ( | |
| k[:1] | |
| ) | |
| kv_cache_keyboard["v"][:, local_start_index:local_end_index] = ( | |
| v[:1] | |
| ) | |
| if FLASH_ATTN_3_AVAILABLE: | |
| attn, attn_prob = flash_attn.flash_attn_func( | |
| q, | |
| kv_cache_keyboard["k"][ | |
| :, | |
| max( | |
| 0, local_end_index - max_attention_size | |
| ) : local_end_index, | |
| ].repeat(S, 1, 1, 1), | |
| kv_cache_keyboard["v"][ | |
| :, | |
| max( | |
| 0, local_end_index - max_attention_size | |
| ) : local_end_index, | |
| ].repeat(S, 1, 1, 1), | |
| ) | |
| else: | |
| attn = flash_attn_func( | |
| q, | |
| kv_cache_keyboard["k"][ | |
| :, | |
| max( | |
| 0, local_end_index - max_attention_size | |
| ) : local_end_index, | |
| ].repeat(S, 1, 1, 1), | |
| kv_cache_keyboard["v"][ | |
| :, | |
| max( | |
| 0, local_end_index - max_attention_size | |
| ) : local_end_index, | |
| ].repeat(S, 1, 1, 1), | |
| ) | |
| kv_cache_keyboard["global_end_index"].fill_(current_end) | |
| kv_cache_keyboard["local_end_index"].fill_(local_end_index) | |
| else: | |
| attn = flash_attn_func( | |
| q, # 1, f*880, 16, 64 | |
| k, # 1, f, 16, 64 | |
| v, # 1, f, 16, 64 | |
| causal=False, | |
| ) | |
| attn = rearrange(attn, "(B S) T H D -> B (T S) (H D)", S=S) | |
| else: | |
| if is_causal: | |
| if kv_cache_keyboard is None: | |
| padded_length = math.ceil(q.shape[1] / 32) * 32 - q.shape[1] | |
| padded_q = torch.cat( | |
| [ | |
| q, | |
| torch.zeros( | |
| [q.shape[0], padded_length, q.shape[2], q.shape[3]], | |
| device=q.device, | |
| dtype=v.dtype, | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| padded_k = torch.cat( | |
| [ | |
| k, | |
| torch.zeros( | |
| [k.shape[0], padded_length, k.shape[2], k.shape[3]], | |
| device=k.device, | |
| dtype=v.dtype, | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| padded_v = torch.cat( | |
| [ | |
| v, | |
| torch.zeros( | |
| [v.shape[0], padded_length, v.shape[2], v.shape[3]], | |
| device=v.device, | |
| dtype=v.dtype, | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| attn = flex_attention( | |
| query=padded_q.transpose(2, 1), # after: B, HW, F, C | |
| key=padded_k.transpose(2, 1), | |
| value=padded_v.transpose(2, 1), | |
| block_mask=block_mask_keyboard, | |
| )[:, :, :-padded_length].transpose(2, 1) | |
| else: | |
| current_start = start_frame | |
| current_end = current_start + k.shape[1] | |
| assert k.shape[1] == num_frame_per_block | |
| sink_size = 0 | |
| local_attn_size = self.local_attn_size | |
| max_attention_size = self.local_attn_size | |
| sink_tokens = sink_size * 1 | |
| kv_cache_size = kv_cache_keyboard["k"].shape[1] | |
| num_new_tokens = k.shape[1] | |
| if ( | |
| current_end > kv_cache_keyboard["global_end_index"].item() | |
| ) and ( | |
| num_new_tokens + kv_cache_keyboard["local_end_index"].item() | |
| > kv_cache_size | |
| ): | |
| num_evicted_tokens = ( | |
| num_new_tokens | |
| + kv_cache_keyboard["local_end_index"].item() | |
| - kv_cache_size | |
| ) | |
| num_rolled_tokens = ( | |
| kv_cache_keyboard["local_end_index"].item() | |
| - num_evicted_tokens | |
| - sink_tokens | |
| ) | |
| kv_cache_keyboard["k"][ | |
| :, sink_tokens : sink_tokens + num_rolled_tokens | |
| ] = kv_cache_keyboard["k"][ | |
| :, | |
| sink_tokens + num_evicted_tokens : sink_tokens | |
| + num_evicted_tokens | |
| + num_rolled_tokens, | |
| ].clone() | |
| kv_cache_keyboard["v"][ | |
| :, sink_tokens : sink_tokens + num_rolled_tokens | |
| ] = kv_cache_keyboard["v"][ | |
| :, | |
| sink_tokens + num_evicted_tokens : sink_tokens | |
| + num_evicted_tokens | |
| + num_rolled_tokens, | |
| ].clone() | |
| # Insert the new keys/values at the end | |
| local_end_index = ( | |
| kv_cache_keyboard["local_end_index"].item() | |
| + current_end | |
| - kv_cache_keyboard["global_end_index"].item() | |
| - num_evicted_tokens | |
| ) | |
| local_start_index = local_end_index - num_new_tokens | |
| else: | |
| local_end_index = ( | |
| kv_cache_keyboard["local_end_index"].item() | |
| + current_end | |
| - kv_cache_keyboard["global_end_index"].item() | |
| ) | |
| local_start_index = local_end_index - num_new_tokens | |
| kv_cache_keyboard["k"][:, local_start_index:local_end_index] = k | |
| kv_cache_keyboard["v"][:, local_start_index:local_end_index] = v | |
| attn = flash_attn_func( | |
| q, | |
| kv_cache_keyboard["k"][ | |
| :, | |
| max( | |
| 0, local_end_index - max_attention_size | |
| ) : local_end_index, | |
| ], | |
| kv_cache_keyboard["v"][ | |
| :, | |
| max( | |
| 0, local_end_index - max_attention_size | |
| ) : local_end_index, | |
| ], | |
| # causal=is_causal | |
| ) | |
| kv_cache_keyboard["global_end_index"].fill_(current_end) | |
| kv_cache_keyboard["local_end_index"].fill_(local_end_index) | |
| else: | |
| attn = flash_attn_func( | |
| q, # 1, f*880, 16, 64 | |
| k, # 1, f, 16, 64 | |
| v, # 1, f, 16, 64 | |
| # causal=is_causal, | |
| ) | |
| attn = rearrange(attn, "B L H D -> B L (H D)") | |
| attn = self.proj_keyboard(attn) | |
| hidden_states = hidden_states + attn | |
| return hidden_states | |