PyO3: Add CI (#1135)
* Add PyO3 ci

* Update python.yml

* Format `bert.py`
LLukas22 authored Oct 20, 2023
1 parent 7366aea commit cfb423a
Showing 2 changed files with 66 additions and 3 deletions.
62 changes: 62 additions & 0 deletions .github/workflows/python.yml
@@ -0,0 +1,62 @@
name: PyO3-CI

on:
  workflow_dispatch:
  push:
    branches:
      - main
    paths:
      - candle-pyo3/**
  pull_request:
    paths:
      - candle-pyo3/**

jobs:
  build_and_test:
    name: Check everything builds & tests
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest] # For now, only test on Linux
    steps:
      - name: Checkout repository
        uses: actions/checkout@v2

      - name: Install Rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: stable

      - name: Install Python
        uses: actions/setup-python@v4
        with:
          python-version: 3.11
          architecture: "x64"

      - name: Cache Cargo Registry
        uses: actions/cache@v1
        with:
          path: ~/.cargo/registry
          key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}

      - name: Install
        working-directory: ./candle-pyo3
        run: |
          python -m venv .env
          source .env/bin/activate
          pip install -U pip
          pip install pytest maturin black
          python -m maturin develop -r

      - name: Check style
        working-directory: ./candle-pyo3
        run: |
          source .env/bin/activate
          python stub.py --check
          black --check .

      - name: Run tests
        working-directory: ./candle-pyo3
        run: |
          source .env/bin/activate
          python -m pytest -s -v tests
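The `Cache Cargo Registry` step keys the cache on the runner OS plus a hash of `Cargo.lock`, so the registry cache is invalidated exactly when dependencies change. A minimal Python sketch of what that key computation amounts to (the `cargo_cache_key` helper is hypothetical, and the real `hashFiles` expression in Actions combines SHA-256 digests over all matched files):

```python
import hashlib
from pathlib import Path


def cargo_cache_key(os_name: str, lockfile: Path) -> str:
    # Roughly mirrors `${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}`:
    # the key changes whenever the lockfile's contents change.
    digest = hashlib.sha256(lockfile.read_bytes()).hexdigest()
    return f"{os_name}-cargo-registry-{digest}"
```

Any edit to `Cargo.lock` produces a new key, forcing a fresh cache; identical lockfiles always map to the same key, so unchanged dependency sets hit the cache.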
7 changes: 4 additions & 3 deletions candle-pyo3/py_src/candle/models/bert.py
@@ -59,8 +59,7 @@ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
         attention_scores = attention_scores / float(self.attention_head_size) ** 0.5
         if attention_mask is not None:
             b_size, _, _, last_dim = attention_scores.shape
-            attention_scores = attention_scores.broadcast_add(
-                attention_mask.reshape((b_size, 1, 1, last_dim)))
+            attention_scores = attention_scores.broadcast_add(attention_mask.reshape((b_size, 1, 1, last_dim)))
         attention_probs = F.softmax(attention_scores, dim=-1)

         context_layer = attention_probs.matmul(value)
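The reshape to `(b_size, 1, 1, last_dim)` lets one mask row per batch element broadcast across every head and query position of the `(b, h, L, L)` score tensor. A plain-Python sketch of that broadcast add (nested lists instead of candle tensors; names and shapes are illustrative):

```python
def broadcast_add_mask(scores, mask):
    # scores: nested lists shaped (b, h, L, L); mask: shaped (b, L).
    # Adding mask[i][k] to scores[i][j][q][k] for every head j and query q
    # is exactly what broadcasting a (b, 1, 1, L) mask over the scores does.
    return [
        [[[s + m for s, m in zip(row, mask[i])] for row in head] for head in batch]
        for i, batch in enumerate(scores)
    ]


scores = [[[[1.0, 2.0], [3.0, 4.0]]]]  # b=1, h=1, L=2
mask = [[0.0, float("-inf")]]          # key position 1 masked for batch 0
masked = broadcast_add_mask(scores, mask)
```

Every query row ends up with `-inf` in the masked key column, regardless of head, which is why a single `(b, L)` mask suffices.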
@@ -198,7 +197,9 @@ def __init__(self, config: Config, add_pooling_layer=True) -> None:
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None

-    def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(
+        self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None
+    ) -> Tuple[Tensor, Optional[Tensor]]:
         if attention_mask is not None:
             # Replace 0s with -inf, and 1s with 0s.
             attention_mask = masked_fill(float("-inf"), attention_mask, 1.0)
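Mapping padding positions to `-inf` works because an additive `-inf` drives the corresponding softmax probability to exactly zero. A small stdlib-only sketch of that effect (this `softmax` is a local helper for illustration, not candle's `F.softmax`):

```python
import math


def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.5]
additive_mask = [0.0, 0.0, float("-inf")]  # third key position masked out
probs = softmax([s + a for s, a in zip(scores, additive_mask)])
# math.exp(-inf) is 0.0, so the masked position gets zero probability
# while the unmasked probabilities still sum to 1.
```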
