advance common_modules.py #846

Open
wants to merge 1 commit into base: main
238 changes: 65 additions & 173 deletions alphafold/model/common_modules.py
@@ -12,180 +12,72 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""A collection of common Haiku modules for use in protein folding."""
import numbers
from typing import Union, Sequence

# Advanced Haiku Protein Folding Modules
import haiku as hk
import jax.numpy as jnp
import numpy as np


# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978,
                                            dtype=np.float32)


def get_initializer_scale(initializer_name, input_shape):
  """Get Initializer for weights and scale to multiply activations by."""

  if initializer_name == 'zeros':
    w_init = hk.initializers.Constant(0.0)
  else:
    # fan-in scaling
    scale = 1.
    for channel_dim in input_shape:
      scale /= channel_dim
    if initializer_name == 'relu':
      scale *= 2

    noise_scale = scale

    stddev = np.sqrt(noise_scale)
    # Adjust stddev for truncation.
    stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR
    w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)

  return w_init

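
def _example_initializer_scale():
  # Illustrative sketch (not used anywhere in the model): the fan-in numbers
  # here are made up. For a 'relu' layer with 256 input channels the fan-in
  # scale is 2 / 256, and dividing by TRUNCATED_NORMAL_STDDEV_FACTOR widens the
  # sampling stddev so that the *truncated* samples keep the intended spread.
  fan_in = 256
  scale = 2.0 / fan_in  # 'relu' doubles the fan-in scale.
  stddev = np.sqrt(scale) / TRUNCATED_NORMAL_STDDEV_FACTOR
  return stddev  # ~0.1005, versus sqrt(scale) ~= 0.0884 before the correction.
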

class Linear(hk.Module):
  """Protein folding specific Linear module.

  This differs from the standard Haiku Linear in a few ways:
    * It supports inputs and outputs of arbitrary rank
    * Initializers are specified by strings
  """

  def __init__(self,
               num_output: Union[int, Sequence[int]],
               initializer: str = 'linear',
               num_input_dims: int = 1,
               use_bias: bool = True,
               bias_init: float = 0.,
               precision=None,
               name: str = 'linear'):
    """Constructs Linear Module.

    Args:
      num_output: Number of output channels. Can be tuple when outputting
        multiple dimensions.
      initializer: What initializer to use, should be one of {'linear', 'relu',
        'zeros'}
      num_input_dims: Number of dimensions from the end to project.
      use_bias: Whether to include trainable bias
      bias_init: Value used to initialize bias.
      precision: What precision to use for matrix multiplication, defaults
        to None.
      name: Name of module, used for name scopes.
    """
    super().__init__(name=name)
    if isinstance(num_output, numbers.Integral):
      self.output_shape = (num_output,)
    else:
      self.output_shape = tuple(num_output)
    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init
    self.num_input_dims = num_input_dims
    self.num_output_dims = len(self.output_shape)
    self.precision = precision

  def __call__(self, inputs):
    """Connects Module.

    Args:
      inputs: Tensor with at least num_input_dims dimensions.

    Returns:
      output of shape [...] + num_output.
    """

    num_input_dims = self.num_input_dims

    if self.num_input_dims > 0:
      in_shape = inputs.shape[-self.num_input_dims:]
    else:
      in_shape = ()

    weight_init = get_initializer_scale(self.initializer, in_shape)

    in_letters = 'abcde'[:self.num_input_dims]
    out_letters = 'hijkl'[:self.num_output_dims]

    weight_shape = in_shape + self.output_shape
    weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
                               weight_init)

    equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'

    output = jnp.einsum(equation, inputs, weights, precision=self.precision)

    if self.use_bias:
      bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output

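
def _example_linear_usage():
  # Illustrative sketch (not used anywhere in the model): Linear projects
  # several trailing dimensions at once via einsum. The shapes and RNG key
  # below are made up for the example.
  import jax  # Example-only dependency.

  def forward(x):
    return Linear(num_output=(4, 8), num_input_dims=2, initializer='relu')(x)

  fwd = hk.transform(forward)
  x = jnp.zeros((3, 5, 7))                     # Batch of 3, input dims (5, 7).
  params = fwd.init(jax.random.PRNGKey(0), x)  # Weights get shape (5, 7, 4, 8).
  return fwd.apply(params, None, x)            # Output shape (3, 4, 8).
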

class LayerNorm(hk.LayerNorm):
  """LayerNorm module.

  Equivalent to hk.LayerNorm but with different parameter shapes: they are
  always vectors rather than possibly higher-rank tensors. This makes it easier
  to change the layout while keeping the model weight-compatible.
  """

  def __init__(self,
               axis,
               create_scale: bool,
               create_offset: bool,
               eps: float = 1e-5,
               scale_init=None,
               offset_init=None,
               use_fast_variance: bool = False,
               name=None,
               param_axis=None):
    super().__init__(
        axis=axis,
        create_scale=False,
        create_offset=False,
        eps=eps,
        scale_init=None,
        offset_init=None,
        use_fast_variance=use_fast_variance,
        name=name,
        param_axis=param_axis)
    self._temp_create_scale = create_scale
    self._temp_create_offset = create_offset

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    is_bf16 = (x.dtype == jnp.bfloat16)
    if is_bf16:
      x = x.astype(jnp.float32)

    param_axis = self.param_axis[0] if self.param_axis else -1
    param_shape = (x.shape[param_axis],)

    param_broadcast_shape = [1] * x.ndim
    param_broadcast_shape[param_axis] = x.shape[param_axis]
    scale = None
    offset = None
    if self._temp_create_scale:
      scale = hk.get_parameter(
          'scale', param_shape, x.dtype, init=self.scale_init)
      scale = scale.reshape(param_broadcast_shape)

    if self._temp_create_offset:
      offset = hk.get_parameter(
          'offset', param_shape, x.dtype, init=self.offset_init)
      offset = offset.reshape(param_broadcast_shape)

    out = super().__call__(x, scale=scale, offset=offset)

    if is_bf16:
      out = out.astype(jnp.bfloat16)

    return out

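
def _example_layer_norm_usage():
  # Illustrative sketch (not used anywhere in the model): LayerNorm stores its
  # scale/offset as vectors over the normalised axis and runs in float32 even
  # for bfloat16 inputs, casting back afterwards. Shapes and key are made up.
  import jax  # Example-only dependency.

  def forward(x):
    return LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)

  fwd = hk.transform(forward)
  x = jnp.ones((2, 16), dtype=jnp.bfloat16)
  params = fwd.init(jax.random.PRNGKey(0), x)  # 'scale' and 'offset' have shape (16,).
  return fwd.apply(params, None, x)            # Output is cast back to bfloat16.
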
class ProteinLinear(hk.Module):
  """Simplified linear projection over the trailing input dimension.

  Unlike `Linear` above it projects with `jnp.matmul`, so it expects an
  integer `num_output` and, in practice, `num_input_dims=1`.
  """

  def __init__(self, num_output, initializer='linear', num_input_dims=1,
               use_bias=True, bias_init=0., name='protein_linear'):
    super().__init__(name=name)
    self.num_output = num_output
    self.initializer = initializer
    self.num_input_dims = num_input_dims
    self.use_bias = use_bias
    self.bias_init = bias_init

  def __call__(self, inputs):
    input_shape = inputs.shape[-self.num_input_dims:]
    weight_init = self._get_initializer_scale(input_shape)
    # Parameter shapes must be tuples, so wrap the integer output size.
    weights = hk.get_parameter('weights', input_shape + (self.num_output,),
                               inputs.dtype, weight_init)
    output = jnp.matmul(inputs, weights)

    if self.use_bias:
      bias = hk.get_parameter('bias', (self.num_output,), inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output

  def _get_initializer_scale(self, input_shape):
    if self.initializer == 'zeros':
      return hk.initializers.Constant(0.0)
    else:
      # Fan-in scaling, doubled for ReLU layers.
      scale = 1.0 / np.prod(input_shape)
      if self.initializer == 'relu':
        scale *= 2
      stddev = np.sqrt(scale)
      # Correct for truncation at +/- 2 stddev
      # (scipy.stats.truncnorm.std(a=-2, b=2) ~= 0.8796).
      stddev /= 0.87962566103423978
      return hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)


class ProteinLayerNorm(hk.LayerNorm):
  """LayerNorm variant that stores scale/offset as vectors (see `LayerNorm`)."""

  def __init__(self, axis, create_scale=True, create_offset=True, eps=1e-5,
               name=None):
    super().__init__(axis=axis, create_scale=False, create_offset=False,
                     eps=eps, name=name)
    # The parameters are created in __call__ so they stay rank-1 vectors.
    self._temp_create_scale = create_scale
    self._temp_create_offset = create_offset

  def __call__(self, x):
    # Normalise in float32 even for bfloat16 inputs, then cast back.
    is_bf16 = (x.dtype == jnp.bfloat16)
    if is_bf16:
      x = x.astype(jnp.float32)

    param_axis = self.param_axis[0] if self.param_axis else -1
    param_shape = (x.shape[param_axis],)
    param_broadcast_shape = [1] * x.ndim
    param_broadcast_shape[param_axis] = x.shape[param_axis]
    scale = None
    offset = None

    if self._temp_create_scale:
      scale = hk.get_parameter('scale', param_shape, x.dtype,
                               init=self.scale_init)
      scale = scale.reshape(param_broadcast_shape)

    if self._temp_create_offset:
      offset = hk.get_parameter('offset', param_shape, x.dtype,
                                init=self.offset_init)
      offset = offset.reshape(param_broadcast_shape)

    out = super().__call__(x, scale=scale, offset=offset)

    if is_bf16:
      out = out.astype(jnp.bfloat16)

    return out
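

def _example_protein_modules_usage():
  # Illustrative sketch (not part of the diff above): chains the two new
  # modules inside hk.transform. The sizes and RNG key are made up.
  import jax  # Example-only dependency.

  def forward(x):
    x = ProteinLinear(num_output=32, initializer='relu')(x)
    return ProteinLayerNorm(axis=-1)(x)

  fwd = hk.transform(forward)
  x = jnp.zeros((4, 16))
  params = fwd.init(jax.random.PRNGKey(0), x)
  return fwd.apply(params, None, x)  # Output shape (4, 32).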