Skip to content

Commit

Permalink
agent: fix tool execution
Browse files Browse the repository at this point in the history
Still not close to being an agent but.
  • Loading branch information
danbev committed Jan 9, 2025
1 parent 9b618aa commit 0ff7899
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 23 deletions.
10 changes: 8 additions & 2 deletions agents/llama-cpp-agent/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,17 @@ run-print-tool: tool-runner print-component
cd tool-runner && cargo run -- -c ../components/print-tool-component.wasm --message "Something to print"

### Agent
run-agent:
run-agent-echo:
cd agent && cargo run -- -m ../models/Phi-3-mini-4k-instruct-q4.gguf \
-c ../components/echo-tool-component.wasm \
-c ../components/print-tool-component.wasm \
-p "Please echo back 'Something'"
-p "Please echo the 'Hello'"

run-agent-print:
cd agent && cargo run -- -m ../models/Phi-3-mini-4k-instruct-q4.gguf \
-c ../components/echo-tool-component.wasm \
-c ../components/print-tool-component.wasm \
-p "Please print the following 'coffee'"

download-phi-mini-instruct: models
cd models && \
Expand Down
54 changes: 37 additions & 17 deletions agents/llama-cpp-agent/agent/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,29 +110,49 @@ Available tools and their usage patterns:
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);

if self.model.is_eog_token(token) {
eprintln!();
break;
if !self.model.is_eog_token(token) {
let output_bytes = self.model.token_to_bytes(token, Special::Tokenize)?;
let mut output_string = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
response.push_str(&output_string);
}

let output_bytes = self.model.token_to_bytes(token, Special::Tokenize)?;
let mut output_string = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
response.push_str(&output_string);

print!("{output_string}");
std::io::stdout().flush()?;

if response.contains("USE_TOOL: Echo") && response.contains("value=") {
if let Some(cmd) = response.split("USE_TOOL: Echo, value=").nth(1) {
let value = cmd.trim();
if !value.is_empty() {
return self.tool_manager.execute_tool("Echo", vec![("value".to_string(), value.to_string())]);
//print!("{output_string}");
//std::io::stdout().flush()?;

if self.model.is_eog_token(token) {
if response.contains("USE_TOOL:") {
let metadata = self.tool_manager.get_metadata();

for md in &metadata {
let tool_prefix = format!("USE_TOOL: {}", md.name);

if response.contains(&tool_prefix) {
let mut params = Vec::new();

if let Some(cmd_part) = response.split(&tool_prefix).nth(1) {
let cmd_part = cmd_part.trim_start_matches(", ").trim();
for param in &md.params {
let param_prefix = format!("{}=", param.name);

if let Some(value_part) = cmd_part.split(&param_prefix).nth(1) {
let value = value_part.trim();

if !value.is_empty() {
params.push((param.name.clone(), value.to_string()));
}
}
}

if params.len() == md.params.len() {
return self.tool_manager.execute_tool(&md.name, params);
}
}
}
}
}
}


batch.clear();
batch.add(token, n_cur, &[0], true)?;

Expand Down
1 change: 0 additions & 1 deletion agents/llama-cpp-agent/agent/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ impl ToolManager {
println!(" - {}: {}", name, value);
}
let result = tool.call_execute(&mut store, &params)?;
println!("Execution result: {}", result.data);

if result.success {
Ok(result.data)
Expand Down
4 changes: 1 addition & 3 deletions agents/llama-cpp-agent/tools/print/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl Guest for PrintTool {

let usage = ToolUsage {
user: "Please print this string 'something'".to_string(),
assistent: "USE_TOOL: Print, message='something'".to_string(),
assistent: "USE_TOOL: Print, message=something".to_string(),
};

ToolMetadata {
Expand All @@ -39,8 +39,6 @@ impl Guest for PrintTool {
fn execute(params: Vec<(String, String)>) -> ToolResult {
let message = &params.get(0).unwrap().1;

println!("{}", message);

ToolResult {
success: true,
data: format!("Successfully printed: {}", message),
Expand Down

0 comments on commit 0ff7899

Please sign in to comment.