From e3e5386edcce25cae000bce816e883a0a24fe466 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Fri, 20 Oct 2023 20:05:14 +0200 Subject: [PATCH] PyO3: Add CI (#1135) * Add PyO3 ci * Update python.yml * Format `bert.py` --- .github/workflows/python.yml | 62 ++++++++++++++++++++++++ candle-pyo3/py_src/candle/models/bert.py | 7 +-- 2 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/python.yml diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml new file mode 100644 index 0000000000..bf85f5e512 --- /dev/null +++ b/.github/workflows/python.yml @@ -0,0 +1,62 @@ +name: PyO3-CI + +on: + workflow_dispatch: + push: + branches: + - main + paths: + - candle-pyo3/** + pull_request: + paths: + - candle-pyo3/** + +jobs: + build_and_test: + name: Check everything builds & tests + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] # For now, only test on Linux + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + + - name: Install Python + uses: actions/setup-python@v4 + with: + python-version: 3.11 + architecture: "x64" + + - name: Cache Cargo Registry + uses: actions/cache@v1 + with: + path: ~/.cargo/registry + key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + + - name: Install + working-directory: ./candle-pyo3 + run: | + python -m venv .env + source .env/bin/activate + pip install -U pip + pip install pytest maturin black + python -m maturin develop -r + + - name: Check style + working-directory: ./candle-pyo3 + run: | + source .env/bin/activate + python stub.py --check + black --check . + + - name: Run tests + working-directory: ./candle-pyo3 + run: | + source .env/bin/activate + python -m pytest -s -v tests \ No newline at end of file diff --git a/candle-pyo3/py_src/candle/models/bert.py b/candle-pyo3/py_src/candle/models/bert.py index 36e242ad00..ecb238d86f 100644 --- a/candle-pyo3/py_src/candle/models/bert.py +++ b/candle-pyo3/py_src/candle/models/bert.py @@ -59,8 +59,7 @@ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor: attention_scores = attention_scores / float(self.attention_head_size) ** 0.5 if attention_mask is not None: b_size, _, _, last_dim = attention_scores.shape - attention_scores = attention_scores.broadcast_add( - attention_mask.reshape((b_size, 1, 1, last_dim))) + attention_scores = attention_scores.broadcast_add(attention_mask.reshape((b_size, 1, 1, last_dim))) attention_probs = F.softmax(attention_scores, dim=-1) context_layer = attention_probs.matmul(value) @@ -198,7 +197,9 @@ def __init__(self, config: Config, add_pooling_layer=True) -> None: self.encoder = BertEncoder(config) self.pooler = BertPooler(config) if add_pooling_layer else None - def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]: + def forward( + self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None + ) -> Tuple[Tensor, Optional[Tensor]]: if attention_mask is not None: # Replace 0s with -inf, and 1s with 0s. attention_mask = masked_fill(float("-inf"), attention_mask, 1.0)