From 0ed24b9852ccc7dfb92d555afba3d56c2a3f3224 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 14 Nov 2024 21:08:04 +0100
Subject: [PATCH 01/15] Add max-all/min-all. (#2616)

---
 candle-core/src/tensor.rs | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e7355aadc5..75dc1c8a55 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1760,6 +1760,42 @@ impl Tensor {
         &self.op
     }
 
+    /// Computes the max of all the elements in this tensor and returns a tensor holding this
+    /// scalar with zero dimensions.
+    ///
+    /// ```rust
+    /// use candle_core::{Tensor, Device};
+    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+    /// let tensor = tensor.max_all()?;
+    /// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn max_all(&self) -> Result<Tensor> {
+        if self.rank() == 0 {
+            Ok(self.clone())
+        } else {
+            self.flatten_all()?.max(0)
+        }
+    }
+
+    /// Computes the min of all the elements in this tensor and returns a tensor holding this
+    /// scalar with zero dimensions.
+    ///
+    /// ```rust
+    /// use candle_core::{Tensor, Device};
+    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
+    /// let tensor = tensor.min_all()?;
+    /// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn min_all(&self) -> Result<Tensor> {
+        if self.rank() == 0 {
+            Ok(self.clone())
+        } else {
+            self.flatten_all()?.min(0)
+        }
+    }
+
     /// Computes the sum of all the elements in this tensor and returns a tensor holding this
     /// scalar with zero dimensions.
     ///

From f689ce5d39c6f1475dfc71503288ea2905c8f685 Mon Sep 17 00:00:00 2001
From: zachcp
Date: Fri, 15 Nov 2024 02:30:15 -0500
Subject: [PATCH 02/15] Documentation Pass for Models (#2617)

* links in chinese_clip
* links for clip model
* add mod docs for flux and llava
* module doc for MMDIT and MIMI
* add docs for a few more models
* mod docs for bert naser and beit
* add module docs for convmixer colpali codegeex and chatglm
* add another series of moddocs
* add fastvit-llama2_c
* module docs mamba -> mobileone
* module docs from moondream-phi3
* mod docs for quantized and qwen
* update to yi
* fix long names
* Update llama2_c.rs
* Update llama2_c_weights.rs
* Fix the link for mimi + tweaks

---------

Co-authored-by: Laurent Mazare
---
 candle-transformers/src/models/based.rs | 7 +++----
 candle-transformers/src/models/beit.rs | 7 +++++++
 candle-transformers/src/models/bert.rs | 6 ++++++
 candle-transformers/src/models/bigcode.rs | 7 +++++++
 candle-transformers/src/models/blip.rs | 7 +++++++
 candle-transformers/src/models/blip_text.rs | 6 ++++++
 candle-transformers/src/models/chatglm.rs | 7 +++++++
 .../src/models/chinese_clip/mod.rs | 5 +++--
 candle-transformers/src/models/clip/mod.rs | 5 +++--
 .../src/models/codegeex4_9b.rs | 7 +++++++
 candle-transformers/src/models/colpali.rs | 5 +++++
 candle-transformers/src/models/convmixer.rs | 7 +++++++
 candle-transformers/src/models/convnext.rs | 14 ++++++-------
 candle-transformers/src/models/dac.rs | 7 ++++++-
 .../src/models/depth_anything_v2.rs | 6 ++++++
 candle-transformers/src/models/dinov2.rs | 5 +++++
 candle-transformers/src/models/dinov2reg4.rs | 7 +++++++
 candle-transformers/src/models/distilbert.rs | 5 +++++
 .../src/models/efficientnet.rs | 5 +++++
 .../src/models/efficientvit.rs | 7 +++----
 candle-transformers/src/models/encodec.rs | 6 ++++++
 candle-transformers/src/models/eva2.rs | 6
++++++ candle-transformers/src/models/falcon.rs | 6 ++++++ candle-transformers/src/models/fastvit.rs | 8 +++---- candle-transformers/src/models/flux/mod.rs | 7 +++++++ candle-transformers/src/models/gemma.rs | 6 ++++++ candle-transformers/src/models/gemma2.rs | 6 ++++++ candle-transformers/src/models/glm4.rs | 6 ++++++ candle-transformers/src/models/granite.rs | 7 +++++++ candle-transformers/src/models/hiera.rs | 8 +++---- candle-transformers/src/models/jina_bert.rs | 6 ++++++ candle-transformers/src/models/llama.rs | 6 ++++++ candle-transformers/src/models/llama2_c.rs | 6 ++++++ .../src/models/llama2_c_weights.rs | 6 ++++++ candle-transformers/src/models/llava/mod.rs | 10 +++++++++ candle-transformers/src/models/mamba.rs | 9 ++++++-- candle-transformers/src/models/marian.rs | 6 ++++++ candle-transformers/src/models/metavoice.rs | 6 ++++++ candle-transformers/src/models/mimi/mod.rs | 11 +++++++--- candle-transformers/src/models/mistral.rs | 7 +++++++ candle-transformers/src/models/mixformer.rs | 7 +++++++ candle-transformers/src/models/mixtral.rs | 17 +++++++++++++++ candle-transformers/src/models/mmdit/mod.rs | 9 ++++++++ candle-transformers/src/models/mobileclip.rs | 16 ++++++++++++++ candle-transformers/src/models/mobilenetv4.rs | 11 +++++++--- candle-transformers/src/models/mobileone.rs | 5 +++-- candle-transformers/src/models/moondream.rs | 11 ++++++++++ candle-transformers/src/models/mpt.rs | 8 +++++++ candle-transformers/src/models/olmo.rs | 16 ++++++++++++++ .../src/models/openclip/mod.rs | 8 +++++++ candle-transformers/src/models/paligemma.rs | 16 ++++++++++++++ candle-transformers/src/models/parler_tts.rs | 17 +++++++++++++++ candle-transformers/src/models/persimmon.rs | 16 ++++++++++++++ candle-transformers/src/models/phi.rs | 17 +++++++++++++++ candle-transformers/src/models/phi3.rs | 19 +++++++++++++++++ candle-transformers/src/models/pixtral/mod.rs | 8 +++++++ .../src/models/quantized_blip.rs | 16 ++++++++++++++ .../src/models/quantized_blip_text.rs | 17 +++++++++++++++ .../src/models/quantized_llama.rs | 17 +++++++++++++++ .../src/models/quantized_llama2_c.rs | 16 ++++++++++++++ .../src/models/quantized_metavoice.rs | 16 ++++++++++++++ .../src/models/quantized_mistral.rs | 17 +++++++++++++++ .../src/models/quantized_mixformer.rs | 13 ++++++++++++ .../src/models/quantized_moondream.rs | 15 +++++++++++++ .../src/models/quantized_mpt.rs | 18 ++++++++++++++++ .../src/models/quantized_phi.rs | 17 +++++++++++++++ .../src/models/quantized_phi3.rs | 15 +++++++++++++ .../src/models/quantized_qwen2.rs | 15 +++++++++++++ .../src/models/quantized_recurrent_gemma.rs | 17 +++++++++++++++ .../src/models/quantized_rwkv_v5.rs | 17 +++++++++++++++ .../src/models/quantized_rwkv_v6.rs | 18 ++++++++++++++++ .../src/models/quantized_stable_lm.rs | 15 +++++++++++++ .../src/models/quantized_t5.rs | 18 ++++++++++++++-- candle-transformers/src/models/qwen2.rs | 17 +++++++++++++++ candle-transformers/src/models/qwen2_moe.rs | 18 ++++++++++++++++ .../src/models/recurrent_gemma.rs | 21 +++++++++++++++++-- candle-transformers/src/models/repvgg.rs | 11 ++++++++++ candle-transformers/src/models/resnet.rs | 14 ++++++++++--- candle-transformers/src/models/rwkv_v5.rs | 17 +++++++++++++++ candle-transformers/src/models/rwkv_v6.rs | 16 ++++++++++++++ candle-transformers/src/models/segformer.rs | 16 ++++++++++++++ .../src/models/segment_anything/mod.rs | 8 +++++++ candle-transformers/src/models/siglip.rs | 8 +++++++ .../src/models/stable_diffusion/mod.rs | 9 ++++++++ candle-transformers/src/models/stable_lm.rs | 
15 +++++++++++++ candle-transformers/src/models/starcoder2.rs | 17 +++++++++++++++ .../src/models/stella_en_v5.rs | 17 +++++++++++++++ candle-transformers/src/models/t5.rs | 18 ++++++++++++++-- candle-transformers/src/models/trocr.rs | 16 ++++++++++++++ candle-transformers/src/models/vgg.rs | 15 +++++++++++-- candle-transformers/src/models/vit.rs | 17 +++++++++++++++ candle-transformers/src/models/whisper/mod.rs | 8 +++++++ .../src/models/wuerstchen/mod.rs | 9 ++++++++ candle-transformers/src/models/yi.rs | 16 +++++++++++++- 94 files changed, 1001 insertions(+), 51 deletions(-) diff --git a/candle-transformers/src/models/based.rs b/candle-transformers/src/models/based.rs index aa28f52333..c54ff96629 100644 --- a/candle-transformers/src/models/based.rs +++ b/candle-transformers/src/models/based.rs @@ -1,10 +1,9 @@ //! Based from the Stanford Hazy Research group. //! //! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024 -//! - -//! Original code: -//! https://github.com/HazyResearch/based +//! - [Arxiv](https://arxiv.org/abs/2402.18668) +//! - [Github](https://github.com/HazyResearch/based) +//! use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/beit.rs b/candle-transformers/src/models/beit.rs index 8f6284a8e6..2f61d9d6f1 100644 --- a/candle-transformers/src/models/beit.rs +++ b/candle-transformers/src/models/beit.rs @@ -1,3 +1,10 @@ +//! Based on the BEIT vision-language model. +//! +//! See "BEIT: BERT Pre-Training of Image Transformers", Bao et al. 2021 +//! - [Arxiv](https://arxiv.org/abs/2106.08254) +//! - [Github](https://github.com/microsoft/unilm/tree/master/beit) +//! + use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index bdc0385deb..a7db075cbb 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -1,3 +1,9 @@ +//! BERT (Bidirectional Encoder Representations from Transformers) +//! +//! See "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding", Devlin et al. 2018 +//! - [Arxiv](https://arxiv.org/abs/1810.04805) +//! - [Github](https://github.com/google-research/bert) +//! use super::with_tracing::{layer_norm, linear, LayerNorm, Linear}; use candle::{DType, Device, Result, Tensor}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs index f6b4a4efdc..8ed1462b1c 100644 --- a/candle-transformers/src/models/bigcode.rs +++ b/candle-transformers/src/models/bigcode.rs @@ -1,3 +1,10 @@ +//! BigCode implementation in Rust based on the GPT-BigCode model. +//! +//! See "StarCoder: A State-of-the-Art LLM for Code", Mukherjee et al. 2023 +//! - [Arxiv](https://arxiv.org/abs/2305.06161) +//! - [Github](https://github.com/bigcode-project/starcoder) +//! + use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs index e0b0b6a596..0330386574 100644 --- a/candle-transformers/src/models/blip.rs +++ b/candle-transformers/src/models/blip.rs @@ -1,3 +1,10 @@ +//! Based on the BLIP paper from Salesforce Research. +//! +//! 
See "BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation" +//! - [Arxiv](https://arxiv.org/abs/2201.12086) +//! - [Github](https://github.com/salesforce/BLIP) +//! + use super::blip_text; use super::with_tracing::{conv2d, linear, Conv2d, Linear}; use candle::{Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/blip_text.rs b/candle-transformers/src/models/blip_text.rs index 1862abef4b..aceaf4ac1b 100644 --- a/candle-transformers/src/models/blip_text.rs +++ b/candle-transformers/src/models/blip_text.rs @@ -1,3 +1,9 @@ +//! Implementation of BLIP text encoder/decoder. +//! +//! See "BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation" +//! https://arxiv.org/abs/2201.12086 +//! + use super::with_tracing::{linear, Embedding, Linear}; use candle::{Module, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, VarBuilder}; diff --git a/candle-transformers/src/models/chatglm.rs b/candle-transformers/src/models/chatglm.rs index 0686b34ef3..8d5d9ec601 100644 --- a/candle-transformers/src/models/chatglm.rs +++ b/candle-transformers/src/models/chatglm.rs @@ -1,3 +1,10 @@ +//! Implementation of the ChatGLM2/3 models from THUDM. +//! +//! See: +//! - ChatGLM3: ["ChatGLM3: Advancing Multilingual Conversational Language Models with High-Quality Data"](https://github.com/THUDM/ChatGLM3) +//! - ChatGLM2: ["ChatGLM2: An Open Bilingual Chat LLM"](https://github.com/THUDM/ChatGLM2-6B) +//! + use crate::models::with_tracing::{linear_b as linear, Linear}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::VarBuilder; diff --git a/candle-transformers/src/models/chinese_clip/mod.rs b/candle-transformers/src/models/chinese_clip/mod.rs index 0f6eedd0f2..86616baa1c 100644 --- a/candle-transformers/src/models/chinese_clip/mod.rs +++ b/candle-transformers/src/models/chinese_clip/mod.rs @@ -3,8 +3,9 @@ //! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on //! pairs of images with related texts. //! -//! https://github.com/OFA-Sys/Chinese-CLIP -//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py +//! - [GH Link](https://github.com/OFA-Sys/Chinese-CLIP) +//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py) +//! use candle::{Module, Result, Tensor, D}; use candle_nn as nn; diff --git a/candle-transformers/src/models/clip/mod.rs b/candle-transformers/src/models/clip/mod.rs index 3dd5fb485b..e83f27e388 100644 --- a/candle-transformers/src/models/clip/mod.rs +++ b/candle-transformers/src/models/clip/mod.rs @@ -3,8 +3,9 @@ //! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on //! pairs of images with related texts. //! -//! https://github.com/openai/CLIP -//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip +//! - [GH Link](https://github.com/openai/CLIP) +//! 
- Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip) + use self::{ text_model::{Activation, ClipTextTransformer}, vision_model::ClipVisionTransformer, diff --git a/candle-transformers/src/models/codegeex4_9b.rs b/candle-transformers/src/models/codegeex4_9b.rs index aaa99fd96d..baf4745922 100644 --- a/candle-transformers/src/models/codegeex4_9b.rs +++ b/candle-transformers/src/models/codegeex4_9b.rs @@ -1,3 +1,10 @@ +//! CodeGeeX4 - A multi-language code generation model +//! +//! See "CodeGeeX: A Pre-Trained Model For Code Generation with Multilingual Evaluations on HumanEval-X", Qian et al. 2023 +//! - [Arxiv](https://arxiv.org/abs/2303.17568) +//! - [Github](https://github.com/THUDM/CodeGeeX) +//! + use crate::models::with_tracing::{linear_b as linear, Linear}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::VarBuilder; diff --git a/candle-transformers/src/models/colpali.rs b/candle-transformers/src/models/colpali.rs index 1299b0a410..16ca4eb304 100644 --- a/candle-transformers/src/models/colpali.rs +++ b/candle-transformers/src/models/colpali.rs @@ -1,3 +1,8 @@ +//! Colpali Model for text/image similarity scoring. +//! +//! Colpali combines a vision encoder with an efficient LM for retrieving content. +//! + use candle::{Module, Result, Tensor}; use candle_nn::VarBuilder; diff --git a/candle-transformers/src/models/convmixer.rs b/candle-transformers/src/models/convmixer.rs index f5abfa5da3..e095f793a4 100644 --- a/candle-transformers/src/models/convmixer.rs +++ b/candle-transformers/src/models/convmixer.rs @@ -1,3 +1,10 @@ +//! ConvMixer implementation. +//! +//! See "Patches Are All You Need?" by Trockman et al. 2022 +//! - [Arxiv](https://arxiv.org/abs/2201.09792) +//! - [Github](https://github.com/locuslab/convmixer) +//! + use candle::Result; use candle_nn::{batch_norm, Conv2dConfig, Module, VarBuilder}; diff --git a/candle-transformers/src/models/convnext.rs b/candle-transformers/src/models/convnext.rs index 94b1833ec2..d791895f1d 100644 --- a/candle-transformers/src/models/convnext.rs +++ b/candle-transformers/src/models/convnext.rs @@ -1,15 +1,13 @@ //! ConvNeXt implementation. //! -//! See "A ConvNet for the 2020s" Liu et al. 2022 -//! +//! See ["A ConvNet for the 2020s" Liu et al. 2022](https://arxiv.org/abs/2201.03545) //! and -//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023 -//! - +//! ["ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023](https://arxiv.org/abs/2301.00808) +//! //! Original code: -//! https://github.com/facebookresearch/ConvNeXt/ -//! https://github.com/facebookresearch/ConvNeXt-V2/ -//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py +//! - [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/) +//! - [ConvNeXt-V2](https://github.com/facebookresearch/ConvNeXt-V2/) +//! - [timm](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py) use candle::shape::ShapeWithOneHole; use candle::{Result, D}; diff --git a/candle-transformers/src/models/dac.rs b/candle-transformers/src/models/dac.rs index fa6c8c7120..78728b4d09 100644 --- a/candle-transformers/src/models/dac.rs +++ b/candle-transformers/src/models/dac.rs @@ -1,4 +1,9 @@ -/// Adapted from https://github.com/descriptinc/descript-audio-codec +//! Implementation of the Descript Audio Codec (DAC) model +//! 
+//! See: [Descript Audio Codec](https://github.com/descriptinc/descript-audio-codec) +//! +/// An efficient neural codec for compressing/decompressing audio +/// use crate::models::encodec; use candle::{IndexOp, Result, Tensor, D}; use candle_nn::{Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, VarBuilder}; diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs index 9eee6d1130..411b0764ff 100644 --- a/candle-transformers/src/models/depth_anything_v2.rs +++ b/candle-transformers/src/models/depth_anything_v2.rs @@ -1,3 +1,9 @@ +//! Implementation of the Depth Anything model from FAIR. +//! +//! See: +//! - ["Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data"](https://github.com/LiheYoung/Depth-Anything) +//! + use candle::D::Minus1; use candle::{Module, Result, Tensor}; use candle_nn::ops::Identity; diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs index 706dfda0e7..df8834d1f7 100644 --- a/candle-transformers/src/models/dinov2.rs +++ b/candle-transformers/src/models/dinov2.rs @@ -1,3 +1,8 @@ +//! Implementation of the DINOv2 models from Meta Research. +//! +//! See: +//! - DINOv2: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2) +//! use candle::{IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/dinov2reg4.rs b/candle-transformers/src/models/dinov2reg4.rs index 1d81703c9c..0d2320e14c 100644 --- a/candle-transformers/src/models/dinov2reg4.rs +++ b/candle-transformers/src/models/dinov2reg4.rs @@ -1,3 +1,10 @@ +//! Implementation of the DINOv2 revision (4 regularization) +//! +//! See: +//! - DINOv2: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2) +//! +//! This code implements the regularization tokens version with 4 regularization tokens. +//! use candle::{IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/distilbert.rs b/candle-transformers/src/models/distilbert.rs index f899d772a2..fad76cfcce 100644 --- a/candle-transformers/src/models/distilbert.rs +++ b/candle-transformers/src/models/distilbert.rs @@ -1,3 +1,8 @@ +//! Implementation of DistilBert, a distilled version of BERT. +//! +//! See: +//! - ["DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter"](https://arxiv.org/abs/1910.01108) +//! use super::with_tracing::{layer_norm, linear, LayerNorm, Linear}; use candle::{DType, Device, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs index f15c9c797e..ecca2509ae 100644 --- a/candle-transformers/src/models/efficientnet.rs +++ b/candle-transformers/src/models/efficientnet.rs @@ -1,3 +1,8 @@ +//! Implementation of EfficientBert, an efficient variant of BERT for computer vision tasks. +//! +//! See: +//! - ["EfficientBERT: Progressively Searching Multilayer Perceptron Architectures for BERT"](https://arxiv.org/abs/2201.00462) +//! 
use candle::{Result, Tensor, D}; use candle_nn as nn; use nn::{Module, VarBuilder}; diff --git a/candle-transformers/src/models/efficientvit.rs b/candle-transformers/src/models/efficientvit.rs index b17c4ea0a1..9724f702a6 100644 --- a/candle-transformers/src/models/efficientvit.rs +++ b/candle-transformers/src/models/efficientvit.rs @@ -1,9 +1,8 @@ //! EfficientViT (MSRA) inference implementation based on timm. //! -//! See "EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention" -//! https://arxiv.org/abs/2305.07027 - -//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py +//! See ["EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"](https://arxiv.org/abs/2305.07027) +//! +//! Based on implementation from [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py) use candle::{Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs index ba6686f605..a8d509ce8b 100644 --- a/candle-transformers/src/models/encodec.rs +++ b/candle-transformers/src/models/encodec.rs @@ -1,3 +1,9 @@ +//! EnCodec neural audio codec based on the Encodec implementation. +//! +//! See ["High Fidelity Neural Audio Compression"](https://arxiv.org/abs/2210.13438) +//! +//! Based on implementation from [huggingface/transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py) + #![allow(unused)] use candle::{DType, IndexOp, Layout, Module, Result, Shape, Tensor, D}; use candle_nn::{conv1d, Conv1d, Conv1dConfig, ConvTranspose1d, VarBuilder}; diff --git a/candle-transformers/src/models/eva2.rs b/candle-transformers/src/models/eva2.rs index 013c385d1c..ee84cca43c 100644 --- a/candle-transformers/src/models/eva2.rs +++ b/candle-transformers/src/models/eva2.rs @@ -1,3 +1,9 @@ +//! EVA-2 inference implementation. +//! +//! See ["EVA-02: A Visual Representation for Neon Genesis"](https://arxiv.org/abs/2303.11331) +//! +//! Based on implementation from [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/eva2.py) + use candle::{IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs index 50ec66f316..c75b4d70d3 100644 --- a/candle-transformers/src/models/falcon.rs +++ b/candle-transformers/src/models/falcon.rs @@ -1,3 +1,9 @@ +//! Falcon language model inference implementation +//! +//! See ["Falcon: a new approach to large language models"](https://huggingface.co/blog/falcon) +//! +//! Based on implementation from [Huggingface Transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon) + use candle::{DType, Device, Result, Tensor, D}; use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; use serde::Deserialize; diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs index 8eae8bb200..4e29665358 100644 --- a/candle-transformers/src/models/fastvit.rs +++ b/candle-transformers/src/models/fastvit.rs @@ -1,9 +1,9 @@ -//! FastViT inference implementation based on timm +//! # FastViT inference implementation based on timm //! -//! See "FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization" -//! 
https://arxiv.org/pdf/2303.14189 +//! ## Description +//! See ["FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization"](https://arxiv.org/pdf/2303.14189) //! -//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py +//! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py) use candle::{DType, Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs index b0c8a6939a..8eb928f557 100644 --- a/candle-transformers/src/models/flux/mod.rs +++ b/candle-transformers/src/models/flux/mod.rs @@ -1,3 +1,10 @@ +//! Flux Model +//! +//! Flux is a series of text-to-image generation models based on diffusion transformers. +//! +//! - [GH Link](https://github.com/black-forest-labs/flux) +//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py) +//! use candle::{Result, Tensor}; pub trait WithForward { diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index c22a39480c..4b656d6a7f 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -1,3 +1,9 @@ +//! Gemma inference implementation. +//! +//! See ["Gemma: Open Models Based on Gemini Technology"](https://blog.google/technology/developers/gemma-open-ai-model/) +//! +//! Based on implementation from Google and PyTorch + use std::sync::Arc; use candle::{DType, Device, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/gemma2.rs b/candle-transformers/src/models/gemma2.rs index f0d650479e..ec23efc529 100644 --- a/candle-transformers/src/models/gemma2.rs +++ b/candle-transformers/src/models/gemma2.rs @@ -1,3 +1,9 @@ +//! Gemma LLM architecture (Google) inference implementation. +//! +//! See ["Gemma: Open Models Based on Gemini Technology"](https://blog.google/technology/developers/gemma-open-models/) +//! +//! Based on implementations from Google and OpenLLM + use std::sync::Arc; use candle::{DType, Device, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/glm4.rs b/candle-transformers/src/models/glm4.rs index 3b436eaa6d..de6581d0b7 100644 --- a/candle-transformers/src/models/glm4.rs +++ b/candle-transformers/src/models/glm4.rs @@ -1,3 +1,9 @@ +//! GLM-4 inference implementation. +//! +//! An open bilingual language model with 130B parameters. +//! +//! Based on implementation from [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) + use crate::models::with_tracing::{linear_b as linear, Linear}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::VarBuilder; diff --git a/candle-transformers/src/models/granite.rs b/candle-transformers/src/models/granite.rs index 6d25c339b2..f1b2c4db5b 100644 --- a/candle-transformers/src/models/granite.rs +++ b/candle-transformers/src/models/granite.rs @@ -1,3 +1,10 @@ +//! Granite is a Long Context Transformer Language Model. +//! +//! A high performance transformer model optimized for efficient processing +//! of very long context sequences +//! +//! 
Based on implementation from [Nod.ai](https://github.com/nod-ai/granite) + use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; diff --git a/candle-transformers/src/models/hiera.rs b/candle-transformers/src/models/hiera.rs index 52efb78ea3..39f8d639b6 100644 --- a/candle-transformers/src/models/hiera.rs +++ b/candle-transformers/src/models/hiera.rs @@ -1,9 +1,9 @@ -//! Hiera inference implementation based on timm. +//! [Hiera] inference implementation based on timm. //! -//! See "Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles" -//! https://arxiv.org/abs/2306.00989 +//! See "[Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles]" +//! [Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles]: https://arxiv.org/abs/2306.00989 //! -//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/hiera.py +//! [Hiera]: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/hiera.py use candle::{Result, D}; use candle_nn::{conv2d, layer_norm, linear, ops::softmax, Conv2dConfig, Func, VarBuilder}; diff --git a/candle-transformers/src/models/jina_bert.rs b/candle-transformers/src/models/jina_bert.rs index 1f0fae1ee4..40535a8bb9 100644 --- a/candle-transformers/src/models/jina_bert.rs +++ b/candle-transformers/src/models/jina_bert.rs @@ -1,3 +1,9 @@ +//! # JinaBERT inference implementation +//! +//! Based on implementation from huggingface for Jina BERT and its variants +//! +//! See: [Jina Embeddings on HuggingFace](https://huggingface.co/jinaai/jina-embeddings-v2-base-en) + use super::with_tracing::{linear, linear_no_bias, Embedding, Linear}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder}; diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index e77697340e..4396063ff7 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -1,3 +1,9 @@ +//! Llama inference implementation. +//! +//! See ["LLaMA: Open and Efficient Foundation Language Models"](https://arxiv.org/abs/2302.13971) +//! +//! Implementation based on Hugging Face's [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) + use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index 923a270646..d825d8e4dd 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -1,3 +1,9 @@ +//! Llama2 inference implementation. +//! +//! See ["LLaMA 2: Open Foundation and Fine-Tuned Chat Models"](https://arxiv.org/abs/2307.09288) +//! +//! 
Based on the [llama2.c](https://github.com/karpathy/llama2.c) implementation + use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::linear_no_bias as linear; use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; diff --git a/candle-transformers/src/models/llama2_c_weights.rs b/candle-transformers/src/models/llama2_c_weights.rs index e5a8bb8806..8149c214c9 100644 --- a/candle-transformers/src/models/llama2_c_weights.rs +++ b/candle-transformers/src/models/llama2_c_weights.rs @@ -1,3 +1,9 @@ +//! Llama2 inference implementation. +//! +//! See ["LLaMA 2: Open Foundation and Fine-Tuned Chat Models"](https://arxiv.org/abs/2307.09288) +//! +//! Based on the [llama2.c](https://github.com/karpathy/llama2.c) implementation + use byteorder::{LittleEndian, ReadBytesExt}; use candle::{DType, Device, IndexOp, Result, Shape, Tensor}; use candle_nn::VarBuilder; diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs index 1ed3b50c63..44a00bf9a1 100644 --- a/candle-transformers/src/models/llava/mod.rs +++ b/candle-transformers/src/models/llava/mod.rs @@ -1,3 +1,13 @@ +//! The LLaVA (Large Language and Vision Assistant) model. +//! +//! This provides the main model implementation combining a vision tower (CLIP) with +//! language model (Llama) for multimodal capabilities. +//! +//! The architecture implements the training-free projection technique from the paper: +//! [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485). +//! +//! - [GH Link](https://github.com/haotian-liu/LLaVA/tree/main) +//! pub mod config; pub mod utils; diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs index a75ee87a6e..18a0285ff6 100644 --- a/candle-transformers/src/models/mamba.rs +++ b/candle-transformers/src/models/mamba.rs @@ -1,5 +1,10 @@ -/// A fast implementation of mamba for inference only. -/// This is based on: https://github.com/LaurentMazare/mamba.rs +//! Mamba inference implementation. +//! +//! See ["Mamba: Linear-Time Sequence Modeling with Selective State Spaces"](https://arxiv.org/abs/2312.00752) +//! +//! Based on reference implementation from the AlbertMamba project +//! A fast implementation of mamba for inference only. +//! Based on Laurent Mazare's rust implementation: [mamba.rs](https://github.com/LaurentMazare/mamba.rs) use crate::models::with_tracing::{linear, linear_no_bias, Linear}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{RmsNorm, VarBuilder}; diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs index e93370c23e..c4ba0a154d 100644 --- a/candle-transformers/src/models/marian.rs +++ b/candle-transformers/src/models/marian.rs @@ -1,3 +1,9 @@ +//! Marian Neural Machine Translation +//! +//! See "Marian: Fast Neural Machine Translation in C++" Junczys-Dowmunt et al. 2018 +//! - [ACL Anthology](https://aclanthology.org/P18-4020/) +//! - [Github](https://github.com/marian-nmt/marian) +//! use super::with_tracing::{linear, Embedding, Linear}; use candle::{Result, Tensor}; use candle_nn::{layer_norm, LayerNorm, VarBuilder}; diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs index 43de594f9d..92d3ffba08 100644 --- a/candle-transformers/src/models/metavoice.rs +++ b/candle-transformers/src/models/metavoice.rs @@ -1,3 +1,9 @@ +//! MetaVoice Studio ML Models +//! +//! See MetaVoice's TTS and voice cloning models: +//! 
- [Github](https://github.com/metavoiceio/metavoice-src) +//! - [Website](https://studio.metavoice.ai/) + use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D}; use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; diff --git a/candle-transformers/src/models/mimi/mod.rs b/candle-transformers/src/models/mimi/mod.rs index dc40e38e29..f19f9ae5fa 100644 --- a/candle-transformers/src/models/mimi/mod.rs +++ b/candle-transformers/src/models/mimi/mod.rs @@ -1,9 +1,14 @@ -// Adapted from the reference implementation at: -// https://github.com/kyutai-labs/moshi +//! mimi model +//! +//! Mimi is a state-of-the-art audio neural codec. +//! +//! - [HuggingFace Model Card](https://huggingface.co/kyutai/mimi) +//! - [GitHub](https://github.com/kyutai-labs/moshi) +//! + // Copyright (c) Kyutai, all rights reserved. // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. - pub use candle; pub use candle_nn; diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs index e8f7a7c4b8..f927f88b2d 100644 --- a/candle-transformers/src/models/mistral.rs +++ b/candle-transformers/src/models/mistral.rs @@ -1,3 +1,10 @@ +//! Mixtral Model, based on the Mistral architecture +//! +//! See Mistral and Mixtral at: +//! - [Hugging Face](https://huggingface.co/docs/transformers/model_doc/mixtral) +//! - [Github](https://github.com/mistralai/mistral-src) +//! + use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; /// Mistral LLM, https://github.com/mistralai/mistral-src use candle::{DType, Device, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index 700829e33b..2c2909c3e0 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -1,3 +1,10 @@ +//! MixFormer (Microsoft's Phi Architecture) +//! +//! See "Textbooks Are All You Need II: phi-1.5 technical report", Lin et al. 2023 +//! - [Arxiv](https://arxiv.org/abs/2309.05463) +//! - [Github](https://huggingface.co/microsoft/phi-1_5) +//! + use crate::models::with_tracing::{linear, Embedding as E, Linear}; /// MixFormer model. /// https://huggingface.co/microsoft/phi-1_5 diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs index a578d6fed0..70115e10a3 100644 --- a/candle-transformers/src/models/mixtral.rs +++ b/candle-transformers/src/models/mixtral.rs @@ -1,3 +1,20 @@ +//! Mixtral Model, a sparse mixture of expert model based on the Mistral architecture +//! +//! See Mixtral model details at: +//! - [Hugging Face](https://huggingface.co/docs/transformers/model_doc/mixtral) +//! - [Mixtral-8x7B Blog Post](https://mistral.ai/news/mixtral-of-experts/) +//! +//! The model uses a mixture of experts architecture with: +//! - 8 experts per layer +//! - Top 2 expert routing +//! - Sliding window attention +//! - RoPE embeddings +//! +//! References: +//! - [Hugging Face Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py) +//! - [Mixtral Blog Post](https://mistral.ai/news/mixtral-of-experts/) +//! 
+ use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; /// Mixtral Model /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py diff --git a/candle-transformers/src/models/mmdit/mod.rs b/candle-transformers/src/models/mmdit/mod.rs index 9c4db6e085..ce4872e0b2 100644 --- a/candle-transformers/src/models/mmdit/mod.rs +++ b/candle-transformers/src/models/mmdit/mod.rs @@ -1,3 +1,12 @@ +//! Mix of Multi-scale Dilated and Traditional Convolutions +//! +//! Mix of Multi-scale Dilated and Traditional Convolutions (MMDiT) is an architecture +//! introduced for Stable Diffusion 3, with the MMDiT-X variant used in Stable Diffusion 3.5. +//! +//! - [Research Paper](https://arxiv.org/abs/2403.03206) +//! - ComfyUI [reference implementation](https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py) +//! - Stability-AI [MMDiT-X implementation](https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py) + pub mod blocks; pub mod embedding; pub mod model; diff --git a/candle-transformers/src/models/mobileclip.rs b/candle-transformers/src/models/mobileclip.rs index 45a5dbad9f..f0baf9e10c 100644 --- a/candle-transformers/src/models/mobileclip.rs +++ b/candle-transformers/src/models/mobileclip.rs @@ -1,3 +1,19 @@ +//! Mobile CLIP model, combining a lightweight vision encoder with a text encoder +//! +//! A mobile-optimized CLIP implementation that uses: +//! - FastViT as the vision encoder +//! - OpenCLIP text encoder +//! - Projection layers to align the feature spaces +//! +//! See model details at: +//! - [FastViT](https://arxiv.org/abs/2303.14189) +//! - [OpenCLIP](https://github.com/mlfoundations/open_clip) +//! +//! References: +//! - [MobileVLM](https://huggingface.co/mobileVLM) +//! - [MetaCLIP](https://arxiv.org/abs/2309.16671) +//! + use super::fastvit; use super::openclip::text_model; use candle::{Result, Tensor, D}; diff --git a/candle-transformers/src/models/mobilenetv4.rs b/candle-transformers/src/models/mobilenetv4.rs index 7cbae7c385..ab1e70803f 100644 --- a/candle-transformers/src/models/mobilenetv4.rs +++ b/candle-transformers/src/models/mobilenetv4.rs @@ -1,9 +1,14 @@ +//! # MobileNet-v4 +//! //! MobileNet-v4 inference implementation based on timm. //! -//! See "MobileNetV4 - Universal Models for the Mobile Ecosystem" -//! https://arxiv.org/abs/2404.10518 +//! ## Paper +//! +//! ["MobileNetV4 - Universal Models for the Mobile Ecosystem"](https://arxiv.org/abs/2404.10518) +//! +//! ## References //! -//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mobilenetv3.py +//! - [PyTorch Implementation](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mobilenetv3.py) use candle::{Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/mobileone.rs b/candle-transformers/src/models/mobileone.rs index 674da40b97..e8836745b9 100644 --- a/candle-transformers/src/models/mobileone.rs +++ b/candle-transformers/src/models/mobileone.rs @@ -1,7 +1,8 @@ +//! # MobileOne +//! //! MobileOne inference implementation based on timm and candle-repvgg //! -//! See "MobileOne: An Improved One millisecond Mobile Backbone" -//! https://arxiv.org/abs/2206.04040 +//! 
See ["MobileOne: An Improved One millisecond Mobile Backbone"](https://arxiv.org/abs/2206.04040) use candle::{DType, Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs index cde59d43d6..d351d7c019 100644 --- a/candle-transformers/src/models/moondream.rs +++ b/candle-transformers/src/models/moondream.rs @@ -1,3 +1,14 @@ +//! MoonDream Model vision-to-text +//! +//! The model consists of: +//! - Vision encoder using a ViT-style architecture +//! - Text decoder based on Microsoft's Phi model +//! - Vision projection module to align vision and text embeddings +//! +//! References: +//! - [MoonDream Original Implementation](https://github.com/vikhyat/moondream) +//! + use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel}; use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear}; use candle::{IndexOp, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs index d46524fcc2..d4170d6bff 100644 --- a/candle-transformers/src/models/mpt.rs +++ b/candle-transformers/src/models/mpt.rs @@ -1,3 +1,11 @@ +//! Module implementing the MPT (Multi-Purpose Transformer) model +//! +//! References: +//! - [MPT Model used by replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py) +//! - [Configuration](https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/configuration_mpt.py) +//! +//! The model uses grouped query attention and alibi positional embeddings. + use crate::models::with_tracing::{linear_no_bias, Embedding, Linear}; /// MPT model used by replit-code-v1_5-3b /// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py diff --git a/candle-transformers/src/models/olmo.rs b/candle-transformers/src/models/olmo.rs index 983a33340a..6cf5b1f79d 100644 --- a/candle-transformers/src/models/olmo.rs +++ b/candle-transformers/src/models/olmo.rs @@ -1,3 +1,19 @@ +//! OLMo (Open Language Model) implementation +//! +//! See OLMo model details at: +//! - [Hugging Face](https://huggingface.co/allenai/OLMo) +//! - [OLMo Paper](https://allenai.org/olmo) +//! +//! The model uses: +//! - RoPE embeddings +//! - Sliding window attention +//! - Transformer architecture +//! +//! References: +//! - [Hugging Face Implementation](https://huggingface.co/allenai/OLMo) +//! - [OLMo Paper](https://allenai.org/olmo) +//! + use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{linear_b, linear_no_bias, Activation, LayerNorm, Linear, VarBuilder}; use std::sync::Arc; diff --git a/candle-transformers/src/models/openclip/mod.rs b/candle-transformers/src/models/openclip/mod.rs index ee2a501d6a..dacb627f9e 100644 --- a/candle-transformers/src/models/openclip/mod.rs +++ b/candle-transformers/src/models/openclip/mod.rs @@ -1 +1,9 @@ +//! Open Contrastive Language-Image Pre-Training +//! +//! Open Contrastive Language-Image Pre-Training (OpenCLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! - [GH Link](https://github.com/mlfoundations/open_clip) +//! + pub mod text_model; diff --git a/candle-transformers/src/models/paligemma.rs b/candle-transformers/src/models/paligemma.rs index a5e7f694f5..e992869923 100644 --- a/candle-transformers/src/models/paligemma.rs +++ b/candle-transformers/src/models/paligemma.rs @@ -1,3 +1,19 @@ +//! 
Multimodal multi-purpose model combining Gemma-based language model with SigLIP image understanding +//! +//! See PaLiGemma details at: +//! - [Paper](https://arxiv.org/abs/2402.05257) +//! - [Google Blog Post](https://blog.research.google/2024/02/paligemma-scaling-language-image.html) +//! +//! The model is a multimodal combination of: +//! - SigLIP vision encoder +//! - Gemma language model +//! - Cross-projection layers +//! +//! References: +//! - [HuggingFace Implementation](https://huggingface.co/google/paligemma-3b) +//! - [Paper: PaLI-3 and Beyond: Scaling Language-Image Learning](https://arxiv.org/abs/2402.05257) +//! + use crate::models::{gemma, siglip}; use candle::{Module, Result, Tensor}; use candle_nn::{linear, Linear, VarBuilder}; diff --git a/candle-transformers/src/models/parler_tts.rs b/candle-transformers/src/models/parler_tts.rs index da40124741..0c08aa9427 100644 --- a/candle-transformers/src/models/parler_tts.rs +++ b/candle-transformers/src/models/parler_tts.rs @@ -1,3 +1,20 @@ +//! Parler Model implementation for parler_tts text-to-speech synthesis +//! +//! Implements a transformer-based decoder architecture for generating audio tokens +//! from text using discrete tokens. The model converts text into audio segments +//! using multiple codebooks of quantized audio tokens. +//! +//! The model architecture includes: +//! - Multi-head attention layers for text and audio processing +//! - Feed-forward networks +//! - Layer normalization +//! - Positional embeddings +//! - Multiple codebook prediction heads +//! +//! The implementation follows the original parler_tts architecture while focusing +//! on audio token generation for text-to-speech synthesis. +//! + use crate::generation::LogitsProcessor; use crate::models::t5; use candle::{IndexOp, Result, Tensor}; diff --git a/candle-transformers/src/models/persimmon.rs b/candle-transformers/src/models/persimmon.rs index afee7c83ee..0996decf55 100644 --- a/candle-transformers/src/models/persimmon.rs +++ b/candle-transformers/src/models/persimmon.rs @@ -1,3 +1,19 @@ +//! Persimmon Model +//! +//! A transformer language model for efficient inference and general-purpose tasks. See Persimmon model details at: +//! - [Hugging Face](https://huggingface.co/adept/persimmon-8b-base) +//! +//! The model uses a standard transformer architecture with: +//! - Layer normalization for Q/K attention +//! - RoPE embeddings with partial rotary factor +//! - ReLU activation +//! - Separate number of attention heads and KV heads +//! +//! References: +//! - [Hugging Face Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/modeling_persimmon.py) +//! - [Persimmon Config](https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/configuration_persimmon.py) +//! + use candle::DType; use serde::Deserialize; diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs index bffc14faed..36a08bb3c6 100644 --- a/candle-transformers/src/models/phi.rs +++ b/candle-transformers/src/models/phi.rs @@ -1,3 +1,20 @@ +//! Microsoft Phi model implementation +//! +//! See Phi model details at: +//! - [Phi-2 Model](https://huggingface.co/microsoft/phi-2) +//! +//! The Phi series are decoder-only transformers designed for code and language tasks. +//! Key characteristics: +//! - Decoder-only transformer architecture +//! - RoPE embeddings +//! - Layer normalization +//! - QK normalization +//! +//! References: +//! 
- [Hugging Face Implementation](https://huggingface.co/microsoft/phi-2) +//! - [Alternative Implementation](https://huggingface.co/microsoft/phi-2/tree/main) +//! + use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear}; /// Phi model. /// https://huggingface.co/microsoft/phi-2 diff --git a/candle-transformers/src/models/phi3.rs b/candle-transformers/src/models/phi3.rs index a5e3e9a948..7ce9e987c9 100644 --- a/candle-transformers/src/models/phi3.rs +++ b/candle-transformers/src/models/phi3.rs @@ -1,3 +1,22 @@ +//! Microsoft Phi-3 model implementation +//! +//! See Phi model details at: +//! - [Phi-3 Model](https://huggingface.co/microsoft/phi-3) +//! +//! The Phi series are decoder-only transformers designed for code and language tasks. +//! Key characteristics: +//! - Decoder-only transformer architecture +//! - RoPE embeddings +//! - Layer normalization +//! - QK normalization +//! - Mixed activation functions +//! - Improved context window handling +//! +//! References: +//! - [Hugging Face Implementation](https://huggingface.co/microsoft/phi-3) +//! - [Alternative Implementation](https://huggingface.co/microsoft/phi-3/tree/main) +//! + // This implementation is based on: // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; diff --git a/candle-transformers/src/models/pixtral/mod.rs b/candle-transformers/src/models/pixtral/mod.rs index 9d0eccfb57..53f9ef9182 100644 --- a/candle-transformers/src/models/pixtral/mod.rs +++ b/candle-transformers/src/models/pixtral/mod.rs @@ -1,3 +1,11 @@ +//! Pixtral Language-Image Pre-Training +//! +//! Pixtral is an architecture trained for multimodal learning +//! using images paired with text descriptions. +//! +//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/main/src/transformers/models/pixtral) +//! + pub mod llava; pub mod vision_model; diff --git a/candle-transformers/src/models/quantized_blip.rs b/candle-transformers/src/models/quantized_blip.rs index 31e22b4570..acba9ba191 100644 --- a/candle-transformers/src/models/quantized_blip.rs +++ b/candle-transformers/src/models/quantized_blip.rs @@ -1,3 +1,19 @@ +//! BLIP model implementation with quantization support. +//! +//! BLIP is a vision-language model for image understanding and generation tasks. +//! This implementation provides quantization for reduced memory and compute. +//! +//! Key characteristics: +//! - Vision encoder using ViT architecture +//! - Text decoder using BERT-style transformer +//! - Cross-attention between vision and text features +//! - Support for 8-bit quantization +//! +//! References: +//! - [BLIP Paper](https://arxiv.org/abs/2201.12086) +//! - [Hugging Face Implementation](https://huggingface.co/docs/transformers/model_doc/blip) +//! + use super::quantized_blip_text as blip_text; use crate::quantized_nn::{layer_norm, linear, Linear}; pub use crate::quantized_var_builder::VarBuilder; diff --git a/candle-transformers/src/models/quantized_blip_text.rs b/candle-transformers/src/models/quantized_blip_text.rs index 652205d6f6..61e468e78b 100644 --- a/candle-transformers/src/models/quantized_blip_text.rs +++ b/candle-transformers/src/models/quantized_blip_text.rs @@ -1,3 +1,20 @@ +//! Quantized BLIP text module implementation. +//! +//! Provides the text decoder portion of the BLIP model with 8-bit quantization. +//! Uses a BERT-style transformer architecture for text processing. +//! 
+//! Key components: +//! - Text embeddings layer with position embeddings +//! - Multi-head self attention layers +//! - Cross-attention for vision-text fusion +//! - Layer normalization and feed-forward layers +//! - Quantized linear transformations +//! +//! References: +//! - [BLIP Paper](https://arxiv.org/abs/2201.12086) +//! - [Hugging Face Implementation](https://huggingface.co/docs/transformers/model_doc/blip) +//! + use crate::models::with_tracing::QMatMul; use crate::quantized_nn::{layer_norm, linear, Embedding, Linear}; pub use crate::quantized_var_builder::VarBuilder; diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 04a50981b6..7efd385d61 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -1,3 +1,20 @@ +//! Quantized llama model implementation. +//! +//! This provides a quantized implementation of the llama language model architecture. +//! The model implements parameter efficient quantization for reduced memory usage +//! while maintaining model quality. +//! +//! Key characteristics: +//! - Transformer decoder architecture +//! - Support for 2/3/4/8-bit quantization +//! - Optimized memory usage through quantization +//! - Configurable model sizes and parameter counts +//! +//! References: +//! - [LLaMA Paper](https://arxiv.org/abs/2302.13971) +//! - [LLaMA Model](https://github.com/facebookresearch/llama) +//! + use std::collections::HashMap; use crate::quantized_nn::RmsNorm; diff --git a/candle-transformers/src/models/quantized_llama2_c.rs b/candle-transformers/src/models/quantized_llama2_c.rs index cbb8aad8da..3eb14bb9e6 100644 --- a/candle-transformers/src/models/quantized_llama2_c.rs +++ b/candle-transformers/src/models/quantized_llama2_c.rs @@ -1,3 +1,19 @@ +//! Quantized Llama2 model implementation. +//! +//! This provides an 8-bit quantized implementation of Meta's LLaMA2 language model +//! for reduced memory usage and faster inference. +//! +//! Key characteristics: +//! - Decoder-only transformer architecture +//! - RoPE position embeddings +//! - Grouped Query Attention +//! - 8-bit quantization of weights +//! +//! References: +//! - [LLaMA2 Paper](https://arxiv.org/abs/2307.09288) +//! - [LLaMA2 Technical Report](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/) +//! + use super::llama2_c::{Cache, Config}; use crate::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm}; pub use crate::quantized_var_builder::VarBuilder; diff --git a/candle-transformers/src/models/quantized_metavoice.rs b/candle-transformers/src/models/quantized_metavoice.rs index 947ab750cd..ac72162715 100644 --- a/candle-transformers/src/models/quantized_metavoice.rs +++ b/candle-transformers/src/models/quantized_metavoice.rs @@ -1,3 +1,19 @@ +//! Quantized MetaVoice model implementation. +//! +//! MetaVoice is a conditional text-to-speech model based on a transformer architecture. +//! This implementation provides quantization for reduced memory and compute. +//! +//! Key characteristics: +//! - Transformer-based autoregressive decoder +//! - Speaker conditioning +//! - Support for 8-bit quantization +//! - Key-value caching for efficient inference +//! - RMS normalization layers +//! +//! References: +//! - [MetaVoice Code](https://github.com/metavoiceio/metavoice) +//! 
+ use crate::quantized_nn::{linear_b, Embedding, Linear, RmsNorm}; pub use crate::quantized_var_builder::VarBuilder; diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs index 0583810a0d..cdb687d573 100644 --- a/candle-transformers/src/models/quantized_mistral.rs +++ b/candle-transformers/src/models/quantized_mistral.rs @@ -1,3 +1,20 @@ +//! Mistral model implementation with quantization support. +//! +//! Mistral is a large language model optimized for efficiency. +//! This implementation provides quantization for reduced memory and compute. +//! +//! Key characteristics: +//! - Sliding window attention mechanism +//! - Grouped query attention (GQA) +//! - RMSNorm for layer normalization +//! - Rotary positional embeddings (RoPE) +//! - Support for 8-bit quantization +//! +//! References: +//! - [Mistral Paper](https://arxiv.org/abs/2310.06825) +//! - [Model Card](https://huggingface.co/mistralai/Mistral-7B-v0.1) +//! + use crate::quantized_nn::{linear_no_bias, Embedding, Linear, RmsNorm}; pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, Device, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs index fa72672a9e..8736544625 100644 --- a/candle-transformers/src/models/quantized_mixformer.rs +++ b/candle-transformers/src/models/quantized_mixformer.rs @@ -1,3 +1,16 @@ +//! Module containing quantized MixFormer model implementation. +//! +//! MixFormer is an efficient transformer variant for text generation that uses +//! mixture-of-experts and parallel attention/feed-forward blocks. +//! This implementation provides quantization for reduced memory usage. +//! +//! Key features: +//! - Parallel attention and feed-forward computation +//! - Rotary positional embeddings +//! - Optional key-value caching +//! - Support for 8-bit quantization +//! + use crate::quantized_nn::{layer_norm, linear, Linear}; pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/quantized_moondream.rs b/candle-transformers/src/models/quantized_moondream.rs index 1b125d9306..c1daffafe4 100644 --- a/candle-transformers/src/models/quantized_moondream.rs +++ b/candle-transformers/src/models/quantized_moondream.rs @@ -1,3 +1,18 @@ +//! Implementation of a quantized Moondream vision language model. +//! +//! Moondream is a lightweight vision-language model for image understanding and generation. +//! This module provides a quantized version for reduced memory usage and faster inference. +//! +//! Key features: +//! - ViT-based vision encoder +//! - Phi-2 text decoder model +//! - Memory efficient 8-bit quantization +//! - Optimized for efficient deployment +//! +//! References: +//! - [Moondream Model](https://github.com/vikhyat/moondream) +//! + use crate::models::moondream::{Config, VisionConfig}; use crate::models::quantized_mixformer::MixFormerSequentialForCausalLM as PhiModel; use crate::quantized_nn::{layer_norm, linear_b, Linear}; diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs index 056fcac2d1..44d8566b7b 100644 --- a/candle-transformers/src/models/quantized_mpt.rs +++ b/candle-transformers/src/models/quantized_mpt.rs @@ -1,3 +1,21 @@ +//! Quantized MPT model implementation. +//! +//! MPT (MPT-7B) is a causal transformer model series optimized for code generation. +//! 
This implementation provides quantization for reduced memory and compute. +//! +//! Key characteristics: +//! - Multi-Query Grouped Attention (MQA) +//! - Support for KV-caching +//! - Pre-computed ALiBi attention biases +//! - Support for 8-bit quantization +//! +//! References: +//! - [Replit Code Models](https://huggingface.co/replit/replit-code-v1_5-3b) +//! - [MPT-7B Implementation](https://github.com/mosaicml/llm-foundry) +//! +/// MPT model used by replit-code-v1_5-3b +/// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py +/// use crate::quantized_nn::{layer_norm_no_bias, linear_no_bias, Embedding, Linear}; pub use crate::quantized_var_builder::VarBuilder; /// MPT model used by replit-code-v1_5-3b diff --git a/candle-transformers/src/models/quantized_phi.rs b/candle-transformers/src/models/quantized_phi.rs index 0ebf7f4d4b..b874ad94ea 100644 --- a/candle-transformers/src/models/quantized_phi.rs +++ b/candle-transformers/src/models/quantized_phi.rs @@ -1,3 +1,20 @@ +//! Phi2 model implementation with quantization support. +//! +//! Phi2 is a 2.7B parameter language model using scaled-up Transformer decoder architecture. +//! This implementation provides quantization for reduced memory and compute usage. +//! +//! Key characteristics: +//! - Partial attention with learned mixing to reduce quadratic costs +//! - Layer reuse for improved inference efficiency +//! - Linear transformations with scalar mixing +//! - Rotary positional embeddings (RoPE) +//! - Support for 8-bit quantization +//! +//! References: +//! - [Phi2 Paper](https://arxiv.org/abs/2309.05463) +//! - [Model Card](https://huggingface.co/microsoft/phi-2) +//! + use std::collections::HashMap; use candle::quantized::gguf_file; diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs index 257ad98379..51a75f3895 100644 --- a/candle-transformers/src/models/quantized_phi3.rs +++ b/candle-transformers/src/models/quantized_phi3.rs @@ -1,3 +1,18 @@ +//! Phi3 model implementation with quantization support. +//! +//! Phi3 is a language model intended for research purposes. +//! This implementation provides quantization for reduced memory usage. +//! +//! Key characteristics: +//! - Multi-head attention +//! - RMSNorm for layer normalization +//! - Rotary positional embeddings (RoPE) +//! - Support for quantization +//! +//! References: +//! - [Model Card](https://huggingface.co/microsoft/phi-3) +//! + use std::collections::HashMap; use candle::quantized::gguf_file; diff --git a/candle-transformers/src/models/quantized_qwen2.rs b/candle-transformers/src/models/quantized_qwen2.rs index addfab2b04..c04da56925 100644 --- a/candle-transformers/src/models/quantized_qwen2.rs +++ b/candle-transformers/src/models/quantized_qwen2.rs @@ -1,3 +1,18 @@ +//! Qwen2 model implementation with quantization support. +//! +//! Qwen2 is a chat-optimized language model that supports 8-bit quantization +//! for reduced memory usage and faster inference. +//! +//! Key characteristics: +//! - Group Query Attention (GQA) +//! - RMSNorm for layer normalization +//! - Rotary positional embeddings (RoPE) +//! - Support for 8-bit quantization +//! +//! References: +//! - [Model Card](https://huggingface.co/Qwen/Qwen2) +//! 
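+//!
+//! # Note on grouped query attention
+//!
+//! With GQA the key/value projections use fewer heads than the queries and each
+//! key/value head is shared by a group of query heads, which is what the
+//! `repeat_kv` helper imported below is used for. The sketch that follows is a
+//! minimal, self-contained illustration of that head expansion; the function
+//! name and the plain `Vec` shapes are illustrative and not part of this
+//! module's API.
+//!
+//! ```ignore
+//! /// Repeat each of the `kv.len()` key/value heads so that every one of the
+//! /// `n_heads` query heads has a matching row (assumes n_heads % kv.len() == 0).
+//! fn expand_kv_heads(kv: &[Vec<f32>], n_heads: usize) -> Vec<Vec<f32>> {
+//!     let group = n_heads / kv.len();
+//!     kv.iter()
+//!         .flat_map(|head| std::iter::repeat(head.clone()).take(group))
+//!         .collect()
+//! }
+//! ```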
+ use crate::{quantized_nn::RmsNorm, utils::repeat_kv}; use candle::{ quantized::{gguf_file, QMatMul}, diff --git a/candle-transformers/src/models/quantized_recurrent_gemma.rs b/candle-transformers/src/models/quantized_recurrent_gemma.rs index c28064da6b..e40daa1f33 100644 --- a/candle-transformers/src/models/quantized_recurrent_gemma.rs +++ b/candle-transformers/src/models/quantized_recurrent_gemma.rs @@ -1,3 +1,20 @@ +//! Recurrent Gemma model implementation with quantization support. +//! +//! Gemma is a large language model optimized for efficiency. +//! This implementation provides quantization for reduced memory and compute. +//! +//! Key characteristics: +//! - Recurrent blocks with gated recurrent units +//! - Convolution and attention blocks +//! - RMSNorm for layer normalization +//! - Rotary positional embeddings (RoPE) +//! - Support for 8-bit quantization +//! +//! References: +//! - [Gemma Paper](https://arxiv.org/abs/2401.06751) +//! - [Model Card](https://ai.google.dev/gemma) +//! + use crate::quantized_nn::{linear_b as linear, Embedding, Linear}; pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/quantized_rwkv_v5.rs b/candle-transformers/src/models/quantized_rwkv_v5.rs index c41d7b4e08..cc5204bf24 100644 --- a/candle-transformers/src/models/quantized_rwkv_v5.rs +++ b/candle-transformers/src/models/quantized_rwkv_v5.rs @@ -1,3 +1,20 @@ +//! RWKV v5 model implementation with quantization support. +//! +//! RWKV v5 is an attention-free language model optimized for efficiency. +//! This implementation provides quantization for reduced memory and compute. +//! +//! Key characteristics: +//! - Linear attention mechanism +//! - GroupNorm layer normalization +//! - Time-mixing layers +//! - State-based sequential processing +//! - Support for 8-bit quantization +//! +//! References: +//! - [RWKV Model](https://github.com/BlinkDL/RWKV-LM) +//! - [RWKV v5 Architecture](https://www.rwkv.com/v5) +//! + use crate::{ quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear}, quantized_var_builder::VarBuilder, diff --git a/candle-transformers/src/models/quantized_rwkv_v6.rs b/candle-transformers/src/models/quantized_rwkv_v6.rs index 81150c3ec0..91288c2e61 100644 --- a/candle-transformers/src/models/quantized_rwkv_v6.rs +++ b/candle-transformers/src/models/quantized_rwkv_v6.rs @@ -1,3 +1,21 @@ +//! RWKV v6 model implementation with quantization support. +//! +//! RWKV is a linear attention model that combines the efficiency of RNNs +//! with the parallelizable training of Transformers. Version 6 builds on previous +//! versions with further optimizations. +//! +//! Key characteristics: +//! - Linear attention mechanism +//! - Time mixing layers +//! - Channel mixing layers +//! - RMSNorm for normalization +//! - Support for 8-bit quantization +//! +//! References: +//! - [RWKV Architecture](https://github.com/BlinkDL/RWKV-LM) +//! - [RWKV v6 Release](https://huggingface.co/BlinkDL/rwkv-6) +//! + use crate::{ quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear}, quantized_var_builder::VarBuilder, diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs index da4475220f..d74ed743d8 100644 --- a/candle-transformers/src/models/quantized_stable_lm.rs +++ b/candle-transformers/src/models/quantized_stable_lm.rs @@ -1,3 +1,18 @@ +//! Module for quantized StableLM implementation. +//! +//! 
StableLM is a series of open-source large language models +//! optimized for performance and stability. This implementation +//! provides quantization support for efficient model deployment. +//! +//! Key characteristics: +//! - RMSNorm for layer normalization +//! - Rotary positional embeddings (RoPE) +//! - Support for 8-bit quantization +//! +//! References: +//! - [StableLM](https://github.com/Stability-AI/StableLM) +//! + use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear}; pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, Device, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs index 88224d2da3..9f770d69d9 100644 --- a/candle-transformers/src/models/quantized_t5.rs +++ b/candle-transformers/src/models/quantized_t5.rs @@ -1,5 +1,19 @@ -// T5 Text Model, quantized version -// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py +//! T5 model implementation with quantization support. +//! +//! T5 is an encoder-decoder model pre-trained on a multi-task mixture of supervised +//! and unsupervised tasks. This implementation provides quantization for reduced +//! memory and compute requirements. +//! +//! Key characteristics: +//! - Encoder-decoder architecture +//! - Layer normalization +//! - Relative positional encodings +//! - Support for 8-bit quantization +//! +//! References: +//! - [T5 Paper](https://arxiv.org/abs/1910.10683) +//! - [Model Card](https://huggingface.co/t5-base) +//! - Original model from [T5](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py) use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating}; use crate::models::with_tracing::QMatMul; diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs index 187ea98a10..8dbca36b3e 100644 --- a/candle-transformers/src/models/qwen2.rs +++ b/candle-transformers/src/models/qwen2.rs @@ -1,3 +1,20 @@ +//! Qwen2 model implementation with quantization support. +//! +//! Qwen2 is a large language model from Alibaba optimized for efficiency. +//! This implementation provides quantization for reduced memory and compute. +//! +//! Key characteristics: +//! - Streaming decode support +//! - Grouped query attention (GQA) +//! - RMSNorm for layer normalization +//! - Rotary positional embeddings (RoPE) +//! - Support for 8-bit quantization +//! +//! References: +//! - [Qwen2 Model](https://huggingface.co/Qwen/Qwen2-7B) +//! - [Model Card](https://huggingface.co/Qwen/Qwen2-7B) +//! + use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; diff --git a/candle-transformers/src/models/qwen2_moe.rs b/candle-transformers/src/models/qwen2_moe.rs index 8d1d2f70f4..40e0279748 100644 --- a/candle-transformers/src/models/qwen2_moe.rs +++ b/candle-transformers/src/models/qwen2_moe.rs @@ -1,3 +1,21 @@ +//! Qwen2 model implementation with Mixture of Experts support. +//! +//! Qwen2 is a large language model using sparse Mixture of Experts (MoE). +//! This implementation provides support for sparsely activated MoE layers. +//! +//! Key characteristics: +//! - Mixture of Experts architecture +//! - Sparse expert activation +//! - Shared expert routing mechanism +//! - Grouped query attention (GQA) +//! 
- RMSNorm for layer normalization +//! - Rotary positional embeddings (RoPE) +//! +//! References: +//! - [Qwen2 Paper](https://arxiv.org/abs/2401.08985) +//! - [Model Card](https://huggingface.co/Qwen/Qwen2-7B-beta) +//! + use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; diff --git a/candle-transformers/src/models/recurrent_gemma.rs b/candle-transformers/src/models/recurrent_gemma.rs index 24d2b7e38b..d6a029babc 100644 --- a/candle-transformers/src/models/recurrent_gemma.rs +++ b/candle-transformers/src/models/recurrent_gemma.rs @@ -1,5 +1,22 @@ -// This implementation is based on the python version from huggingface/transformers. -// https://github.com/huggingface/transformers/blob/b109257f4fb8b1166e7c53cc5418632014ed53a5/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L2 +//! Recurrent Gemma model implementation +//! +//! Recurrent Gemma is a version of the Gemma language model that incorporates recurrent memory. +//! This allows the model to maintain state between predictions and have longer-range memory. +//! +//! Key characteristics: +//! - Real-gated linear recurrent units (RGLRU) +//! - 1D convolution for local context +//! - RMSNorm for layer normalization +//! - Rotary positional embeddings (RoPE) +//! - Grouped query attention +//! +//! References: +//! - [Gemma: Open Models Based on Gemini Technology](https://blog.google/technology/developers/gemma-open-models/) +//! - [Recurrent Memory model architecture](https://arxiv.org/abs/2402.00441) +//! +//! This implementation is based on the python version from huggingface/transformers. +//! https://github.com/huggingface/transformers/blob/b109257f4fb8b1166e7c53cc5418632014ed53a5/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L2 +//! use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{linear_b as linear, Linear, VarBuilder}; use std::sync::Arc; diff --git a/candle-transformers/src/models/repvgg.rs b/candle-transformers/src/models/repvgg.rs index 34016e5b45..a6ffce0d6d 100644 --- a/candle-transformers/src/models/repvgg.rs +++ b/candle-transformers/src/models/repvgg.rs @@ -2,6 +2,17 @@ //! //! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021 //! https://arxiv.org/abs/2101.03697 +//! +//! Key characteristics: +//! - Efficient inference architecture through structural reparameterization +//! - Single 3x3 conv layer after fusing 3x3 branch, 1x1 branch and identity branch +//! - Different configurations including a0-a2, b0-b3 and variants with group convolutions +//! - High accuracy with VGG-like plain architecture and training +//! +//! References: +//! - [RepVGG Paper](https://arxiv.org/abs/2101.03697) +//! - [Official Implementation](https://github.com/DingXiaoH/RepVGG) +//! use candle::{Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/resnet.rs b/candle-transformers/src/models/resnet.rs index 30029a0bd1..31395c8f84 100644 --- a/candle-transformers/src/models/resnet.rs +++ b/candle-transformers/src/models/resnet.rs @@ -1,7 +1,15 @@ -//! ResNet implementation. +//! # ResNet Implementation //! -//! See "Deep Residual Learning for Image Recognition" He et al. 2015 -//! +//! Implementation of ResNet architectures as described in the paper: +//! +//! ## Reference +//! +//! [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) +//! He et al. (2015) +//! +//! 
This paper introduced ResNet, a deep neural network architecture that utilizes +//! skip connections ("residual connections") to enable training of very deep networks. + use candle::{Result, D}; use candle_nn::{batch_norm, Conv2d, Func, VarBuilder}; diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs index eb51273196..6390f886d2 100644 --- a/candle-transformers/src/models/rwkv_v5.rs +++ b/candle-transformers/src/models/rwkv_v5.rs @@ -1,3 +1,20 @@ +//! RWKV v5 model implementation. +//! +//! RWKV is an RNN with transformer-level performance that can be implemented +//! as either a transformer or RNN. +//! +//! Key characteristics: +//! - Time-mix attention mechanism +//! - Channel-mix feed-forward network +//! - Linear attention +//! - Group normalization +//! - Token shift mechanism +//! +//! References: +//! - [RWKV Language Model](https://github.com/BlinkDL/RWKV-LM) +//! - [RWKV v5 Release](https://github.com/BlinkDL/ChatRWKV/tree/main) +//! + use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear}; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; diff --git a/candle-transformers/src/models/rwkv_v6.rs b/candle-transformers/src/models/rwkv_v6.rs index 457c351ec1..c75aa885e9 100644 --- a/candle-transformers/src/models/rwkv_v6.rs +++ b/candle-transformers/src/models/rwkv_v6.rs @@ -1,3 +1,19 @@ +//! RWKV v6 model implementation. +//! +//! RWKV is an RNN with transformer-like performance. +//! Version 6 introduces refinements to the architecture. +//! +//! Key characteristics: +//! - Linear attention mechanism +//! - Time-mixing for temporal dependencies +//! - Group normalization +//! - Feed forward gating +//! - State recycling for efficient inference +//! +//! References: +//! - [RWKV Model](https://github.com/BlinkDL/RWKV-LM) +//! + use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear}; use candle::{IndexOp, Result, Tensor}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs index 260ceb3a84..9e0461bc70 100644 --- a/candle-transformers/src/models/segformer.rs +++ b/candle-transformers/src/models/segformer.rs @@ -1,3 +1,19 @@ +//! Segformer model implementation for semantic segmentation and image classification. +//! +//! Segformer is a transformer-based model designed for vision tasks. It uses a hierarchical +//! structure that progressively generates features at different scales. +//! +//! Key characteristics: +//! - Efficient self-attention with sequence reduction +//! - Hierarchical feature generation +//! - Mix-FFN for local and global feature interaction +//! - Lightweight all-MLP decode head +//! +//! References: +//! - [SegFormer Paper](https://arxiv.org/abs/2105.15203) +//! - [Model Card](https://huggingface.co/nvidia/mit-b0) +//! + use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear}; use candle::{Module, ModuleT, Result, Tensor, D}; use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder}; diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs index c54493d296..3e85fe3594 100644 --- a/candle-transformers/src/models/segment_anything/mod.rs +++ b/candle-transformers/src/models/segment_anything/mod.rs @@ -1,3 +1,11 @@ +//! Segment Anything Model (SAM) +//! +//! 
SAM is an architecture for image segmentation, capable of segmenting any object +//! in an image based on prompts like points or boxes. +//! +//! - [GH Link](https://github.com/facebookresearch/segment-anything) +//! - [Paper](https://arxiv.org/abs/2304.02643) +//! pub use crate::models::with_tracing::Linear; use candle::{Result, Tensor}; use candle_nn::{Module, VarBuilder}; diff --git a/candle-transformers/src/models/siglip.rs b/candle-transformers/src/models/siglip.rs index 63b6635dc1..2046401428 100644 --- a/candle-transformers/src/models/siglip.rs +++ b/candle-transformers/src/models/siglip.rs @@ -1,3 +1,11 @@ +//! Siglip model implementation. +//! +//! Siglip architecture combining vision and language for zero-shot tasks. +//! +//! References: +//! - [Model Card](https://huggingface.co/google/siglip-base-patch16-224) +//! + use crate::models::clip::div_l2_norm; use candle::{IndexOp, Module, Result, Tensor, D}; use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder}; diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index 37f4cdbf59..d3e2032b6e 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -1,3 +1,12 @@ +//! Stable Diffusion +//! +//! Stable Diffusion is a latent text-to-image diffusion model capable of +//! generating photo-realistic images given any text input. +//! +//! - [Original Repository](https://github.com/CompVis/stable-diffusion) +//! - [Hugging Face](https://huggingface.co/runwayml/stable-diffusion-v1-5) +//! + pub mod attention; pub mod clip; pub mod ddim; diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs index 2b46e8a12f..c5dbd3958d 100644 --- a/candle-transformers/src/models/stable_lm.rs +++ b/candle-transformers/src/models/stable_lm.rs @@ -1,3 +1,18 @@ +//! StableLM model implementation. +//! +//! StableLM is a family of language models trained by Stability AI. +//! This implementation supports the StableLM architecture. +//! +//! Key characteristics: +//! - Grouped query attention (GQA) +//! - Layer normalization +//! - Rotary positional embeddings (RoPE) +//! - Support for different model sizes (3B, 7B) +//! +//! References: +//! - [Model Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t) +//! + use crate::models::with_tracing::{linear, linear_no_bias, Linear}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, LayerNorm, VarBuilder}; diff --git a/candle-transformers/src/models/starcoder2.rs b/candle-transformers/src/models/starcoder2.rs index d108d06235..833cb0679f 100644 --- a/candle-transformers/src/models/starcoder2.rs +++ b/candle-transformers/src/models/starcoder2.rs @@ -1,3 +1,20 @@ +//! StarCoder model implementation with quantization support. +//! +//! StarCoder is a large language model optimized for code generation. +//! This implementation provides quantization for reduced memory and compute. +//! +//! Key characteristics: +//! - Causal self-attention mechanism +//! - Multi-query attention (MQA) +//! - LayerNorm for normalization +//! - Absolute positional embeddings +//! - Support for 8-bit quantization +//! +//! References: +//! - [StarCoder Paper](https://arxiv.org/abs/2305.06161) +//! - [Model Card](https://huggingface.co/bigcode/starcoder) +//! 
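+//!
+//! # Note on the causal attention mask
+//!
+//! The causal self-attention listed above only lets a token attend to itself
+//! and to earlier positions. A minimal, self-contained sketch of the additive
+//! mask that enforces this is shown below; the function name and the use of
+//! `f32::NEG_INFINITY` as the masked value are illustrative, not this module's
+//! API.
+//!
+//! ```ignore
+//! /// Build a (seq_len x seq_len) additive mask: 0.0 where attention is allowed
+//! /// (j <= i) and -inf where a future position would otherwise be attended to.
+//! fn causal_mask(seq_len: usize) -> Vec<Vec<f32>> {
+//!     (0..seq_len)
+//!         .map(|i| {
+//!             (0..seq_len)
+//!                 .map(|j| if j <= i { 0.0 } else { f32::NEG_INFINITY })
+//!                 .collect()
+//!         })
+//!         .collect()
+//! }
+//! ```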
+ #![allow(unused)] use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{layer_norm, linear_b, LayerNorm, Linear, VarBuilder}; diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs index 9d933fade5..7c1d2b5ae9 100644 --- a/candle-transformers/src/models/stella_en_v5.rs +++ b/candle-transformers/src/models/stella_en_v5.rs @@ -1,3 +1,20 @@ +//! Stella v5 model implementation. +//! +//! Stella is a dense text embedding model optimized for retrieval and similarity tasks. +//! This implementation provides support for multiple embedding dimensions. +//! +//! Key characteristics: +//! - Dense text embeddings optimized for similarity search +//! - Multiple output dimension support (256 to 8192) +//! - Grouped query attention (GQA) +//! - RMSNorm for layer normalization +//! - Rotary positional embeddings (RoPE) +//! +//! References: +//! - [MRL Framework](https://arxiv.org/abs/2205.13147) +//! - [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) +//! + use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_nn::{Activation, VarBuilder}; diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 8ba0c1c1d7..9da0c1afec 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -1,5 +1,19 @@ -// T5 Text Model -// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py +//! T5 model implementation. +//! +//! T5 (Text-to-Text Transfer Transformer) is a unified text-to-text transformer model. +//! This implementation follows the original model architecture. +//! +//! Key characteristics: +//! - Text-to-text framework +//! - Relative positional embeddings +//! - T5-specific layer normalization +//! - Encoder-decoder architecture +//! - Support for sequence-to-sequence tasks +//! +//! References: +//! - [T5 Paper](https://arxiv.org/abs/1910.10683) +//! - [HuggingFace T5](https://huggingface.co/docs/transformers/model_doc/t5) +//! - [GH Model](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py) use crate::models::with_tracing::Embedding; use candle::{DType, Device, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/trocr.rs b/candle-transformers/src/models/trocr.rs index d17eda17bf..88418dd3ca 100644 --- a/candle-transformers/src/models/trocr.rs +++ b/candle-transformers/src/models/trocr.rs @@ -1,3 +1,19 @@ +//! TrOCR model implementation. +//! +//! TrOCR is a Transformer-based OCR model that uses a Vision Transformer encoder +//! and a BART-like decoder for optical character recognition. +//! +//! Key characteristics: +//! - Vision Transformer encoder for image processing +//! - BART-style decoder for text generation +//! - Learned positional embeddings +//! - Layer normalization and self-attention +//! +//! References: +//! - [Paper](https://arxiv.org/abs/2109.10282) +//! - [Model Card](https://huggingface.co/microsoft/trocr-base-handwritten) +//! + use crate::models::vit::{Config, Embeddings, Encoder}; use candle::{DType, Result, Tensor}; use candle_nn::{ diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs index 010643c8d2..57f9ae67bb 100644 --- a/candle-transformers/src/models/vgg.rs +++ b/candle-transformers/src/models/vgg.rs @@ -1,7 +1,18 @@ //! VGG-16 model implementation. //! -//! 
See Very Deep Convolutional Networks for Large-Scale Image Recognition -//! +//! VGG-16 is a convolutional neural network architecture. It consists of 13 +//! convolutional layers followed by 3 fully connected layers. +//! +//! Key characteristics: +//! - Conv layers with 3x3 filters +//! - Max pooling after every 2-3 conv layers +//! - Three fully connected layers of 4096, 4096, 1000 units +//! - ReLU activation and dropout +//! +//! References: +//! - [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556) +//! + use candle::{ModuleT, Result, Tensor}; use candle_nn::{FuncT, VarBuilder}; diff --git a/candle-transformers/src/models/vit.rs b/candle-transformers/src/models/vit.rs index 3be72bf599..49ab463017 100644 --- a/candle-transformers/src/models/vit.rs +++ b/candle-transformers/src/models/vit.rs @@ -1,3 +1,20 @@ +//! Vision Transformer (ViT) implementation. +//! +//! Vision Transformer applies transformer architecture to image classification +//! by splitting images into patches and processing them as a sequence. +//! +//! Key characteristics: +//! - Image patches as sequence tokens +//! - Self-attention between patches +//! - Position embeddings +//! - CLS token for classification +//! - Layer normalization +//! +//! References: +//! - [ViT Paper](https://arxiv.org/abs/2010.11929) +//! - [Model Card](https://huggingface.co/google/vit-base-patch16-224) +//! + use crate::models::with_tracing::{conv2d, linear, linear_no_bias, Conv2d, Linear}; use candle::{IndexOp, Module, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, VarBuilder}; diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs index 8028cf2c66..6123884ae4 100644 --- a/candle-transformers/src/models/whisper/mod.rs +++ b/candle-transformers/src/models/whisper/mod.rs @@ -1,3 +1,11 @@ +//! Whisper Model Implementation +//! +//! Whisper is an automatic speech recognition (ASR) system trained on large amounts +//! of multilingual and multitask supervised data collected from the web. +//! +//! - [GH Link](https://github.com/openai/whisper) +//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py) +//! pub mod audio; pub mod model; pub mod quantized_model; diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs index 7b076f0610..9bb37a3bcc 100644 --- a/candle-transformers/src/models/wuerstchen/mod.rs +++ b/candle-transformers/src/models/wuerstchen/mod.rs @@ -1,3 +1,12 @@ +//! Würstchen Efficient Diffusion Model +//! +//! Würstchen is an efficient diffusion model architecture for generating images using +//! a two-stage approach with a small decoder and prior network. +//! +//! - [Paper Link](https://openreview.net/pdf?id=gU58AyJlYz) +//! - [GH Link](https://github.com/dome272/Wuerstchen) +//! - [Reference Implementation](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py) +//! pub mod attention_processor; pub mod common; pub mod ddpm; diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs index df78ddce7a..047ea77046 100644 --- a/candle-transformers/src/models/yi.rs +++ b/candle-transformers/src/models/yi.rs @@ -1,4 +1,18 @@ -/// https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py +//! Yi model implementation. +//! +//! 
Yi is a decoder-only large language model trained by 01.AI. +//! It follows a standard transformer architecture similar to Llama. +//! +//! Key characteristics: +//! - Multi-head attention with rotary positional embeddings +//! - RMS normalization +//! - SwiGLU activation in feed-forward layers +//! - Grouped-query attention for efficient inference +//! +//! References: +//! - [Yi Model](https://huggingface.co/01-ai/Yi-6B) +//! - [Hugging Face](https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py) + use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; From 00d8a0c178f588b6454c02e66b709917628c2bae Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 15 Nov 2024 16:46:55 +0100 Subject: [PATCH 03/15] Remove some unused macros. (#2618) * Remove some unused macros. * More unused fixes. --- candle-examples/Cargo.toml | 2 +- candle-examples/examples/reinforcement-learning/ddpg.rs | 8 +++++--- .../examples/reinforcement-learning/gym_env.rs | 1 - candle-examples/examples/reinforcement-learning/main.rs | 2 -- .../examples/reinforcement-learning/policy_gradient.rs | 2 +- .../examples/reinforcement-learning/vec_gym_env.rs | 5 +++-- candle-pyo3/Cargo.toml | 2 +- candle-transformers/src/models/encodec.rs | 4 ++-- candle-transformers/src/models/starcoder2.rs | 1 - 9 files changed, 13 insertions(+), 14 deletions(-) diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 0c1219d760..df85302d6d 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -27,7 +27,7 @@ intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } palette = { version = "0.7.6", optional = true } enterpolation = { version = "0.2.1", optional = true} -pyo3 = { version = "0.22.0", features = ["auto-initialize"], optional = true } +pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true } rayon = { workspace = true } rubato = { version = "0.15.0", optional = true } safetensors = { workspace = true } diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs index 5309eaf669..389caac1a1 100644 --- a/candle-examples/examples/reinforcement-learning/ddpg.rs +++ b/candle-examples/examples/reinforcement-learning/ddpg.rs @@ -1,5 +1,4 @@ use std::collections::VecDeque; -use std::fmt::Display; use candle::{DType, Device, Error, Module, Result, Tensor, Var}; use candle_nn::{ @@ -167,6 +166,7 @@ fn track( Ok(()) } +#[allow(unused)] struct Actor<'a> { varmap: VarMap, vb: VarBuilder<'a>, @@ -211,7 +211,7 @@ impl Actor<'_> { let target_network = make_network("target-actor")?; // this sets the two networks to be equal to each other using tau = 1.0 - track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0); + track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0)?; Ok(Self { varmap, @@ -244,6 +244,7 @@ impl Actor<'_> { } } +#[allow(unused)] struct Critic<'a> { varmap: VarMap, vb: VarBuilder<'a>, @@ -287,7 +288,7 @@ impl Critic<'_> { let target_network = make_network("target-critic")?; // this sets the two networks to be equal to each other using tau = 1.0 - track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0); + track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0)?; Ok(Self { varmap, @@ -322,6 +323,7 @@ impl Critic<'_> { } } +#[allow(unused)] #[allow(clippy::upper_case_acronyms)] pub struct DDPG<'a> { actor: Actor<'a>, diff --git 
a/candle-examples/examples/reinforcement-learning/gym_env.rs b/candle-examples/examples/reinforcement-learning/gym_env.rs index a2b6652f87..05518b1bf1 100644 --- a/candle-examples/examples/reinforcement-learning/gym_env.rs +++ b/candle-examples/examples/reinforcement-learning/gym_env.rs @@ -1,4 +1,3 @@ -#![allow(unused)] //! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym) use candle::{Device, Result, Tensor}; use pyo3::prelude::*; diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs index 1a25cd93ef..34115b228a 100644 --- a/candle-examples/examples/reinforcement-learning/main.rs +++ b/candle-examples/examples/reinforcement-learning/main.rs @@ -1,5 +1,3 @@ -#![allow(unused)] - #[cfg(feature = "mkl")] extern crate intel_mkl_src; diff --git a/candle-examples/examples/reinforcement-learning/policy_gradient.rs b/candle-examples/examples/reinforcement-learning/policy_gradient.rs index 6c355fe62f..3ae2617d16 100644 --- a/candle-examples/examples/reinforcement-learning/policy_gradient.rs +++ b/candle-examples/examples/reinforcement-learning/policy_gradient.rs @@ -14,7 +14,7 @@ fn new_model( ) -> Result<(impl Module, VarMap)> { let input_size = input_shape.iter().product(); - let mut varmap = VarMap::new(); + let varmap = VarMap::new(); let var_builder = VarBuilder::from_varmap(&varmap, dtype, device); let model = seq() diff --git a/candle-examples/examples/reinforcement-learning/vec_gym_env.rs b/candle-examples/examples/reinforcement-learning/vec_gym_env.rs index e382ad76da..a985d9e978 100644 --- a/candle-examples/examples/reinforcement-learning/vec_gym_env.rs +++ b/candle-examples/examples/reinforcement-learning/vec_gym_env.rs @@ -1,9 +1,8 @@ -#![allow(unused)] //! Vectorized version of the gym environment. use candle::{DType, Device, Result, Tensor}; use pyo3::prelude::*; -use pyo3::types::PyDict; +#[allow(unused)] #[derive(Debug)] pub struct Step { pub obs: Tensor, @@ -11,6 +10,7 @@ pub struct Step { pub is_done: Tensor, } +#[allow(unused)] pub struct VecGymEnv { env: PyObject, action_space: usize, @@ -21,6 +21,7 @@ fn w(res: PyErr) -> candle::Error { candle::Error::wrap(res) } +#[allow(unused)] impl VecGymEnv { pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result { Python::with_gil(|py| { diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 2776a3f77c..d91619fbb3 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -20,7 +20,7 @@ candle-nn = { workspace = true } candle-onnx = { workspace = true, optional = true } half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } -pyo3 = { version = "0.22.0", features = ["extension-module", "abi3-py38"] } +pyo3 = { version = "0.22.0", features = ["extension-module", "abi3-py311"] } [build-dependencies] pyo3-build-config = "0.22" diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs index a8d509ce8b..517b9b1d7e 100644 --- a/candle-transformers/src/models/encodec.rs +++ b/candle-transformers/src/models/encodec.rs @@ -4,9 +4,8 @@ //! //! 
Based on implementation from [huggingface/transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py) -#![allow(unused)] use candle::{DType, IndexOp, Layout, Module, Result, Shape, Tensor, D}; -use candle_nn::{conv1d, Conv1d, Conv1dConfig, ConvTranspose1d, VarBuilder}; +use candle_nn::{conv1d, Conv1d, ConvTranspose1d, VarBuilder}; // Encodec Model // https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py @@ -226,6 +225,7 @@ impl candle::CustomOp2 for CodebookEncode { } // https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340 +#[allow(unused)] #[derive(Clone, Debug)] pub struct EuclideanCodebook { inited: Tensor, diff --git a/candle-transformers/src/models/starcoder2.rs b/candle-transformers/src/models/starcoder2.rs index 833cb0679f..0df5990b89 100644 --- a/candle-transformers/src/models/starcoder2.rs +++ b/candle-transformers/src/models/starcoder2.rs @@ -15,7 +15,6 @@ //! - [Model Card](https://huggingface.co/bigcode/starcoder) //! -#![allow(unused)] use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{layer_norm, linear_b, LayerNorm, Linear, VarBuilder}; use std::sync::Arc; From a3f200e36991418c25cddef0e09c426deea90606 Mon Sep 17 00:00:00 2001 From: zachcp Date: Sat, 16 Nov 2024 03:09:17 -0500 Subject: [PATCH 04/15] Module Docs (#2620) * update bert docs * update based * update bigcode * add pixtral * add flux as well --- candle-transformers/src/models/based.rs | 6 +- candle-transformers/src/models/bert.rs | 59 ++++++++++++++++++- candle-transformers/src/models/bigcode.rs | 18 +++++- candle-transformers/src/models/flux/mod.rs | 22 ++++++- candle-transformers/src/models/pixtral/mod.rs | 31 ++++++++++ 5 files changed, 126 insertions(+), 10 deletions(-) diff --git a/candle-transformers/src/models/based.rs b/candle-transformers/src/models/based.rs index c54ff96629..1dbd6dc2a6 100644 --- a/candle-transformers/src/models/based.rs +++ b/candle-transformers/src/models/based.rs @@ -1,9 +1,9 @@ //! Based from the Stanford Hazy Research group. //! //! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024 -//! - [Arxiv](https://arxiv.org/abs/2402.18668) -//! - [Github](https://github.com/HazyResearch/based) -//! +//! - Simple linear attention language models balance the recall-throughput tradeoff. [Arxiv](https://arxiv.org/abs/2402.18668) +//! - [Github Rep](https://github.com/HazyResearch/based) +//! - [Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based) use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{ diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index a7db075cbb..808ca41557 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -1,8 +1,61 @@ //! BERT (Bidirectional Encoder Representations from Transformers) //! -//! See "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding", Devlin et al. 2018 -//! - [Arxiv](https://arxiv.org/abs/1810.04805) -//! - [Github](https://github.com/google-research/bert) +//! Bert is a general large language model that can be used for various language tasks: +//! - Compute sentence embeddings for a prompt. +//! - Compute similarities between a set of sentences. +//! 
- [Arxiv](https://arxiv.org/abs/1810.04805) "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" +//! - Upstream [Github repo](https://github.com/google-research/bert). +//! - See bert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code +//! +//! ```no_run +//! // for sentence embeddings +//! # use candle_core::Tensor; +//! # use candle_nn::{VarBuilder, Module}; +//! # fn main() -> candle_core::Result<()> { +//! # let model = todo!(); +//! # let prompt = "Here is a test sentence"; +//! let embeddings = model.forward(prompt)?; +//! // Returns tensor of shape [1, 7, 384] +//! println!("{embeddings}"); +//! # Ok(()) +//! # } +//! +//! // Different models can be loaded using the model ID +//! # use candle_core::Tensor; +//! # use candle_nn::{VarBuilder, Module}; +//! # fn main() -> candle_core::Result<()> { +//! # let vb = todo!(); +//! # let config = todo!(); +//! let model = BertModel::load(vb, &config )?; +//! # Ok(()) +//! # } +//! +//! // Gelu approximation +//! // You can get a speedup by configuring the model +//! // to use an approximation of the gelu activation: +//! # use candle_core::Tensor; +//! # use candle_nn::{VarBuilder, Module}; +//! # fn main() -> candle_core::Result<()> { +//! # let mut config = todo!(); +//! config.hidden_act = HiddenAct::GeluApproximate; +//! # Ok(()) +//! # } +//! +//! // Similarities +//! // Bert can compute sentence embeddings which can then be used to calculate +//! // semantic similarities between sentences through cosine similarity scoring. +//! // The sentence embeddings are computed using average pooling across all tokens. +//! # use candle_core::Tensor; +//! # use candle_nn::{VarBuilder, Module}; +//! # fn main() -> candle_core::Result<()> { +//! # let model = todo!(); +//! let sentence1 = "The new movie is awesome"; +//! let sentence2 = "The new movie is so great"; +//! let emb1 = model.forward(sentence1)?; +//! let emb2 = model.forward(sentence2)?; +//! # Ok(()) +//! # } +//! ``` //! use super::with_tracing::{layer_norm, linear, LayerNorm, Linear}; use candle::{DType, Device, Result, Tensor}; diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs index 8ed1462b1c..c5dcb6bc80 100644 --- a/candle-transformers/src/models/bigcode.rs +++ b/candle-transformers/src/models/bigcode.rs @@ -1,9 +1,25 @@ //! BigCode implementation in Rust based on the GPT-BigCode model. //! -//! See "StarCoder: A State-of-the-Art LLM for Code", Mukherjee et al. 2023 +//! [StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is a LLM +//! model specialized to code generation. The initial model was trained on 80 +//! programming languages. See "StarCoder: A State-of-the-Art LLM for Code", Mukherjee et al. 2023 //! - [Arxiv](https://arxiv.org/abs/2305.06161) //! - [Github](https://github.com/bigcode-project/starcoder) //! +//! ## Running some example +//! +//! ```bash +//! cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64" +//! +//! > fn fact(n: u64) -> u64 { +//! > if n == 0 { +//! > 1 +//! > } else { +//! > n * fact(n - 1) +//! > } +//! > } +//! ``` +//! 
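+//!
+//! # Note on how the completion is produced
+//!
+//! The snippet above is generated token by token: the model scores the next
+//! token given the prompt plus everything generated so far, one token is
+//! chosen, appended, and the loop repeats. The sketch below shows the simplest
+//! (greedy) variant of that loop with a stand-in scoring closure; it is not the
+//! example's actual sampling code, which may use temperature or top-p sampling.
+//!
+//! ```ignore
+//! /// Greedy decoding: always append the highest-scoring next token.
+//! fn greedy_decode(
+//!     mut tokens: Vec<u32>,
+//!     max_new_tokens: usize,
+//!     eos: u32,
+//!     mut next_logits: impl FnMut(&[u32]) -> Vec<f32>,
+//! ) -> Vec<u32> {
+//!     for _ in 0..max_new_tokens {
+//!         let logits = next_logits(&tokens[..]);
+//!         // Argmax over the vocabulary (assumes `logits` is non-empty).
+//!         let next = logits
+//!             .iter()
+//!             .enumerate()
+//!             .max_by(|a, b| a.1.total_cmp(b.1))
+//!             .map(|(i, _)| i as u32)
+//!             .unwrap();
+//!         if next == eos {
+//!             break;
+//!         }
+//!         tokens.push(next);
+//!     }
+//!     tokens
+//! }
+//! ```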
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs index 8eb928f557..064c5130f5 100644 --- a/candle-transformers/src/models/flux/mod.rs +++ b/candle-transformers/src/models/flux/mod.rs @@ -1,10 +1,26 @@ //! Flux Model //! -//! Flux is a series of text-to-image generation models based on diffusion transformers. +//! Flux is a 12B rectified flow transformer capable of generating images from text descriptions. //! -//! - [GH Link](https://github.com/black-forest-labs/flux) -//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py) +//! - [Hugging Face Model](https://huggingface.co/black-forest-labs/FLUX.1-schnell) +//! - [GitHub Repository](https://github.com/black-forest-labs/flux) +//! - [Blog Post](https://blackforestlabs.ai/announcing-black-forest-labs/) //! +//! # Usage +//! +//! ```bash +//! cargo run --features cuda \ +//! --example flux -r -- \ +//! --height 1024 --width 1024 \ +//! --prompt "a rusty robot walking on a beach holding a small torch, \ +//! the robot has the word \"rust\" written on it, high quality, 4k" +//! ``` +//! +//!
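+//!
+//! # Note on rectified-flow sampling
+//!
+//! As a rectified flow model, Flux generates an image by integrating a learned
+//! velocity field from noise towards data over a small number of timesteps.
+//! The loop below is a schematic Euler integration of such a field; the closure
+//! and names are illustrative and do not correspond to this module's API or to
+//! the exact schedule used by the example above.
+//!
+//! ```ignore
+//! /// Integrate x along the predicted velocity between consecutive timesteps.
+//! fn euler_sample(
+//!     mut x: Vec<f32>,
+//!     timesteps: &[f32],
+//!     mut velocity: impl FnMut(&[f32], f32) -> Vec<f32>,
+//! ) -> Vec<f32> {
+//!     for w in timesteps.windows(2) {
+//!         let (t, t_next) = (w[0], w[1]);
+//!         let v = velocity(&x[..], t);
+//!         let dt = t_next - t;
+//!         for (xi, vi) in x.iter_mut().zip(v.iter()) {
+//!             *xi += vi * dt;
+//!         }
+//!     }
+//!     x
+//! }
+//! ```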
+//! +//!
+//! + use candle::{Result, Tensor}; pub trait WithForward { diff --git a/candle-transformers/src/models/pixtral/mod.rs b/candle-transformers/src/models/pixtral/mod.rs index 53f9ef9182..e722ffcfd2 100644 --- a/candle-transformers/src/models/pixtral/mod.rs +++ b/candle-transformers/src/models/pixtral/mod.rs @@ -4,7 +4,38 @@ //! using images paired with text descriptions. //! //! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/main/src/transformers/models/pixtral) +//! - [Blog Post](https://mistral.ai/news/pixtral-12b/) - +//! - [HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) - +//! - [HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b). //! +//! # Example +//! +//!
+//! +//!
+//! +//! ```bash +//! cargo run --profile=release-with-debug \ +//! --features cuda \ +//! --example pixtral -- \ +//! --image candle-examples/examples/flux/assets/flux-robot.jpg +//! ``` +//! +//! ```txt +//! Describe the image. +//! +//! The image depicts a charming, rustic robot standing on a sandy beach at sunset. +//! The robot has a vintage, steampunk aesthetic with visible gears and mechanical +//! parts. It is holding a small lantern in one hand, which emits a warm glow, and +//! its other arm is extended forward as if reaching out or guiding the way. The +//! robot's body is adorned with the word "RUST" in bright orange letters, adding to +//! its rustic theme. +//! +//! The background features a dramatic sky filled with clouds, illuminated by the +//! setting sun, casting a golden hue over the scene. Gentle waves lap against the +//! shore, creating a serene and picturesque atmosphere. The overall mood of the +//! image is whimsical and nostalgic, evoking a sense of adventure and tranquility. +//! ``` pub mod llava; pub mod vision_model; From 12d7e7b1450f0c3f87c3cce3a2a1dd1674cb8fd7 Mon Sep 17 00:00:00 2001 From: zachcp Date: Sun, 17 Nov 2024 14:27:24 -0500 Subject: [PATCH 05/15] More Model Module Docs (#2623) * dinov2 * add another example * ad dinov2reg4 * eva2 * efficientvit * moondream * update t5 * update t5 * rwkv * stable diffusion docs * add wasm link * add segment_anything * adjsut for clippy * ignore bertdoc * dinov2 ignore * update block to be text * remove the rust blocks for the moment * bump python to 3.11 * add a setup-python step * add py311 to test as well --- .github/workflows/rust-ci.yml | 6 +++ candle-transformers/src/models/bert.rs | 50 ------------------- candle-transformers/src/models/dinov2.rs | 38 +++++++++++++- candle-transformers/src/models/dinov2reg4.rs | 31 ++++++++++-- .../src/models/efficientvit.rs | 37 ++++++++++++-- candle-transformers/src/models/eva2.rs | 28 +++++++++-- candle-transformers/src/models/moondream.rs | 30 ++++++++++- candle-transformers/src/models/rwkv_v5.rs | 20 +++++++- candle-transformers/src/models/rwkv_v6.rs | 21 ++++++-- .../src/models/segment_anything/mod.rs | 29 +++++++++-- .../src/models/stable_diffusion/mod.rs | 30 +++++++++++ candle-transformers/src/models/t5.rs | 43 ++++++++++++++++ 12 files changed, 291 insertions(+), 72 deletions(-) diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index ee480c474c..db25503079 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -16,6 +16,9 @@ jobs: rust: [stable] steps: - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" - uses: actions-rs/toolchain@v1 with: profile: minimal @@ -35,6 +38,9 @@ jobs: rust: [stable] steps: - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" - uses: actions-rs/toolchain@v1 with: profile: minimal diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index 808ca41557..da8734160a 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -7,56 +7,6 @@ //! - Upstream [Github repo](https://github.com/google-research/bert). //! - See bert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code //! -//! ```no_run -//! // for sentence embeddings -//! # use candle_core::Tensor; -//! # use candle_nn::{VarBuilder, Module}; -//! # fn main() -> candle_core::Result<()> { -//! # let model = todo!(); -//! 
# let prompt = "Here is a test sentence"; -//! let embeddings = model.forward(prompt)?; -//! // Returns tensor of shape [1, 7, 384] -//! println!("{embeddings}"); -//! # Ok(()) -//! # } -//! -//! // Different models can be loaded using the model ID -//! # use candle_core::Tensor; -//! # use candle_nn::{VarBuilder, Module}; -//! # fn main() -> candle_core::Result<()> { -//! # let vb = todo!(); -//! # let config = todo!(); -//! let model = BertModel::load(vb, &config )?; -//! # Ok(()) -//! # } -//! -//! // Gelu approximation -//! // You can get a speedup by configuring the model -//! // to use an approximation of the gelu activation: -//! # use candle_core::Tensor; -//! # use candle_nn::{VarBuilder, Module}; -//! # fn main() -> candle_core::Result<()> { -//! # let mut config = todo!(); -//! config.hidden_act = HiddenAct::GeluApproximate; -//! # Ok(()) -//! # } -//! -//! // Similarities -//! // Bert can compute sentence embeddings which can then be used to calculate -//! // semantic similarities between sentences through cosine similarity scoring. -//! // The sentence embeddings are computed using average pooling across all tokens. -//! # use candle_core::Tensor; -//! # use candle_nn::{VarBuilder, Module}; -//! # fn main() -> candle_core::Result<()> { -//! # let model = todo!(); -//! let sentence1 = "The new movie is awesome"; -//! let sentence2 = "The new movie is so great"; -//! let emb1 = model.forward(sentence1)?; -//! let emb2 = model.forward(sentence2)?; -//! # Ok(()) -//! # } -//! ``` -//! use super::with_tracing::{layer_norm, linear, LayerNorm, Linear}; use candle::{DType, Device, Result, Tensor}; use candle_nn::{embedding, Embedding, Module, VarBuilder}; diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs index df8834d1f7..4d46941f8b 100644 --- a/candle-transformers/src/models/dinov2.rs +++ b/candle-transformers/src/models/dinov2.rs @@ -1,8 +1,42 @@ //! Implementation of the DINOv2 models from Meta Research. //! -//! See: -//! - DINOv2: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2) +//! This module implements the DINOv2 vision transformer model from Meta AI Research. +//! DINOv2 is a self-supervised learning model that can learn visual features +//! without using any labeled data. See: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2) //! +//! ## Running an example with color map and CUDA +//! +//! ```bash +//! cargo run \ +//! --features cuda,depth_anything_v2 \ +//! --package candle-examples \ +//! --example depth_anything_v2 \ +//! -- --color-map \ +//! --image candle-examples/examples/yolo-v8/assets/bike.jpg +//! ``` +//! +//! ## Running as an ImageNet classifier +//! +//! The model returns the probability for the image to belong to each of the 1000 ImageNet categories. +//! +//!
+//! +//!
+//! +//! ```bash +//! cargo run \ +//! --example dinov2 \ +//! --release \ +//! -- --image candle-examples/examples/yolo-v8/assets/bike.jpg +//! +//! > mountain bike, all-terrain bike, off-roader: 43.67% +//! > bicycle-built-for-two, tandem bicycle, tandem: 33.20% +//! > crash helmet : 13.23% +//! > unicycle, monocycle : 2.44% +//! > maillot : 2.42% +//! ``` +//! + use candle::{IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/dinov2reg4.rs b/candle-transformers/src/models/dinov2reg4.rs index 0d2320e14c..549f2c3ce5 100644 --- a/candle-transformers/src/models/dinov2reg4.rs +++ b/candle-transformers/src/models/dinov2reg4.rs @@ -1,9 +1,34 @@ //! Implementation of the DINOv2 revision (4 regularization) //! -//! See: -//! - DINOv2: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2) +//! The DINOv2-reg4 model is a variant of DINOv2 that adds 4 regularization tokens to the +//! original architecture. This implementation is specifically trained for plant species +//! classification on the PlantCLEF2024 dataset with 7,806 classes. //! -//! This code implements the regularization tokens version with 4 regularization tokens. +//! - [Paper](https://arxiv.org/abs/2309.16588). DINOv2: Learning Robust Visual Features without Supervision +//! - [GH Repo](https://github.com/facebookresearch/dinov2) +//! +//! # Example +//! +//! ```bash +//! # Download classes names and a plant picture to identify +//! # see candle/examples/dinov2reg4 for full code. +//! +//! # Perform inference +//! cargo run \ +//! --example dinov2reg4 \ +//! --release -- \ +//! --image +//! +//! > Orchis simia Lam. : 45.55% +//! > Orchis × bergonii Nanteuil: 9.80% +//! > Orchis italica Poir. : 9.66% +//! > Orchis × angusticruris Franch.: 2.76% +//! > Orchis × bivonae Tod. : 2.54% +//! ``` +//! +//!
+//! +//!
//! use candle::{IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/efficientvit.rs b/candle-transformers/src/models/efficientvit.rs index 9724f702a6..4c231d7679 100644 --- a/candle-transformers/src/models/efficientvit.rs +++ b/candle-transformers/src/models/efficientvit.rs @@ -1,9 +1,40 @@ //! EfficientViT (MSRA) inference implementation based on timm. //! -//! See ["EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"](https://arxiv.org/abs/2305.07027) +//! This crate provides an implementation of the EfficientViT model from Microsoft Research Asia +//! for efficient image classification. The model uses cascaded group attention modules +//! to achieve strong performance while maintaining low memory usage. +//! +//! The model was originally described in the paper: +//! ["EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"](https://arxiv.org/abs/2305.07027) +//! +//! This implementation is based on the reference implementation from +//! [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py). +//! +//! # Example Usage +//! +//! This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference. +//! The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes. +//! +//! +//! ```bash +//! cargo run +//! --example efficientvit \ +//! --release -- \ +//! --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1 +//! +//! > loaded image Tensor[dims 3, 224, 224; f32] +//! > model built +//! > mountain bike, all-terrain bike, off-roader: 69.80% +//! > unicycle, monocycle : 13.03% +//! > bicycle-built-for-two, tandem bicycle, tandem: 9.28% +//! > crash helmet : 2.25% +//! > alp : 0.46% +//! ``` +//! +//!
+//! +//!
//! -//! Based on implementation from [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py) - use candle::{Result, Tensor, D}; use candle_nn::{ batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, ops::softmax, Conv2dConfig, Func, diff --git a/candle-transformers/src/models/eva2.rs b/candle-transformers/src/models/eva2.rs index ee84cca43c..9e31f58c73 100644 --- a/candle-transformers/src/models/eva2.rs +++ b/candle-transformers/src/models/eva2.rs @@ -1,9 +1,31 @@ //! EVA-2 inference implementation. //! -//! See ["EVA-02: A Visual Representation for Neon Genesis"](https://arxiv.org/abs/2303.11331) +//! EVA-02 is a computer vision model that can be used as an ImageNet classifier. +//! The model returns the probability for an image to belong to each of the 1000 +//! ImageNet categories. +//! +//! - [Paper](https://arxiv.org/abs/2303.11331). EVA-02: A Visual Representation for Neon Genesis +//! - [Code](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/eva2.py) +//! +//! # Example +//! +//! ```bash +//! cargo run \ +//! --example eva2 \ +//! --release -- \ +//! --image candle-examples/examples/yolo-v8/assets/bike.jpg +//! +//! > mountain bike, all-terrain bike, off-roader: 37.09% +//! > maillot : 8.30% +//! > alp : 2.13% +//! > bicycle-built-for-two, tandem bicycle, tandem: 0.84% +//! > crash helmet : 0.73% +//! ``` +//! +//!
+//! +//!
//! -//! Based on implementation from [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/eva2.py) - use candle::{IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs index d351d7c019..a9dc9b7dc2 100644 --- a/candle-transformers/src/models/moondream.rs +++ b/candle-transformers/src/models/moondream.rs @@ -1,13 +1,39 @@ //! MoonDream Model vision-to-text //! +//! +//! Moondream is a computer-vision model that can answer real-world questions about images. +//! It's lightweight with only 1.6B parameters, enabling it to run on mobile phones and edge devices. +//! [MoonDream Original Implementation](https://github.com/vikhyat/moondream) +//! //! The model consists of: //! - Vision encoder using a ViT-style architecture //! - Text decoder based on Microsoft's Phi model //! - Vision projection module to align vision and text embeddings //! -//! References: -//! - [MoonDream Original Implementation](https://github.com/vikhyat/moondream) +//! # Examples +//! +//! +//! +//! ```bash +//! # download an example image +//! wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg +//! +//! # Now you can run Moondream from the `candle-examples` crate: +//! cargo run --example moondream \ +//! --release -- \ +//! --prompt "What is the girl eating?" +//! --image "./demo-1.jpg" //! +//! > avavx: false, neon: true, simd128: false, f16c: false +//! > temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64 +//! > retrieved the files in 3.395583ms +//! > Running on CPU, to run on GPU(metal), build this example with `--features metal` +//! > loaded the model in 5.485493792s +//! > loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s +//! > starting the inference loop +//! > The girl is eating a hamburger.< +//! > 9 tokens generated (0.68 token/s) +//! ``` use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel}; use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear}; diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs index 6390f886d2..15e386d292 100644 --- a/candle-transformers/src/models/rwkv_v5.rs +++ b/candle-transformers/src/models/rwkv_v5.rs @@ -1,7 +1,9 @@ //! RWKV v5 model implementation. //! -//! RWKV is an RNN with transformer-level performance that can be implemented -//! as either a transformer or RNN. +//! The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model +//! with performance on par with transformer architectures. Several variants are +//! available, candle implements the v5 and v6 versions and can be used with +//! Eagle 7B([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)). //! //! Key characteristics: //! - Time-mix attention mechanism @@ -14,6 +16,20 @@ //! - [RWKV Language Model](https://github.com/BlinkDL/RWKV-LM) //! - [RWKV v5 Release](https://github.com/BlinkDL/ChatRWKV/tree/main) //! +//! # Example +//! +//! ```bash +//! cargo run --example rwkv --release -- \ +//! --prompt "The smallest prime is " +//! +//! > avx: true, neon: false, simd128: false, f16c: true +//! > temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64 +//! > The smallest prime is ϕ(2) = 2. +//! > The smallest composite is ϕ(3) = 3. +//! > The smallest perfect number is ϕ(5) = 5. +//! > The smallest perfect square is ϕ(4) = 4. +//! 
> The smallest perfect cube is ϕ(6) = 6. +//! ``` use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear}; use candle::{DType, Device, IndexOp, Result, Tensor}; diff --git a/candle-transformers/src/models/rwkv_v6.rs b/candle-transformers/src/models/rwkv_v6.rs index c75aa885e9..5da1c5ce81 100644 --- a/candle-transformers/src/models/rwkv_v6.rs +++ b/candle-transformers/src/models/rwkv_v6.rs @@ -1,7 +1,9 @@ //! RWKV v6 model implementation. //! -//! RWKV is an RNN with transformer-like performance. -//! Version 6 introduces refinements to the architecture. +//! The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model +//! with performance on par with transformer architectures. Several variants are +//! available, candle implements the v5 and v6 versions and can be used with +//! Eagle 7B([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)). //! //! Key characteristics: //! - Linear attention mechanism @@ -10,9 +12,20 @@ //! - Feed forward gating //! - State recycling for efficient inference //! -//! References: -//! - [RWKV Model](https://github.com/BlinkDL/RWKV-LM) +//! # Example //! +//! ```bash +//! cargo run --example rwkv --release -- \ +//! --prompt "The smallest prime is " +//! +//! > avx: true, neon: false, simd128: false, f16c: true +//! > temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64 +//! > The smallest prime is ϕ(2) = 2. +//! > The smallest composite is ϕ(3) = 3. +//! > The smallest perfect number is ϕ(5) = 5. +//! > The smallest perfect square is ϕ(4) = 4. +//! > The smallest perfect cube is ϕ(6) = 6. +//! ``` use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear}; use candle::{IndexOp, Result, Tensor}; diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs index 3e85fe3594..fe0b099008 100644 --- a/candle-transformers/src/models/segment_anything/mod.rs +++ b/candle-transformers/src/models/segment_anything/mod.rs @@ -1,10 +1,33 @@ //! Segment Anything Model (SAM) //! //! SAM is an architecture for image segmentation, capable of segmenting any object -//! in an image based on prompts like points or boxes. +//! in an image based on prompts like points or boxes. //! This model provides a robust and fast image segmentation pipeline that can be tweaked via +//! some prompting (requesting some points to be in the target mask, requesting some +//! points to be part of the background so _not_ in the target mask, specifying some +//! bounding box). //! -//! - [GH Link](https://github.com/facebookresearch/segment-anything) -//! - [Paper](https://arxiv.org/abs/2304.02643) +//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/candle-segment-anything-wasm) +//! - 💻 [GH Link](https://github.com/facebookresearch/segment-anything) +//! - 📝 [Paper](https://arxiv.org/abs/2304.02643) +//! - 💡 The default backbone can be replaced by the smaller and faster TinyViT model based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM). +//! +//! +//! ## Example +//! +//! ```bash +//! cargo run --example segment-anything --release -- \ +//! --image candle-examples/examples/yolo-v8/assets/bike.jpg +//! --use-tiny --point 0.6,0.6 --point 0.6,0.55 +//! ``` +//! +//!
+//! +//! +//! +//!
+//! +//! +//! > Original; Prompt with `--point 0.6,0.55`; Prompt with `--point 0.6,0.6 --point 0.6,0.55` //! pub use crate::models::with_tracing::Linear; use candle::{Result, Tensor}; diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index d3e2032b6e..458a7de2d4 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -5,7 +5,37 @@ //! //! - [Original Repository](https://github.com/CompVis/stable-diffusion) //! - [Hugging Face](https://huggingface.co/runwayml/stable-diffusion-v1-5) +//! - The default scheduler for the v1.5, v2.1 and XL 1.0 version is the Denoising Diffusion Implicit Model scheduler (DDIM). The original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim). The default scheduler for the XL Turbo version is the Euler Ancestral scheduler. //! +//! +//! # Example +//! +//!
+//!
+//! [image: rusty robot holding a candle]
+//!
+//! +//! _"A rusty robot holding a fire torch in its hand."_ Generated by Stable Diffusion XL using Rust and [candle](https://github.com/huggingface/candle). +//! +//! ```bash +//! # example running with cuda +//! # see the candle-examples/examples/stable-diffusion for all options +//! cargo run --example stable-diffusion --release --features=cuda,cudnn \ +//! -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" +//! +//! # with sd-turbo +//! cargo run --example stable-diffusion --release --features=cuda,cudnn \ +//! -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" \ +//! --sd-version turbo +//! +//! # with flash attention. +//! # feature flag: `--features flash-attn` +//! # cli flag: `--use-flash-attn`. +//! # flash-attention-v2 is only compatible with Ampere, Ada, \ +//! # or Hopper GPUs (e.g., A100/H100, RTX 3090/4090). +//! cargo run --example stable-diffusion --release --features=cuda,cudnn \ +//! -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" \ +//! --use-flash-attn +//! ``` pub mod attention; pub mod clip; diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 9da0c1afec..d3fd2ba686 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -14,6 +14,49 @@ //! - [T5 Paper](https://arxiv.org/abs/1910.10683) //! - [HuggingFace T5](https://huggingface.co/docs/transformers/model_doc/t5) //! - [GH Model](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py) +//! +//! # Encoder-decoder example: +//! +//! ```bash +//! cargo run --example t5 --release -- \ +//! --model-id "t5-small" \ +//! --prompt "translate to German: A beautiful candle." \ +//! --decode +//! > ... +//! > Eine schöne Kerze. +//! > 9 tokens generated (2.42 token/s) +//! ``` +//! +//! Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported. +//! +//! # Translation with MADLAD +//! +//! +//! [MADLAD-400](https://arxiv.org/abs/2309.04662) is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models. +//! +//! ```bash +//! cargo run --example t5 --release -- \ +//! --model-id "jbochi/madlad400-3b-mt" \ +//! --prompt "<2de> How are you, my friend?" \ +//! --decode --temperature 0 +//! ... +//! Wie geht es dir, mein Freund? +//! ``` +//! +//! ## Sentence embedding example +//! +//! ```bash +//! cargo run --example t5 --release -- \ +//! --model-id "t5-small" --prompt "A beautiful candle." +//! ... +//! [[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265], +//! [-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164], +//! [ 0.0624, -0.1024, 0.0430, ..., -0.1388, 0.0564, -0.2962], +//! [-0.0389, -0.1173, 0.0026, ..., 0.1064, -0.1065, 0.0990], +//! [ 0.1300, 0.0027, -0.0326, ..., 0.0026, -0.0317, 0.0851]]] +//! Tensor[[1, 5, 512], f32] +//! Took 303.766583ms +//! 
``` use crate::models::with_tracing::Embedding; use candle::{DType, Device, Module, Result, Tensor, D}; From 386fd8abb4be23c125e8100fed932f17d356a160 Mon Sep 17 00:00:00 2001 From: zachcp Date: Mon, 18 Nov 2024 08:19:23 -0500 Subject: [PATCH 06/15] Module Docs (#2624) * update whisper * update llama2c * update t5 * update phi and t5 * add a blip model * qlamma doc * add two new docs * add docs and emoji * additional models * openclip * pixtral * edits on the model docs * update yu * update a fe wmore models * add persimmon * add model-level doc * names * update module doc * links in heira * remove empty URL * update more hyperlinks * updated hyperlinks * more links * Update mod.rs --------- Co-authored-by: Laurent Mazare --- candle-transformers/src/models/blip.rs | 9 ++++--- candle-transformers/src/models/blip_text.rs | 9 ++++--- candle-transformers/src/models/chatglm.rs | 6 ++--- .../src/models/chinese_clip/mod.rs | 5 ++-- .../src/models/chinese_clip/text_model.rs | 6 ++--- .../src/models/chinese_clip/vision_model.rs | 6 ++--- candle-transformers/src/models/clip/mod.rs | 6 +++-- .../src/models/clip/text_model.rs | 4 ++-- .../src/models/codegeex4_9b.rs | 7 +++--- candle-transformers/src/models/convmixer.rs | 6 ++--- candle-transformers/src/models/convnext.rs | 15 +++++++----- candle-transformers/src/models/flux/mod.rs | 6 ++--- candle-transformers/src/models/hiera.rs | 7 +++--- candle-transformers/src/models/llama2_c.rs | 4 +++- candle-transformers/src/models/llava/mod.rs | 9 ++++--- candle-transformers/src/models/mimi/mod.rs | 24 ++++++++++++++++--- candle-transformers/src/models/mmdit/mod.rs | 12 +++++++--- candle-transformers/src/models/mod.rs | 16 +++++++++++++ .../src/models/openclip/mod.rs | 6 ++++- candle-transformers/src/models/persimmon.rs | 10 ++++---- candle-transformers/src/models/phi.rs | 9 +++---- candle-transformers/src/models/pixtral/mod.rs | 8 +++---- .../src/models/quantized_llama.rs | 7 +++--- .../src/models/quantized_t5.rs | 6 ++--- candle-transformers/src/models/qwen2.rs | 3 +-- candle-transformers/src/models/repvgg.rs | 5 +--- candle-transformers/src/models/siglip.rs | 2 +- .../src/models/stable_diffusion/clip.rs | 2 +- .../src/models/stable_diffusion/ddpm.rs | 2 +- .../euler_ancestral_discrete.rs | 9 ++----- .../src/models/stable_diffusion/mod.rs | 6 ++--- .../src/models/stable_diffusion/resnet.rs | 3 ++- .../src/models/stable_diffusion/schedulers.rs | 2 +- candle-transformers/src/models/stable_lm.rs | 2 +- candle-transformers/src/models/starcoder2.rs | 4 ++-- candle-transformers/src/models/t5.rs | 7 +++--- candle-transformers/src/models/whisper/mod.rs | 10 +++++--- .../src/models/wuerstchen/mod.rs | 13 +++++++--- candle-transformers/src/models/yi.rs | 12 ++++++---- 39 files changed, 170 insertions(+), 115 deletions(-) diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs index 0330386574..a391daacbf 100644 --- a/candle-transformers/src/models/blip.rs +++ b/candle-transformers/src/models/blip.rs @@ -1,8 +1,11 @@ //! Based on the BLIP paper from Salesforce Research. //! -//! See "BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation" -//! - [Arxiv](https://arxiv.org/abs/2201.12086) -//! - [Github](https://github.com/salesforce/BLIP) +//! The blip-image-captioning model can generate captions for an input image. +//! +//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning) +//! - 💻 [GH Link](https://github.com/salesforce/BLIP) +//! 
- 🤗 [HF Link](https://huggingface.co/Salesforce/blip-image-captioning-base) +//! - 📝 [Paper](https://arxiv.org/abs/2201.12086) //! use super::blip_text; diff --git a/candle-transformers/src/models/blip_text.rs b/candle-transformers/src/models/blip_text.rs index aceaf4ac1b..ad28193b16 100644 --- a/candle-transformers/src/models/blip_text.rs +++ b/candle-transformers/src/models/blip_text.rs @@ -1,9 +1,12 @@ //! Implementation of BLIP text encoder/decoder. //! -//! See "BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation" -//! https://arxiv.org/abs/2201.12086 +//! - 📝 [Paper](https://arxiv.org/abs/2201.12086). BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation" +//! +//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning) +//! - 💻 [GH Link](https://github.com/salesforce/BLIP) +//! - 🤗 [HF Link](https://huggingface.co/Salesforce/blip-image-captioning-base) +//! - 📝 [Paper](https://arxiv.org/abs/2201.12086) //! - use super::with_tracing::{linear, Embedding, Linear}; use candle::{Module, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, VarBuilder}; diff --git a/candle-transformers/src/models/chatglm.rs b/candle-transformers/src/models/chatglm.rs index 8d5d9ec601..a115c7fef2 100644 --- a/candle-transformers/src/models/chatglm.rs +++ b/candle-transformers/src/models/chatglm.rs @@ -1,10 +1,8 @@ //! Implementation of the ChatGLM2/3 models from THUDM. //! -//! See: -//! - ChatGLM3: ["ChatGLM3: Advancing Multilingual Conversational Language Models with High-Quality Data"](https://github.com/THUDM/ChatGLM3) -//! - ChatGLM2: ["ChatGLM2: An Open Bilingual Chat LLM"](https://github.com/THUDM/ChatGLM2-6B) +//! - 💻 [Github](https://github.com/THUDM/ChatGLM3) ChatGLM3: Advancing Multilingual Conversational Language Models with High-Quality Data +//! - 💻 [Github](https://github.com/THUDM/ChatGLM2-6B) ChatGLM2-6B. //! - use crate::models::with_tracing::{linear_b as linear, Linear}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::VarBuilder; diff --git a/candle-transformers/src/models/chinese_clip/mod.rs b/candle-transformers/src/models/chinese_clip/mod.rs index 86616baa1c..1edc903179 100644 --- a/candle-transformers/src/models/chinese_clip/mod.rs +++ b/candle-transformers/src/models/chinese_clip/mod.rs @@ -3,10 +3,9 @@ //! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on //! pairs of images with related texts. //! -//! - [GH Link](https://github.com/OFA-Sys/Chinese-CLIP) -//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py) +//! - 💻 [GH Link](https://github.com/OFA-Sys/Chinese-CLIP) +//! - 💻 Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py) //! - use candle::{Module, Result, Tensor, D}; use candle_nn as nn; diff --git a/candle-transformers/src/models/chinese_clip/text_model.rs b/candle-transformers/src/models/chinese_clip/text_model.rs index 19499709a7..1cbf7c914e 100644 --- a/candle-transformers/src/models/chinese_clip/text_model.rs +++ b/candle-transformers/src/models/chinese_clip/text_model.rs @@ -3,8 +3,8 @@ //! 
Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on //! pairs of images with related texts. //! -//! https://github.com/OFA-Sys/Chinese-CLIP -//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py +//! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP) +//! - 💻 [HF](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py) use candle::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_nn as nn; @@ -67,7 +67,7 @@ impl Default for ChineseClipTextConfig { } impl ChineseClipTextConfig { - /// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json + /// [referer](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json) pub fn clip_vit_base_patch16() -> Self { Self { vocab_size: 21128, diff --git a/candle-transformers/src/models/chinese_clip/vision_model.rs b/candle-transformers/src/models/chinese_clip/vision_model.rs index 2d345e0f4a..a20535c40e 100644 --- a/candle-transformers/src/models/chinese_clip/vision_model.rs +++ b/candle-transformers/src/models/chinese_clip/vision_model.rs @@ -3,8 +3,8 @@ //! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on //! pairs of images with related texts. //! -//! https://github.com/OFA-Sys/Chinese-CLIP -//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py +//! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP) +//! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py_ use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D}; use candle_nn as nn; @@ -49,7 +49,7 @@ impl Default for ChineseClipVisionConfig { } impl ChineseClipVisionConfig { - /// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json + /// [referer](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json) pub fn clip_vit_base_patch16() -> Self { Self { hidden_size: 768, diff --git a/candle-transformers/src/models/clip/mod.rs b/candle-transformers/src/models/clip/mod.rs index e83f27e388..2b00267317 100644 --- a/candle-transformers/src/models/clip/mod.rs +++ b/candle-transformers/src/models/clip/mod.rs @@ -3,8 +3,10 @@ //! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on //! pairs of images with related texts. //! -//! - [GH Link](https://github.com/openai/CLIP) -//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip) +//! - 💻 [GH Link](https://github.com/openai/CLIP) +//! - 💻 Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip) +//! - 🤗 [HF Model](https://huggingface.co/openai/clip-vit-large-patch14-336) +//! 
use self::{ text_model::{Activation, ClipTextTransformer}, diff --git a/candle-transformers/src/models/clip/text_model.rs b/candle-transformers/src/models/clip/text_model.rs index 4662f65fda..eb103bd29a 100644 --- a/candle-transformers/src/models/clip/text_model.rs +++ b/candle-transformers/src/models/clip/text_model.rs @@ -3,8 +3,8 @@ //! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on //! pairs of images with related texts. //! -//! https://github.com/openai/CLIP -//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip +//! - [GH](https://github.com/openai/CLIP) +//! - [Code](https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip) use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn as nn; diff --git a/candle-transformers/src/models/codegeex4_9b.rs b/candle-transformers/src/models/codegeex4_9b.rs index baf4745922..c37a97d57e 100644 --- a/candle-transformers/src/models/codegeex4_9b.rs +++ b/candle-transformers/src/models/codegeex4_9b.rs @@ -1,8 +1,9 @@ //! CodeGeeX4 - A multi-language code generation model //! -//! See "CodeGeeX: A Pre-Trained Model For Code Generation with Multilingual Evaluations on HumanEval-X", Qian et al. 2023 -//! - [Arxiv](https://arxiv.org/abs/2303.17568) -//! - [Github](https://github.com/THUDM/CodeGeeX) +//! A Pre-Trained Model For Code Generation with Multilingual Evaluations on HumanEval-X" +//! +//! - 📝 [Arxiv](https://arxiv.org/abs/2303.17568) +//! - 💻 [Github](https://github.com/THUDM/CodeGeeX) //! use crate::models::with_tracing::{linear_b as linear, Linear}; diff --git a/candle-transformers/src/models/convmixer.rs b/candle-transformers/src/models/convmixer.rs index e095f793a4..7f1b75ebc4 100644 --- a/candle-transformers/src/models/convmixer.rs +++ b/candle-transformers/src/models/convmixer.rs @@ -1,10 +1,10 @@ //! ConvMixer implementation. //! //! See "Patches Are All You Need?" by Trockman et al. 2022 -//! - [Arxiv](https://arxiv.org/abs/2201.09792) -//! - [Github](https://github.com/locuslab/convmixer) //! - +//! - 📝 [Arxiv](https://arxiv.org/abs/2201.09792) +//! - 💻 [Github](https://github.com/locuslab/convmixer) +//! use candle::Result; use candle_nn::{batch_norm, Conv2dConfig, Module, VarBuilder}; diff --git a/candle-transformers/src/models/convnext.rs b/candle-transformers/src/models/convnext.rs index d791895f1d..727e11381c 100644 --- a/candle-transformers/src/models/convnext.rs +++ b/candle-transformers/src/models/convnext.rs @@ -1,13 +1,16 @@ //! ConvNeXt implementation. //! -//! See ["A ConvNet for the 2020s" Liu et al. 2022](https://arxiv.org/abs/2201.03545) -//! and -//! ["ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023](https://arxiv.org/abs/2301.00808) +//! This candle implementation uses a pre-trained ConvNeXt network for inference. The +//! classification head has been trained on the ImageNet dataset and returns the +//! probabilities for the top-5 classes. //! //! Original code: -//! - [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/) -//! - [ConvNeXt-V2](https://github.com/facebookresearch/ConvNeXt-V2/) -//! - [timm](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py) +//! - 💻 [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/) +//! - 💻 [ConvNeXt-V2](https://github.com/facebookresearch/ConvNeXt-V2/) +//! 
- 💻 [timm](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py) +//! - 📝 [Paper](https://arxiv.org/abs/2201.03545) A ConvNet for the 2020s +//! - 📝 [Paper](https://arxiv.org/abs/2301.00808) ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders +//! use candle::shape::ShapeWithOneHole; use candle::{Result, D}; diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs index 064c5130f5..1d2fa4ef33 100644 --- a/candle-transformers/src/models/flux/mod.rs +++ b/candle-transformers/src/models/flux/mod.rs @@ -2,9 +2,9 @@ //! //! Flux is a 12B rectified flow transformer capable of generating images from text descriptions. //! -//! - [Hugging Face Model](https://huggingface.co/black-forest-labs/FLUX.1-schnell) -//! - [GitHub Repository](https://github.com/black-forest-labs/flux) -//! - [Blog Post](https://blackforestlabs.ai/announcing-black-forest-labs/) +//! - 🤗 [Hugging Face Model](https://huggingface.co/black-forest-labs/FLUX.1-schnell) +//! - 💻 [GitHub Repository](https://github.com/black-forest-labs/flux) +//! - 📝 [Blog Post](https://blackforestlabs.ai/announcing-black-forest-labs/) //! //! # Usage //! diff --git a/candle-transformers/src/models/hiera.rs b/candle-transformers/src/models/hiera.rs index 39f8d639b6..98ad825737 100644 --- a/candle-transformers/src/models/hiera.rs +++ b/candle-transformers/src/models/hiera.rs @@ -1,9 +1,8 @@ -//! [Hiera] inference implementation based on timm. +//! Hiera inference implementation based on timm. //! -//! See "[Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles]" -//! [Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles]: https://arxiv.org/abs/2306.00989 //! -//! [Hiera]: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/hiera.py +//! - 💻 [Hiera](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/hiera.py) +//! - 📝 [Paper](https://arxiv.org/abs/2306.00989). Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles use candle::{Result, D}; use candle_nn::{conv2d, layer_norm, linear, ops::softmax, Conv2dConfig, Func, VarBuilder}; diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index d825d8e4dd..930c8b8aa6 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -2,7 +2,9 @@ //! //! See ["LLaMA 2: Open Foundation and Fine-Tuned Chat Models"](https://arxiv.org/abs/2307.09288) //! -//! Based on the [llama2.c](https://github.com/karpathy/llama2.c) implementation +//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/lmz/candle-llama2) +//! - 💻 llama2.c [GH Link](https://github.com/karpathy/llama2.c) +//! use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::linear_no_bias as linear; diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs index 44a00bf9a1..c252dbed56 100644 --- a/candle-transformers/src/models/llava/mod.rs +++ b/candle-transformers/src/models/llava/mod.rs @@ -1,13 +1,12 @@ //! The LLaVA (Large Language and Vision Assistant) model. //! //! This provides the main model implementation combining a vision tower (CLIP) with -//! language model (Llama) for multimodal capabilities. +//! language model (Llama) for multimodal capabilities. The architecture implements the training-free projection technique. //! -//! 
The architecture implements the training-free projection technique from the paper: -//! [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485). -//! -//! - [GH Link](https://github.com/haotian-liu/LLaVA/tree/main) +//! - 💻[GH Link](https://github.com/haotian-liu/LLaVA/tree/main) +//! - 📝 [Paper](https://arxiv.org/abs/2304.08485)/ Visual Instruction Tuning //! + pub mod config; pub mod utils; diff --git a/candle-transformers/src/models/mimi/mod.rs b/candle-transformers/src/models/mimi/mod.rs index f19f9ae5fa..8945abfb03 100644 --- a/candle-transformers/src/models/mimi/mod.rs +++ b/candle-transformers/src/models/mimi/mod.rs @@ -1,9 +1,27 @@ //! mimi model //! -//! Mimi is a state-of-the-art audio neural codec. +//! [Mimi](https://huggingface.co/kyutai/mimi) is a state of the art audio +//! compression model using an encoder/decoder architecture with residual vector +//! quantization. The candle implementation supports streaming meaning that it's +//! possible to encode or decode a stream of audio tokens on the flight to provide +//! low latency interaction with an audio model. //! -//! - [HuggingFace Model Card](https://huggingface.co/kyutai/mimi) -//! - [GitHub](https://github.com/kyutai-labs/moshi) +//! - 🤗 [HuggingFace Model Card](https://huggingface.co/kyutai/mimi) +//! - 💻 [GitHub](https://github.com/kyutai-labs/moshi) +//! +//! +//! # Example +//! ```bash +//! # Generating some audio tokens from an audio files. +//! wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3 +//! cargo run --example mimi \ +//! --features mimi --release -- \ +//! audio-to-code bria.mp3 bria.safetensors +//! +//! # And decoding the audio tokens back into a sound file. +//! cargo run --example mimi +//! --features mimi --release -- \ +//! code-to-audio bria.safetensors bria.wav //! // Copyright (c) Kyutai, all rights reserved. diff --git a/candle-transformers/src/models/mmdit/mod.rs b/candle-transformers/src/models/mmdit/mod.rs index ce4872e0b2..88e73e1e3d 100644 --- a/candle-transformers/src/models/mmdit/mod.rs +++ b/candle-transformers/src/models/mmdit/mod.rs @@ -3,9 +3,15 @@ //! Mix of Multi-scale Dilated and Traditional Convolutions (MMDiT) is an architecture //! introduced for Stable Diffusion 3, with the MMDiT-X variant used in Stable Diffusion 3.5. //! -//! - [Research Paper](https://arxiv.org/abs/2403.03206) -//! - ComfyUI [reference implementation](https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py) -//! - Stability-AI [MMDiT-X implementation](https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py) +//! - 📝 [Research Paper](https://arxiv.org/abs/2403.03206) +//! - 💻 ComfyUI [reference implementation](https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py) +//! - 💻 Stability-AI [MMDiT-X implementation](https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py) + +//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning) +//! - 💻 [GH Link](https://github.com/salesforce/BLIP) +//! - 🤗 [HF Link](https://huggingface.co/Salesforce/blip-image-captioning-base) +//! - 📝 [Paper](https://arxiv.org/abs/2201.12086) +//! 
pub mod blocks; pub mod embedding; diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 23edf349ad..571a88614d 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -1,3 +1,19 @@ +//! Candle implementations for various deep learning models +//! +//! This crate provides implementations of popular machine learning models and architectures for different modalities. +//! +//! - Large language models: [`llama`], [`phi3`], [`mamba`], [`mixtral`], [`bert`], ... +//! - Text to text models: [`t5`], ... +//! - Image to text models: [`blip`], ... +//! - Text to image models: [`stable_diffusion`] and [`wuerstchen`], ... +//! - Audio models: [`whisper`], [`encodec`], [`metavoice`], [`parler_tts`], ... +//! - Computer vision models: [`dinov2`], [`convmixer`], [`efficientnet`], ... +//! +//! Some of the models also have quantized variants, e.g. [`quantized_blip`], [`quantized_llama`] and [`quantized_qwen2`]. +//! +//! The implementations aim to be readable while maintaining good performance. For more information +//! on each model see the model's module docs in the links below. + pub mod based; pub mod beit; pub mod bert; diff --git a/candle-transformers/src/models/openclip/mod.rs b/candle-transformers/src/models/openclip/mod.rs index dacb627f9e..b3864b815e 100644 --- a/candle-transformers/src/models/openclip/mod.rs +++ b/candle-transformers/src/models/openclip/mod.rs @@ -3,7 +3,11 @@ //! Open Contrastive Language-Image Pre-Training (OpenCLIP) is an architecture trained on //! pairs of images with related texts. //! -//! - [GH Link](https://github.com/mlfoundations/open_clip) +//! - 💻 [GH Link](https://github.com/mlfoundations/open_clip) +//! - 📝 [Paper](https://arxiv.org/abs/2212.07143) //! +//! ## Overview +//! +//! ![](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/CLIP.png) pub mod text_model; diff --git a/candle-transformers/src/models/persimmon.rs b/candle-transformers/src/models/persimmon.rs index 0996decf55..d1e3db316f 100644 --- a/candle-transformers/src/models/persimmon.rs +++ b/candle-transformers/src/models/persimmon.rs @@ -1,17 +1,15 @@ //! Persimmon Model //! -//! A transformer language model for efficient inference and general-purpose tasks. See Persimmon model details at: -//! - [Hugging Face](https://huggingface.co/adept/persimmon-8b-base) -//! -//! The model uses a standard transformer architecture with: +//! A transformer language model for efficient inference and general-purpose tasks. The model uses a standard transformer architecture with: //! - Layer normalization for Q/K attention //! - RoPE embeddings with partial rotary factor //! - ReLU activation //! - Separate number of attention heads and KV heads //! //! References: -//! - [Hugging Face Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/modeling_persimmon.py) -//! - [Persimmon Config](https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/configuration_persimmon.py) +//! - 💻 [Hugging Face Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/modeling_persimmon.py) +//! - 💻 [Persimmon Config](https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/configuration_persimmon.py) +//! - 🤗 [Hugging Face](https://huggingface.co/adept/persimmon-8b-base) //! 
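The Persimmon notes above (and the Phi, StarCoder2 and Yi module docs later in this series) list rotary position embeddings among their key characteristics, Persimmon with a partial rotary factor. As a hedged illustration of what that means — a plain-Rust sketch, not the candle-transformers implementation — each query/key vector is split into pairs of dimensions that are rotated by a position-dependent angle, and with a partial rotary factor only a leading fraction of the dimensions is rotated:

```rust
// Sketch of rotary position embeddings (RoPE) with a partial rotary factor.
// Illustrative only; not the candle-transformers implementation.
fn apply_rope(x: &mut [f32], position: usize, rotary_dim: usize, base: f32) {
    // Only the first `rotary_dim` dimensions are rotated ("partial" RoPE);
    // the remaining dimensions pass through unchanged.
    assert!(rotary_dim <= x.len() && rotary_dim % 2 == 0);
    for i in (0..rotary_dim).step_by(2) {
        let theta = (position as f32) / base.powf(i as f32 / rotary_dim as f32);
        let (sin, cos) = theta.sin_cos();
        let (a, b) = (x[i], x[i + 1]);
        x[i] = a * cos - b * sin;
        x[i + 1] = a * sin + b * cos;
    }
}

fn main() {
    // Hypothetical head dimension of 8 with a rotary factor of 0.5,
    // so only the first 4 dimensions are rotated.
    let mut q = [0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
    apply_rope(&mut q, 3, 4, 10_000.0);
    println!("{q:?}");
}
```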
use candle::DType; diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs index 36a08bb3c6..c94ef6686b 100644 --- a/candle-transformers/src/models/phi.rs +++ b/candle-transformers/src/models/phi.rs @@ -1,18 +1,15 @@ //! Microsoft Phi model implementation //! -//! See Phi model details at: -//! - [Phi-2 Model](https://huggingface.co/microsoft/phi-2) -//! //! The Phi series are decoder-only transformers designed for code and language tasks. +//! //! Key characteristics: //! - Decoder-only transformer architecture //! - RoPE embeddings //! - Layer normalization //! - QK normalization //! -//! References: -//! - [Hugging Face Implementation](https://huggingface.co/microsoft/phi-2) -//! - [Alternative Implementation](https://huggingface.co/microsoft/phi-2/tree/main) +//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/Candle-phi1-phi2-wasm-demo) +//! - 🤗 [HF Link](https://huggingface.co/microsoft/phi-2) //! use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear}; diff --git a/candle-transformers/src/models/pixtral/mod.rs b/candle-transformers/src/models/pixtral/mod.rs index e722ffcfd2..18bcc5f793 100644 --- a/candle-transformers/src/models/pixtral/mod.rs +++ b/candle-transformers/src/models/pixtral/mod.rs @@ -3,10 +3,10 @@ //! Pixtral is an architecture trained for multimodal learning //! using images paired with text descriptions. //! -//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/main/src/transformers/models/pixtral) -//! - [Blog Post](https://mistral.ai/news/pixtral-12b/) - -//! - [HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) - -//! - [HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b). +//! - 💻 Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/main/src/transformers/models/pixtral) +//! - 📝 [Blog Post](https://mistral.ai/news/pixtral-12b/) +//! - 🤗 [HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) +//! - 🤗 [HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b) //! //! # Example //! diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 7efd385d61..e171b54fd8 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -10,9 +10,10 @@ //! - Optimized memory usage through quantization //! - Configurable model sizes and parameter counts //! -//! References: -//! - [LLaMA Paper](https://arxiv.org/abs/2302.13971) -//! - [LLaMA Model](https://github.com/facebookresearch/llama) +//! - 💻 [GH Link](https://github.com/facebookresearch/llama) +//! - 📝 [Paper](https://arxiv.org/abs/2302.13971) +//! +//! ![](https://raw.githubusercontent.com/huggingface/candle/main/candle-examples/examples/quantized/assets/aoc.gif) //! use std::collections::HashMap; diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs index 9f770d69d9..4fc9c537f8 100644 --- a/candle-transformers/src/models/quantized_t5.rs +++ b/candle-transformers/src/models/quantized_t5.rs @@ -11,9 +11,9 @@ //! - Support for 8-bit quantization //! //! References: -//! - [T5 Paper](https://arxiv.org/abs/1910.10683) -//! - [Model Card](https://huggingface.co/t5-base) -//! - Original model from [T5](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py) +//! 
- 📝 [T5 Paper](https://arxiv.org/abs/1910.10683) +//! - 🤗 [Model Card](https://huggingface.co/t5-base) +//! - 🤗 Original model from [T5](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py) use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating}; use crate::models::with_tracing::QMatMul; diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs index 8dbca36b3e..8a29646efe 100644 --- a/candle-transformers/src/models/qwen2.rs +++ b/candle-transformers/src/models/qwen2.rs @@ -11,8 +11,7 @@ //! - Support for 8-bit quantization //! //! References: -//! - [Qwen2 Model](https://huggingface.co/Qwen/Qwen2-7B) -//! - [Model Card](https://huggingface.co/Qwen/Qwen2-7B) +//! - 🤗 [Qwen2 Model](https://huggingface.co/Qwen/Qwen2-7B) //! use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; diff --git a/candle-transformers/src/models/repvgg.rs b/candle-transformers/src/models/repvgg.rs index a6ffce0d6d..6e45c2d68c 100644 --- a/candle-transformers/src/models/repvgg.rs +++ b/candle-transformers/src/models/repvgg.rs @@ -1,8 +1,5 @@ //! RepVGG inference implementation //! -//! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021 -//! https://arxiv.org/abs/2101.03697 -//! //! Key characteristics: //! - Efficient inference architecture through structural reparameterization //! - Single 3x3 conv layer after fusing 3x3 branch, 1x1 branch and identity branch @@ -10,7 +7,7 @@ //! - High accuracy with VGG-like plain architecture and training //! //! References: -//! - [RepVGG Paper](https://arxiv.org/abs/2101.03697) +//! - [RepVGG Paper](https://arxiv.org/abs/2101.03697). RepVGG: Making VGG-style ConvNets Great Again //! - [Official Implementation](https://github.com/DingXiaoH/RepVGG) //! diff --git a/candle-transformers/src/models/siglip.rs b/candle-transformers/src/models/siglip.rs index 2046401428..932970ed3b 100644 --- a/candle-transformers/src/models/siglip.rs +++ b/candle-transformers/src/models/siglip.rs @@ -3,7 +3,7 @@ //! Siglip architecture combining vision and language for zero-shot tasks. //! //! References: -//! - [Model Card](https://huggingface.co/google/siglip-base-patch16-224) +//! - 🤗 [Model Card](https://huggingface.co/google/siglip-base-patch16-224) //! use crate::models::clip::div_l2_norm; diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs index 2f631248bc..4c3f9d512d 100644 --- a/candle-transformers/src/models/stable_diffusion/clip.rs +++ b/candle-transformers/src/models/stable_diffusion/clip.rs @@ -3,7 +3,7 @@ //! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on //! pairs of images with related texts. //! -//! https://github.com/openai/CLIP +//! - [CLIP](https://github.com/openai/CLIP) use candle::{DType, Device, Result, Tensor, D}; use candle_nn as nn; use candle_nn::Module; diff --git a/candle-transformers/src/models/stable_diffusion/ddpm.rs b/candle-transformers/src/models/stable_diffusion/ddpm.rs index d393f39aac..42a0dc7e17 100644 --- a/candle-transformers/src/models/stable_diffusion/ddpm.rs +++ b/candle-transformers/src/models/stable_diffusion/ddpm.rs @@ -104,7 +104,7 @@ impl DDPMScheduler { }; let current_beta_t = 1. 
- alpha_prod_t / alpha_prod_t_prev; - // For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + // For t > 0, compute predicted variance βt (see formula (6) and (7) from [the pdf](https://arxiv.org/pdf/2006.11239.pdf)) // and sample from it to get previous sample // x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t; diff --git a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs index 9576c2de40..edd5eb508b 100644 --- a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs +++ b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs @@ -1,12 +1,7 @@ //! Ancestral sampling with Euler method steps. //! -//! Reference implementation in Rust: -//! -//! https://github.com/pykeio/diffusers/blob/250b9ad1898af41e76a74c0d8d4292652823338a/src/schedulers/euler_ancestral_discrete.rs -//! -//! Based on the original [`k-diffusion` implementation by Katherine Crowson][kd]. +//! Based on the original [`k-diffusion` implementation by Katherine Crowson]( https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72). /// -/// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 use super::{ schedulers::{ betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, @@ -29,7 +24,7 @@ pub struct EulerAncestralDiscreteSchedulerConfig { pub steps_offset: usize, /// prediction type of the scheduler function, one of `epsilon` (predicting /// the noise of the diffusion process), `sample` (directly predicting the noisy sample`) - /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + /// or `v_prediction` (see [section 2.4](https://imagen.research.google/video/paper.pdf)) pub prediction_type: PredictionType, /// number of diffusion steps used to train the model pub train_timesteps: usize, diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index 458a7de2d4..6d89f9cd43 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -3,9 +3,9 @@ //! Stable Diffusion is a latent text-to-image diffusion model capable of //! generating photo-realistic images given any text input. //! -//! - [Original Repository](https://github.com/CompVis/stable-diffusion) -//! - [Hugging Face](https://huggingface.co/runwayml/stable-diffusion-v1-5) -//! - The default scheduler for the v1.5, v2.1 and XL 1.0 version is the Denoising Diffusion Implicit Model scheduler (DDIM). The original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim). The default scheduler for the XL Turbo version is the Euler Ancestral scheduler. +//! - 💻 [Original Repository](https://github.com/CompVis/stable-diffusion) +//! - 🤗 [Hugging Face](https://huggingface.co/runwayml/stable-diffusion-v1-5) +//! - The default scheduler for the v1.5, v2.1 and XL 1.0 version is the Denoising Diffusion Implicit Model scheduler (DDIM). The original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim). The default scheduler for the XL Turbo version is the Euler Ancestral scheduler. //! //! //! 
# Example diff --git a/candle-transformers/src/models/stable_diffusion/resnet.rs b/candle-transformers/src/models/stable_diffusion/resnet.rs index 5df04a8b44..5cca7edd30 100644 --- a/candle-transformers/src/models/stable_diffusion/resnet.rs +++ b/candle-transformers/src/models/stable_diffusion/resnet.rs @@ -3,7 +3,8 @@ //! Some Residual Network blocks used in UNet models. //! //! Denoising Diffusion Implicit Models, K. He and al, 2015. -//! https://arxiv.org/abs/1512.03385 +//! - [Paper](https://arxiv.org/abs/1512.03385) +//! use crate::models::with_tracing::{conv2d, Conv2d}; use candle::{Result, Tensor, D}; use candle_nn as nn; diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs index 94f8ab86f7..1d39037f8f 100644 --- a/candle-transformers/src/models/stable_diffusion/schedulers.rs +++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs @@ -43,7 +43,7 @@ pub enum PredictionType { /// Time step spacing for the diffusion process. /// -/// "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 +/// "linspace", "leading", "trailing" corresponds to annotation of Table 2. of the [paper](https://arxiv.org/abs/2305.08891) #[derive(Debug, Clone, Copy)] pub enum TimestepSpacing { Leading, diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs index c5dbd3958d..536f7727e4 100644 --- a/candle-transformers/src/models/stable_lm.rs +++ b/candle-transformers/src/models/stable_lm.rs @@ -10,7 +10,7 @@ //! - Support for different model sizes (3B, 7B) //! //! References: -//! - [Model Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t) +//! - 🤗 [Model Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t) //! use crate::models::with_tracing::{linear, linear_no_bias, Linear}; diff --git a/candle-transformers/src/models/starcoder2.rs b/candle-transformers/src/models/starcoder2.rs index 0df5990b89..266221e5c8 100644 --- a/candle-transformers/src/models/starcoder2.rs +++ b/candle-transformers/src/models/starcoder2.rs @@ -11,8 +11,8 @@ //! - Support for 8-bit quantization //! //! References: -//! - [StarCoder Paper](https://arxiv.org/abs/2305.06161) -//! - [Model Card](https://huggingface.co/bigcode/starcoder) +//! - 📝 [StarCoder Paper](https://arxiv.org/abs/2305.06161) +//! - 🤗 [Model Card](https://huggingface.co/bigcode/starcoder) //! use candle::{DType, Device, Module, Result, Tensor, D}; diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index d3fd2ba686..5d23549f21 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -11,9 +11,10 @@ //! - Support for sequence-to-sequence tasks //! //! References: -//! - [T5 Paper](https://arxiv.org/abs/1910.10683) -//! - [HuggingFace T5](https://huggingface.co/docs/transformers/model_doc/t5) -//! - [GH Model](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py) +//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm) +//! - 💻[GH Model](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py) +//! - 🤗 [HF Link](https://huggingface.co/docs/transformers/model_doc/t5) +//! - 📝 [T5 Paper](https://arxiv.org/abs/1910.10683) //! //! # Encoder-decoder example: //! 
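The quantized_llama and quantized_t5 module docs earlier in this series advertise 2-bit to 8-bit quantized weights, and the CUDA kernels patched further down operate directly on such blocks. As a loose illustration of the underlying idea — a toy layout, not candle's actual GGML-compatible Q4_0/Q8_0 formats — weights are split into fixed-size blocks, each stored as one per-block scale plus small integers:

```rust
// Toy 8-bit block quantization: each block stores one f32 scale and
// `BLOCK` small integers. Illustrative only; the real GGML formats use
// different layouts and bit widths.
const BLOCK: usize = 32;

fn quantize_block(xs: &[f32]) -> (f32, Vec<i8>) {
    let amax = xs.iter().fold(0f32, |m, &x| m.max(x.abs()));
    let scale = if amax == 0.0 { 1.0 } else { amax / 127.0 };
    let qs = xs.iter().map(|&x| (x / scale).round() as i8).collect();
    (scale, qs)
}

fn dequantize_block(scale: f32, qs: &[i8]) -> Vec<f32> {
    qs.iter().map(|&q| q as f32 * scale).collect()
}

fn main() {
    let weights: Vec<f32> = (0..BLOCK).map(|i| (i as f32 * 0.7).sin()).collect();
    let (scale, qs) = quantize_block(&weights);
    let restored = dequantize_block(scale, &qs);
    let max_err = weights
        .iter()
        .zip(&restored)
        .map(|(a, b)| (a - b).abs())
        .fold(0f32, f32::max);
    println!("scale = {scale:.4}, max round-trip error = {max_err:.4}");
}
```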
diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs index 6123884ae4..d7082ea6d8 100644 --- a/candle-transformers/src/models/whisper/mod.rs +++ b/candle-transformers/src/models/whisper/mod.rs @@ -1,10 +1,14 @@ //! Whisper Model Implementation //! //! Whisper is an automatic speech recognition (ASR) system trained on large amounts -//! of multilingual and multitask supervised data collected from the web. +//! of multilingual and multitask supervised data collected from the web. It can be used to +//! convert audio files (in the `.wav` format) to text. Supported features include +//! language detection as well as multilingual speech recognition. +//! +//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/lmz/candle-whisper) +//! - 💻 [GH Link](https://github.com/openai/whisper) +//! - 💻 Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py) //! -//! - [GH Link](https://github.com/openai/whisper) -//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py) //! pub mod audio; pub mod model; diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs index 9bb37a3bcc..ae42c4a884 100644 --- a/candle-transformers/src/models/wuerstchen/mod.rs +++ b/candle-transformers/src/models/wuerstchen/mod.rs @@ -3,10 +3,17 @@ //! Würstchen is an efficient diffusion model architecture for generating images using //! a two-stage approach with a small decoder and prior network. //! -//! - [Paper Link](https://openreview.net/pdf?id=gU58AyJlYz) -//! - [GH Link](https://github.com/dome272/Wuerstchen) -//! - [Reference Implementation](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py) +//! - 💻 [GH Link](https://github.com/dome272/Wuerstchen) +//! - 🤗 [HF Link](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py) +//! - 📝 [Paper](https://openreview.net/pdf?id=gU58AyJlYz) //! +//! ## Example +//! +//!
+//!
+//! "Anthropomorphic cat dressed as a fire fighter"
+//!
+ pub mod attention_processor; pub mod common; pub mod ddpm; diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs index 047ea77046..8a2fb111be 100644 --- a/candle-transformers/src/models/yi.rs +++ b/candle-transformers/src/models/yi.rs @@ -1,7 +1,12 @@ //! Yi model implementation. //! -//! Yi is a decoder-only large language model trained by 01.AI. -//! It follows a standard transformer architecture similar to Llama. +//! This candle implementation uses a pre-trained Yi decoder-only large language model for inference. +//! The model was trained by 01.AI and follows a standard transformer architecture similar to LLaMA. +//! +//! Original code: +//! - 💻 [Yi Model](https://huggingface.co/01-ai/Yi-6B) +//! - 💻 [Yi Modeling Code](https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py) +//! - 📝 [Technical Report](https://arxiv.org/abs/2403.04652) Yi: Open Foundation Models by 01.AI //! //! Key characteristics: //! - Multi-head attention with rotary positional embeddings @@ -9,9 +14,6 @@ //! - SwiGLU activation in feed-forward layers //! - Grouped-query attention for efficient inference //! -//! References: -//! - [Yi Model](https://huggingface.co/01-ai/Yi-6B) -//! - [Hugging Face](https://huggingface.co/01-ai/Yi-6B/blob/main/modeling_yi.py) use crate::models::with_tracing::{linear_no_bias, Linear, RmsNorm}; use candle::{DType, Device, Module, Result, Tensor, D}; From e86565624bcbc1c4bf2d33410d924bf97ad05f31 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 18 Nov 2024 14:32:38 +0100 Subject: [PATCH 07/15] Fix for clippy. (#2626) --- .../src/models/stable_diffusion/euler_ancestral_discrete.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs index edd5eb508b..c27e983a34 100644 --- a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs +++ b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs @@ -1,7 +1,7 @@ //! Ancestral sampling with Euler method steps. //! //! Based on the original [`k-diffusion` implementation by Katherine Crowson]( https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72). -/// +//! use super::{ schedulers::{ betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, From 1a0f9ccf16de9fc311b000a61e8e9e357a15855b Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 19 Nov 2024 03:41:34 +0100 Subject: [PATCH 08/15] Import the ggml_cuda_dp4a function. 
(#2628) --- candle-kernels/src/quantized.cu | 77 +++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 33 deletions(-) diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu index 05f878f3d6..b6a4310005 100644 --- a/candle-kernels/src/quantized.cu +++ b/candle-kernels/src/quantized.cu @@ -82,6 +82,17 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * #define CC_RDNA2 (CC_OFFSET_AMD + 1030) #define CC_RDNA3 (CC_OFFSET_AMD + 1100) +static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A + return __dp4a(a, b, c); +#else // __CUDA_ARCH__ >= MIN_CC_DP4A + const int8_t * a8 = (const int8_t *) &a; + const int8_t * b8 = (const int8_t *) &b; + return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3]; +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + + #define MMQ_X_Q4_0_RDNA2 64 #define MMQ_Y_Q4_0_RDNA2 128 #define NWARPS_Q4_0_RDNA2 8 @@ -1821,8 +1832,8 @@ template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; // SIMD dot product of quantized values - sumi = __dp4a(vi0, u[2*i+0], sumi); - sumi = __dp4a(vi1, u[2*i+1], sumi); + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); } const float2 ds8f = __half22float2(ds8); @@ -1844,8 +1855,8 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; // SIMD dot product of quantized values - sumi = __dp4a(vi0, u[2*i+0], sumi); - sumi = __dp4a(vi1, u[2*i+1], sumi); + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); } #ifdef GGML_CUDA_F16 @@ -1878,14 +1889,14 @@ template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 - sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 - sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values } const float2 ds8f = __half22float2(ds8); @@ -1909,14 +1920,14 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 - sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 - sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values } #ifdef GGML_CUDA_F16 @@ -1945,7 +1956,7 @@ template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp #pragma 
unroll for (int i = 0; i < vdr; ++i) { // SIMD dot product of quantized values - sumi = __dp4a(v[i], u[i], sumi); + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); } return d8_0*d8_1 * sumi; @@ -1959,7 +1970,7 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp #pragma unroll for (int i = 0; i < vdr; ++i) { // SIMD dot product of quantized values - sumi = __dp4a(v[i], u[i], sumi); + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); } #ifdef GGML_CUDA_F16 @@ -1994,13 +2005,13 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( const int vi = (v >> (2*i)) & 0x03030303; - sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product // fill int with 4x m int m = sc >> 4; m |= m << 8; m |= m << 16; - sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values + sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values } const float2 dm2f = __half22float2(dm2); @@ -2029,8 +2040,8 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( #pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product - sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m + sumi_d_sc = ggml_cuda_dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product + sumi_m = ggml_cuda_dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m } sumi_d += sumi_d_sc * (sc & 0xF); @@ -2071,7 +2082,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( const int vi = __vsubss4(vil, vih); - sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product } return d3 * sumf; @@ -2089,7 +2100,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( int sumi_sc = 0; for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product + sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product } sumi += sumi_sc * scales[i0 / (QI8_1/2)]; @@ -2114,8 +2125,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; - const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product - const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u + const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product + const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u sumf_d += d8[i] * (dot1 * sc[i]); sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values @@ -2140,7 +2151,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( #pragma unroll for (int j = 0; j < QI8_1; ++j) { - sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product + sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product } const float2 ds8f = __half22float2(ds8[i]); @@ -2176,8 +2187,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( const int v0i = vl0i | vh0i; const int v1i = vl1i | vh1i; - const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product - const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 
0)); // sum of u + const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product + const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u sumf_d += d8[i] * (dot1 * sc[i]); sumf_m += d8[i] * (dot2 * m[i]); @@ -2203,7 +2214,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( #pragma unroll for (int j = 0; j < QI8_1; ++j) { - sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product + sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product } const float2 ds8f = __half22float2(ds8[i]); @@ -2237,7 +2248,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 - sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product } return d*sumf; @@ -2256,11 +2267,11 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( #pragma unroll for (int i = i0; i < i0 + 2; ++i) { - sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product - sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product + sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product + sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product - sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product - sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product + sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product + sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product } sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); @@ -2488,10 +2499,10 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const int v1 = q4[0]; const int v2 = q4[4]; - const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); - const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); - const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); - const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); + const int dot1 = ggml_cuda_dp4a(ui2, v2 & 0x0f0f0f0f, ggml_cuda_dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = ggml_cuda_dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, ggml_cuda_dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = ggml_cuda_dp4a(0x01010101, ui2, ggml_cuda_dp4a(0x01010101, ui1, 0)); + const int dot4 = ggml_cuda_dp4a(0x01010101, ui4, ggml_cuda_dp4a(0x01010101, ui3, 0)); sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); @@ -2576,8 +2587,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); - const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) - + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]); + const float sumf_d = d8_1 * (ggml_cuda_dp4a(ui1, v1, 0) * s[0] + ggml_cuda_dp4a(ui2, v2, 0) * s[1]) + + d8_2 * (ggml_cuda_dp4a(ui3, v3, 0) * s[2] + ggml_cuda_dp4a(ui4, v4, 0) * s[3]); return d * sumf_d; #endif From 3159f91b90a5bc68b275f8688472ba8917a834da Mon Sep 17 00:00:00 2001 From: zachcp Date: Mon, 18 Nov 2024 22:07:07 -0500 Subject: [PATCH 09/15] 20241118 docs (#2629) * module docs * varbuilder gguf docs * add 
a link to gguf files * small additonal mod doc titles * safetensor docs * more core docs * more module docs in canlde_core * 2 more link fixes --- candle-core/src/backend.rs | 2 ++ candle-core/src/backprop.rs | 2 +- candle-core/src/conv.rs | 2 ++ candle-core/src/cpu/mod.rs | 2 ++ candle-core/src/cpu_backend/mod.rs | 1 + candle-core/src/cuda_backend/mod.rs | 2 ++ candle-core/src/device.rs | 1 + candle-core/src/display.rs | 7 ++++--- candle-core/src/dummy_cuda_backend.rs | 2 ++ candle-core/src/error.rs | 1 + candle-core/src/layout.rs | 1 + candle-core/src/lib.rs | 8 ++++---- candle-core/src/metal_backend/mod.rs | 2 ++ candle-core/src/op.rs | 2 ++ candle-core/src/pickle.rs | 2 +- candle-core/src/quantized/ggml_file.rs | 2 +- candle-core/src/quantized/gguf_file.rs | 3 +-- candle-core/src/quantized/mod.rs | 1 + candle-core/src/safetensors.rs | 11 +++++++++++ candle-core/src/scalar.rs | 2 ++ candle-core/src/streaming.rs | 2 ++ candle-core/src/utils.rs | 1 + candle-transformers/src/generation/mod.rs | 5 +++++ candle-transformers/src/object_detection.rs | 6 ++++++ candle-transformers/src/quantized_nn.rs | 6 ++++++ candle-transformers/src/quantized_var_builder.rs | 6 ++++++ candle-transformers/src/utils.rs | 2 ++ 27 files changed, 72 insertions(+), 12 deletions(-) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index afe3e40754..f98cb4f4fd 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -1,3 +1,5 @@ +//! Traits to Define Backend Behavior +//! use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index a556677478..d19f099f71 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -1,4 +1,4 @@ -/// Methods for backpropagation of gradients. +//! Methods for backpropagation of gradients. use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp}; use crate::{Error, Result, Tensor, TensorId}; use std::collections::HashMap; diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 7b3922dd73..4728c21a23 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -1,3 +1,5 @@ +//! 1D and 2D Convolutions +//! use crate::{op::BackpropOp, op::Op, Error, Result, Tensor}; #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/candle-core/src/cpu/mod.rs b/candle-core/src/cpu/mod.rs index e7d8b6906f..be5b99128e 100644 --- a/candle-core/src/cpu/mod.rs +++ b/candle-core/src/cpu/mod.rs @@ -1,3 +1,5 @@ +//! Traits and methods for CPU-backed Tensors + pub mod erf; pub mod kernels; diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 58773c8020..229e3bbce1 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -1,3 +1,4 @@ +//! Implementation of Backend Fns for CPU use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index f14e00d533..37fef5078e 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1,3 +1,5 @@ +//! Implementation of Backend traits for CUDA device +//! 
use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType}; diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 18aa61aff7..9b1fb9ee00 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -11,6 +11,7 @@ pub enum DeviceLocation { Metal { gpu_id: usize }, } +/// Cpu, Cuda, or Metal #[derive(Debug, Clone)] pub enum Device { Cpu, diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 7e6e3cf8f1..76d39010a9 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -1,6 +1,7 @@ -/// Pretty printing of tensors -/// This implementation should be in line with the PyTorch version. -/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py +//! Pretty printing of tensors +//! +//! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py). +//! use crate::{DType, Result, Tensor, WithDType}; use half::{bf16, f16}; diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index b4f2e8aa00..9d30d8214d 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -1,3 +1,5 @@ +//! Implementation of the Cuda backend when Cuda support has not been compiled in. +//! #![allow(dead_code)] use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Error, Layout, Result, Shape}; diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index a35bec3cbe..15604c15a8 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -1,3 +1,4 @@ +//! Candle-specific Error and Result use crate::{DType, DeviceLocation, Layout, MetalError, Shape}; #[derive(Debug, Clone)] diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index 7e3b7afbba..949695848b 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -1,3 +1,4 @@ +//! Tensor Layouts including contiguous or sparse strides use crate::{Error, Result, Shape}; #[derive(Debug, PartialEq, Eq, Clone)] diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 4b73d00696..5f9a1c97a5 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -7,8 +7,8 @@ //! //! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?; //! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?; -//! //! let c = a.matmul(&b)?; +//! //! # Ok(())} //! ``` //! @@ -140,7 +140,7 @@ impl ToUsize2 for (usize, usize) { } } -// A simple trait defining a module with forward method using a single argument. +/// Defining a module with forward method using a single argument. pub trait Module { fn forward(&self, xs: &Tensor) -> Result; } @@ -160,8 +160,8 @@ impl Module for Option<&M> { } } -// A trait defining a module with forward method using a single tensor argument and a flag to -// separate the training and evaluation behaviors. +/// A single forward method using a single single tensor argument and a flag to +/// separate the training and evaluation behaviors. pub trait ModuleT { fn forward_t(&self, xs: &Tensor, train: bool) -> Result; } diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index de107a61b0..47f54c8d59 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1,3 +1,5 @@ +//! 
Implementation of Backend traits for Metal +//! use crate::backend::{BackendDevice, BackendStorage}; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 49ba44be89..c5fc3fc475 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -1,3 +1,5 @@ +//! Tensor Opertion Enums and Traits +//! #![allow(clippy::redundant_closure_call)] use crate::Tensor; use half::{bf16, f16}; diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 08335257c6..24f13d2025 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -1,4 +1,4 @@ -// Just enough pickle support to be able to read PyTorch checkpoints. +//! Just enough pickle support to be able to read PyTorch checkpoints. // This hardcodes objects that are required for tensor reading, we may want to make this a bit more // composable/tensor agnostic at some point. use crate::{DType, Error as E, Layout, Result, Tensor}; diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 99200bbd06..0f7e9c118c 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -134,7 +134,7 @@ fn from_raw_data( super::QTensor::new(data, dims) } -/// Creates a [Tensor] from a raw GGML tensor. +/// Creates a Tensor from a raw GGML tensor. pub fn qtensor_from_ggml( ggml_dtype: GgmlDType, raw_data: &[u8], diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index d3fe4b5852..cdd1a1543e 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -1,6 +1,5 @@ -//! Support for the GGUF file format. +//! Support for the [GGUF file format](https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md). //! -//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md use super::{GgmlDType, QTensor}; use crate::{Device, Result}; diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index d852d50410..236f5a9811 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,3 +1,4 @@ +//! Code for GGML and GGUF files use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor}; use k_quants::*; use std::borrow::Cow; diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 5ea1f192b3..618e391e34 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -1,3 +1,14 @@ +//! Module to load `safetensor` files into CPU/GPU memory. +//! +//! There are multiple ways to load tensors from safetensor files: +//! - `load` function for loading directly into memory and returning a HashMap of tensors +//! - `MmapedSafetensors` for memory mapping files and avoiding full allocation +//! - `SliceSafetensors` for working with in-memory buffers +//! - `BufferedSafetensors` for owning a buffer of data +//! +//! Tensors can also be serialized to safetensor format using the `save` function or +//! `Tensor::save_safetensors` method. +//! use crate::{DType, Device, Error, Result, Tensor, WithDType}; use safetensors::tensor as st; use safetensors::tensor::SafeTensors; diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs index 43e1f4c8c5..30308d11c0 100644 --- a/candle-core/src/scalar.rs +++ b/candle-core/src/scalar.rs @@ -1,3 +1,5 @@ +//! TensorScalar Enum and Trait +//! 
use crate::{Result, Tensor, WithDType}; pub enum TensorScalar { diff --git a/candle-core/src/streaming.rs b/candle-core/src/streaming.rs index f70ec51e6c..f4c0a9ff0b 100644 --- a/candle-core/src/streaming.rs +++ b/candle-core/src/streaming.rs @@ -1,3 +1,5 @@ +//! StreamTensror useful for streaming ops. +//! use crate::{Result, Shape, Tensor}; pub trait Dim: crate::shape::Dim + Copy {} diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs index 78c45a9a9d..aa4d2705ef 100644 --- a/candle-core/src/utils.rs +++ b/candle-core/src/utils.rs @@ -1,3 +1,4 @@ +//! Useful functions for checking features. use std::str::FromStr; pub fn get_num_threads() -> usize { diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index c250a1865f..d95a05953a 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -1,3 +1,8 @@ +//! Logit Processing and Sampling +//! +//! Functionality for modeling sampling strategies and logits processing in text generation +//! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p), +//! and combinations thereof. use candle::{DType, Error, Result, Tensor}; use rand::{distributions::Distribution, SeedableRng}; diff --git a/candle-transformers/src/object_detection.rs b/candle-transformers/src/object_detection.rs index e922075fcc..d1b78cfa25 100644 --- a/candle-transformers/src/object_detection.rs +++ b/candle-transformers/src/object_detection.rs @@ -1,3 +1,9 @@ +//! Bounding Boxes and Intersection +//! +//! This module provides functionality for handling bounding boxes and their manipulation, +//! particularly in the context of object detection. It includes tools for calculating +//! intersection over union (IoU) and non-maximum suppression (NMS). + /// A bounding box around an object. #[derive(Debug, Clone)] pub struct Bbox { diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs index 9298b80e7e..4a83253d2e 100644 --- a/candle-transformers/src/quantized_nn.rs +++ b/candle-transformers/src/quantized_nn.rs @@ -1,3 +1,9 @@ +//! Utilities for quanitized network layers +//! +//! This module contains various implementations of standard neural network layers, modules and +//! utilities including embedding, linear layers, and various normalization techniques. +//! Most implementations provide quantized weights support. + use crate::models::with_tracing::QMatMul; use crate::quantized_var_builder::VarBuilder; use candle::quantized::QTensor; diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs index 875a2b454d..2ac64aa5e7 100644 --- a/candle-transformers/src/quantized_var_builder.rs +++ b/candle-transformers/src/quantized_var_builder.rs @@ -1,3 +1,9 @@ +//! Varbuilder for Loading gguf files +//! +//! VarBuilder is a utility to store quantized tensors from a [GGUF model file](https://huggingface.co/docs/hub/gguf). +//! These tensors can be loaded from disk using `from_gguf` or from an in-memory +//! buffer using `from_gguf_buffer`. + use candle::quantized::QTensor; use candle::{Device, Result, Shape}; use std::sync::Arc; diff --git a/candle-transformers/src/utils.rs b/candle-transformers/src/utils.rs index 17e836946f..884d4f378a 100644 --- a/candle-transformers/src/utils.rs +++ b/candle-transformers/src/utils.rs @@ -1,3 +1,5 @@ +//! 
Apply penalty and repeat_kv + use candle::{Result, Tensor}; pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result { From f86f4d62243d301b84c0992088be0effa153f22e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 19 Nov 2024 04:32:36 +0100 Subject: [PATCH 10/15] Tweak the CI to avoid running out of disk space. (#2630) * Tweak the CI to avoid running out of disk space. * Linux only. --- .github/workflows/rust-ci.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index db25503079..33d859dc36 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -37,6 +37,9 @@ jobs: os: [ubuntu-latest, windows-latest, macOS-latest] rust: [stable] steps: + - name: Delete huge unnecessary tools folder + if: runner.os == 'Linux' + run: rm -rf /opt/hostedtoolcache - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: From c12db594e389610c2b0d20fc90ecffd32c2f8d40 Mon Sep 17 00:00:00 2001 From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com> Date: Sat, 23 Nov 2024 02:40:00 -0500 Subject: [PATCH 11/15] fix typo (#2606) --- candle-core/src/tensor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 75dc1c8a55..3169928893 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -242,7 +242,7 @@ impl Tensor { Self::zeros_impl(shape, dtype, device, false) } - /// Creates a new tensor filled with ones with same shape, dtype, and device as the other + /// Creates a new tensor filled with zeros with same shape, dtype, and device as the other /// tensor. /// /// ```rust From b4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd Mon Sep 17 00:00:00 2001 From: zachcp Date: Tue, 26 Nov 2024 16:52:53 -0500 Subject: [PATCH 12/15] Provide a method to allow PTH files with state maps to be loaded. (#2639) * Provide a method to allow PTH files iwth state maps to be loaded. * add a line to the doc * String-. &str --- candle-nn/src/var_builder.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 0d836c7fd4..2731456d43 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> { let pth = candle::pickle::PthTensors::new(p, None)?; Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } - + /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. + /// similar to [`from_pth`] but requires a `state_key`. + pub fn from_pth_with_state>( + p: P, + dtype: DType, + state_key: &str, + dev: &Device, + ) -> Result { + let pth = candle::pickle::PthTensors::new(p, Some(state_key))?; + Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) + } /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before /// passing the new names to the inner VarBuilder. /// From 21c686387cead049aad32e6d1cc494d6c79e46e3 Mon Sep 17 00:00:00 2001 From: Ionut Mihalcea Date: Tue, 26 Nov 2024 23:10:09 +0100 Subject: [PATCH 13/15] Onnx Support for Sign operation #2641 (#2642) * Support for Sign operation #2641 * Apply rustfmt. 
--------- Co-authored-by: Laurent --- candle-onnx/src/eval.rs | 6 ++++++ candle-onnx/tests/ops.rs | 41 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 358af7acff..2c60ed2f23 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1944,6 +1944,12 @@ fn simple_eval_( values.insert(node.output[0].clone(), out); } + // https://onnx.ai/onnx/operators/onnx__Sign.html + "Sign" => { + let input = get(&node.input[0])?; + let output = input.sign()?; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index a84ba481ee..3586bfbd68 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -5869,3 +5869,44 @@ fn test_xor() -> Result<()> { } Ok(()) } + +#[test] +fn test_sign_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Sign".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap = HashMap::new(); + inputs.insert( + INPUT_X.to_string(), + Tensor::new(vec![-2f32, -1., 0., 1., 2.], &Device::Cpu)?, + ); + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + assert_eq!( + z.to_dtype(candle::DType::I64)?.to_vec1::()?.to_vec(), + vec![-1, -1, 0, 1, 1] + ); + Ok(()) +} From 23ed8a9ded155df7b5961d6a5ae12b4e8096a9c2 Mon Sep 17 00:00:00 2001 From: Adam Nelson Date: Wed, 27 Nov 2024 22:35:11 +0100 Subject: [PATCH 14/15] Fix for whisper-microphone example failure if audio isn't chunk aligned (#2645) At least on my macOS Sequoia system (MBP 14" 2021, M1 Pro), when I run the `whisper-microphone` example after it has gathered 10 seconds of audio, it fails before the transcription: ``` Error: Insufficient buffer size 384 for input channel 0, expected 1024 ``` At least for the audio device I'm using (Airpods Pro Max), there is no guarantee that each audio buffer is a multiple of 1024 samples. Thus at the end of the 10 seconds, `buffered_pcm` can have some samples at the end that do not form a complete 1024 sample chunk. This fixes that by tracking when there is a partial chunk at the end of the buffer, and leaving it in `buffered_pcm` to be processed on the next loop iteration. Note that, in the interest of keeping this PR as small as possible, I didn't make any other changes to this example. 
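As a reading aid for the diff that follows, here is a minimal stand-alone sketch of the chunking strategy described above: only complete 1024-sample chunks are handed to the resampler, and any trailing partial chunk is moved to the front of the buffer for the next loop iteration. The `resample_chunk` closure is a placeholder for the real resampler call and is illustrative only.

```rust
/// Drain `buffered_pcm` in exact 1024-sample chunks, keeping any partial tail
/// in place for the next iteration of the capture loop.
fn drain_full_chunks(
    buffered_pcm: &mut Vec<f32>,
    mut resample_chunk: impl FnMut(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    const CHUNK: usize = 1024;
    let full_chunks = buffered_pcm.len() / CHUNK;
    let remainder = buffered_pcm.len() % CHUNK;
    let mut resampled = Vec::new();
    for c in 0..full_chunks {
        resampled.extend_from_slice(&resample_chunk(&buffered_pcm[c * CHUNK..(c + 1) * CHUNK]));
    }
    if remainder == 0 {
        buffered_pcm.clear();
    } else {
        // Copy the incomplete tail to the start of the buffer and truncate,
        // avoiding a fresh allocation.
        buffered_pcm.copy_within(full_chunks * CHUNK.., 0);
        buffered_pcm.truncate(remainder);
    }
    resampled
}
```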
--- .../examples/whisper-microphone/main.rs | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/candle-examples/examples/whisper-microphone/main.rs b/candle-examples/examples/whisper-microphone/main.rs index 5165da1c1e..373c40e2bb 100644 --- a/candle-examples/examples/whisper-microphone/main.rs +++ b/candle-examples/examples/whisper-microphone/main.rs @@ -624,13 +624,27 @@ pub fn main() -> Result<()> { continue; } let mut resampled_pcm = vec![]; - for buffered_pcm in buffered_pcm.chunks(1024) { + // resample the audio, one chunk of 1024 samples at a time. + // in case the audio input failed to produce an exact multiple of 1024 samples, + // process the remainder on the next iteration of the loop. + let full_chunks = buffered_pcm.len() / 1024; + let remainder = buffered_pcm.len() % 1024; + for chunk in 0..full_chunks { + let buffered_pcm = &buffered_pcm[chunk * 1024..(chunk + 1) * 1024]; let pcm = resampler.process(&[&buffered_pcm], None)?; - resampled_pcm.extend_from_slice(&pcm[0]) + resampled_pcm.extend_from_slice(&pcm[0]); } let pcm = resampled_pcm; println!("{} {}", buffered_pcm.len(), pcm.len()); - buffered_pcm.clear(); + if remainder == 0 { + buffered_pcm.clear(); + } else { + // efficiently copy the remainder to the beginning of the `buffered_pcm` buffer and + // truncate it. That's more efficient then allocating a new vector and copying into it + println!("audio device produced partial chunk with {remainder} samples; processing the remainder on the next iteration of the loop"); + buffered_pcm.copy_within(full_chunks * 1024.., 0); + buffered_pcm.truncate(remainder); + } let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters); let mel_len = mel.len(); let mel = Tensor::from_vec( From 54e7fc3c97a6d40e459cee4d4bf2eff5c82390da Mon Sep 17 00:00:00 2001 From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> Date: Fri, 29 Nov 2024 03:30:21 +0530 Subject: [PATCH 15/15] Lint fixes introduced with Rust 1.83 (#2646) * Fixes for lint errors introduced with Rust 1.83 * rustfmt * Fix more lints. 
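Most of the mechanical changes in this patch swap the manual ceiling-division idiom for `usize::div_ceil` and elide now-redundant lifetimes in `impl` blocks, following the newer clippy lints. The snippet below uses illustrative values only (not taken from the patch) to show that the two ceiling-division forms agree:

```rust
fn main() {
    let (len, block) = (1030usize, 256usize);
    // Pre-existing idiom used throughout the codebase:
    let blocks_manual = (len + block - 1) / block;
    // Replacement applied in this patch:
    let blocks = len.div_ceil(block);
    assert_eq!(blocks_manual, blocks);
    println!("{len} elements -> {blocks} blocks of {block}");
}
```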
--------- Co-authored-by: Laurent --- candle-core/src/cpu_backend/mod.rs | 22 +++++++++---------- candle-core/src/quantized/gguf_file.rs | 2 +- candle-core/src/quantized/k_quants.rs | 4 ++-- candle-core/src/safetensors.rs | 2 +- candle-core/src/strided_index.rs | 2 +- candle-datasets/src/nlp/tinystories.rs | 2 +- .../examples/mamba-minimal/model.rs | 2 +- candle-examples/src/imagenet.rs | 1 - candle-metal-kernels/src/lib.rs | 20 ++++++++--------- candle-metal-kernels/src/utils.rs | 17 ++++++++------ candle-nn/src/func.rs | 8 +++---- candle-nn/src/var_builder.rs | 12 +++++----- candle-pyo3/src/lib.rs | 2 +- candle-transformers/src/models/convmixer.rs | 4 ++-- .../src/models/depth_anything_v2.rs | 2 +- .../src/models/efficientnet.rs | 4 ++-- candle-transformers/src/models/encodec.rs | 2 +- candle-transformers/src/models/mamba.rs | 2 +- .../src/models/stable_diffusion/utils.rs | 2 +- 19 files changed, 57 insertions(+), 55 deletions(-) diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 229e3bbce1..11ff1a406f 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -66,7 +66,7 @@ impl Map2U8 for Cmp { struct WCond<'a, T: IntDType>(&'a [T], &'a Layout); -impl<'a, I: IntDType> Map2 for WCond<'a, I> { +impl Map2 for WCond<'_, I> { const OP: &'static str = "where"; #[inline(always)] fn f(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result> { @@ -216,7 +216,7 @@ struct ReduceSum<'a> { reduce_dims_and_stride: Vec<(usize, usize)>, } -impl<'a> ReduceSum<'a> { +impl ReduceSum<'_> { #[inline(always)] fn fold_impl(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result> where @@ -281,7 +281,7 @@ impl<'a> ReduceSum<'a> { } } -impl<'a> Map1 for ReduceSum<'a> { +impl Map1 for ReduceSum<'_> { #[inline(always)] fn f(&self, src: &[T], src_l: &Layout) -> Result> { self.fold_impl(src, src_l, T::zero()) @@ -454,7 +454,7 @@ struct Gather<'a, I: IntDType> { dim: usize, } -impl<'a, I: IntDType> Map1 for Gather<'a, I> { +impl Map1 for Gather<'_, I> { fn f(&self, src: &[T], src_l: &Layout) -> Result> { let ids = match self.ids_l.contiguous_offsets() { Some((a, b)) => &self.ids[a..b], @@ -507,7 +507,7 @@ struct IndexSelect<'a, T: IntDType> { dim: usize, } -impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> { +impl Map1 for IndexSelect<'_, I> { fn f(&self, src: &[T], layout: &Layout) -> Result> { let src = match layout.contiguous_offsets() { Some((a, b)) => &src[a..b], @@ -560,7 +560,7 @@ struct ScatterAdd<'a, I: IntDType> { dim: usize, } -impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> { +impl Map2 for ScatterAdd<'_, I> { const OP: &'static str = "scatter-add"; fn f(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result> { let dst_len = l1.shape().elem_count(); @@ -616,7 +616,7 @@ struct IndexAdd<'a, I: IntDType> { dim: usize, } -impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> { +impl Map2 for IndexAdd<'_, I> { const OP: &'static str = "index-add"; // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_ // v1, l1 -> self @@ -736,7 +736,7 @@ fn copy_strided_src_(src: &[T], dst: &mut [T], dst_offset: usize, src_l struct Conv1D<'a>(&'a crate::conv::ParamsConv1D); -impl<'a> Map2 for Conv1D<'a> { +impl Map2 for Conv1D<'_> { const OP: &'static str = "conv1d"; fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { let p = self.0; @@ -960,7 +960,7 @@ impl Map1 for Col2Im1D { struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D); -impl<'a> Map2 for 
ConvTranspose1D<'a> { +impl Map2 for ConvTranspose1D<'_> { const OP: &'static str = "conv_transpose1d"; fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { let p = self.0; @@ -1029,7 +1029,7 @@ impl<'a> Map2 for ConvTranspose1D<'a> { struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); -impl<'a> Map2 for Conv2D<'a> { +impl Map2 for Conv2D<'_> { const OP: &'static str = "conv2d"; fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { let p = self.0; @@ -1117,7 +1117,7 @@ impl<'a> Map2 for Conv2D<'a> { struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D); -impl<'a> Map2 for ConvTranspose2D<'a> { +impl Map2 for ConvTranspose2D<'_> { const OP: &'static str = "conv_transpose2d"; fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { let p = self.0; diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index cdd1a1543e..ccbd59eb5c 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -457,7 +457,7 @@ impl Content { Some(Value::I32(v)) if *v >= 0 => *v as u64, _ => DEFAULT_ALIGNMENT, }; - let tensor_data_offset = (position + alignment - 1) / alignment * alignment; + let tensor_data_offset = position.div_ceil(alignment) * alignment; Ok(Self { magic, metadata, diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 6210ac1e9f..1d3e053898 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1850,8 +1850,8 @@ pub fn matmul( crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); } - let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE; - let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE; + let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE); + let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE); // TODO: Do not make this copy if the DotType is f32. // TODO: Pre-allocate this. 
let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 618e391e34..d402d6b8e0 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -182,7 +182,7 @@ pub trait Load { fn load(&self, device: &Device) -> Result; } -impl<'a> Load for st::TensorView<'a> { +impl Load for st::TensorView<'_> { fn load(&self, device: &Device) -> Result { convert(self, device) } diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs index eb6a736f83..9354e8ea3c 100644 --- a/candle-core/src/strided_index.rs +++ b/candle-core/src/strided_index.rs @@ -32,7 +32,7 @@ impl<'a> StridedIndex<'a> { } } -impl<'a> Iterator for StridedIndex<'a> { +impl Iterator for StridedIndex<'_> { type Item = usize; fn next(&mut self) -> Option { diff --git a/candle-datasets/src/nlp/tinystories.rs b/candle-datasets/src/nlp/tinystories.rs index c657c9eb6b..ba471728f3 100644 --- a/candle-datasets/src/nlp/tinystories.rs +++ b/candle-datasets/src/nlp/tinystories.rs @@ -87,7 +87,7 @@ impl<'a> DatasetRandomIter<'a> { } } -impl<'a> Iterator for DatasetRandomIter<'a> { +impl Iterator for DatasetRandomIter<'_> { type Item = Result<(Tensor, Tensor)>; fn next(&mut self) -> Option { diff --git a/candle-examples/examples/mamba-minimal/model.rs b/candle-examples/examples/mamba-minimal/model.rs index 4a0a345d17..7ebea76a8d 100644 --- a/candle-examples/examples/mamba-minimal/model.rs +++ b/candle-examples/examples/mamba-minimal/model.rs @@ -17,7 +17,7 @@ pub struct Config { impl Config { fn vocab_size(&self) -> usize { let pad = self.pad_vocab_size_multiple; - (self.vocab_size + pad - 1) / pad * pad + self.vocab_size.div_ceil(pad) * pad } fn dt_rank(&self) -> usize { diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs index a3b1242387..ca77b5df06 100644 --- a/candle-examples/src/imagenet.rs +++ b/candle-examples/src/imagenet.rs @@ -6,7 +6,6 @@ pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225]; /// Loads an image from disk using the image crate at the requested resolution, /// using the given std and mean parameters. /// This returns a tensor with shape (3, res, res). imagenet normalization is applied. 
- pub fn load_image_with_std_mean>( p: P, res: usize, diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 0843cc1179..5f948cbf4c 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -372,7 +372,7 @@ pub fn call_unary_contiguous_tiled( let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let tile_size = 2; - let tiles = (length + tile_size - 1) / tile_size; + let tiles = length.div_ceil(tile_size); encoder.set_compute_pipeline_state(&pipeline); @@ -594,7 +594,7 @@ pub fn call_reduce_contiguous( let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - (elements_to_sum as u64 + 2 - 1) / 2, + (elements_to_sum as u64).div_ceil(2), ) .next_power_of_two(); @@ -1735,7 +1735,7 @@ pub fn call_sdpa_full( } }; - let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?; + let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -1759,16 +1759,16 @@ pub fn call_sdpa_full( let ldo = dk; let tn = 1; - let tm = (m + BM - 1) / BM; + let tm = m.div_ceil(BM); let b_stride_q = dk * qseq; let b_stride_k = dk * qseq; let b_stride_v = dk * qseq; let b_stride_o = dk * qseq; let swizzle_log = 0; - let gemm_n_iterations_aligned = (n + BN - 1) / BN; - let gemm_k_iterations_aligned = (k + bk - 1) / bk; - let gemm_sv_m_block_iterations = (m + BM - 1) / BM; + let gemm_n_iterations_aligned = n.div_ceil(BN); + let gemm_k_iterations_aligned = k.div_ceil(*bk); + let gemm_sv_m_block_iterations = m.div_ceil(BM); let batch_ndim = batch_shape.len(); let alpha = if softcapping != 1. { @@ -1906,7 +1906,7 @@ pub fn call_sdpa_vector( alpha }; - let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?; + let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -1933,7 +1933,7 @@ pub fn call_sdpa_vector( let grid_dims = MTLSize { width: 1, height: b as u64, - depth: 1 as u64, + depth: 1_u64, }; let group_dims = MTLSize { width: 1024, @@ -2320,7 +2320,7 @@ pub fn call_quantized_matmul_mv_t( } fn divide(m: usize, b: usize) -> NSUInteger { - ((m + b - 1) / b) as NSUInteger + m.div_ceil(b) as NSUInteger } #[allow(clippy::too_many_arguments)] diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index 0092ecfa58..025808d754 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -8,7 +8,7 @@ use std::ffi::c_void; pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { let size = length as u64; let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); - let count = (size + width - 1) / width; + let count = size.div_ceil(width); let thread_group_count = MTLSize { width: count, height: 1, @@ -128,7 +128,7 @@ impl EncoderParam for (&Buffer, usize) { } } -impl<'a> EncoderParam for &BufferOffset<'a> { +impl EncoderParam for &BufferOffset<'_> { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64); } @@ -169,7 +169,7 @@ pub struct WrappedEncoder<'a> { end_encoding_on_drop: bool, } -impl<'a> Drop for WrappedEncoder<'a> { +impl Drop for WrappedEncoder<'_> { fn drop(&mut self) { if 
self.end_encoding_on_drop { self.inner.end_encoding() @@ -177,14 +177,15 @@ impl<'a> Drop for WrappedEncoder<'a> { } } -impl<'a> AsRef for WrappedEncoder<'a> { +impl AsRef for WrappedEncoder<'_> { fn as_ref(&self) -> &metal::ComputeCommandEncoderRef { self.inner } } impl EncoderProvider for &metal::CommandBuffer { - type Encoder<'a> = WrappedEncoder<'a> + type Encoder<'a> + = WrappedEncoder<'a> where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { @@ -196,7 +197,8 @@ impl EncoderProvider for &metal::CommandBuffer { } impl EncoderProvider for &metal::CommandBufferRef { - type Encoder<'a> = WrappedEncoder<'a> + type Encoder<'a> + = WrappedEncoder<'a> where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { @@ -208,7 +210,8 @@ impl EncoderProvider for &metal::CommandBufferRef { } impl EncoderProvider for &ComputeCommandEncoderRef { - type Encoder<'a> = WrappedEncoder<'a> + type Encoder<'a> + = WrappedEncoder<'a> where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs index 3adfda860d..72744404ac 100644 --- a/candle-nn/src/func.rs +++ b/candle-nn/src/func.rs @@ -9,7 +9,7 @@ pub struct Func<'a> { f: Arc Result + Send + Sync>, } -impl<'a> std::fmt::Debug for Func<'a> { +impl std::fmt::Debug for Func<'_> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "func") } @@ -22,7 +22,7 @@ where Func { f: Arc::new(f) } } -impl<'a> super::Module for Func<'a> { +impl super::Module for Func<'_> { fn forward(&self, xs: &Tensor) -> Result { (*self.f)(xs) } @@ -44,7 +44,7 @@ pub struct FuncT<'a> { f: Arc Result + Send + Sync>, } -impl<'a> std::fmt::Debug for FuncT<'a> { +impl std::fmt::Debug for FuncT<'_> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "func") } @@ -57,7 +57,7 @@ where FuncT { f: Arc::new(f) } } -impl<'a> super::ModuleT for FuncT<'a> { +impl super::ModuleT for FuncT<'_> { fn forward_t(&self, xs: &Tensor, train: bool) -> Result { (*self.f)(xs, train) } diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 2731456d43..ba410e4ea8 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -20,7 +20,7 @@ pub struct VarBuilderArgs<'a, B: Backend> { _phantom: std::marker::PhantomData<&'a B>, } -impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> { +impl Clone for VarBuilderArgs<'_, B> { fn clone(&self) -> Self { Self { data: self.data.clone(), @@ -76,7 +76,7 @@ pub trait SimpleBackend: Send + Sync { fn contains_tensor(&self, name: &str) -> bool; } -impl<'a> Backend for Box { +impl Backend for Box { type Hints = crate::Init; fn get( &self, @@ -94,7 +94,7 @@ impl<'a> Backend for Box { } } -impl<'a, B: Backend> VarBuilderArgs<'a, B> { +impl VarBuilderArgs<'_, B> { pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self { let data = TensorData { backend, @@ -286,7 +286,7 @@ pub struct SafeTensorWithRouting<'a> { safetensors: Vec>, } -impl<'a> SimpleBackend for SafeTensorWithRouting<'a> { +impl SimpleBackend for SafeTensorWithRouting<'_> { fn get( &self, s: Shape, @@ -439,7 +439,7 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors { } } -impl<'a> SimpleBackend for candle::safetensors::SliceSafetensors<'a> { +impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> { fn get( &self, s: Shape, @@ -732,7 +732,7 @@ pub struct Rename<'a, R: Renamer> { renamer: R, } -impl<'a, R: Renamer + Sync + Send> SimpleBackend for Rename<'a, R> { +impl SimpleBackend for Rename<'_, R> { fn get( &self, s: Shape, diff --git 
a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 722b5e3ace..b8695cc8a0 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -276,7 +276,7 @@ impl PyTensor { /// &RETURNS&: _ArrayLike fn values(&self, py: Python<'_>) -> PyResult { struct M<'a>(Python<'a>); - impl<'a> MapDType for M<'a> { + impl MapDType for M<'_> { type Output = PyObject; fn f(&self, t: &Tensor) -> PyResult { match t.rank() { diff --git a/candle-transformers/src/models/convmixer.rs b/candle-transformers/src/models/convmixer.rs index 7f1b75ebc4..7f92479431 100644 --- a/candle-transformers/src/models/convmixer.rs +++ b/candle-transformers/src/models/convmixer.rs @@ -21,8 +21,8 @@ fn conv2d_same( let module = candle_nn::func(move |xs| { let ih = xs.dim(2)?; let iw = xs.dim(3)?; - let oh = (ih + s - 1) / s; - let ow = (iw + s - 1) / s; + let oh = ih.div_ceil(s); + let ow = iw.div_ceil(s); let pad_h = usize::max((oh - 1) * s + k - ih, 0); let pad_w = usize::max((ow - 1) * s + k - iw, 0); if pad_h > 0 || pad_w > 0 { diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs index 411b0764ff..8eddbf2af5 100644 --- a/candle-transformers/src/models/depth_anything_v2.rs +++ b/candle-transformers/src/models/depth_anything_v2.rs @@ -543,7 +543,7 @@ impl<'a> DepthAnythingV2<'a> { } } -impl<'a> Module for DepthAnythingV2<'a> { +impl Module for DepthAnythingV2<'_> { fn forward(&self, xs: &Tensor) -> Result { let features = self.pretrained.get_intermediate_layers( xs, diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs index ecca2509ae..36754f2102 100644 --- a/candle-transformers/src/models/efficientnet.rs +++ b/candle-transformers/src/models/efficientnet.rs @@ -125,8 +125,8 @@ impl Module for Conv2DSame { let s = self.s; let k = self.k; let (_, _, ih, iw) = xs.dims4()?; - let oh = (ih + s - 1) / s; - let ow = (iw + s - 1) / s; + let oh = ih.div_ceil(s); + let ow = iw.div_ceil(s); let pad_h = usize::max((oh - 1) * s + k - ih, 0); let pad_w = usize::max((ow - 1) * s + k - iw, 0); if pad_h > 0 || pad_w > 0 { diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs index 517b9b1d7e..d8dff74c0e 100644 --- a/candle-transformers/src/models/encodec.rs +++ b/candle-transformers/src/models/encodec.rs @@ -89,7 +89,7 @@ impl Config { fn frame_rate(&self) -> usize { let hop_length: usize = self.upsampling_ratios.iter().product(); - (self.sampling_rate + hop_length - 1) / hop_length + self.sampling_rate.div_ceil(hop_length) } fn num_quantizers(&self) -> usize { diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs index 18a0285ff6..a29f261955 100644 --- a/candle-transformers/src/models/mamba.rs +++ b/candle-transformers/src/models/mamba.rs @@ -23,7 +23,7 @@ pub struct Config { impl Config { fn vocab_size(&self) -> usize { let pad = self.pad_vocab_size_multiple; - (self.vocab_size + pad - 1) / pad * pad + self.vocab_size.div_ceil(pad) * pad } fn dt_rank(&self) -> usize { diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs index 5b5fa0f797..0118bafc54 100644 --- a/candle-transformers/src/models/stable_diffusion/utils.rs +++ b/candle-transformers/src/models/stable_diffusion/utils.rs @@ -21,7 +21,7 @@ struct LinearInterpolator<'x, 'y> { cache: usize, } -impl<'x, 'y> LinearInterpolator<'x, 'y> { +impl LinearInterpolator<'_, '_> { fn accel_find(&mut self, x: f64) 
-> usize { let xidx = self.cache; if x < self.xp[xidx] {