Skip to content

Commit

Permalink
Fix CI (#140)
Browse files Browse the repository at this point in the history
* Drop support for Pythons older than 3.8

* Remove 2020-era Windows patch from CI

* Remove baroque check_version step from CI

* CI: renovate, add pip cache

* CI: split linting to separate step

* CI: don't allow compiling libraries from scratch, use correct Torch repo

* CI: print pytest progress

* Drop Boltons dependency for small custom LRU dict
  • Loading branch information
akx authored Sep 26, 2023
1 parent efa291a commit 7222329
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 63 deletions.
75 changes: 21 additions & 54 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,78 +9,45 @@ on:
- cron: "0 2 * * 6"

jobs:
check_version:
build:
strategy:
matrix:
python-version: [ 3.8 ]
os: [ ubuntu-latest ]
os: [ ubuntu-latest, macOS-latest, windows-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: setup.py

- name: Check version
if: (github.event_name == 'pull_request' && github.base_ref == 'master')
run: |
python -m pip install --upgrade pip
python -m pip install git+https://github.com/google-research/torchsde.git
master_info=$(pip list | grep torchsde)
master_version=$(echo ${master_info} | cut -d " " -f2)
python -m pip uninstall -y torchsde
python setup.py install
pr_info=$(pip list | grep torchsde)
pr_version=$(echo ${pr_info} | cut -d " " -f2)
- name: Install
run: pip install pytest -e . --only-binary=numpy,scipy,matplotlib,torch
env:
PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu

python -c "import itertools as it
import sys
_, master_version, pr_version = sys.argv
master_version_ = [int(i) for i in master_version.split('.')]
pr_version_ = [int(i) for i in pr_version.split('.')]
master_version__ = tuple(m for p, m in it.zip_longest(pr_version_, master_version_, fillvalue=0))
pr_version__ = tuple(p for p, m in it.zip_longest(pr_version_, master_version_, fillvalue=0))
sys.exit(pr_version__ < master_version__)" ${master_version} ${pr_version}
- name: Test with pytest
run: python -m pytest -v

build:
needs: [ check_version ]
strategy:
matrix:
python-version: [ 3.6, 3.8 ]
os: [ ubuntu-latest, macOS-latest, windows-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
- name: Windows patch # Specifically for windows, since pip fails to fetch torch 1.6.0 as of Oct 2020.
if: runner.os == 'Windows'
run: python -m pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python-version: "3.11"
cache: pip
cache-dependency-path: setup.py

- name: Lint with flake8
run: |
python -m pip install flake8
python -m flake8 .
- name: Test with pytest
run: |
python setup.py install
python -m pytest
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ This library provides [stochastic differential equation (SDE)](https://en.wikipe
pip install torchsde
```

**Requirements:** Python >=3.6 and PyTorch >=1.6.0.
**Requirements:** Python >=3.8 and PyTorch >=1.6.0.

## Documentation
Available [here](./DOCUMENTATION.md).
Expand Down
9 changes: 3 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,12 @@
url="https://github.com/google-research/torchsde",
packages=setuptools.find_packages(exclude=['benchmarks', 'diagnostics', 'examples', 'tests']),
install_requires=[
"boltons>=20.2.1",
"numpy==1.19;python_version<'3.7'",
"numpy>=1.19;python_version>='3.7'",
"scipy==1.5;python_version<'3.7'",
"scipy>=1.5;python_version>='3.7'",
"numpy>=1.19",
"scipy>=1.5",
"torch>=1.6.0",
"trampoline>=0.1.2",
],
python_requires='~=3.6',
python_requires='>=3.8',
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
Expand Down
18 changes: 16 additions & 2 deletions torchsde/_brownian/brownian_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import trampoline
import warnings

import boltons.cacheutils
import numpy as np
import torch

Expand Down Expand Up @@ -112,6 +111,21 @@ def __getitem__(self, item):
raise KeyError


class _LRUDict(dict):
def __init__(self, max_size):
super().__init__()
self._max_size = max_size
self._keys = []

def __setitem__(self, key, value):
if key in self:
self._keys.remove(key)
elif len(self) >= self._max_size:
del self[self._keys.pop(0)]
super().__setitem__(key, value)
self._keys.append(key)


class _Interval:
# Intervals correspond to some subinterval of the overall interval [t0, t1].
# They are arranged as a binary tree: each node corresponds to an interval. If a node has children, they are left
Expand Down Expand Up @@ -505,7 +519,7 @@ def __init__(self,
elif cache_size == 0:
self._increment_and_space_time_levy_area_cache = _EmptyDict()
else:
self._increment_and_space_time_levy_area_cache = boltons.cacheutils.LRU(max_size=cache_size)
self._increment_and_space_time_levy_area_cache = _LRUDict(max_size=cache_size)

# We keep track of the most recently queried interval, and start searching for the next interval from that
# element of the binary tree. This is because subsequent queries are likely to be near the most recent query.
Expand Down

0 comments on commit 7222329

Please sign in to comment.