-
Notifications
You must be signed in to change notification settings - Fork 26
39 lines (37 loc) · 1.4 KB
/
jax_tests.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Runs the unit-test suite against a range of pinned JAX releases.
# Triggered manually, or by adding the `test_jax` label to a pull request.
name: Dependency test JAX

on:
  pull_request:
    types: [labeled]
  workflow_dispatch:

jobs:
  jax_tests:
    # Run for manual dispatches, or for pull requests that were just labeled
    # `test_jax`. Parentheses make the &&/|| grouping explicit.
    if: ${{ (github.event.label.name == 'test_jax' && github.event_name == 'pull_request') || github.event_name == 'workflow_dispatch' }}
    runs-on: ubuntu-latest
    strategy:
      # Keep testing the remaining versions when one matrix entry fails.
      fail-fast: false
      matrix:
        # Versions are quoted so they are always treated as strings.
        jax-version:
          - "0.3.0"
          - "0.3.1"
          - "0.3.2"
          - "0.3.3"
          - "0.3.4"
          - "0.3.5"
          - "0.3.6"
          - "0.3.7"
          - "0.3.8"
          - "0.3.9"
          - "0.3.10"
          - "0.3.11"
          - "0.3.12"
          - "0.3.13"
          - "0.3.14"
          - "0.3.15"
          - "0.3.16"
          - "0.3.17"
          - "0.3.19"
          - "0.3.20"
          - "0.3.21"
          - "0.3.22"
          - "0.3.23"
          - "0.3.24"
          - "0.3.25"
          - "0.4.1"
          - "0.4.2"
          - "0.4.3"
          - "0.4.4"
          - "0.4.5"
          - "0.4.6"
          - "0.4.7"
          - "0.4.8"
          - "0.4.9"
          - "0.4.10"
          - "0.4.11"
          - "0.4.12"
          - "0.4.13"
          - "0.4.14"
        # One job per test split. The number of entries here MUST equal the
        # `--splits` value passed to pytest below, otherwise part of the
        # test suite is silently never executed.
        group: [1, 2]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.9
        uses: actions/setup-python@v4
        with:
          # Quoted so the version is never parsed as a float.
          python-version: "3.9"
          cache: pip
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r devtools/dev-requirements.txt
          pip install matplotlib==3.5.0
      # Drop whatever jax/jaxlib is present after the dependency install
      # (presumably pulled in by the requirements file) so the matrix
      # version below is the one actually tested.
      - name: Remove jax
        run: |
          pip uninstall jax jaxlib -y
      - name: install jax
        run: |
          pip install "jax[cpu]==${{ matrix.jax-version }}"
      - name: Test with pytest
        run: |
          pwd
          lscpu
          # --splits matches the two entries of matrix.group so that the
          # groups together cover the whole unit-test suite.
          python -m pytest -m unit --durations=0 --mpl --maxfail=1 \
            --splits 2 --group ${{ matrix.group }} \
            --splitting-algorithm least_duration