
Commit

Merge remote-tracking branch 'origin/main' into feature/stella-400m
LaurentMazare committed Nov 29, 2024
2 parents 339b435 + 54e7fc3 commit 7415f26
Showing 154 changed files with 1,715 additions and 193 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/rust-ci.yml
@@ -16,6 +16,9 @@ jobs:
rust: [stable]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@@ -34,7 +37,13 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- name: Delete huge unnecessary tools folder
if: runner.os == 'Linux'
run: rm -rf /opt/hostedtoolcache
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
profile: minimal
2 changes: 2 additions & 0 deletions candle-core/src/backend.rs
@@ -1,3 +1,5 @@
//! Traits to Define Backend Behavior
//!
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};

2 changes: 1 addition & 1 deletion candle-core/src/backprop.rs
@@ -1,4 +1,4 @@
/// Methods for backpropagation of gradients.
//! Methods for backpropagation of gradients.
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
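Note on the comment-style changes in this and several later files: the hunks switch file-header comments from `///` (an outer doc comment, attached to the next item) to `//!` (an inner doc comment, attached to the enclosing module), so the text now shows up on the module's own rustdoc page rather than on its first item. A minimal sketch of the distinction, assuming only standard rustdoc behavior; the function name is hypothetical:

//! Methods for backpropagation of gradients.
//!
//! An inner doc comment (`//!`) documents the module that contains it.

/// An outer doc comment (`///`) documents the item that follows it.
pub fn backprop_step() {}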
2 changes: 2 additions & 0 deletions candle-core/src/conv.rs
@@ -1,3 +1,5 @@
//! 1D and 2D Convolutions
//!
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};

#[derive(Debug, Clone, PartialEq, Eq)]
2 changes: 2 additions & 0 deletions candle-core/src/cpu/mod.rs
@@ -1,3 +1,5 @@
//! Traits and methods for CPU-backed Tensors
pub mod erf;
pub mod kernels;

23 changes: 12 additions & 11 deletions candle-core/src/cpu_backend/mod.rs
@@ -1,3 +1,4 @@
//! Implementation of Backend Fns for CPU
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
@@ -65,7 +66,7 @@ impl Map2U8 for Cmp {

struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);

impl<'a, I: IntDType> Map2 for WCond<'a, I> {
impl<I: IntDType> Map2 for WCond<'_, I> {
const OP: &'static str = "where";
#[inline(always)]
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
@@ -215,7 +216,7 @@ struct ReduceSum<'a> {
reduce_dims_and_stride: Vec<(usize, usize)>,
}

impl<'a> ReduceSum<'a> {
impl ReduceSum<'_> {
#[inline(always)]
fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
where
@@ -280,7 +281,7 @@ impl<'a> ReduceSum<'a> {
}
}

impl<'a> Map1 for ReduceSum<'a> {
impl Map1 for ReduceSum<'_> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
self.fold_impl(src, src_l, T::zero())
@@ -453,7 +454,7 @@ struct Gather<'a, I: IntDType> {
dim: usize,
}

impl<'a, I: IntDType> Map1 for Gather<'a, I> {
impl<I: IntDType> Map1 for Gather<'_, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
@@ -506,7 +507,7 @@ struct IndexSelect<'a, T: IntDType> {
dim: usize,
}

impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
impl<I: IntDType> Map1 for IndexSelect<'_, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
@@ -559,7 +560,7 @@ struct ScatterAdd<'a, I: IntDType> {
dim: usize,
}

impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
const OP: &'static str = "scatter-add";
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let dst_len = l1.shape().elem_count();
@@ -615,7 +616,7 @@ struct IndexAdd<'a, I: IntDType> {
dim: usize,
}

impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
impl<I: IntDType> Map2 for IndexAdd<'_, I> {
const OP: &'static str = "index-add";
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
// v1, l1 -> self
@@ -735,7 +736,7 @@ fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l

struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);

impl<'a> Map2 for Conv1D<'a> {
impl Map2 for Conv1D<'_> {
const OP: &'static str = "conv1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@@ -959,7 +960,7 @@ impl Map1 for Col2Im1D {

struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);

impl<'a> Map2 for ConvTranspose1D<'a> {
impl Map2 for ConvTranspose1D<'_> {
const OP: &'static str = "conv_transpose1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@@ -1028,7 +1029,7 @@ impl<'a> Map2 for ConvTranspose1D<'a> {

struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);

impl<'a> Map2 for Conv2D<'a> {
impl Map2 for Conv2D<'_> {
const OP: &'static str = "conv2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@@ -1116,7 +1117,7 @@ impl<'a> Map2 for Conv2D<'a> {

struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);

impl<'a> Map2 for ConvTranspose2D<'a> {
impl Map2 for ConvTranspose2D<'_> {
const OP: &'static str = "conv_transpose2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
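Note on the `impl<'a> ... for Foo<'a>` rewrites above: declaring a lifetime parameter on the impl and using it only in the self type is unnecessary, and newer toolchains lint against it (clippy's needless_lifetimes), so the anonymous lifetime `'_` is used instead. A standalone sketch of the same purely syntactic change; the `Rows` type is hypothetical:

struct Rows<'a>(&'a [f32]);

// Before: impl<'a> Iterator for Rows<'a> { ... }
// After: the elided form below compiles to exactly the same code.
impl Iterator for Rows<'_> {
    type Item = f32;

    fn next(&mut self) -> Option<f32> {
        // Pop the first element off the borrowed slice.
        let (&head, tail) = self.0.split_first()?;
        self.0 = tail;
        Some(head)
    }
}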
2 changes: 2 additions & 0 deletions candle-core/src/cuda_backend/mod.rs
@@ -1,3 +1,5 @@
//! Implementation of Backend traits for CUDA device
//!
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
1 change: 1 addition & 0 deletions candle-core/src/device.rs
@@ -11,6 +11,7 @@ pub enum DeviceLocation {
Metal { gpu_id: usize },
}

/// Cpu, Cuda, or Metal
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
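The new doc line summarizes the three `Device` variants. As a rough usage sketch (assuming the existing `Device::cuda_if_available` constructor and `Tensor::zeros`, neither of which is part of this diff):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Fall back to the CPU when CUDA is unavailable or not compiled in.
    let device = Device::cuda_if_available(0)?;
    let t = Tensor::zeros((2, 3), DType::F32, &device)?;
    println!("{:?} on {:?}", t.shape(), t.device());
    Ok(())
}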
7 changes: 4 additions & 3 deletions candle-core/src/display.rs
@@ -1,6 +1,7 @@
/// Pretty printing of tensors
/// This implementation should be in line with the PyTorch version.
/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py
//! Pretty printing of tensors
//!
//! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py).
//!
use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};

2 changes: 2 additions & 0 deletions candle-core/src/dummy_cuda_backend.rs
@@ -1,3 +1,5 @@
//! Implementation of the Cuda backend when Cuda support has not been compiled in.
//!
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
1 change: 1 addition & 0 deletions candle-core/src/error.rs
@@ -1,3 +1,4 @@
//! Candle-specific Error and Result
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};

#[derive(Debug, Clone)]
1 change: 1 addition & 0 deletions candle-core/src/layout.rs
@@ -1,3 +1,4 @@
//! Tensor Layouts including contiguous or sparse strides
use crate::{Error, Result, Shape};

#[derive(Debug, PartialEq, Eq, Clone)]
8 changes: 4 additions & 4 deletions candle-core/src/lib.rs
@@ -7,8 +7,8 @@
//!
//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
//! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
//!
//! let c = a.matmul(&b)?;
//!
//! # Ok(())}
//! ```
//!
@@ -140,7 +140,7 @@ impl ToUsize2 for (usize, usize) {
}
}

// A simple trait defining a module with forward method using a single argument.
/// Defines a module with a forward method using a single argument.
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
@@ -160,8 +160,8 @@ impl<M: Module> Module for Option<&M> {
}
}

// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
/// A single forward method using a single tensor argument and a flag to
/// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
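For the reworded `Module` and `ModuleT` docs above, a minimal sketch of what implementing the single-argument trait looks like; the `Scale` type is hypothetical:

use candle_core::{Module, Result, Tensor};

// A toy module that multiplies its input by a constant factor.
struct Scale(f64);

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // affine computes `x * mul + add` element-wise.
        xs.affine(self.0, 0.0)
    }
}

A type implementing `Module` can then be called as `Scale(2.0).forward(&xs)?`, while `ModuleT::forward_t` adds the extra `train: bool` flag for layers such as dropout that behave differently during training and evaluation.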
2 changes: 2 additions & 0 deletions candle-core/src/metal_backend/mod.rs
@@ -1,3 +1,5 @@
//! Implementation of Backend traits for Metal
//!
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
2 changes: 2 additions & 0 deletions candle-core/src/op.rs
@@ -1,3 +1,5 @@
//! Tensor Operation Enums and Traits
//!
#![allow(clippy::redundant_closure_call)]
use crate::Tensor;
use half::{bf16, f16};
2 changes: 1 addition & 1 deletion candle-core/src/pickle.rs
@@ -1,4 +1,4 @@
// Just enough pickle support to be able to read PyTorch checkpoints.
//! Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
2 changes: 1 addition & 1 deletion candle-core/src/quantized/ggml_file.rs
@@ -134,7 +134,7 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
super::QTensor::new(data, dims)
}

/// Creates a [Tensor] from a raw GGML tensor.
/// Creates a Tensor from a raw GGML tensor.
pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
5 changes: 2 additions & 3 deletions candle-core/src/quantized/gguf_file.rs
@@ -1,6 +1,5 @@
//! Support for the GGUF file format.
//! Support for the [GGUF file format](https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md).
//!
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::{Device, Result};
@@ -458,7 +457,7 @@ impl Content {
Some(Value::I32(v)) if *v >= 0 => *v as u64,
_ => DEFAULT_ALIGNMENT,
};
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
let tensor_data_offset = position.div_ceil(alignment) * alignment;
Ok(Self {
magic,
metadata,
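The alignment computation above (and the block counts in k_quants.rs below) replace the manual `(x + d - 1) / d` round-up with `usize::div_ceil`, stable since Rust 1.73. A small equivalence check with illustrative values:

fn main() {
    let position: usize = 1234;
    let alignment: usize = 32;

    // Old form: manual ceiling division, then multiply back by the alignment.
    let old = (position + alignment - 1) / alignment * alignment;
    // New form: div_ceil states the intent directly.
    let new = position.div_ceil(alignment) * alignment;

    assert_eq!(old, new);
    assert_eq!(new, 1248); // 1234 rounded up to the next multiple of 32
}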
4 changes: 2 additions & 2 deletions candle-core/src/quantized/k_quants.rs
@@ -1850,8 +1850,8 @@ pub fn matmul<T: GgmlType>(
crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
}

let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE;
let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE;
let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE);
let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE);
// TODO: Do not make this copy if the DotType is f32.
// TODO: Pre-allocate this.
let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
1 change: 1 addition & 0 deletions candle-core/src/quantized/mod.rs
@@ -1,3 +1,4 @@
//! Code for GGML and GGUF files
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
13 changes: 12 additions & 1 deletion candle-core/src/safetensors.rs
@@ -1,3 +1,14 @@
//! Module to load `safetensor` files into CPU/GPU memory.
//!
//! There are multiple ways to load tensors from safetensor files:
//! - `load` function for loading directly into memory and returning a HashMap of tensors
//! - `MmapedSafetensors` for memory mapping files and avoiding full allocation
//! - `SliceSafetensors` for working with in-memory buffers
//! - `BufferedSafetensors` for owning a buffer of data
//!
//! Tensors can also be serialized to safetensor format using the `save` function or
//! `Tensor::save_safetensors` method.
//!
use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;
Expand Down Expand Up @@ -171,7 +182,7 @@ pub trait Load {
fn load(&self, device: &Device) -> Result<Tensor>;
}

impl<'a> Load for st::TensorView<'a> {
impl Load for st::TensorView<'_> {
fn load(&self, device: &Device) -> Result<Tensor> {
convert(self, device)
}
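A rough sketch of the simplest entry point listed in the new module docs, the `load` helper that reads a whole file into a `HashMap<String, Tensor>`; the file name is hypothetical:

use candle_core::{safetensors, Device, Result};

fn main() -> Result<()> {
    // Load every tensor from a .safetensors file into CPU memory.
    let tensors = safetensors::load("model.safetensors", &Device::Cpu)?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}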
2 changes: 2 additions & 0 deletions candle-core/src/scalar.rs
@@ -1,3 +1,5 @@
//! TensorScalar Enum and Trait
//!
use crate::{Result, Tensor, WithDType};

pub enum TensorScalar {
2 changes: 2 additions & 0 deletions candle-core/src/streaming.rs
@@ -1,3 +1,5 @@
//! StreamTensor, useful for streaming ops.
//!
use crate::{Result, Shape, Tensor};

pub trait Dim: crate::shape::Dim + Copy {}
2 changes: 1 addition & 1 deletion candle-core/src/strided_index.rs
@@ -32,7 +32,7 @@ impl<'a> StridedIndex<'a> {
}
}

impl<'a> Iterator for StridedIndex<'a> {
impl Iterator for StridedIndex<'_> {
type Item = usize;

fn next(&mut self) -> Option<Self::Item> {
38 changes: 37 additions & 1 deletion candle-core/src/tensor.rs
@@ -242,7 +242,7 @@ impl Tensor {
Self::zeros_impl(shape, dtype, device, false)
}

/// Creates a new tensor filled with ones with same shape, dtype, and device as the other
/// Creates a new tensor filled with zeros with same shape, dtype, and device as the other
/// tensor.
///
/// ```rust
@@ -1760,6 +1760,42 @@ impl Tensor {
&self.op
}

/// Computes the max of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.max_all()?;
/// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn max_all(&self) -> Result<Tensor> {
if self.rank() == 0 {
Ok(self.clone())
} else {
self.flatten_all()?.max(0)
}
}

/// Computes the min of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.min_all()?;
/// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn min_all(&self) -> Result<Tensor> {
if self.rank() == 0 {
Ok(self.clone())
} else {
self.flatten_all()?.min(0)
}
}

/// Computes the sum of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
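As a follow-up to the new `max_all` and `min_all` methods, a short sketch of min-max normalization built on top of them, assuming the existing `to_scalar` and `affine` helpers; the values are chosen so the expected result is exact:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[0f32, 1.], [2., 4.]], &Device::Cpu)?;
    let min = t.min_all()?.to_scalar::<f32>()? as f64;
    let max = t.max_all()?.to_scalar::<f32>()? as f64;
    // affine computes `x * mul + add` element-wise; here it rescales into [0, 1].
    let normalized = t.affine(1.0 / (max - min), -min / (max - min))?;
    assert_eq!(normalized.to_vec2::<f32>()?, vec![vec![0.0, 0.25], vec![0.5, 1.0]]);
    Ok(())
}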
1 change: 1 addition & 0 deletions candle-core/src/utils.rs
@@ -1,3 +1,4 @@
//! Useful functions for checking features.
use std::str::FromStr;

pub fn get_num_threads() -> usize {
2 changes: 1 addition & 1 deletion candle-datasets/src/nlp/tinystories.rs
@@ -87,7 +87,7 @@ impl<'a> DatasetRandomIter<'a> {
}
}

impl<'a> Iterator for DatasetRandomIter<'a> {
impl Iterator for DatasetRandomIter<'_> {
type Item = Result<(Tensor, Tensor)>;

fn next(&mut self) -> Option<Self::Item> {