| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| from monai.losses.adversarial_loss import PatchAdversarialLoss |
|
|
# Loss criteria used by generator_loss / discriminator_loss below.
# L1 (mean absolute error with torch's default 'mean' reduction) reconstruction term.
intensity_loss = torch.nn.L1Loss()
# Patch-based adversarial criterion using the least-squares formulation.
adv_loss = PatchAdversarialLoss(criterion="least_squares")

# Relative weights applied to the individual loss terms.
adv_weight = 0.1
perceptual_weight = 0.1
kl_weight = 1e-7
|
|
|
|
def compute_kl_loss(z_mu, z_sigma):
    """Return the batch-averaged KL divergence of N(z_mu, z_sigma^2) from N(0, 1).

    The per-sample divergence is summed over every non-batch dimension and the
    per-sample values are then averaged over dimension 0 (the batch).

    Args:
        z_mu: Posterior means; dimension 0 is treated as the batch dimension.
            Works for any rank >= 2 (e.g. 2D or 3D latent maps).
        z_sigma: Posterior standard deviations, same shape as ``z_mu``.
            Assumed strictly positive (log is taken).

    Returns:
        A 0-d tensor holding the mean per-sample KL divergence.
    """
    # Reduce over all non-batch dims instead of hard-coding dims 1-4, so both
    # (B, C, H, W) and (B, C, D, H, W) latents are supported.
    reduce_dims = list(range(1, z_mu.dim()))
    # 2*log(sigma) == log(sigma**2), but it stays finite when sigma**2 would
    # underflow to 0 in the tensor's dtype (log(0) = -inf).
    kl_per_sample = 0.5 * torch.sum(
        z_mu.pow(2) + z_sigma.pow(2) - 2.0 * torch.log(z_sigma) - 1,
        dim=reduce_dims,
    )
    return torch.sum(kl_per_sample) / kl_per_sample.shape[0]
|
|
|
|
def generator_loss(gen_images, real_images, z_mu, z_sigma, disc_net, loss_perceptual):
    """Return the total autoencoder (generator) loss for one batch.

    The result is the sum of:
      * the reconstruction term ``intensity_loss(gen_images, real_images)``,
      * ``kl_weight`` times ``compute_kl_loss(z_mu, z_sigma)``,
      * ``perceptual_weight`` times ``loss_perceptual`` on float copies of the images,
      * ``adv_weight`` times the adversarial term from ``adv_loss``, computed with
        ``target_is_real=True, for_discriminator=False`` (i.e. the generator is
        rewarded when the discriminator scores its outputs as real).

    Args:
        gen_images: Images produced by the autoencoder.
        real_images: Target images compared against ``gen_images``.
        z_mu: Latent means, forwarded to ``compute_kl_loss``.
        z_sigma: Latent standard deviations, forwarded to ``compute_kl_loss``.
        disc_net: Discriminator; the LAST element of its output is used as the
            logits (presumably earlier elements are intermediate features — confirm).
        loss_perceptual: Callable perceptual loss taking (generated, real).

    Returns:
        The combined loss tensor (presumably a scalar with the criteria's default
        reductions — confirm).
    """
    recons_loss = intensity_loss(gen_images, real_images)
    kl_loss = compute_kl_loss(z_mu, z_sigma)
    # Cast to float before the perceptual network; presumably to keep it in
    # float32 when training in reduced precision — confirm.
    p_loss = loss_perceptual(gen_images.float(), real_images.float())
    loss_g = recons_loss + kl_weight * kl_loss + perceptual_weight * p_loss

    # gen_images is intentionally NOT detached: the adversarial term must
    # back-propagate into the generator.
    logits_fake = disc_net(gen_images)[-1]
    # Named adv_term (not generator_loss) to avoid shadowing this function's name.
    adv_term = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
    loss_g = loss_g + adv_weight * adv_term

    return loss_g
|
|
|
|
def discriminator_loss(gen_images, real_images, disc_net):
    """Return the weighted discriminator loss for one batch.

    The discriminator is scored on generated images with ``target_is_real=False``
    and on real images with ``target_is_real=True`` (both with
    ``for_discriminator=True``). The two terms are averaged and the average is
    scaled by ``adv_weight``.

    Args:
        gen_images: Images produced by the autoencoder.
        real_images: Real training images.
        disc_net: Discriminator; the LAST element of its output is used as the
            logits (presumably earlier elements are intermediate features — confirm).

    Returns:
        ``adv_weight * 0.5 * (fake_term + real_term)``.
    """
    # Inputs are detached so this loss produces no gradients upstream of the
    # images (e.g. into the generator); only disc_net's parameters receive them.
    logits_fake = disc_net(gen_images.contiguous().detach())[-1]
    loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
    logits_real = disc_net(real_images.contiguous().detach())[-1]
    loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
    # Named d_term (not discriminator_loss) to avoid shadowing this function's name.
    d_term = (loss_d_fake + loss_d_real) * 0.5
    return adv_weight * d_term
|
|