from typing import Optional

import torch
import torch.nn as nn

from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import permute_final_dims
|
|
class TriangleMultiplicativeUpdate(nn.Module): |
    """
    Implements Algorithms 11 and 12.
    """
    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
        """
        super(TriangleMultiplicativeUpdate, self).__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        self.linear_a_p = Linear(self.c_z, self.c_hidden)
        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_b_p = Linear(self.c_z, self.c_hidden)
        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()
|
    def _combine_projections(self,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")
|
    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        # [*, N_res, N_res, 1]
        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)

        # Gated projections of the pair representation; masked pairs are
        # zeroed so they cannot contribute to the triangle sums
        a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
        a = a * mask
        b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
        b = b * mask

        # Subclass-specific contraction over the third residue of each triangle
        x = self._combine_projections(a, b)
        x = self.layer_norm_out(x)
        x = self.linear_z(x)

        # Gate the output with the (normalized) input
        g = self.sigmoid(self.linear_g(z))
        z = x * g

        return z
|
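# Note: in the pair representation z, entry (i, j) encodes the edge between
# residues i and j. The two subclasses below differ only in which edges of
# each triangle (i, j, k) are multiplied together: the "outgoing" update
# (Algorithm 11) combines edges (i, k) and (j, k), while the "incoming"
# update (Algorithm 12) combines edges (k, i) and (k, j).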
|
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 11.
    """
    def _combine_projections(self,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        # out[..., i, j, c] = sum_k a[..., i, k, c] * b[..., j, k, c]
        # [*, C, N_res, N_res]
        p = torch.matmul(
            permute_final_dims(a, (2, 0, 1)),
            permute_final_dims(b, (2, 1, 0)),
        )

        # [*, N_res, N_res, C]
        return permute_final_dims(p, (1, 2, 0))
|
|
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 12.
    """
    def _combine_projections(self,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        # out[..., i, j, c] = sum_k a[..., k, i, c] * b[..., k, j, c]
        # [*, C, N_res, N_res]
        p = torch.matmul(
            permute_final_dims(a, (2, 1, 0)),
            permute_final_dims(b, (2, 0, 1)),
        )

        # [*, N_res, N_res, C]
        return permute_final_dims(p, (1, 2, 0))
|
|
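if __name__ == "__main__":
    # Illustrative smoke test only, not part of the module. The channel sizes
    # below mirror common AlphaFold 2 pair-stack settings but are an arbitrary
    # assumption for this sketch.
    batch, n_res, c_z, c_hidden = 1, 16, 128, 128
    z = torch.rand(batch, n_res, n_res, c_z)
    mask = torch.ones(batch, n_res, n_res)

    for module_cls in (TriangleMultiplicationOutgoing, TriangleMultiplicationIncoming):
        module = module_cls(c_z, c_hidden)
        out = module(z, mask=mask)
        # The update preserves the pair-representation shape, so it can be
        # applied residually in the Evoformer stack
        assert out.shape == z.shape

    # Cross-check the matmul-based combines against their einsum equivalents
    a = torch.rand(batch, n_res, n_res, c_hidden)
    b = torch.rand(batch, n_res, n_res, c_hidden)
    out_ref = torch.einsum("...ikc,...jkc->...ijc", a, b)
    in_ref = torch.einsum("...kic,...kjc->...ijc", a, b)
    assert torch.allclose(
        TriangleMultiplicationOutgoing(c_z, c_hidden)._combine_projections(a, b),
        out_ref, atol=1e-5,
    )
    assert torch.allclose(
        TriangleMultiplicationIncoming(c_z, c_hidden)._combine_projections(a, b),
        in_ref, atol=1e-5,
    )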