from collections import defaultdict
from contextlib import contextmanager
from logging import getLogger
import math
import sys
from typing import List, Union, Iterable

import numpy as np
import torch
from torch import nn

from timm.models import VisionTransformer
from einops import rearrange

from .extra_models import DinoWrapper

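# Fallback interleaving constants; DEFAULT_NUM_GLOBAL is used when neither
# num_windowed nor num_global is specified in VitDetArgs.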
DEFAULT_NUM_WINDOWED = 5
DEFAULT_NUM_GLOBAL = 4

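# Hyper-parameters for the ViTDet hook. If num_windowed is given it takes
# precedence; otherwise num_global (falling back to DEFAULT_NUM_GLOBAL)
# determines how many global-attention blocks are used.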
class VitDetArgs:
    def __init__(self,
                 window_size: int,
                 num_summary_tokens: int,
                 num_windowed: int = None,
                 num_global: int = None,
    ):
        self.window_size = window_size
        self.num_summary_tokens = num_summary_tokens
        self.num_windowed = num_windowed
        self.num_global = num_global

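# Attach the ViTDet windowing hooks to a supported model. Returns the hook
# object for supported model types, and only prints a warning otherwise.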
def apply_vitdet_arch(model: Union[VisionTransformer, DinoWrapper], args: VitDetArgs):
    if isinstance(model, VisionTransformer):
        patch_embed = getattr(model, 'patch_generator', model.patch_embed)

        return ViTDetHook(patch_embed, model.blocks, args)
    elif isinstance(model, DinoWrapper):
        inner = model.inner

        patch_embed = getattr(inner, 'patch_generator', inner.patch_embed)
        return ViTDetHook(patch_embed, inner.blocks, args)
    else:
        print('Warning: Unable to apply VitDet arch!', file=sys.stderr)

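# Applies ViTDet-style interleaved windowed/global attention using forward
# hooks only, leaving the wrapped model's code untouched:
#   * a pre-hook on the embedder records the input resolution,
#   * a pre-hook on the block stack permutes tokens into window-major order,
#   * pre-hooks on individual blocks reshape between windowed and global form,
#   * a hook on the block stack restores the original token order on exit.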
class ViTDetHook:
    def __init__(self,
                 embedder: nn.Module,
                 blocks: nn.Sequential,
                 args: VitDetArgs,
    ):
        self.blocks = blocks
        self.num_summary_tokens = args.num_summary_tokens
        self.window_size = args.window_size

        self._input_resolution = None
        self._num_windows = None
        self._cls_patch = None
        self._order_cache = dict()

        embedder.register_forward_pre_hook(self._enter_model)

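        # Before the first block runs, permute the patch tokens into
        # window-major order so that switching between windowed and global
        # attention reduces to a reshape (see _rearrange_patches).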
        blocks.register_forward_pre_hook(self._enter_blocks)

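        # Choose the period with which global-attention blocks are interleaved:
        # either one global block after every num_windowed windowed blocks, or
        # num_global (default DEFAULT_NUM_GLOBAL) global blocks spread evenly
        # over the depth of the network.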
        is_global = True
        if args.num_windowed is not None:
            period = args.num_windowed + 1
        else:
            num_global = args.num_global or DEFAULT_NUM_GLOBAL
            period = max(len(blocks) // num_global, 1)

        for i, layer in enumerate(blocks[:-1]):
            ctr = i % period
            if ctr == 0:
                layer.register_forward_pre_hook(self._to_windows)
                is_global = False
            elif ctr == period - 1:
                layer.register_forward_pre_hook(self._to_global)
                is_global = True

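        # Make sure the final block always runs in global mode so that the
        # output sequence spans the full image before it is un-permuted.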
        if not is_global:
            blocks[-1].register_forward_pre_hook(self._to_global)

        blocks.register_forward_hook(self._exit_model)

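    # Record the raw input resolution so the patch grid can be inferred later.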
    def _enter_model(self, _, input: List[torch.Tensor]):
        self._input_resolution = input[0].shape[-2:]

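    # Runs once per forward pass, just before the first block: puts the token
    # sequence into window-major order.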
    def _enter_blocks(self, _, input: List[torch.Tensor]):
        patches = input[0]
        patches = self._rearrange_patches(patches)

        return (patches,) + input[1:]

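    # Split off the summary tokens and fold each window into the batch
    # dimension so the following blocks attend only within their window.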
    def _to_windows(self, _, input: List[torch.Tensor]):
        patches = input[0]

        if self.num_summary_tokens:
            self._cls_patch = patches[:, :self.num_summary_tokens]
            patches = patches[:, self.num_summary_tokens:]

        patches = rearrange(
            patches, 'b (p t) c -> (b p) t c',
            p=self._num_windows, t=self.window_size ** 2,
        )

        return (patches,) + input[1:]

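    # Inverse of _to_windows: fold the windows back into the sequence
    # dimension and re-attach the summary tokens.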
    def _to_global(self, _, input: List[torch.Tensor]):
        patches = input[0]

        patches = rearrange(
            patches, '(b p) t c -> b (p t) c',
            p=self._num_windows, t=self.window_size ** 2,
            b=patches.shape[0] // self._num_windows,
        )

        if self.num_summary_tokens:
            patches = torch.cat([
                self._cls_patch,
                patches,
            ], dim=1)

        return (patches,) + input[1:]

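    # After the last block, scatter the tokens back into their original
    # row-major order so downstream consumers see the standard ViT layout.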
    def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):
        patch_order = self._order_cache[self._input_resolution][0]
        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)

        ret_patches = torch.empty_like(patches)
        ret_patches = torch.scatter(
            ret_patches,
            dim=1,
            index=patch_order,
            src=patches,
        )

        return ret_patches

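    # Compute (and cache, per input resolution) the permutation that groups
    # each window's tokens contiguously, keeping the summary tokens in front,
    # then apply it to the incoming token sequence.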
    def _rearrange_patches(self, patches: torch.Tensor):
        patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None))
        if patch_order is None:
            num_feat_patches = patches.shape[1] - self.num_summary_tokens
            num_pixels = self._input_resolution[0] * self._input_resolution[1]

            patch_size = int(round(math.sqrt(num_pixels / num_feat_patches)))
            rows = self._input_resolution[-2] // patch_size
            cols = self._input_resolution[-1] // patch_size

            w_rows = rows // self.window_size
            w_cols = cols // self.window_size

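            # Build the permutation mapping the row-major patch grid to
            # window-major order: all tokens of window (wy, wx) become
            # contiguous, ordered row-major within the window.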
            patch_order = torch.arange(0, num_feat_patches, device=patches.device)

            patch_order = rearrange(
                patch_order, '(wy py wx px) -> (wy wx py px)',
                wy=w_rows, wx=w_cols,
                py=self.window_size, px=self.window_size,
            )

            if self.num_summary_tokens:
                patch_order = torch.cat([
                    torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device),
                    patch_order + self.num_summary_tokens,
                ])

            self._num_windows = w_rows * w_cols
            self._order_cache[self._input_resolution] = (
                patch_order,
                self._num_windows,
            )

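        # Apply the cached permutation to the token sequence.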
        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
        patches = torch.gather(patches, dim=1, index=patch_order)
        return patches
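
# Example usage (illustrative sketch only, not part of this module's API):
# assumes a timm ViT with a single CLS summary token and an input whose patch
# grid is divisible by the chosen window size.
#
#   import torch
#   from timm import create_model
#
#   model = create_model('vit_base_patch16_224', pretrained=False)
#   hook = apply_vitdet_arch(model, VitDetArgs(window_size=7, num_summary_tokens=1))
#   tokens = model.forward_features(torch.randn(1, 3, 224, 224))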