From 197392558b8f1e6e9e0018e6fa54c9a0326c29f0 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 23 Aug 2024 11:07:16 +0800 Subject: [PATCH] =?UTF-8?q?refactor(causal-lm):=20render=20=E5=AD=98?= =?UTF-8?q?=E5=82=A8=20bos/eos=EF=BC=8C=E4=B8=8D=E4=BE=9D=E8=B5=96?= =?UTF-8?q?=E5=A4=96=E9=83=A8=E6=8F=90=E4=BE=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .cargo/config.toml | 1 - Cargo.lock | 6 +-- Cargo.toml | 2 +- causal-lm/src/render.rs | 59 ++++++++++++++++---------- models/llama/common-cpu/src/lib.rs | 2 +- models/llama/nvidia-gpu/src/lib.rs | 2 +- service/Cargo.toml | 1 - xtask/src/cast.rs | 67 ------------------------------ xtask/src/main.rs | 1 - 9 files changed, 42 insertions(+), 99 deletions(-) delete mode 100644 xtask/src/cast.rs diff --git a/.cargo/config.toml b/.cargo/config.toml index 01d5e979..9baa024b 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -5,5 +5,4 @@ list-turbo = "xtask list-turbo" deploy = "xtask deploy" generate = "xtask generate" chat = "xtask chat" -cast = "xtask cast" service = "xtask service" diff --git a/Cargo.lock b/Cargo.lock index 8425cba4..76c81429 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -525,7 +525,7 @@ dependencies = [ [[package]] name = "ggus" version = "0.2.0" -source = "git+https://github.com/YdrMaster/gguf?rev=aa1281a#aa1281a937e970ed2a5b3ba67a32329adca9dafe" +source = "git+https://github.com/YdrMaster/gguf?rev=1f4a7ba#1f4a7ba1f2c3c30c8881e697c718787d29ab294a" dependencies = [ "fancy-regex", "indexmap", @@ -909,9 +909,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] diff --git a/Cargo.toml b/Cargo.toml index 72b18dff..cc0485b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] } digit-layout = "0.0" build-script-cfg = "0.0" -ggus = { git = "https://github.com/YdrMaster/gguf", rev = "aa1281a" } +ggus = { git = "https://github.com/YdrMaster/gguf", rev = "1f4a7ba" } operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "29b950c", default-features = false } search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d089ada" } search-neuware-tools = "0.0" diff --git a/causal-lm/src/render.rs b/causal-lm/src/render.rs index e188d472..c0c84141 100644 --- a/causal-lm/src/render.rs +++ b/causal-lm/src/render.rs @@ -1,14 +1,19 @@ -use common::GGufModel; +use crate::Tokenize; +use common::GGufModel; use minijinja::Environment; use serde::Serialize; use std::sync::{ atomic::{AtomicUsize, Ordering::Relaxed}, - OnceLock, RwLock, + LazyLock, RwLock, }; +use tokeneer::utok; /// A template for rendering chat messages. -#[repr(transparent)] -pub struct ChatTemplate(String); +pub struct ChatTemplate { + id: String, + bos: String, + eos: String, +} #[derive(Serialize)] pub struct Message<'a> { @@ -17,7 +22,7 @@ pub struct Message<'a> { } /// Build a chat template from the GGuf model. -pub fn build_render(gguf: &GGufModel) -> Option { +pub fn build_render(gguf: &GGufModel, tokenize: &dyn Tokenize) -> Option { let template = gguf .meta_kvs .get("tokenizer.chat_template")? @@ -25,12 +30,26 @@ pub fn build_render(gguf: &GGufModel) -> Option { .read_str() .unwrap() .into(); - Some(ChatTemplate::new(template)) + + let bos = gguf.meta_kvs["tokenizer.ggml.bos_token_id"] + .value_reader() + .read::() + .unwrap(); + let eos = gguf.meta_kvs["tokenizer.ggml.eos_token_id"] + .value_reader() + .read::() + .unwrap(); + + Some(ChatTemplate::new( + template, + tokenize.decode(bos).into(), + tokenize.decode(eos).into(), + )) } impl ChatTemplate { /// Create a new chat template. - pub fn new(template: String) -> Self { + pub fn new(template: String, bos: String, eos: String) -> Self { static NEXT: AtomicUsize = AtomicUsize::new(0); let id = NEXT.fetch_add(1, Relaxed).to_string(); @@ -40,15 +59,13 @@ impl ChatTemplate { .add_template_owned(id.clone(), template) .unwrap(); - Self(id) + Self { id, bos, eos } } /// Render the chat template with the given messages. pub fn render( &self, messages: &[Message], - bos_token: &str, - eos_token: &str, add_generation_prompt: bool, ) -> Result { #[derive(Serialize)] @@ -62,12 +79,12 @@ impl ChatTemplate { jinja() .read() .unwrap() - .get_template(&self.0) + .get_template(&self.id) .unwrap() .render(Args { messages, - bos_token, - eos_token, + bos_token: &self.bos, + eos_token: &self.eos, add_generation_prompt, }) } @@ -75,13 +92,12 @@ impl ChatTemplate { impl Drop for ChatTemplate { fn drop(&mut self) { - jinja().write().unwrap().remove_template(&self.0); + jinja().write().unwrap().remove_template(&self.id); } } fn jinja() -> &'static RwLock> { - static ENV: OnceLock>> = OnceLock::new(); - ENV.get_or_init(|| { + static ENV: LazyLock>> = LazyLock::new(|| { let mut env = Environment::empty(); env.set_unknown_method_callback(|_, value, method, args| { use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value}; @@ -93,7 +109,8 @@ fn jinja() -> &'static RwLock> { } }); RwLock::new(env) - }) + }); + &ENV } #[test] @@ -101,14 +118,12 @@ fn test() { const TAIDE: &str = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = '<>\n' + messages[0]['content'] + '\n<>\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content + ' [/INST]'}}{% elif message['role'] == 'assistant' %}{{ ' ' + content + ' ' + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"; const MINICPM: &str = "{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}"; - let result = ChatTemplate::new(TAIDE.into()) + let result = ChatTemplate::new(TAIDE.into(), "".into(), "".into()) .render( &[Message { role: "user", content: "Hello, who are you?", }], - "", - "", true, ) .unwrap(); @@ -118,14 +133,12 @@ fn test() { "[INST] Hello, who are you? [/INST]<|im_start|>assistant\n" ); - let result = ChatTemplate::new(MINICPM.into()) + let result = ChatTemplate::new(MINICPM.into(), "".into(), "".into()) .render( &[Message { role: "user", content: "Hello, who are you?", }], - "", - "", true, ) .unwrap(); diff --git a/models/llama/common-cpu/src/lib.rs b/models/llama/common-cpu/src/lib.rs index b3506e2e..e676efcf 100644 --- a/models/llama/common-cpu/src/lib.rs +++ b/models/llama/common-cpu/src/lib.rs @@ -33,7 +33,7 @@ impl Model for Transformer { let gguf = GGufModel::read(_files.iter().map(|f| &**f)); let tokenize = build_tokenize(&gguf); - let render = build_render(&gguf); + let render = build_render(&gguf, &*tokenize); let model = LlamaModel::from_gguf(&gguf); #[inline(always)] diff --git a/models/llama/nvidia-gpu/src/lib.rs b/models/llama/nvidia-gpu/src/lib.rs index c37135f1..e8a6469c 100644 --- a/models/llama/nvidia-gpu/src/lib.rs +++ b/models/llama/nvidia-gpu/src/lib.rs @@ -83,7 +83,7 @@ impl Model for Transformer { let gguf = GGufModel::read(_files.iter().map(|f| &**f)); let tokenize = build_tokenize(&gguf); - let render = build_render(&gguf); + let render = build_render(&gguf, &*tokenize); let llama = LlamaModel::from_gguf(&gguf); let LlamaMeta { dt_norm, diff --git a/service/Cargo.toml b/service/Cargo.toml index b87ffa37..3192b83e 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -10,7 +10,6 @@ authors = ["YdrMaster "] common = { path = "../common" } tensor = { path = "../tensor" } causal-lm = { path = "../causal-lm" } -chat-template = { path = "../chat-template" } log.workspace = true tokio.workspace = true memmap2.workspace = true diff --git a/xtask/src/cast.rs b/xtask/src/cast.rs deleted file mode 100644 index c8cb2d7f..00000000 --- a/xtask/src/cast.rs +++ /dev/null @@ -1,67 +0,0 @@ -use std::{fs, path::PathBuf, time::Instant}; - -use digit_layout::types::{BF16, F16, F32}; - -#[derive(Args, Default)] -pub(crate) struct CastArgs { - /// Original model directory. - #[clap(short, long)] - model: String, - /// Target model directory. - #[clap(short, long)] - target: Option, - /// Target model type. - /// avliable value includes: "f32", "f16", "bf16", "float32", etc. - #[clap(long)] - dt: Option, -} - -impl CastArgs { - pub fn invoke(self) { - let ty = match self.dt.as_deref() { - Some("f32") | Some("float") | Some("float32") | None => F32, - Some("f16") | Some("half") | Some("float16") => F16, - Some("bf16") | Some("bfloat16") => BF16, - Some(ty) => panic!("Unknown data type: \"{ty}\""), - }; - let model_dir = PathBuf::from(self.model); - - let time = Instant::now(); - let model = llama::Storage::load_safetensors(&model_dir).unwrap(); - println!("load model ... {:?}", time.elapsed()); - - let target = self.target.map(PathBuf::from).unwrap_or_else(|| { - model_dir.parent().unwrap().join(format!( - "{}_{}", - model_dir.file_name().unwrap().to_str().unwrap(), - match ty { - F16 => "f16", - F32 => "f32", - BF16 => "bf16", - _ => unreachable!(), - } - )) - }); - fs::create_dir_all(&target).unwrap(); - - let time = Instant::now(); - let model = model.cast(ty); - println!("cast data type ... {:?}", time.elapsed()); - - let time = Instant::now(); - model.save(&target).unwrap(); - println!("save model ... {:?}", time.elapsed()); - - let copy_file = |name: &str| { - let src = model_dir.join(name); - if src.is_file() { - let time = Instant::now(); - fs::copy(&src, target.join(name)).unwrap(); - println!("copy {name} ... {:?}", time.elapsed()); - } - }; - - copy_file("tokenizer.model"); - copy_file("vocabs.txt"); - } -} diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 8785e99d..983de69e 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -1,4 +1,3 @@ -mod cast; mod chat; mod deploy; mod generate;