| import torch.nn as nn |
| from compressai.entropy_models import EntropyBottleneck |
| from timm.models.vision_transformer import Block |
|
|
class IF_Module(nn.Module):
    """Transformer encoder/decoder pair with an entropy bottleneck in between.

    A sequence of tokens is refined by ``depth`` transformer encoder blocks,
    optionally passed through a learned entropy bottleneck (training only),
    and then reconstructed by ``depth`` transformer decoder blocks.

    Args:
        embed_dim: Token embedding dimension; also the channel count of the
            entropy bottleneck.
        num_heads: Number of attention heads per transformer block.
        mlp_ratio: MLP hidden-dim expansion ratio inside each block.
        depth: Number of transformer blocks in the encoder and in the
            decoder (default 4).
        norm_layer: Normalization layer constructor (default ``nn.LayerNorm``).
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio, depth=4, norm_layer=nn.LayerNorm):
        super().__init__()

        # Encoder: `depth` identical transformer blocks followed by a norm.
        self.encoder_blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.encoder_norm = norm_layer(embed_dim)

        # Decoder: mirrors the encoder structure.
        self.decoder_blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.decoder_norm = norm_layer(embed_dim)

        # Learned entropy model over the embedding channels.
        self.entropy_bottleneck = EntropyBottleneck(embed_dim)

    def forward(self, x, is_training=False):
        """Encode, (optionally) bottleneck, and decode a token sequence.

        Args:
            x: Input tokens; assumed shape (batch, num_tokens, embed_dim)
               based on the permutes below — TODO confirm against callers.
            is_training: When True, route the encoded tokens through the
                entropy bottleneck to obtain quantized tokens and their
                likelihoods.

        Returns:
            Tuple ``(x_hat, x_likelihood)`` where ``x_hat`` are the decoded
            tokens and ``x_likelihood`` are the bottleneck likelihoods, or
            ``None`` when ``is_training`` is False.
        """
        for blk in self.encoder_blocks:
            x = blk(x)
        x = self.encoder_norm(x)

        if is_training:
            # EntropyBottleneck operates over dim 1, so move embed_dim to
            # the channel position: (B, N, C) -> (B, C, N) and back.
            x = x.permute(0, 2, 1)
            x_hat, x_likelihood = self.entropy_bottleneck(x)
            x_hat = x_hat.permute(0, 2, 1)
        else:
            # NOTE(review): at inference the bottleneck is bypassed entirely,
            # so x_hat is the un-quantized encoder output and no likelihoods
            # are produced — confirm this is intentional (compressai models
            # usually still quantize at eval time).
            x_hat = x
            x_likelihood = None

        for blk in self.decoder_blocks:
            x_hat = blk(x_hat)
        x_hat = self.decoder_norm(x_hat)

        return x_hat, x_likelihood
|
|