Skip to content

Commit

Permalink
refactor(causal-lm): render 存储 bos/eos,不依赖外部提供
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Aug 23, 2024
1 parent ced0bb0 commit 1973925
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 99 deletions.
1 change: 0 additions & 1 deletion .cargo/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@ list-turbo = "xtask list-turbo"
deploy = "xtask deploy"
generate = "xtask generate"
chat = "xtask chat"
cast = "xtask cast"
service = "xtask service"
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] }
digit-layout = "0.0"
build-script-cfg = "0.0"

ggus = { git = "https://github.com/YdrMaster/gguf", rev = "aa1281a" }
ggus = { git = "https://github.com/YdrMaster/gguf", rev = "1f4a7ba" }
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "29b950c", default-features = false }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d089ada" }
search-neuware-tools = "0.0"
Expand Down
59 changes: 36 additions & 23 deletions causal-lm/src/render.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
use common::GGufModel;
use crate::Tokenize;
use common::GGufModel;
use minijinja::Environment;
use serde::Serialize;
use std::sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
OnceLock, RwLock,
LazyLock, RwLock,
};
use tokeneer::utok;

/// A template for rendering chat messages.
#[repr(transparent)]
pub struct ChatTemplate(String);
pub struct ChatTemplate {
id: String,
bos: String,
eos: String,
}

#[derive(Serialize)]
pub struct Message<'a> {
Expand All @@ -17,20 +22,34 @@ pub struct Message<'a> {
}

/// Build a chat template from the GGuf model.
pub fn build_render(gguf: &GGufModel) -> Option<ChatTemplate> {
pub fn build_render(gguf: &GGufModel, tokenize: &dyn Tokenize) -> Option<ChatTemplate> {
let template = gguf
.meta_kvs
.get("tokenizer.chat_template")?
.value_reader()
.read_str()
.unwrap()
.into();
Some(ChatTemplate::new(template))

let bos = gguf.meta_kvs["tokenizer.ggml.bos_token_id"]
.value_reader()
.read::<utok>()
.unwrap();
let eos = gguf.meta_kvs["tokenizer.ggml.eos_token_id"]
.value_reader()
.read::<utok>()
.unwrap();

Some(ChatTemplate::new(
template,
tokenize.decode(bos).into(),
tokenize.decode(eos).into(),
))
}

impl ChatTemplate {
/// Create a new chat template.
pub fn new(template: String) -> Self {
pub fn new(template: String, bos: String, eos: String) -> Self {
static NEXT: AtomicUsize = AtomicUsize::new(0);
let id = NEXT.fetch_add(1, Relaxed).to_string();

Expand All @@ -40,15 +59,13 @@ impl ChatTemplate {
.add_template_owned(id.clone(), template)
.unwrap();

Self(id)
Self { id, bos, eos }
}

/// Render the chat template with the given messages.
pub fn render(
&self,
messages: &[Message],
bos_token: &str,
eos_token: &str,
add_generation_prompt: bool,
) -> Result<String, minijinja::Error> {
#[derive(Serialize)]
Expand All @@ -62,26 +79,25 @@ impl ChatTemplate {
jinja()
.read()
.unwrap()
.get_template(&self.0)
.get_template(&self.id)
.unwrap()
.render(Args {
messages,
bos_token,
eos_token,
bos_token: &self.bos,
eos_token: &self.eos,
add_generation_prompt,
})
}
}

impl Drop for ChatTemplate {
fn drop(&mut self) {
jinja().write().unwrap().remove_template(&self.0);
jinja().write().unwrap().remove_template(&self.id);
}
}

fn jinja() -> &'static RwLock<Environment<'static>> {
static ENV: OnceLock<RwLock<Environment<'_>>> = OnceLock::new();
ENV.get_or_init(|| {
static ENV: LazyLock<RwLock<Environment<'_>>> = LazyLock::new(|| {
let mut env = Environment::empty();
env.set_unknown_method_callback(|_, value, method, args| {
use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value};
Expand All @@ -93,22 +109,21 @@ fn jinja() -> &'static RwLock<Environment<'static>> {
}
});
RwLock::new(env)
})
});
&ENV
}

#[test]
fn test() {
const TAIDE: &str = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = '<<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content + ' [/INST]'}}{% elif message['role'] == 'assistant' %}{{ ' ' + content + ' ' + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
const MINICPM: &str = "{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}";

let result = ChatTemplate::new(TAIDE.into())
let result = ChatTemplate::new(TAIDE.into(), "<s>".into(), "</s>".into())
.render(
&[Message {
role: "user",
content: "Hello, who are you?",
}],
"<s>",
"</s>",
true,
)
.unwrap();
Expand All @@ -118,14 +133,12 @@ fn test() {
"<s>[INST] Hello, who are you? [/INST]<|im_start|>assistant\n"
);

let result = ChatTemplate::new(MINICPM.into())
let result = ChatTemplate::new(MINICPM.into(), "<s>".into(), "</s>".into())
.render(
&[Message {
role: "user",
content: "Hello, who are you?",
}],
"<s>",
"</s>",
true,
)
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion models/llama/common-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Model for Transformer {
let gguf = GGufModel::read(_files.iter().map(|f| &**f));

let tokenize = build_tokenize(&gguf);
let render = build_render(&gguf);
let render = build_render(&gguf, &*tokenize);
let model = LlamaModel::from_gguf(&gguf);

#[inline(always)]
Expand Down
2 changes: 1 addition & 1 deletion models/llama/nvidia-gpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl Model for Transformer {
let gguf = GGufModel::read(_files.iter().map(|f| &**f));

let tokenize = build_tokenize(&gguf);
let render = build_render(&gguf);
let render = build_render(&gguf, &*tokenize);
let llama = LlamaModel::from_gguf(&gguf);
let LlamaMeta {
dt_norm,
Expand Down
1 change: 0 additions & 1 deletion service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
common = { path = "../common" }
tensor = { path = "../tensor" }
causal-lm = { path = "../causal-lm" }
chat-template = { path = "../chat-template" }
log.workspace = true
tokio.workspace = true
memmap2.workspace = true
Expand Down
67 changes: 0 additions & 67 deletions xtask/src/cast.rs

This file was deleted.

1 change: 0 additions & 1 deletion xtask/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
mod cast;
mod chat;
mod deploy;
mod generate;
Expand Down

0 comments on commit 1973925

Please sign in to comment.