| from typing import overload, Tuple, Optional |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| import numpy as np |
| from PIL import Image |
| from einops import rearrange |
|
|
| from model.cldm import ControlLDM |
| from model.gaussian_diffusion import Diffusion |
| from model.bsrnet import RRDBNet |
| from model.swinir import SwinIR |
| from model.scunet import SCUNet |
| from utils.sampler import SpacedSampler |
| from utils.cond_fn import Guidance |
| from utils.common import wavelet_decomposition, wavelet_reconstruction, count_vram_usage |
|
|
|
|
def bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray:
    """Rescale an image array by ``scale`` using PIL's bicubic resampling.

    Args:
        img: Image array accepted by ``PIL.Image.fromarray``.
        scale: Multiplier applied to both width and height; each resulting
            edge length is truncated to an integer.

    Returns:
        The resized image as a new numpy array.
    """
    source = Image.fromarray(img)
    # PIL reports size as (width, height); both edges are scaled the same way.
    width, height = source.size
    target_size = (int(width * scale), int(height * scale))
    resized = source.resize(target_size, Image.BICUBIC)
    return np.array(resized)
|
|
|
|
def resize_short_edge_to(imgs: torch.Tensor, size: int) -> torch.Tensor:
    """Resize a 4-D image batch so that its shorter spatial edge equals ``size``.

    The longer edge is scaled by the same factor (truncated to an integer).
    Resampling is bicubic with antialiasing enabled.

    Args:
        imgs: Tensor whose last two dimensions are height and width.
        size: Target length of the shorter edge.

    Returns:
        The resized tensor.
    """
    _batch, _channels, height, width = imgs.shape

    if height == width:
        # Square input: skip the float ratio so round-off cannot yield size - 1.
        target = (size, size)
    elif height < width:
        target = (size, int(width * (size / height)))
    else:
        target = (int(height * (size / width)), size)

    return F.interpolate(imgs, size=target, mode="bicubic", antialias=True)
|
|
|
|
def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor:
    """Zero-pad height and width up to the next multiple of ``multiple``.

    Padding is added only on the bottom and right, so the original content
    stays at the top-left corner of the result.

    Args:
        imgs: 4-D tensor whose last two dimensions are height and width.
        multiple: Value that both spatial dimensions must be divisible by.

    Returns:
        A new tensor: a clone of ``imgs`` when no padding is needed, otherwise
        the zero-padded tensor.
    """
    _batch, _channels, height, width = imgs.shape

    if height % multiple == 0 and width % multiple == 0:
        # Already aligned; still return a copy so callers never alias the input.
        return imgs.clone()

    # Distance from each edge length up to the next multiple (0 if aligned).
    extra_h = -height % multiple
    extra_w = -width % multiple
    # F.pad order for the last two dims is (left, right, top, bottom).
    return F.pad(imgs, pad=(0, extra_w, 0, extra_h), mode="constant", value=0)
|
|
|
|
class Pipeline:
    """Base class for a two-stage restoration pipeline.

    Stage 1 (``run_stage1``, supplied by subclasses) produces a cleaned image
    from the low-quality input; stage 2 (``run_stage2``) runs diffusion
    sampling conditioned on that cleaned image. ``run`` wires both stages
    together and converts between numpy images and tensors.
    """

    def __init__(
        self,
        stage1_model: nn.Module,
        cldm: ControlLDM,
        diffusion: Diffusion,
        cond_fn: Optional[Guidance],
        device: str,
    ) -> None:
        """Store the models and settings used by both stages.

        Args:
            stage1_model: Model invoked by ``run_stage1``.
            cldm: Model used for conditioning, VAE encode/decode and sampling.
            diffusion: Provides ``betas``, ``num_timesteps`` and ``q_sample``.
            cond_fn: Optional guidance object; when given, its target is loaded
                in stage 2 and it is forwarded to the sampler.
            device: Device on which tensors are created.
        """
        self.stage1_model = stage1_model
        self.cldm = cldm
        self.diffusion = diffusion
        self.cond_fn = cond_fn
        self.device = device
        # (height, width) of the final output; filled in by set_final_size().
        self.final_size: Optional[Tuple[int, int]] = None

    def set_final_size(self, lq: torch.Tensor) -> None:
        """Record the output size; the base pipeline keeps the input size.

        Args:
            lq: Input batch whose dimensions 2 and 3 are height and width.
        """
        h, w = lq.shape[2:]
        self.final_size = (h, w)

    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        """Produce the stage-1 result for ``lq``; must be overridden.

        Args:
            lq: Low-quality input batch.

        Returns:
            The cleaned batch that is passed on to ``run_stage2``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement run_stage1()"
        )

    @count_vram_usage
    def run_stage2(
        self,
        clean: torch.Tensor,
        steps: int,
        strength: float,
        tiled: bool,
        tile_size: int,
        tile_stride: int,
        pos_prompt: str,
        neg_prompt: str,
        cfg_scale: float,
        better_start: bool
    ) -> torch.Tensor:
        """Run diffusion sampling conditioned on the stage-1 output.

        Args:
            clean: Stage-1 output batch (N, C, H, W).
            steps: Number of sampling steps.
            strength: Value assigned to every control scale while sampling.
            tiled: Whether to use the tiled condition/VAE/sampling variants.
            tile_size: Tile size in pixels (used when ``tiled``).
            tile_stride: Tile stride in pixels (used when ``tiled``).
            pos_prompt: Prompt for the conditional branch.
            neg_prompt: Prompt for the unconditional branch.
            cfg_scale: Classifier-free guidance scale passed to the sampler.
            better_start: If true, start sampling from the noised low-frequency
                part of the condition instead of pure Gaussian noise.

        Returns:
            The decoded sample cropped back to the height and width of
            ``clean``.
        """
        bs, _, ori_h, ori_w = clean.shape

        # Pad so that height and width are multiples of 64; undone by the final crop.
        pad_clean = pad_to_multiples_of(clean, multiple=64)
        h, w = pad_clean.shape[2:]

        # Build the conditional / unconditional inputs from the same image.
        if not tiled:
            cond = self.cldm.prepare_condition(pad_clean, [pos_prompt] * bs)
            uncond = self.cldm.prepare_condition(pad_clean, [neg_prompt] * bs)
        else:
            cond = self.cldm.prepare_condition_tiled(pad_clean, [pos_prompt] * bs, tile_size, tile_stride)
            uncond = self.cldm.prepare_condition_tiled(pad_clean, [neg_prompt] * bs, tile_size, tile_stride)
        if self.cond_fn is not None:
            # Rescaled as x * 2 - 1 (presumably [0, 1] -> [-1, 1]; confirm with Guidance).
            self.cond_fn.load_target(pad_clean * 2 - 1)

        old_control_scales = self.cldm.control_scales
        # NOTE(review): 13 is presumably the number of control scales the
        # ControlLDM expects -- confirm against the model definition.
        self.cldm.control_scales = [strength] * 13
        try:
            if better_start:
                # Encode the low-frequency component of the condition and noise
                # it to the last timestep to obtain the starting latent.
                _, low_freq = wavelet_decomposition(pad_clean)
                if not tiled:
                    x_0 = self.cldm.vae_encode(low_freq)
                else:
                    x_0 = self.cldm.vae_encode_tiled(low_freq, tile_size, tile_stride)
                x_T = self.diffusion.q_sample(
                    x_0,
                    torch.full((bs, ), self.diffusion.num_timesteps - 1, dtype=torch.long, device=self.device),
                    torch.randn(x_0.shape, dtype=torch.float32, device=self.device)
                )
            else:
                x_T = torch.randn((bs, 4, h // 8, w // 8), dtype=torch.float32, device=self.device)

            sampler = SpacedSampler(self.diffusion.betas)
            z = sampler.sample(
                model=self.cldm, device=self.device, steps=steps, batch_size=bs, x_size=(4, h // 8, w // 8),
                cond=cond, uncond=uncond, cfg_scale=cfg_scale, x_T=x_T, progress=True,
                progress_leave=True, cond_fn=self.cond_fn, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            if not tiled:
                x = self.cldm.vae_decode(z)
            else:
                # Latents are 1/8 of the pixel resolution, so tile sizes shrink too.
                x = self.cldm.vae_decode_tiled(z, tile_size // 8, tile_stride // 8)
        finally:
            # Always restore the shared model's scales, even if sampling fails.
            self.cldm.control_scales = old_control_scales

        # Drop the padding added above.
        return x[:, :, :ori_h, :ori_w]

    @torch.no_grad()
    def run(
        self,
        lq: np.ndarray,
        steps: int,
        strength: float,
        tiled: bool,
        tile_size: int,
        tile_stride: int,
        pos_prompt: str,
        neg_prompt: str,
        cfg_scale: float,
        better_start: bool
    ) -> np.ndarray:
        """Restore a batch of images end to end.

        Args:
            lq: Low-quality images as an (N, H, W, C) array in the 0-255 range.
            steps: Number of sampling steps for stage 2.
            strength: Control strength for stage 2.
            tiled: Whether stage 2 runs in tiled mode.
            tile_size: Tile size in pixels for tiled mode.
            tile_stride: Tile stride in pixels for tiled mode.
            pos_prompt: Positive prompt.
            neg_prompt: Negative prompt.
            cfg_scale: Classifier-free guidance scale.
            better_start: Whether stage 2 starts from the noised low-frequency
                condition instead of pure noise.

        Returns:
            Restored images as an (N, H, W, C) ``uint8`` array at
            ``self.final_size``.
        """
        # Image to tensor: scale to [0, 1] and switch to channels-first layout.
        lq = torch.tensor((lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device)
        lq = rearrange(lq, "n h w c -> n c h w").contiguous()

        # The output size must be known before stage 1 (subclasses may read it).
        self.set_final_size(lq)
        clean = self.run_stage1(lq)
        sample = self.run_stage2(
            clean, steps, strength, tiled, tile_size, tile_stride,
            pos_prompt, neg_prompt, cfg_scale, better_start
        )

        # Rescale the sample (presumably [-1, 1] -> [0, 1]) and recombine it
        # with the stage-1 image via wavelet reconstruction.
        sample = (sample + 1) / 2
        sample = wavelet_reconstruction(sample, clean)

        # Resize to the requested output size.
        sample = F.interpolate(sample, size=self.final_size, mode="bicubic", antialias=True)

        # Tensor to image: back to channels-last uint8.
        sample = rearrange(sample * 255., "n c h w -> n h w c")
        sample = sample.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy()
        return sample
|
|
|
|
class BSRNetPipeline(Pipeline):
    """Pipeline whose stage 1 is an ``RRDBNet`` and whose output is upscaled."""

    def __init__(
        self,
        bsrnet: RRDBNet,
        cldm: ControlLDM,
        diffusion: Diffusion,
        cond_fn: Optional[Guidance],
        device: str,
        upscale: float,
    ) -> None:
        """Set up the base pipeline and remember the output scale factor.

        Args:
            bsrnet: Stage-1 model.
            cldm: Forwarded to ``Pipeline``.
            diffusion: Forwarded to ``Pipeline``.
            cond_fn: Forwarded to ``Pipeline``.
            device: Forwarded to ``Pipeline``.
            upscale: Factor applied to the input size to get the output size.
        """
        super().__init__(bsrnet, cldm, diffusion, cond_fn, device)
        self.upscale = upscale

    def set_final_size(self, lq: torch.Tensor) -> None:
        """Set the output size to the input size times ``upscale`` (truncated)."""
        height, width = lq.shape[2:]
        scaled_h = int(height * self.upscale)
        scaled_w = int(width * self.upscale)
        self.final_size = (scaled_h, scaled_w)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        """Run the stage-1 model and resize its output for stage 2.

        If the shorter edge of the final size is at least 512 the output is
        resized to exactly ``final_size``; otherwise its shorter edge is
        resized to 512.
        """
        restored = self.stage1_model(lq)

        if min(self.final_size) >= 512:
            return F.interpolate(restored, size=self.final_size, mode="bicubic", antialias=True)
        # Final size is below 512 on its short edge: work at 512 instead.
        return resize_short_edge_to(restored, size=512)
|
|
|
|
class SwinIRPipeline(Pipeline):
    """Pipeline whose stage 1 is a ``SwinIR`` model run on a padded input."""

    def __init__(
        self,
        swinir: SwinIR,
        cldm: ControlLDM,
        diffusion: Diffusion,
        cond_fn: Optional[Guidance],
        device: str,
    ) -> None:
        """Pass all arguments through to ``Pipeline`` with ``swinir`` as stage 1."""
        super().__init__(swinir, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        """Run the stage-1 model on ``lq`` and return the un-padded result.

        Inputs whose shorter edge is below 512 are first resized so that edge
        is 512. The image is then padded to multiples of 64 for the model and
        the output is cropped back to the pre-padding size.
        """
        shortest = min(lq.shape[2:])
        if shortest < 512:
            lq = resize_short_edge_to(lq, size=512)

        # Remember the size before padding so the padding can be cropped off.
        height, width = lq.shape[2:]
        padded = pad_to_multiples_of(lq, multiple=64)
        restored = self.stage1_model(padded)
        return restored[:, :, :height, :width]
|
|
|
|
class SCUNetPipeline(Pipeline):
    """Pipeline whose stage 1 is an ``SCUNet`` model."""

    def __init__(
        self,
        scunet: SCUNet,
        cldm: ControlLDM,
        diffusion: Diffusion,
        cond_fn: Optional[Guidance],
        device: str,
    ) -> None:
        """Pass all arguments through to ``Pipeline`` with ``scunet`` as stage 1."""
        super().__init__(scunet, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        """Run the stage-1 model; upsize the result if its short edge is below 512."""
        restored = self.stage1_model(lq)
        too_small = min(restored.shape[2:]) < 512
        if not too_small:
            return restored
        return resize_short_edge_to(restored, size=512)
|
|