|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Sequence, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from openfold.model.dropout import ( |
|
|
DropoutRowwise, |
|
|
DropoutColumnwise, |
|
|
) |
|
|
from openfold.model.evoformer import ( |
|
|
EvoformerBlock, |
|
|
EvoformerStack, |
|
|
) |
|
|
from openfold.model.outer_product_mean import OuterProductMean |
|
|
from openfold.model.msa import ( |
|
|
MSARowAttentionWithPairBias, |
|
|
MSAColumnAttention, |
|
|
MSAColumnGlobalAttention, |
|
|
) |
|
|
from openfold.model.pair_transition import PairTransition |
|
|
from openfold.model.primitives import Attention, GlobalAttention |
|
|
from openfold.model.structure_module import ( |
|
|
InvariantPointAttention, |
|
|
BackboneUpdate, |
|
|
) |
|
|
from openfold.model.template import TemplatePairStackBlock |
|
|
from openfold.model.triangular_attention import ( |
|
|
TriangleAttentionStartingNode, |
|
|
TriangleAttentionEndingNode, |
|
|
) |
|
|
from openfold.model.triangular_multiplicative_update import ( |
|
|
TriangleMultiplicationOutgoing, |
|
|
TriangleMultiplicationIncoming, |
|
|
) |
|
|
|
|
|
|
|
|
def script_preset_(model: torch.nn.Module):
    """
        Scripts, in place, a fixed preset of small but frequently used
        submodule types that are known to be scriptable: nn.Dropout,
        Attention, GlobalAttention and EvoformerBlock.

        Tracing is not attempted as a fallback, so any scripting failure
        propagates to the caller.

        Args:
            model:
                A torch.nn.Module. Only submodules of the preset types are
                replaced; if the model contains none, nothing changes.
    """
    preset_types = [
        nn.Dropout,
        Attention,
        GlobalAttention,
        EvoformerBlock,
    ]
    script_submodules_(
        model,
        preset_types,
        attempt_trace=False,
        batch_dims=None,
    )
|
|
|
|
|
|
|
|
def _get_module_device(module: torch.nn.Module) -> torch.device: |
|
|
""" |
|
|
Fetches the device of a module, assuming that all of the module's |
|
|
parameters reside on a single device |
|
|
|
|
|
Args: |
|
|
module: A torch.nn.Module |
|
|
Returns: |
|
|
The module's device |
|
|
""" |
|
|
return next(module.parameters()).device |
|
|
|
|
|
|
|
|
def _trace_module(module, batch_dims=None):
    """
        Traces a supported module using randomly generated stand-in inputs.

        Args:
            module:
                An instance of MSARowAttentionWithPairBias,
                MSAColumnAttention or OuterProductMean
            batch_dims:
                Optional tuple of leading batch dimensions prepended to the
                shape of every stand-in input. Defaults to none.
        Returns:
            The result of torch.jit.trace_module on the module's "forward"
        Raises:
            TypeError: if tracing is not supported for the module's type
    """
    if(batch_dims is None):
        batch_dims = ()

    # Arbitrary stand-in sizes used only to drive the trace
    n_seq = 10
    n_res = 10

    # Every stand-in input (including masks) is created on the module's
    # device; mixing CPU and accelerator tensors would make tracing fail.
    device = _get_module_device(module)

    def msa(channel_dim):
        return torch.rand(
            (*batch_dims, n_seq, n_res, channel_dim),
            device=device,
        )

    def pair(channel_dim):
        return torch.rand(
            (*batch_dims, n_res, n_res, channel_dim),
            device=device,
        )

    def mask():
        # Random 0/1 mask with shape (*batch_dims, n_seq, n_res).
        # NOTE(review): this is an integer (int64) tensor; confirm it
        # matches the mask dtype these modules receive at runtime.
        return torch.randint(
            0, 2,
            (*batch_dims, n_seq, n_res),
            device=device,
        )

    if(isinstance(module, MSARowAttentionWithPairBias)):
        inputs = {
            "forward": (
                msa(module.c_in),
                pair(module.c_z),
                mask(),
            ),
        }
    elif(isinstance(module, MSAColumnAttention)):
        inputs = {
            "forward": (
                msa(module.c_in),
                mask(),
            ),
        }
    elif(isinstance(module, OuterProductMean)):
        inputs = {
            "forward": (
                msa(module.c_m),
                mask(),
            )
        }
    else:
        raise TypeError(
            f"tracing is not supported for modules of type {type(module)}"
        )

    return torch.jit.trace_module(module, inputs)
|
|
|
|
|
|
|
|
def _script_submodules_helper_( |
|
|
model, |
|
|
types, |
|
|
attempt_trace, |
|
|
to_trace, |
|
|
): |
|
|
for name, child in model.named_children(): |
|
|
if(types is None or any(isinstance(child, t) for t in types)): |
|
|
try: |
|
|
scripted = torch.jit.script(child) |
|
|
setattr(model, name, scripted) |
|
|
continue |
|
|
except (RuntimeError, torch.jit.frontend.NotSupportedError) as e: |
|
|
if(attempt_trace): |
|
|
to_trace.add(type(child)) |
|
|
else: |
|
|
raise e |
|
|
|
|
|
_script_submodules_helper_(child, types, attempt_trace, to_trace) |
|
|
|
|
|
|
|
|
def _trace_submodules_( |
|
|
model, |
|
|
types, |
|
|
batch_dims=None, |
|
|
): |
|
|
for name, child in model.named_children(): |
|
|
if(any(isinstance(child, t) for t in types)): |
|
|
traced = _trace_module(child, batch_dims=batch_dims) |
|
|
setattr(model, name, traced) |
|
|
else: |
|
|
_trace_submodules_(child, types, batch_dims=batch_dims) |
|
|
|
|
|
|
|
|
def script_submodules_(
    model: nn.Module,
    types: Optional[Sequence[type]] = None,
    attempt_trace: bool = True,
    batch_dims: Optional[Tuple[int, ...]] = None,
):
    """
        Convert all submodules whose types match one of those in the input
        list to recursively scripted equivalents in place. To script the
        entire model, just call torch.jit.script on it directly.

        When types is None, all submodules are scripted.

        Args:
            model:
                A torch.nn.Module
            types:
                A list of types of submodules to script
            attempt_trace:
                Whether to attempt to trace specified modules if scripting
                fails. Recall that tracing eliminates all conditional
                logic---with great tracing comes the mild responsibility of
                having to remember to ensure that the modules in question
                perform the same computations no matter what. When False,
                scripting errors are re-raised.
            batch_dims:
                Leading batch dimensions (a tuple of ints, possibly empty)
                used to shape the stand-in inputs when tracing. Only used
                when attempt_trace is True and some module failed to
                script. None means no batch dimensions.
    """
    # Types of submodules that failed to script and should be traced instead
    to_trace = set()

    _script_submodules_helper_(model, types, attempt_trace, to_trace)

    if(attempt_trace and to_trace):
        _trace_submodules_(model, to_trace, batch_dims=batch_dims)
|
|
|