🌟 This repository contains an implementation of the paper FAENet: Frame Averaging Equivariant GNN for Materials Modeling, accepted at ICML 2023. 🌟 More precisely, you will find:
- FrameAveraging: the transform that projects your pytorch-geometric data into a canonical space, identical for all Euclidean transformations of the input, as defined in the paper.
- FAENet: a GNN architecture for material modeling.
- model_forward: a high-level forward function that computes appropriate equivariant model predictions for the Frame Averaging method, i.e. it handles the different frames and maps them to equivariant predictions.
pip install faenet
To the best of our knowledge, the package requires Python >= 3.8, torch > 1.11 and torch_geometric > 2.1. The mendeleev and pandas packages are also required to derive physics-aware atom embeddings in FAENet.
FrameAveraging is a Transform method applicable to pytorch-geometric Data objects, which should be used in the get_item() function of your Dataset class. This method derives a new canonical position for the atomic graph, identical for all Euclidean symmetries, and stores it under the data attribute fa_pos. You can choose among several frame averaging options, ranging from Full FA to Stochastic FA (in 2D or 3D), as well as traditional data augmentation (DA) with rotated samples. See the full documentation for more details. Note that, although this transform is specific to pytorch-geometric data objects, it can easily be extended to new settings, since the core functions frame_averaging_2D() and frame_averaging_3D() generalise to other data formats.
import torch
from torch_geometric.data import Data
from faenet.transforms import FrameAveraging

frame_averaging = "3D"  # symmetry preservation method used: {"3D", "2D", "DA", ""}
fa_method = "stochastic"  # frame averaging method: {"stochastic", "det", "all", "se3-stochastic", "se3-det", "se3-all", ""}
transform = FrameAveraging(frame_averaging, fa_method)

data = Data(pos=torch.randn(5, 3))  # e.g. a toy PyG graph with 5 atoms
data = transform(data)  # transform the PyG graph data; canonical positions are stored in data.fa_pos
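To make the 3D frame computation more concrete, here is a minimal NumPy sketch of PCA-based frames in the spirit of frame_averaging_3D(). It is not the library's code: compute_frames_3d() and canonical_positions() are hypothetical helpers, the covariance eigenvalues are assumed to be distinct, and the se3/stochastic flags only loosely mirror the "se3-" and "stochastic" options (the "det" option is not covered). Centering removes translations, the eigenvectors give three orthogonal axes, and the 2^3 sign flips of those axes give the 8 frames of Full FA (4 when reflections are excluded).

```python
import itertools
from typing import List

import numpy as np


def compute_frames_3d(pos: np.ndarray, se3: bool = False) -> List[np.ndarray]:
    """Return candidate 3x3 frames (axes as columns) for an (N, 3) point cloud.

    Hypothetical helper; assumes distinct covariance eigenvalues (generic case).
    """
    centered = pos - pos.mean(axis=0)          # remove translations
    cov = centered.T @ centered                # 3x3 (unnormalized) covariance
    _, eigvecs = np.linalg.eigh(cov)           # orthonormal eigenvectors as columns
    frames = []
    for signs in itertools.product([1.0, -1.0], repeat=3):
        frame = eigvecs * np.array(signs)      # flip the sign of selected axes
        if se3 and np.linalg.det(frame) < 0:
            continue                           # keep proper rotations only
        frames.append(frame)
    return frames


def canonical_positions(pos, se3=False, stochastic=False, rng=None):
    """Project centered positions into every frame, or into one random frame."""
    centered = pos - pos.mean(axis=0)
    frames = compute_frames_3d(pos, se3=se3)
    if stochastic:
        rng = rng if rng is not None else np.random.default_rng()
        frames = [frames[rng.integers(len(frames))]]
    return [centered @ frame for frame in frames]
```

Rotating or translating pos only permutes the list of canonical positions returned for all frames, which is what makes this set of frames usable for averaging.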
model_forward() aggregates the predictions of a chosen ML model (e.g. FAENet) when Frame Averaging is applied, as stipulated by Equation (1) of the paper. Indeed, applying the model directly on the canonical positions (fa_pos) would not yield equivariant predictions. This function must be used at both training and inference time to compute all model predictions. It requires batch to have pos, batch and frame averaging attributes (see the documentation).
from faenet.fa_forward import model_forward
preds = model_forward(
    batch=batch,  # a batch from your dataloader
    model=model,  # e.g. FAENet(**kwargs)
    frame_averaging="3D",  # ["2D", "3D", "DA", ""]
    mode="train",  # for training
    crystal_task=True,  # for crystals, with periodic boundary conditions (pbc)
)
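The averaging step can be illustrated with a small NumPy sketch of Equation (1), under simplifying assumptions. toy_model() is a hypothetical, deliberately non-equivariant stand-in for FAENet, all 8 frames of Full 3D FA are used, and fa_forward() is an illustrative helper, not the actual model_forward() implementation. The scalar prediction (energy) is simply averaged over frames. Vector predictions (forces) are first rotated back to the input orientation with the transposed frame, then averaged.

```python
import itertools

import numpy as np


def frames_3d(pos):
    """Centered positions and the 8 PCA frames (assumes distinct eigenvalues)."""
    centered = pos - pos.mean(axis=0)
    _, eigvecs = np.linalg.eigh(centered.T @ centered)
    signs = itertools.product([1.0, -1.0], repeat=3)
    return centered, [eigvecs * np.array(s) for s in signs]


def toy_model(x):
    """Hypothetical stand-in for a model that is NOT equivariant on its own."""
    w = np.array([1.0, 2.0, 3.0])
    energy = np.tanh(x @ w).sum()   # scalar, graph-level prediction
    forces = np.tanh(x * w)         # (N, 3), atom-level vector prediction
    return energy, forces


def fa_forward(pos, model=toy_model):
    """Average predictions over frames, mapping forces back to the input orientation."""
    centered, frames = frames_3d(pos)
    energies, forces = [], []
    for frame in frames:
        e, f = model(centered @ frame)  # predict in canonical coordinates
        energies.append(e)              # invariant output: average as is
        forces.append(f @ frame.T)      # equivariant output: rotate back first
    return np.mean(energies), np.mean(forces, axis=0)
```

Averaged this way, the energy is unchanged and the forces rotate with the input, even though toy_model() alone has neither property.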
Implementation of the FAENet GNN model, compatible with any dataset or transform. In short, FAENet is a very simple, scalable and expressive model. Because it does not explicitly preserve data symmetries, it can process relative atom positions directly and without restriction, which is both efficient and powerful. Although it was specifically designed to be used with the Frame Averaging above, which preserves symmetries without any design restrictions, it can also be used without it. When Frame Averaging is applied, model predictions must be computed with the model_forward() function above; calling model(data) is not enough. Note that the training procedure is not provided here; refer to the original GitHub repository. Check the documentation to see all input parameters.
Note that the model assumes that input data (e.g. batch below) have certain attributes, such as atomic_numbers, batch, pos or edge_index. If your data does not have these attributes, you can apply custom pre-processing functions, taking pbc_preprocess or base_preprocess in utils.py as inspiration. You simply need to pass them to FAENet through its preprocess argument.
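To illustrate what such a pre-processing step computes, the NumPy sketch below builds a cutoff graph with relative positions and edge lengths. radius_graph_preprocess() is a hypothetical helper. Its signature and return values are assumptions and do not necessarily match what FAENet's preprocess argument expects, so check base_preprocess in utils.py for the exact format.

```python
import numpy as np


def radius_graph_preprocess(pos, batch, cutoff):
    """Hypothetical preprocessing: connect atoms closer than `cutoff`.

    Returns edge_index (2, E), relative positions (E, 3) and edge lengths (E,).
    Only atoms belonging to the same graph (same `batch` id) are connected.
    """
    diff = pos[:, None, :] - pos[None, :, :]   # (N, N, 3) pairwise differences
    dist = np.linalg.norm(diff, axis=-1)       # (N, N) pairwise distances
    same_graph = batch[:, None] == batch[None, :]
    no_self_loops = ~np.eye(len(pos), dtype=bool)
    mask = (dist < cutoff) & same_graph & no_self_loops
    src, dst = np.nonzero(mask)
    edge_index = np.stack([src, dst])
    rel_pos = pos[src] - pos[dst]              # relative position along each edge
    return edge_index, rel_pos, np.linalg.norm(rel_pos, axis=-1)
```

The dense pairwise computation needs O(N^2) memory, which is fine for small molecules. Larger systems would typically rely on a dedicated neighbor search such as torch_geometric's radius_graph.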
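from faenet.model import FAENet
model = FAENet(**kwargs)  # kwargs: model hyperparameters, see the documentation
preds = model(batch)  # batch: PyG batch with the attributes listed above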
The eval_model_symmetries() function helps you evaluate the equivariance, invariance and other properties of a model, as we did in the paper.
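The kind of check eval_model_symmetries() performs can be sketched as follows: rotate the input at random and compare predictions. This is a minimal NumPy illustration under our own assumptions, not the library function. symmetry_gaps() and random_rotation() are hypothetical helpers, and predict() is assumed to return an energy and per-atom forces.

```python
import numpy as np


def random_rotation(rng):
    """Sample a random 3x3 rotation matrix (det = +1) via QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))        # fix column signs of the decomposition
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]             # turn a reflection into a rotation
    return q


def symmetry_gaps(predict, pos, n_trials=10, seed=0):
    """Hypothetical helper: mean invariance (energy) and equivariance (forces) errors.

    `predict(pos)` must return (energy: float, forces: (N, 3) array).
    """
    rng = np.random.default_rng(seed)
    energy, forces = predict(pos)
    inv_err, eq_err = [], []
    for _ in range(n_trials):
        rot = random_rotation(rng)
        e_rot, f_rot = predict(pos @ rot.T)           # rotate every atom position
        inv_err.append(abs(e_rot - energy))           # energy should not change
        eq_err.append(np.abs(f_rot - forces @ rot.T).max())  # forces should rotate
    return float(np.mean(inv_err)), float(np.mean(eq_err))


def springs(pos):
    """Toy predictor: harmonic springs to the centroid (invariant and equivariant)."""
    centered = pos - pos.mean(axis=0)
    return 0.5 * float((centered ** 2).sum()), -centered


pos = np.random.default_rng(2).normal(size=(4, 3))
print(symmetry_gaps(springs, pos))  # both values at floating-point noise level
```

A predictor that reads raw coordinates, such as a model applied to positions without Frame Averaging, would typically show clearly non-zero gaps.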
Note: you can predict any atom-level or graph-level property, although the code explicitly refers to energy and forces.
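Graph-level targets are typically obtained by pooling atom-level outputs over the atoms of each graph, using the batch vector. The sketch below shows sum or mean pooling in NumPy. pool_per_graph() is a hypothetical helper, and FAENet's own output head may aggregate differently.

```python
import numpy as np


def pool_per_graph(atom_values, batch, num_graphs=None, reduce="sum"):
    """Aggregate per-atom predictions into per-graph predictions.

    `batch[i]` is the index of the graph that atom i belongs to, following the
    convention of PyTorch Geometric's `batch` vector.
    """
    if num_graphs is None:
        num_graphs = int(batch.max()) + 1
    totals = np.zeros((num_graphs,) + atom_values.shape[1:])
    np.add.at(totals, batch, atom_values)      # unbuffered scatter-add per graph
    if reduce == "mean":
        counts = np.bincount(batch, minlength=num_graphs)
        shape = (-1,) + (1,) * (atom_values.ndim - 1)
        totals = totals / np.maximum(counts, 1).reshape(shape)  # avoid division by 0
    return totals
```

Sum pooling suits extensive properties such as total energy, while mean pooling suits intensive ones.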
The /tests folder contains several useful unit tests. Feel free to look at them to explore how the model can be used. For more advanced examples, please refer to the full repository used in our ICML paper to make predictions on the OC20 IS2RE, S2EF, QM9 and QM7-X datasets.
Installing the development environment and running the tests requires poetry. Make sure torch and torch_geometric are installed in your environment before running the tests. Unfortunately, because of CUDA/torch compatibility issues, neither torch nor torch_geometric is part of the explicit dependencies, so both must be installed independently.
git clone git@github.com:vict0rsch/faenet.git
cd faenet
poetry install --with dev
pytest --cov=faenet --cov-report term-missing
When testing on Macs, you may encounter a Library Not Loaded error.
Authors: Alexandre Duval (alexandre.duval@mila.quebec) and Victor Schmidt (schmidtv@mila.quebec). We welcome your questions and feedback via email or GitHub Issues.