HAT-CL

Redesigned Hard-Attention-to-the-Task for Continual Learning

HAT-CL is a comprehensive reimagining of the Hard-Attention-to-the-Task (HAT) mechanism, designed specifically to combat catastrophic forgetting during Continual Learning (CL). Originally proposed in the paper Overcoming catastrophic forgetting with hard attention to the task, HAT has been instrumental in enabling neural networks to learn successive tasks without erasure of prior knowledge. However, the original implementation had its drawbacks, notably incompatibility with PyTorch's optimizers and the requirement for manual gradient manipulation. HAT-CL aims to rectify these issues with a user-friendly design and a host of new features:

  • Seamless compatibility with all PyTorch operations and optimizers.
  • Automated gradient manipulation through PyTorch hooks.
  • Simple transformation of PyTorch modules to HAT modules with a single line of code.
  • Out-of-the-box HAT networks integrated with timm.

Link to the paper: HAT-CL: A Hard-Attention-to-the-Task PyTorch Library for Continual Learning (arXiv:2307.09653)


Table of Contents

  • Quick Start
  • Modules
  • Networks
  • Examples
  • Limitations
  • TODO
  • Citation
  • Authors

Quick Start

Installation

To install via pip:

pip install hat-cl

Or, if you are using poetry:

poetry add hat-cl

Basic Usage

To use HAT modules, swap generic PyTorch modules for their HAT counterparts (for instance, replace torch.nn.Linear with hat.modules.HATLinear; see Modules for the full list). HAT modules take and return hat.HATPayload instances, which carry the data tensor, the task ID, and the other variables the HAT mechanism needs.

Here's a simple 2-layer MLP example:

import torch
import torch.nn as nn
from hat import HATPayload, HATConfig
from hat.modules import HATLinear


hat_config = HATConfig(num_tasks=5)

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.linear1 = HATLinear(input_dim, hidden_dim, hat_config)
        self.relu = nn.ReLU()
        self.linear2 = HATLinear(hidden_dim, output_dim, hat_config)
        
    def forward(self, x: HATPayload):
        x = self.linear1(x)
        # You can still pass the payload to the non-HAT modules like this
        x = x.forward_by(self.relu)
        x = self.linear2(x)
        return x
    
    
mlp = MLP(input_dim=128, hidden_dim=32, output_dim=2)

input_payload = HATPayload(torch.rand(10, 128), task_id=0, mask_scale=10.0)
output_payload = mlp(input_payload)
output_data = output_payload.data

With these steps, you've created a 2-layer MLP with the HAT mechanism and run a forward pass through the model. Just like any other PyTorch module, it is ready to be trained, evaluated, and more; all under-the-hood operations are handled by the HAT modules.
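
Training on a task then follows a standard PyTorch loop; you only need to wrap each batch in a HATPayload carrying the current task ID. The sketch below is a minimal, illustrative example that reuses the mlp defined above; train_loader is a placeholder for any DataLoader yielding (inputs, targets) pairs, and the fixed mask_scale simply mirrors the earlier example.

import torch.nn as nn
import torch.optim as optim
from hat import HATPayload

task_id = 0  # the task currently being learned
optimizer = optim.SGD(mlp.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

for inputs, targets in train_loader:  # placeholder DataLoader
    optimizer.zero_grad()
    payload = HATPayload(inputs, task_id=task_id, mask_scale=10.0)
    loss = criterion(mlp(payload).data, targets)
    loss.backward()  # HAT's gradient manipulation is applied via hooks here
    optimizer.step()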

Additionally, HAT-CL provides ready-to-use HAT networks with timm integration. Creating a HAT model is as simple as creating any other timm model:

import timm
import hat.timm_models  # This line is necessary to register the HAT models to timm
from hat import HATConfig

hat_config = HATConfig(num_tasks=5)
hat_resnet18 = timm.create_model('hat_resnet18', hat_config=hat_config)
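
A forward pass through a timm-integrated HAT network is expected to work the same way as with the modules above, i.e. by passing a HATPayload; this is a hedged sketch, and the 3×224×224 input shape is only illustrative.

import torch
from hat import HATPayload

# Assumes the timm-integrated HAT networks also take a HATPayload as input and output.
payload = HATPayload(torch.rand(1, 3, 224, 224), task_id=0, mask_scale=10.0)
logits = hat_resnet18(payload).data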

Modules

Here's a handy table of PyTorch modules and their HAT counterparts:

| PyTorch module       | HAT module                              |
|----------------------|-----------------------------------------|
| torch.nn.Linear      | hat.modules.HATLinear                   |
| torch.nn.Conv1d      | hat.modules.HATConv1d                   |
| torch.nn.Conv2d      | hat.modules.HATConv2d                   |
| torch.nn.Conv3d      | hat.modules.HATConv3d                   |
| torch.nn.BatchNorm1d | hat.modules.TaskIndexedBatchNorm1d      |
| torch.nn.BatchNorm2d | hat.modules.TaskIndexedBatchNorm2d      |
| torch.nn.BatchNorm3d | hat.modules.TaskIndexedBatchNorm3d      |
| torch.nn.LayerNorm   | hat.modules.TaskIndexedLayerNorm        |
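
The convolution and normalization modules are expected to be used the same way as HATLinear above. The following is an assumption-heavy sketch: it presumes HATConv2d and TaskIndexedBatchNorm2d accept the usual torch.nn constructor arguments plus the HAT config (the exact argument names and order may differ), and that the task-indexed norm also consumes a HATPayload.

import torch.nn as nn
from hat import HATConfig, HATPayload
from hat.modules import HATConv2d, TaskIndexedBatchNorm2d

hat_config = HATConfig(num_tasks=5)


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Assumed signatures: the torch.nn arguments plus the HAT config.
        self.conv = HATConv2d(in_channels, out_channels, kernel_size=3, padding=1, hat_config=hat_config)
        self.bn = TaskIndexedBatchNorm2d(out_channels, hat_config=hat_config)
        self.relu = nn.ReLU()

    def forward(self, x: HATPayload):
        x = self.conv(x)
        x = self.bn(x)
        return x.forward_by(self.relu)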

Networks

Here are the currently available timm-compatible HAT networks:

| HAT network name          | Pretrained weights | Description                                  |
|---------------------------|--------------------|----------------------------------------------|
| hat_resnet18              | Yes                | HAT ResNet-18                                |
| hat_resnet18s             | No                 | HAT ResNet-18 for smaller images             |
| hat_resnet34              | Yes                | HAT ResNet-34                                |
| hat_resnet34s             | No                 | HAT ResNet-34 for smaller images             |
| hat_vit_tiny_patch16_224  | Yes                | HAT ViT-Tiny (patch size 16, image size 224) |
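
For the networks marked as having pretrained weights, loading them presumably follows the usual timm pattern; whether the flag below behaves exactly like this for the HAT variants is an assumption.

import timm
import hat.timm_models  # registers the HAT models with timm
from hat import HATConfig

hat_config = HATConfig(num_tasks=5)
# Assumes the standard timm `pretrained` flag applies to the HAT networks.
model = timm.create_model("hat_resnet18", pretrained=True, hat_config=hat_config)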

Examples


Limitations

HAT-CL, while designed for broad compatibility with PyTorch, faces some constraints due to the inherent characteristics of the HAT mechanism:

  • Optimizer Re-initialization: We recommend refreshing the optimizer state after each task so that momentum from prior tasks is not carried over. This is easily done by re-initializing the optimizer, as shown in the sketch after this list.
  • Weight Decay (L2 Regularization): Weight decay is not compatible with HAT, because its gradient-altering process can update parameters that the HAT mechanism intends to block, potentially causing forgetting. This rules out the weight_decay optimizer parameter as well as any optimizer that builds in weight decay, such as AdamW.
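
A minimal sketch of the re-initialization recommendation, reusing the mlp and hat_config from the Quick Start section; train_one_task is a placeholder for your per-task training routine.

import torch.optim as optim

for task_id in range(hat_config.num_tasks):
    # Re-create the optimizer at each task boundary so that momentum and other
    # state accumulated on previous tasks are not carried over.
    optimizer = optim.SGD(mlp.parameters(), lr=0.1)
    train_one_task(mlp, optimizer, task_id)  # placeholder training loop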

TODO

  • Add example notebook for pruning
  • Package paper for implementation details
  • Add CLOM notebook example
  • Link PyPI package to GitHub repo

Citation

If you use HAT-CL in your research, please cite:

@misc{duan2023hatcl,
    title={HAT-CL: A Hard-Attention-to-the-Task PyTorch Library for Continual Learning}, 
    author={Xiaotian Duan},
    year={2023},
    eprint={2307.09653},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Authors

Xiaotian Duan (xduan7 at gmail dot com)