diff --git a/causal-lm/src/render.rs b/causal-lm/src/chat_template.rs
similarity index 69%
rename from causal-lm/src/render.rs
rename to causal-lm/src/chat_template.rs
index c0c84141..a7725328 100644
--- a/causal-lm/src/render.rs
+++ b/causal-lm/src/chat_template.rs
@@ -1,4 +1,4 @@
-use crate::Tokenize;
+use crate::Tokenizer;
 use common::GGufModel;
 use minijinja::Environment;
 use serde::Serialize;
@@ -21,39 +21,38 @@ pub struct Message<'a> {
     pub content: &'a str,
 }
 
-/// Build a chat template from the GGuf model.
-pub fn build_render(gguf: &GGufModel, tokenize: &dyn Tokenize) -> Option<ChatTemplate> {
-    let template = gguf
-        .meta_kvs
-        .get("tokenizer.chat_template")?
-        .value_reader()
-        .read_str()
-        .unwrap()
-        .into();
-
-    let bos = gguf.meta_kvs["tokenizer.ggml.bos_token_id"]
-        .value_reader()
-        .read::<utok>()
-        .unwrap();
-    let eos = gguf.meta_kvs["tokenizer.ggml.eos_token_id"]
-        .value_reader()
-        .read::<utok>()
-        .unwrap();
+impl ChatTemplate {
+    pub fn from_gguf(gguf: &GGufModel, tokenize: &Tokenizer) -> Option<Self> {
+        let template = gguf
+            .meta_kvs
+            .get("tokenizer.chat_template")?
+            .value_reader()
+            .read_str()
+            .unwrap()
+            .into();
 
-    Some(ChatTemplate::new(
-        template,
-        tokenize.decode(bos).into(),
-        tokenize.decode(eos).into(),
-    ))
-}
+        let bos = gguf.meta_kvs["tokenizer.ggml.bos_token_id"]
+            .value_reader()
+            .read::<utok>()
+            .unwrap();
+        let eos = gguf.meta_kvs["tokenizer.ggml.eos_token_id"]
+            .value_reader()
+            .read::<utok>()
+            .unwrap();
+
+        Some(ChatTemplate::new(
+            template,
+            tokenize.decode(bos).into(),
+            tokenize.decode(eos).into(),
+        ))
+    }
 
-impl ChatTemplate {
     /// Create a new chat template.
     pub fn new(template: String, bos: String, eos: String) -> Self {
         static NEXT: AtomicUsize = AtomicUsize::new(0);
         let id = NEXT.fetch_add(1, Relaxed).to_string();
-        jinja()
+        JINJA_ENV
            .write()
            .unwrap()
            .add_template_owned(id.clone(), template)
@@ -76,7 +75,7 @@ impl ChatTemplate {
             add_generation_prompt: bool,
         }
 
-        jinja()
+        JINJA_ENV
            .read()
            .unwrap()
            .get_template(&self.id)
@@ -92,26 +91,23 @@ impl ChatTemplate {
 
 impl Drop for ChatTemplate {
     fn drop(&mut self) {
-        jinja().write().unwrap().remove_template(&self.id);
+        JINJA_ENV.write().unwrap().remove_template(&self.id);
     }
 }
 
-fn jinja() -> &'static RwLock<Environment<'static>> {
-    static ENV: LazyLock<RwLock<Environment<'static>>> = LazyLock::new(|| {
-        let mut env = Environment::empty();
-        env.set_unknown_method_callback(|_, value, method, args| {
-            use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value};
-            match (method, value.kind(), args) {
-                ("strip", ThisType::String, []) => Ok(Value::from_safe_string(
-                    value.to_str().unwrap().trim().into(),
-                )),
-                _ => Err(UnknownMethod.into()),
-            }
-        });
-        RwLock::new(env)
+static JINJA_ENV: LazyLock<RwLock<Environment<'static>>> = LazyLock::new(|| {
+    let mut env = Environment::empty();
+    env.set_unknown_method_callback(|_, value, method, args| {
+        use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value};
+        match (method, value.kind(), args) {
+            ("strip", ThisType::String, []) => Ok(Value::from_safe_string(
+                value.to_str().unwrap().trim().into(),
+            )),
+            _ => Err(UnknownMethod.into()),
+        }
     });
-    &ENV
-}
+    RwLock::new(env)
+});
 
 #[test]
 fn test() {
diff --git a/causal-lm/src/lib.rs b/causal-lm/src/lib.rs
index 668b6d95..40ebe667 100644
--- a/causal-lm/src/lib.rs
+++ b/causal-lm/src/lib.rs
@@ -1,21 +1,21 @@
 #![doc = include_str!("../README.md")]
-#![deny(warnings, missing_docs)]
+// #![deny(warnings, missing_docs)]
 
+mod chat_template;
 mod decoding;
 mod query_context;
-mod render;
-mod tokenize;
+mod tokenizer;
 
 use common::{upos, utok};
 use digit_layout::types::U32;
-use std::{path::Path, time::Duration};
+use std::{io::Write, path::Path};
 use tensor::{udim, Tensor};
 
+pub use chat_template::ChatTemplate;
 pub use decoding::DecodingMeta;
 pub use operators::random_sample::SampleArgs;
 pub use query_context::QueryContext;
-pub use render::{build_render, ChatTemplate};
-pub use tokenize::{build_tokenize, Tokenize};
+pub use tokenizer::Tokenizer;
 
 /// A model loaded from the file system.
 pub trait Model: Sized {
@@ -24,7 +24,7 @@
     /// Possible error while loading the model.
     type Error;
     /// Load the model from the file system.
-    fn load(gguf: impl AsRef<Path>, meta: Self::Config) -> Result<FromGGuf<Self>, Self::Error>;
+    fn load(gguf: impl AsRef<Path>, config: Self::Config) -> Result<FromGGuf<Self>, Self::Error>;
 }
 
 /// Model, tokenizer and chat template loaded from a GGuf file.
@@ -32,9 +32,9 @@ pub struct FromGGuf<M: Model> {
     /// The model.
     pub model: M,
     /// The tokenizer.
-    pub tokenize: Box<dyn Tokenize>,
+    pub tokenizer: Tokenizer,
     /// The chat template.
-    pub render: Option<ChatTemplate>,
+    pub chat_template: Option<ChatTemplate>,
 }
 
 /// Causal language model.
@@ -119,32 +119,38 @@ pub fn pos<'a, S: 'a>(
 }
 
 /// Test a model implementation.
-pub fn test_impl<M>(meta: M::Config, prompt: &[utok])
+pub fn test_impl<M>(meta: M::Config, max_steps: usize, prompt: &str)
 where
     M: CausalLM,
     M::Error: std::fmt::Debug,
 {
-    use std::time::Instant;
+    use std::time::{Duration, Instant};
 
     let Some(gguf) = common::test_model::find() else {
         return;
     };
     println!("model: {}", gguf.display());
 
-    let t0 = Instant::now();
-    let FromGGuf { model, .. } = M::load(gguf, meta).unwrap();
-    let t1 = Instant::now();
-    println!("load {:?}", t1 - t0);
+    let time = Instant::now();
+    let FromGGuf {
+        model, tokenizer, ..
+    } = M::load(gguf, meta).unwrap();
+    println!("load {:?}", time.elapsed());
 
-    let mut cache = model.new_cache();
+    let mut prompt = tokenizer.encode(prompt);
+    print!("prompt:");
+    for t in &prompt {
+        print!(" {t}");
+    }
 
-    let mut prompt = prompt.to_vec();
+    let mut tokens = prompt.clone();
     let mut pos = 0;
     let mut time = Duration::ZERO;
     let mut steps = 0;
-    while prompt != [model.eos_token()] {
+    let mut cache = model.new_cache();
+    while prompt != [model.eos_token()] && steps <= max_steps {
         let start = Instant::now();
 
         let token_embedded = CausalLM::token_embed(&model, prompt.iter().copied());
@@ -165,21 +171,33 @@ where
             num_decode: 1,
             args: SampleArgs::ARG_MAX,
         }];
-        let tokens = CausalLM::sample(&model, args, logits);
+        let token = CausalLM::sample(&model, args, logits)[0];
 
         if steps > 0 {
             time += start.elapsed();
         }
         steps += 1;
 
-        println!("{:?}", tokens);
+        print!(" {token}");
+        std::io::stdout().flush().unwrap();
+
         pos += prompt.len() as upos;
-        prompt = tokens;
+        prompt.clear();
+        prompt.push(token);
+        tokens.push(token);
     }
     steps -= 1;
+    println!();
     println!(
         "steps = {steps}, average decoding time = {:?}",
         time.div_f32(steps as _)
     );
+    println!();
+    println!("---");
+    for t in tokens {
+        print!("{}", tokenizer.decode(t));
+    }
+    println!();
+    println!("---");
 }
diff --git a/causal-lm/src/tokenize.rs b/causal-lm/src/tokenize.rs
deleted file mode 100644
index 517c253d..00000000
--- a/causal-lm/src/tokenize.rs
+++ /dev/null
@@ -1,132 +0,0 @@
-use common::GGufModel;
-use ggus::{GGmlTokenType, GGufMetaDataValueType};
-use std::str::{from_utf8, from_utf8_unchecked};
-use tokeneer::{utok, Bpe, Method, Tokeneer};
-
-/// A trait for tokenization.
-pub trait Tokenize {
-    /// Encode a text into a sequence of tokens.
-    fn encode(&self, text: &str) -> Vec<utok>;
-    /// Decode a token into str.
-    fn decode(&self, token: utok) -> &str;
-}
-
-impl<M: Method> Tokenize for Tokeneer<M> {
-    #[inline]
-    fn encode(&self, text: &str) -> Vec<utok> {
-        self.encode(text)
-    }
-    #[inline]
-    fn decode(&self, token: utok) -> &str {
-        unsafe { from_utf8_unchecked(self.internal().decode(token)) }
-    }
-}
-
-/// Build a polymorphic tokenize from the GGuf model.
-pub fn build_tokenize(gguf: &GGufModel) -> Box<dyn Tokenize> {
-    let model = gguf.meta_kvs["tokenizer.ggml.model"]
-        .value_reader()
-        .read_str()
-        .unwrap();
-    match model {
-        "llama" => Box::new(build_bpe(gguf)),
-        _ => panic!("Unsupported tokenizer model: {model}"),
-    }
-}
-
-fn build_bpe(gguf: &GGufModel) -> Tokeneer<Bpe> {
-    let _pre = gguf.meta_kvs["tokenizer.ggml.pre"]
-        .value_reader()
-        .read_str()
-        .unwrap();
-    let mut tokens = gguf.meta_kvs["tokenizer.ggml.tokens"].value_reader();
-    let mut scores = gguf.meta_kvs["tokenizer.ggml.scores"].value_reader();
-    let mut token_type = gguf.meta_kvs["tokenizer.ggml.token_type"].value_reader();
-
-    let unk = gguf.meta_kvs["tokenizer.ggml.unknown_token_id"]
-        .value_reader()
-        .read::<utok>()
-        .unwrap();
-    let bos = gguf.meta_kvs["tokenizer.ggml.bos_token_id"]
-        .value_reader()
-        .read::<utok>()
-        .unwrap();
-    let eos = gguf.meta_kvs["tokenizer.ggml.eos_token_id"]
-        .value_reader()
-        .read::<utok>()
-        .unwrap();
-
-    let (ty, len) = tokens.read_arr_header().unwrap();
-    assert_eq!(ty, GGufMetaDataValueType::String);
-
-    let (ty, len_) = scores.read_arr_header().unwrap();
-    assert_eq!(ty, GGufMetaDataValueType::F32);
-    assert_eq!(len_, len);
-
-    let (ty, len_) = token_type.read_arr_header().unwrap();
-    assert_eq!(ty, GGufMetaDataValueType::I32);
-    assert_eq!(len_, len);
-
-    let vocabs = (0..len).map(|_| tokens.read_str().unwrap());
-    let scores = (0..len).map(|_| scores.read::<f32>().unwrap());
-    let is_byte =
-        (0..len).map(|_| token_type.read::<GGmlTokenType>().unwrap() == GGmlTokenType::Byte);
-
-    let bpe = Bpe::new(vocabs, scores, is_byte, unk);
-    let bos_piece = from_utf8(bpe.decode(bos)).unwrap().to_string();
-    let eos_piece = from_utf8(bpe.decode(eos)).unwrap().to_string();
-
-    let mut tokeneer = Tokeneer::new(bpe);
-    tokeneer.extend_special([(bos_piece, vec![bos]), (eos_piece, vec![eos])]);
-    tokeneer
-}
-
-// pub trait Normalizer {
-//     fn encode<'a>(&self, text: &'a str) -> Cow<'a, str>;
-//     fn decode<'a>(&self, text: &'a str) -> Cow<'a, str>;
-// }
-
-// impl Normalizer for () {
-//     #[inline]
-//     fn encode<'a>(&self, text: &'a str) -> Cow<'a, str> {
-//         Cow::Borrowed(text)
-//     }
-
-//     #[inline]
-//     fn decode<'a>(&self, text: &'a str) -> Cow<'a, str> {
-//         Cow::Borrowed(text)
-//     }
-// }
-
-// #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
-// pub struct BPECommonNormalizer;
-
-// impl Normalizer for BPECommonNormalizer {
-//     fn encode<'a>(&self, text: &'a str) -> Cow<'a, str> {
-//         let mut ans = String::new();
-//         if text
-//             .chars()
-//             .next()
-//             .filter(char::is_ascii_alphabetic)
-//             .is_some()
-//         {
-//             ans.push('▁');
-//         }
-//         for c in text.chars() {
-//             ans.push(match c {
-//                 ' ' => '▁',
-//                 c => c,
-//             });
-//         }
-//         Cow::Owned(ans)
-//     }
-
-//     #[inline]
-//     fn decode<'a>(&self, text: &'a str) -> Cow<'a, str> {
-//         if text.contains('▁') {
-//             Cow::Owned(text.replace('▁', " "))
-//         } else {
-//             Cow::Borrowed(text)
-//         }
-//     }
-// }
diff --git a/causal-lm/src/tokenizer.rs b/causal-lm/src/tokenizer.rs
new file mode 100644
index 00000000..496461f3
--- /dev/null
+++ b/causal-lm/src/tokenizer.rs
@@ -0,0 +1,137 @@
+use common::GGufModel;
+use ggus::{GGmlTokenType, GGufMetaDataValueType};
+use std::{
+    borrow::Cow,
+    str::{from_utf8, from_utf8_unchecked},
+};
+use tokeneer::{utok, Bpe, Method, Tokeneer};
+
+pub struct Tokenizer {
+    tokenize: Box<dyn Tokenize>,
+    replace_space: Option<char>,
+}
+
+impl Tokenizer {
+    pub fn from_gguf(gguf: &GGufModel) -> Self {
+        let model = gguf.meta_kvs["tokenizer.ggml.model"]
+            .value_reader()
+            .read_str()
+            .unwrap();
+        match model {
+            "llama" => Self::bpe_from_gguf(gguf),
+            _ => panic!("Unsupported tokenizer model: {model}"),
+        }
+    }
+
+    pub fn encode(&self, text: &str) -> Vec<utok> {
+        let space = self.replace_space.unwrap_or(' ');
+        let mut chars = text.chars();
+        let mut text = match chars.next() {
+            Some(c) => {
+                if c.is_ascii_alphabetic() {
+                    format!("{space}{c}")
+                } else {
+                    format!("{c}")
+                }
+            }
+            None => return vec![],
+        };
+        for c in chars {
+            text.push(match c {
+                ' ' => space,
+                c => c,
+            })
+        }
+        self.tokenize.encode(&text)
+    }
+
+    pub fn decode(&self, token: utok) -> Cow<str> {
+        let piece = self.tokenize.decode(token);
+        if let Some(c) = self.replace_space {
+            piece.replace(c, " ").into()
+        } else {
+            piece.into()
+        }
+    }
+
+    fn bpe_from_gguf(gguf: &GGufModel) -> Self {
+        let _pre = gguf.meta_kvs["tokenizer.ggml.pre"]
+            .value_reader()
+            .read_str()
+            .unwrap();
+        let mut tokens = gguf.meta_kvs["tokenizer.ggml.tokens"].value_reader();
+        let mut scores = gguf.meta_kvs["tokenizer.ggml.scores"].value_reader();
+        let mut token_type = gguf.meta_kvs["tokenizer.ggml.token_type"].value_reader();
+
+        let unk = gguf.meta_kvs["tokenizer.ggml.unknown_token_id"]
+            .value_reader()
+            .read::<utok>()
+            .unwrap();
+        let bos = gguf.meta_kvs["tokenizer.ggml.bos_token_id"]
+            .value_reader()
+            .read::<utok>()
+            .unwrap();
+        let eos = gguf.meta_kvs["tokenizer.ggml.eos_token_id"]
+            .value_reader()
+            .read::<utok>()
+            .unwrap();
+
+        let (ty, len) = tokens.read_arr_header().unwrap();
+        assert_eq!(ty, GGufMetaDataValueType::String);
+
+        let (ty, len_) = scores.read_arr_header().unwrap();
+        assert_eq!(ty, GGufMetaDataValueType::F32);
+        assert_eq!(len_, len);
+
+        let (ty, len_) = token_type.read_arr_header().unwrap();
+        assert_eq!(ty, GGufMetaDataValueType::I32);
+        assert_eq!(len_, len);
+        //
+        let mut space_exist = false;
+        let mut replace_exist = false;
+        let vocabs = (0..len).map(|_| {
+            let piece = tokens.read_str().unwrap();
+            match piece {
+                " " => space_exist = true,
+                "▁" => replace_exist = true,
+                _ => {}
+            }
+            piece
+        });
+        let scores = (0..len).map(|_| scores.read::<f32>().unwrap());
+        let is_byte = (0..len).map(|_| GGmlTokenType::Byte == token_type.read().unwrap());
+
+        let bpe = Bpe::new(vocabs, scores, is_byte, unk);
+        let bos_piece = from_utf8(bpe.decode(bos)).unwrap().to_string();
+        let eos_piece = from_utf8(bpe.decode(eos)).unwrap().to_string();
+
+        let mut tokeneer = Tokeneer::new(bpe);
+        tokeneer.extend_special([(bos_piece, vec![bos]), (eos_piece, vec![eos])]);
+        Self {
+            tokenize: Box::new(tokeneer),
+            replace_space: match (space_exist, replace_exist) {
+                (true, _) => None,
+                (false, true) => Some('▁'),
+                (false, false) => panic!("Unknown user-defined space"),
+            },
+        }
+    }
+}
+
+/// A trait for tokenization.
+trait Tokenize {
+    /// Encode a text into a sequence of tokens.
+    fn encode(&self, text: &str) -> Vec<utok>;
+    /// Decode a token into str.
+    fn decode(&self, token: utok) -> &str;
+}
+
+impl<M: Method> Tokenize for Tokeneer<M> {
+    #[inline]
+    fn encode(&self, text: &str) -> Vec<utok> {
+        self.encode(text)
+    }
+    #[inline]
+    fn decode(&self, token: utok) -> &str {
+        unsafe { from_utf8_unchecked(self.internal().decode(token)) }
+    }
+}
diff --git a/models/llama/common-cpu/src/lib.rs b/models/llama/common-cpu/src/lib.rs
index e676efcf..d8ced370 100644
--- a/models/llama/common-cpu/src/lib.rs
+++ b/models/llama/common-cpu/src/lib.rs
@@ -1,5 +1,5 @@
 use causal_lm::{
-    build_render, build_tokenize, CausalLM, DecodingMeta, FromGGuf, Model, QueryContext, SampleMeta,
+    CausalLM, ChatTemplate, DecodingMeta, FromGGuf, Model, QueryContext, SampleMeta, Tokenizer,
 };
 use common::{map_files, upos, utok, Blob, GGufModel};
 use common_cpu::{
@@ -32,9 +32,9 @@ impl Model for Transformer {
         let _files = map_files(gguf);
         let gguf = GGufModel::read(_files.iter().map(|f| &**f));
 
-        let tokenize = build_tokenize(&gguf);
-        let render = build_render(&gguf, &*tokenize);
-        let model = LlamaModel::from_gguf(&gguf);
+        let tokenizer = Tokenizer::from_gguf(&gguf);
+        let chat_template = ChatTemplate::from_gguf(&gguf, &tokenizer);
+        let llama = LlamaModel::from_gguf(&gguf);
 
         #[inline(always)]
         const fn keep_lifetime(data: &[u8]) -> &'static [u8] {
@@ -42,11 +42,11 @@ impl Model for Transformer {
         }
 
         let model = Self {
-            meta: model.meta.clone(),
-            token_embed: keep_lifetime(model.token_embed),
-            output_norm: keep_lifetime(model.output_norm),
-            output: keep_lifetime(model.output),
-            blocks: model
+            meta: llama.meta.clone(),
+            token_embed: keep_lifetime(llama.token_embed),
+            output_norm: keep_lifetime(llama.output_norm),
+            output: keep_lifetime(llama.output),
+            blocks: llama
                 .blocks
                 .iter()
                 .map(|blk| blk.as_ref().map(|s| keep_lifetime(s)))
@@ -57,8 +57,8 @@ impl Model for Transformer {
 
         Ok(FromGGuf {
             model,
-            tokenize,
-            render,
+            tokenizer,
+            chat_template,
         })
     }
 }
@@ -293,11 +293,5 @@ impl CausalLM for Transformer {
 
 #[test]
 fn test_infer() {
-    causal_lm::test_impl::<Transformer>(
-        (),
-        &[
-            29966, 29989, 1792, 29989, 29958, 13, 29903, 388, 376, 18567, 29908, 304, 592, 21106,
-            29879, 5299, 29989, 465, 22137, 29989, 29958, 13,
-        ],
-    );
+    causal_lm::test_impl::<Transformer>((), 100, "Once upon a time,");
 }
diff --git a/models/llama/nvidia-gpu/src/lib.rs b/models/llama/nvidia-gpu/src/lib.rs
index e8a6469c..22e53987 100644
--- a/models/llama/nvidia-gpu/src/lib.rs
+++ b/models/llama/nvidia-gpu/src/lib.rs
@@ -6,7 +6,7 @@ mod resource;
 extern crate log;
 
 use causal_lm::{
-    build_render, build_tokenize, CausalLM, DecodingMeta, FromGGuf, Model, QueryContext, SampleMeta,
+    CausalLM, ChatTemplate, DecodingMeta, FromGGuf, Model, QueryContext, SampleMeta, Tokenizer,
 };
 use common::{map_files, upos, utok, Blob, GGufModel};
 use common_nv::{
@@ -82,8 +82,8 @@ impl Model for Transformer {
         let _files = map_files(gguf);
         let gguf = GGufModel::read(_files.iter().map(|f| &**f));
 
-        let tokenize = build_tokenize(&gguf);
-        let render = build_render(&gguf, &*tokenize);
+        let tokenizer = Tokenizer::from_gguf(&gguf);
+        let chat_template = ChatTemplate::from_gguf(&gguf, &tokenizer);
         let llama = LlamaModel::from_gguf(&gguf);
         let LlamaMeta {
             dt_norm,
@@ -157,8 +157,8 @@ impl Model for Transformer {
 
         Ok(FromGGuf {
             model,
-            tokenize,
-            render,
+            tokenizer,
+            chat_template,
         })
     }
 }
@@ -627,9 +627,7 @@ fn test_infer() {
             device,
             load_layers: 20,
         },
-        &[
-            29966, 29989, 1792, 29989, 29958, 13, 29903, 388, 376, 18567, 29908, 304, 592, 21106,
-            29879, 5299, 29989, 465, 22137, 29989, 29958, 13,
-        ],
+        100,
+        "Once upon a time,",
     );
 }
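
--
Reviewer note (not part of the patch): a minimal sketch of how the reworked API
composes, assuming only the names this diff exports (`Tokenizer::from_gguf`,
`ChatTemplate::from_gguf`, the `FromGGuf` struct, and a backend implementing
`Model`, e.g. the common-cpu `Transformer`). The `demo` function and its
parameters are hypothetical.

    use causal_lm::{FromGGuf, Model};
    use std::path::Path;

    // Hypothetical driver over any backend implementing `Model`.
    fn demo<M: Model>(gguf: &Path, config: M::Config)
    where
        M::Error: std::fmt::Debug,
    {
        // `Model::load` now returns the tokenizer and the optional chat
        // template alongside the model, all built from one GGUF read.
        let FromGGuf {
            model: _model,
            tokenizer,
            chat_template,
        } = M::load(gguf, config).unwrap();

        // `encode` substitutes the vocabulary's space piece ('▁') for ' '
        // when the plain space token is absent; `decode` reverses it.
        let tokens = tokenizer.encode("Once upon a time,");
        for t in &tokens {
            print!("{}", tokenizer.decode(*t));
        }
        println!();

        // `ChatTemplate::from_gguf` returns `None` when the GGUF metadata
        // has no `tokenizer.chat_template` key, so callers must handle
        // its absence.
        if chat_template.is_none() {
            println!("no chat template in this GGUF");
        }
    }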