feat(causal-lm): detect whether spaces need replacing; restore the space add-and-replace mechanism
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Aug 23, 2024
1 parent 1973925 commit d68a0eb
Showing 6 changed files with 235 additions and 224 deletions.
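The mechanism named in the commit title is not expanded on this page (the tokenizer changes are among the collapsed files), but the wording matches the SentencePiece convention of spelling spaces as ▁ (U+2581) in vocabulary pieces. As a rough, hypothetical sketch of what "detect whether spaces need replacing" could mean, with every name invented for illustration:

// Hypothetical sketch; names and logic are illustrative, not from this commit.
// SentencePiece-style vocabularies spell a space as "▁" (U+2581). A tokenizer can
// detect the convention from its own vocabulary, then add a leading space and
// replace interior spaces when encoding, and undo the replacement when decoding.
fn needs_space_replacement(vocab: &[String]) -> bool {
    vocab.iter().any(|piece| piece.contains('▁'))
}

fn preprocess(text: &str, replace_space: bool) -> String {
    if replace_space {
        // Add the leading space, then replace every space with "▁".
        format!(" {text}").replace(' ', "▁")
    } else {
        text.to_string()
    }
}

fn postprocess(piece: &str, replace_space: bool) -> String {
    if replace_space {
        piece.replace('▁', " ")
    } else {
        piece.to_string()
    }
}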
84 changes: 40 additions & 44 deletions causal-lm/src/render.rs → causal-lm/src/chat_template.rs
@@ -1,4 +1,4 @@
-use crate::Tokenize;
+use crate::Tokenizer;
 use common::GGufModel;
 use minijinja::Environment;
 use serde::Serialize;
@@ -21,39 +21,38 @@ pub struct Message<'a> {
     pub content: &'a str,
 }

-/// Build a chat template from the GGuf model.
-pub fn build_render(gguf: &GGufModel, tokenize: &dyn Tokenize) -> Option<ChatTemplate> {
-    let template = gguf
-        .meta_kvs
-        .get("tokenizer.chat_template")?
-        .value_reader()
-        .read_str()
-        .unwrap()
-        .into();
-
-    let bos = gguf.meta_kvs["tokenizer.ggml.bos_token_id"]
-        .value_reader()
-        .read::<utok>()
-        .unwrap();
-    let eos = gguf.meta_kvs["tokenizer.ggml.eos_token_id"]
-        .value_reader()
-        .read::<utok>()
-        .unwrap();
-
-    Some(ChatTemplate::new(
-        template,
-        tokenize.decode(bos).into(),
-        tokenize.decode(eos).into(),
-    ))
-}
+impl ChatTemplate {
+    pub fn from_gguf(gguf: &GGufModel, tokenize: &Tokenizer) -> Option<ChatTemplate> {
+        let template = gguf
+            .meta_kvs
+            .get("tokenizer.chat_template")?
+            .value_reader()
+            .read_str()
+            .unwrap()
+            .into();
+
+        let bos = gguf.meta_kvs["tokenizer.ggml.bos_token_id"]
+            .value_reader()
+            .read::<utok>()
+            .unwrap();
+        let eos = gguf.meta_kvs["tokenizer.ggml.eos_token_id"]
+            .value_reader()
+            .read::<utok>()
+            .unwrap();
+
+        Some(ChatTemplate::new(
+            template,
+            tokenize.decode(bos).into(),
+            tokenize.decode(eos).into(),
+        ))
+    }

-impl ChatTemplate {
     /// Create a new chat template.
     pub fn new(template: String, bos: String, eos: String) -> Self {
         static NEXT: AtomicUsize = AtomicUsize::new(0);
         let id = NEXT.fetch_add(1, Relaxed).to_string();

-        jinja()
+        JINJA_ENV
             .write()
             .unwrap()
             .add_template_owned(id.clone(), template)
@@ -76,7 +75,7 @@ impl ChatTemplate {
             add_generation_prompt: bool,
         }

-        jinja()
+        JINJA_ENV
             .read()
             .unwrap()
             .get_template(&self.id)
@@ -92,26 +91,23 @@ impl ChatTemplate {

 impl Drop for ChatTemplate {
     fn drop(&mut self) {
-        jinja().write().unwrap().remove_template(&self.id);
+        JINJA_ENV.write().unwrap().remove_template(&self.id);
     }
 }

-fn jinja() -> &'static RwLock<Environment<'static>> {
-    static ENV: LazyLock<RwLock<Environment<'_>>> = LazyLock::new(|| {
-        let mut env = Environment::empty();
-        env.set_unknown_method_callback(|_, value, method, args| {
-            use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value};
-            match (method, value.kind(), args) {
-                ("strip", ThisType::String, []) => Ok(Value::from_safe_string(
-                    value.to_str().unwrap().trim().into(),
-                )),
-                _ => Err(UnknownMethod.into()),
-            }
-        });
-        RwLock::new(env)
-    });
-    &ENV
-}
+static JINJA_ENV: LazyLock<RwLock<Environment<'_>>> = LazyLock::new(|| {
+    let mut env = Environment::empty();
+    env.set_unknown_method_callback(|_, value, method, args| {
+        use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value};
+        match (method, value.kind(), args) {
+            ("strip", ThisType::String, []) => Ok(Value::from_safe_string(
+                value.to_str().unwrap().trim().into(),
+            )),
+            _ => Err(UnknownMethod.into()),
+        }
+    });
+    RwLock::new(env)
+});

 #[test]
 fn test() {
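Taken together, the rename moves template construction from a free function onto the type itself. A minimal usage sketch, assuming a GGufModel and Tokenizer are already loaded; the Message role field and the render signature are inferred from this diff and may not match the real API:

// Sketch only; signatures are inferred from the diff, not checked against the crate.
fn render_greeting(gguf: &GGufModel, tokenizer: &Tokenizer) -> Option<String> {
    // from_gguf returns None when the GGuf metadata lacks tokenizer.chat_template.
    let template = ChatTemplate::from_gguf(gguf, tokenizer)?;
    let messages = [Message {
        role: "user",
        content: "hello",
    }];
    // `true` corresponds to the add_generation_prompt field seen in the render body.
    template.render(&messages, true).ok()
}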
60 changes: 39 additions & 21 deletions causal-lm/src/lib.rs
@@ -1,21 +1,21 @@
 #![doc = include_str!("../README.md")]
-#![deny(warnings, missing_docs)]
+// #![deny(warnings, missing_docs)]

+mod chat_template;
 mod decoding;
 mod query_context;
-mod render;
-mod tokenize;
+mod tokenizer;

 use common::{upos, utok};
 use digit_layout::types::U32;
-use std::{path::Path, time::Duration};
+use std::{io::Write, path::Path};
 use tensor::{udim, Tensor};

+pub use chat_template::ChatTemplate;
 pub use decoding::DecodingMeta;
 pub use operators::random_sample::SampleArgs;
 pub use query_context::QueryContext;
-pub use render::{build_render, ChatTemplate};
-pub use tokenize::{build_tokenize, Tokenize};
+pub use tokenizer::Tokenizer;

 /// A model loaded from the file system.
 pub trait Model: Sized {
@@ -24,17 +24,17 @@ pub trait Model: Sized {
     /// Possible errors during model loading.
     type Error;
     /// Load a model from the file system.
-    fn load(gguf: impl AsRef<Path>, meta: Self::Config) -> Result<FromGGuf<Self>, Self::Error>;
+    fn load(gguf: impl AsRef<Path>, config: Self::Config) -> Result<FromGGuf<Self>, Self::Error>;
 }

 /// The model, tokenizer, and render template loaded from a GGuf file.
 pub struct FromGGuf<M: Model> {
     /// The model.
     pub model: M,
     /// The tokenizer.
-    pub tokenize: Box<dyn Tokenize>,
+    pub tokenizer: Tokenizer,
     /// The render template.
-    pub render: Option<ChatTemplate>,
+    pub chat_template: Option<ChatTemplate>,
 }

 /// A causal language model.
@@ -119,32 +119,38 @@ pub fn pos<'a, S: 'a>(
 }

 /// Test a model implementation.
-pub fn test_impl<M>(meta: M::Config, prompt: &[utok])
+pub fn test_impl<M>(meta: M::Config, max_steps: usize, prompt: &str)
 where
     M: CausalLM,
     M::Error: std::fmt::Debug,
 {
-    use std::time::Instant;
+    use std::time::{Duration, Instant};

     let Some(gguf) = common::test_model::find() else {
         return;
     };
     println!("model: {}", gguf.display());

-    let t0 = Instant::now();
-    let FromGGuf { model, .. } = M::load(gguf, meta).unwrap();
-    let t1 = Instant::now();
-    println!("load {:?}", t1 - t0);
+    let time = Instant::now();
+    let FromGGuf {
+        model, tokenizer, ..
+    } = M::load(gguf, meta).unwrap();
+    println!("load {:?}", time.elapsed());

-    let mut cache = model.new_cache();
+    let mut prompt = tokenizer.encode(prompt);
+    print!("prompt:");
+    for t in &prompt {
+        print!(" {t}");
+    }

-    let mut prompt = prompt.to_vec();
+    let mut tokens = prompt.clone();
     let mut pos = 0;

+    let mut time = Duration::ZERO;
+    let mut steps = 0;

-    while prompt != [model.eos_token()] {
+    let mut cache = model.new_cache();
+    while prompt != [model.eos_token()] && steps <= max_steps {
+        let start = Instant::now();

         let token_embedded = CausalLM::token_embed(&model, prompt.iter().copied());
@@ -165,21 +171,33 @@ where
             num_decode: 1,
             args: SampleArgs::ARG_MAX,
         }];
-        let tokens = CausalLM::sample(&model, args, logits);
+        let token = CausalLM::sample(&model, args, logits)[0];
+
+        if steps > 0 {
+            time += start.elapsed();
+        }
+        steps += 1;

-        println!("{:?}", tokens);
+        print!(" {token}");
+        std::io::stdout().flush().unwrap();

         pos += prompt.len() as upos;
-        prompt = tokens;
+        prompt.clear();
+        prompt.push(token);
+        tokens.push(token);
     }

+    steps -= 1;
+    println!();
+    println!(
+        "steps = {steps}, average decoding time = {:?}",
+        time.div_f32(steps as _)
+    );
+    println!();
+    println!("---");
+    for t in tokens {
+        print!("{}", tokenizer.decode(t));
+    }
+    println!();
+    println!("---");
 }
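A downstream model crate would drive the reworked harness roughly like this; LlamaModel and its default config are placeholder names for this sketch, not types from the repository:

// Placeholder sketch: substitute a concrete CausalLM implementation and its config.
#[test]
fn test_infer() {
    causal_lm::test_impl::<LlamaModel>(
        Default::default(),  // M::Config for the chosen model
        100,                 // max_steps: cap on decoding iterations
        "Once upon a time",  // the prompt is now plain text, encoded by the harness
    );
}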
132 changes: 0 additions & 132 deletions causal-lm/src/tokenize.rs

This file was deleted.

(The remaining changed files, including the new causal-lm/src/tokenizer.rs, are collapsed in this view and not shown.)
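Judging only from the call sites visible above, the concrete Tokenizer that replaces the deleted Box<dyn Tokenize> exposes at least this surface. An inferred sketch, not the contents of the new file:

// Inferred from usage in ChatTemplate::from_gguf and test_impl; the real
// tokenizer.rs is collapsed in this view and may differ.
pub struct Tokenizer {
    // vocabulary, merges, and the space add/replace flag presumably live here
}

impl Tokenizer {
    /// test_impl calls `tokenizer.encode(prompt)` on a &str and clones the token ids.
    pub fn encode(&self, text: &str) -> Vec<utok> {
        unimplemented!("sketch only")
    }
    /// Call sites display the result directly, so something string-like is returned.
    pub fn decode(&self, token: utok) -> &str {
        unimplemented!("sketch only")
    }
}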
