| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Any, List, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector |
| | from monai.inferers.inferer import Inferer |
| | from torch import Tensor |
| |
|
| |
|
class RetinaNetInferer(Inferer):
    """
    RetinaNet Inferer: runs a detection network through a ``RetinaNetDetector``.

    On each call the given network is attached to the detector. The detector is then told
    to use its sliding-window inferer in one of two cases:

    - always, when ``force_sliding_window=True``;
    - otherwise, only when at least one input image is not smaller than the sliding-window ROI.

    Args:
        detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP
            map into boxes and classification scores.
        force_sliding_window: whether to force using a SlidingWindowInferer to do the inference.
            If False, will check the input spatial size to decide whether to simply
            forward the network or using SlidingWindowInferer.
            If True, will force using SlidingWindowInferer to do the inference.

    Note:
        Extra positional/keyword arguments are accepted by ``__call__`` (not by the
        constructor) and are forwarded to ``detector``.
    """

    def __init__(self, detector: RetinaNetDetector, force_sliding_window: bool = False) -> None:
        super().__init__()
        self.detector = detector
        self.force_sliding_window = force_sliding_window
        # Number of elements in one sliding-window ROI (product of ``roi_size``).
        # It stays None when the detector has no inferer or its inferer has no ``roi_size``.
        # It is computed once here, so an inferer attached to the detector later is not seen.
        self.sliding_window_size = None
        if self.detector.inferer is not None and hasattr(self.detector.inferer, "roi_size"):
            self.sliding_window_size = np.prod(self.detector.inferer.roi_size)

    def __call__(
        self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> Any:
        """Unified callable function API of Inferers.

        Attaches ``network`` to the detector and decides whether the detector should use
        its sliding-window inferer. Returns whatever the detector returns.

        Args:
            inputs: model input data for inference, either a list of tensors or a single
                tensor (iterated over its first dimension). Each item is indexed as
                ``item[0, ...]`` to measure its size.
            network: target detection network to execute inference.
                Supports callables that fulfill the requirements of ``network`` in
                ``monai.apps.detection.networks.retinanet_detector.RetinaNetDetector``.
            args: optional args to be passed to ``detector``.
            kwargs: optional keyword args to be passed to ``detector``.

        Returns:
            The result of ``self.detector(inputs, *args, use_inferer=..., **kwargs)``.

        Note:
            This mutates ``self.detector``. Its ``network`` is replaced by ``network``, and
            its ``training`` flag is copied from ``network.training``.
        """
        self.detector.network = network
        self.detector.training = self.detector.network.training

        # Use the sliding-window inferer when forced, or when a ROI size is known and at
        # least one image is not smaller than the ROI.
        # Images that are all smaller than the ROI are forwarded through the network directly.
        # NOTE(review): ``data_i[0, ...]`` presumably assumes channel-first items, so this
        # counts the elements of one channel (the spatial size). Confirm against callers.
        use_inferer = self.force_sliding_window or (
            self.sliding_window_size is not None
            and any(data_i[0, ...].numel() >= self.sliding_window_size for data_i in inputs)
        )

        return self.detector(inputs, *args, use_inferer=use_inferer, **kwargs)
| |
|