| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | from transformers.image_processing_utils import BaseImageProcessor |
| | from transformers.utils import logging |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
class VQModelImageProcessor(BaseImageProcessor):
    """Image processor for a VQ model.

    Resizes images to a fixed square resolution, normalizes pixel values to
    ``[-1, 1]`` as a CHW ``torch.Tensor``, and converts model outputs back to
    PIL images at their original dimensions.
    """

    def __init__(
        self,
        size: int = 256,
        convert_rgb: bool = False,
        resample: Image.Resampling = Image.Resampling.LANCZOS,
        **kwargs,
    ) -> None:
        """
        Args:
            size: Target side length; inputs are resized to ``(size, size)``.
            convert_rgb: If True, flatten alpha onto a white background and
                emit 3-channel RGB; otherwise the tensor keeps 4 RGBA channels.
            resample: PIL resampling filter used for all resizing.
            kwargs: Forwarded to ``BaseImageProcessor``.
        """
        # Forward kwargs to the base class so BaseImageProcessor's
        # config/serialization machinery (to_dict / from_dict) keeps working;
        # dropping them silently breaks save_pretrained round-trips.
        super().__init__(**kwargs)
        self.size = size
        self.convert_rgb = convert_rgb
        self.resample = resample

    def __call__(self, image: Image.Image) -> dict:
        """Alias for :meth:`preprocess`."""
        return self.preprocess(image)

    def preprocess(self, image: Image.Image) -> dict:
        """Resize and normalize *image* for the model.

        Returns:
            dict with keys ``image`` (normalized CHW float32 tensor),
            ``width`` and ``height`` (the ORIGINAL dimensions, so
            :meth:`postprocess` can restore them).
        """
        width, height = image.size
        image = image.resize((self.size, self.size), resample=self.resample)
        image = image.convert("RGBA")

        if self.convert_rgb:
            # Composite onto a white background through the alpha channel,
            # producing a 3-channel RGB image.
            background = Image.new("RGB", image.size, (255, 255, 255))
            background.paste(image, mask=image.getchannel("A"))
            image = background

        return {
            "image": self.to_tensor(image),
            "width": width,
            "height": height,
        }

    def to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a CHW float32 tensor scaled to ``[-1, 1]``."""
        # uint8 [0, 255] -> float [-1, 1]; HWC -> CHW for the model.
        x = np.array(image) / 127.5 - 1.0
        x = x.transpose(2, 0, 1).astype(np.float32)
        return torch.as_tensor(x)

    def postprocess(
        self,
        x: torch.Tensor,
        width: int | None = None,
        height: int | None = None,
    ) -> Image.Image:
        """Convert a model output tensor back to a PIL image.

        Args:
            x: CHW tensor with values in ``[-1, 1]``.
            width: Target output width; defaults to ``self.size``.
            height: Target output height; defaults to ``self.size``.
        """
        x_np = x.detach().cpu().numpy()
        x_np = x_np.transpose(1, 2, 0)          # CHW -> HWC
        x_np = (x_np + 1.0) * 127.5             # [-1, 1] -> [0, 255]
        x_np = np.clip(x_np, 0, 255).astype(np.uint8)
        image = Image.fromarray(x_np)

        width = width or self.size
        height = height or self.size
        # Skip the resample pass when the image is already the requested size:
        # resizing to the same dimensions with LANCZOS is wasted work and can
        # subtly perturb pixel values.
        if (width, height) != image.size:
            image = image.resize((width, height), resample=self.resample)

        return image