| import os.path |
| import sys |
|
|
| import PIL.Image |
| import numpy as np |
| import torch |
| from tqdm import tqdm |
|
|
| from basicsr.utils.download_util import load_file_from_url |
|
|
| import modules.upscaler |
| from modules import devices, modelloader, script_callbacks, errors |
| from scunet_model_arch import SCUNet as net |
|
|
| from modules.shared import opts |
|
|
|
|
class UpscalerScuNET(modules.upscaler.Upscaler):
    """Upscaler entry for the ScuNET GAN and PSNR models.

    The network itself runs at 1x: ``tiled_inference`` writes its result into
    a buffer of the same height/width as its input, and ``do_upscale`` returns
    an image of the input's size.
    """

    def __init__(self, dirname):
        """Collect the selectable ScuNET models.

        Local ``.pth`` files found by ``find_models`` are listed under their
        friendly names, and a remote URL entry is labelled "ScuNET GAN". The
        PSNR variant is appended as a download-on-demand URL entry unless an
        entry for it is already present.

        Args:
            dirname: user-supplied model directory, stored as ``user_path``
                before the base ``Upscaler`` initialiser runs.
        """
        self.name = "ScuNET"
        self.model_name = "ScuNET GAN"
        self.model_name2 = "ScuNET PSNR"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
        self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pth"])
        scalers = []
        add_model2 = True
        for file in model_paths:
            # A prefix check, not a substring check: a local path such as
            # "/data/http_mirror/x.pth" must keep its own friendly name.
            if self._is_url(file):
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)
            if name == self.model_name2 or file == self.model_url2:
                add_model2 = False
            try:
                scalers.append(modules.upscaler.UpscalerData(name, file, self, 4))
            except Exception:
                errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
        if add_model2:
            scalers.append(modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self))
        self.scalers = scalers

    @staticmethod
    def _is_url(path: str) -> bool:
        """Return True when *path* is a remote http(s) URL rather than a local file."""
        return path.startswith(("http://", "https://"))

    def _download_file_name(self, url: str):
        """Pick the on-disk file name used when downloading *url*.

        The GAN weights keep the historical ``"ScuNET.pth"`` name so existing
        downloads are reused. The PSNR weights get their own name, so the two
        variants no longer share one cached file. Returning None for any other
        URL lets the downloader derive a name from the URL itself.
        """
        if url == self.model_url:
            return f"{self.name}.pth"
        if url == self.model_url2:
            # NOTE(review): presumably modelloader.friendly_name() maps this to
            # "ScuNET PSNR", which would let __init__ dedupe the URL entry once
            # the file exists locally -- confirm.
            return f"{self.model_name2}.pth"
        return None

    @staticmethod
    @torch.no_grad()
    def tiled_inference(img, model):
        """Run *model* over *img* tile by tile and blend overlapping tiles.

        Tile size and overlap come from ``opts.SCUNET_tile`` and
        ``opts.SCUNET_tile_overlap``; a tile size of 0 disables tiling.
        Overlapping outputs are averaged by accumulating them in ``E`` and
        dividing by the per-pixel hit count in ``W``.

        Args:
            img: 4-D tensor indexed as ``img.shape[2:] == (h, w)``; assumed to
                be (1, 3, h, w) with h, w >= tile (``do_upscale`` pads to
                guarantee this) -- TODO confirm for other callers.
            model: callable mapping a tile to an output of the same size.

        Returns:
            Tensor of shape (1, 3, h, w) on the ScuNET device.
        """
        h, w = img.shape[2:]
        tile = opts.SCUNET_tile
        tile_overlap = opts.SCUNET_tile_overlap
        if tile == 0:
            return model(img)

        device = devices.get_device_for('scunet')
        assert tile % 8 == 0, "tile size should be a multiple of window_size"

        # The settings sliders allow an overlap >= the tile size. That makes
        # the stride <= 0: range() would raise (stride 0) or yield only the
        # last tile (negative stride), leaving pixels with W == 0, which turn
        # into NaN after the division below. Clamp to a usable overlap.
        if tile_overlap < 0:
            tile_overlap = 0
        elif tile_overlap >= tile:
            tile_overlap = tile // 2
        stride = tile - tile_overlap

        # Start offsets; the final tile is pinned flush with the far edge.
        h_starts = list(range(0, h - tile, stride)) + [h - tile]
        w_starts = list(range(0, w - tile, stride)) + [w - tile]

        E = torch.zeros(1, 3, h, w, dtype=img.dtype, device=device)  # summed outputs
        W = torch.zeros_like(E)  # number of tiles covering each pixel

        with tqdm(total=len(h_starts) * len(w_starts), desc="ScuNET tiles") as pbar:
            for h_idx in h_starts:
                for w_idx in w_starts:
                    region = (..., slice(h_idx, h_idx + tile), slice(w_idx, w_idx + tile))
                    E[region].add_(model(img[region]))
                    W[region].add_(1)
                    pbar.update(1)

        return E.div_(W)

    def do_upscale(self, img: PIL.Image.Image, selected_file):
        """Run the ScuNET model at *selected_file* over *img*.

        Args:
            img: input image. Non-RGB modes are converted to RGB first.
            selected_file: local model path or download URL (see ``load_model``).

        Returns:
            A new RGB image of the same size as *img*, or *img* itself if the
            model could not be loaded.
        """
        torch.cuda.empty_cache()

        model = self.load_model(selected_file)
        if model is None:
            print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
            return img

        device = devices.get_device_for('scunet')
        tile = opts.SCUNET_tile
        h, w = img.height, img.width

        # The network input has exactly 3 channels; L/P/RGBA images would
        # otherwise give 1- or 4-channel arrays and fail further down.
        rgb_img = img if img.mode == "RGB" else img.convert("RGB")
        np_img = np.array(rgb_img)
        # NOTE(review): the channel order the weights were trained on is not
        # visible here; confirm the RGB -> BGR flip below matches it.
        np_img = np_img[:, :, ::-1]  # RGB -> BGR
        np_img = np_img.transpose((2, 0, 1)) / 255  # HWC -> CHW, scaled to [0, 1]
        torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)

        # tiled_inference needs each side to be at least one tile long, so
        # zero-pad small images to the tile size and crop afterwards.
        if tile > h or tile > w:
            padded = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
            padded[:, :, :h, :w] = torch_img
            torch_img = padded

        torch_output = self.tiled_inference(torch_img, model).squeeze(0)
        torch_output = torch_output[:, :h, :w]  # drop any padding
        np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
        del torch_img, torch_output
        torch.cuda.empty_cache()

        output = np_output.transpose((1, 2, 0))  # CHW -> HWC
        output = output[:, :, ::-1]  # BGR -> RGB
        # Round rather than truncate so [0, 1] floats map to the nearest 8-bit
        # level instead of being biased downward.
        return PIL.Image.fromarray(np.round(output * 255).astype(np.uint8))

    def load_model(self, path: str):
        """Build the ScuNET network and load the weights referenced by *path*.

        Args:
            path: a local ``.pth`` file, or an http(s) URL that is downloaded
                into ``self.model_download_path`` (or reused if already there).

        Returns:
            The model in eval mode with gradients disabled, moved to the
            ScuNET device, or None if no weights file could be located.
        """
        device = devices.get_device_for('scunet')
        if self._is_url(path):
            # Download the URL that was actually selected. The previous code
            # always fetched self.model_url, so the PSNR entry got GAN weights.
            filename = load_file_from_url(
                url=path,
                model_dir=self.model_download_path,
                file_name=self._download_file_name(path),
                progress=True,
            )
        else:
            filename = path

        # Test for None before os.path.join, which raises TypeError on None.
        if filename is None:
            print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
            return None
        # Relative names resolve against model_path; absolute paths pass
        # through os.path.join unchanged. Load the same path that was checked.
        model_file = os.path.join(self.model_path, filename)
        if not os.path.exists(model_file):
            print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
            return None

        model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
        # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
        # hosts; the model is moved to the target device below.
        # NOTE(review): torch.load unpickles the file -- only load trusted weights.
        model.load_state_dict(torch.load(model_file, map_location="cpu"), strict=True)
        model.eval()
        model.requires_grad_(False)
        return model.to(device)
|
|
|
|
def on_ui_settings():
    """Register the ScuNET tiling options in the "Upscaling" settings section.

    Adds two slider options:
      * ``SCUNET_tile`` (default 256, 0-512, step 16; 0 disables tiling)
      * ``SCUNET_tile_overlap`` (default 8, 0-64, step 1)
    """
    import gradio as gr
    from modules import shared

    upscaling_section = ('upscaling', "Upscaling")

    tile_size_option = shared.OptionInfo(
        256,
        "Tile size for SCUNET upscalers.",
        gr.Slider,
        {"minimum": 0, "maximum": 512, "step": 16},
        section=upscaling_section,
    ).info("0 = no tiling")
    shared.opts.add_option("SCUNET_tile", tile_size_option)

    tile_overlap_option = shared.OptionInfo(
        8,
        "Tile overlap for SCUNET upscalers.",
        gr.Slider,
        {"minimum": 0, "maximum": 64, "step": 1},
        section=upscaling_section,
    ).info("Low values = visible seam")
    shared.opts.add_option("SCUNET_tile_overlap", tile_overlap_option)


# Hook the settings registration into the webui's settings-UI callback list.
script_callbacks.on_ui_settings(on_ui_settings)
|
|