How to force tool calling behavior
Prerequisites
This guide assumes familiarity with the following concepts:
To force the LLM to select a specific tool, we can use the tool_choice parameter. First, let's define our model and tools:
import { tool } from "@langchain/core/tools";
import { z } from "zod";

// Tool that adds two numbers and returns the sum as a string.
const add = tool(
  (input) => {
    return `${input.a + input.b}`;
  },
  {
    name: "add",
    description: "Adds a and b.",
    schema: z.object({
      a: z.number(),
      b: z.number(),
    }),
  }
);

// Tool that multiplies two numbers. Its name is "Multiply" (capitalized);
// the tool_choice example below refers to it by this exact name.
const multiply = tool(
  (input) => {
    return `${input.a * input.b}`;
  },
  {
    name: "Multiply",
    description: "Multiplies a and b.",
    schema: z.object({
      a: z.number(),
      b: z.number(),
    }),
  }
);

const tools = [add, multiply];
import { ChatOpenAI } from "@langchain/openai";

const llm = new ChatOpenAI({
  model: "gpt-3.5-turbo",
});
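When tools are bound to an OpenAI chat model, the model sees them as function definitions in the request's tools field. The sketch below is a rough, hand-written approximation of that field for add and Multiply, following OpenAI's documented function-tool shape ({ type: "function", function: { name, description, parameters } }). The type names and twoNumbers are hypothetical, and the exact JSON Schema that @langchain/openai derives from the Zod schemas is an assumption: the real payload may carry extra keys such as additionalProperties.

```typescript
// Hypothetical types mirroring OpenAI's documented function-tool format.
type ParameterSchema = {
  type: "object";
  properties: Record<string, { type: "number" }>;
  required: string[];
};

type FunctionTool = {
  type: "function";
  function: { name: string; description: string; parameters: ParameterSchema };
};

// Assumed JSON Schema equivalent of z.object({ a: z.number(), b: z.number() }).
const twoNumbers: ParameterSchema = {
  type: "object",
  properties: { a: { type: "number" }, b: { type: "number" } },
  required: ["a", "b"],
};

// Hand-written approximation of the "tools" field for [add, multiply].
const requestTools: FunctionTool[] = [
  {
    type: "function",
    function: { name: "add", description: "Adds a and b.", parameters: twoNumbers },
  },
  {
    type: "function",
    function: { name: "Multiply", description: "Multiplies a and b.", parameters: twoNumbers },
  },
];

// These names are what a tool_choice value can point at.
console.log(requestTools.map((t) => t.function.name)); // [ 'add', 'Multiply' ]
```

Because both tools share the same two-number schema, only the name and description distinguish them for the model.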
For example, we can force the model to call the Multiply tool with the following code:
// Force the model to call the Multiply tool on every invocation.
const llmForcedToMultiply = llm.bindTools(tools, {
  tool_choice: "Multiply",
});
const result = await llmForcedToMultiply.invoke("what is 2 + 4");
console.log(JSON.stringify(result.tool_calls, null, 2));
[
  {
    "name": "Multiply",
    "args": {
      "a": 2,
      "b": 4
    },
    "type": "tool_call",
    "id": "call_d5isFbUkn17Wjr6yEtNz7dDF"
  }
]
Even though the question doesn't require multiplication, the model still calls the Multiply tool!
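The model only proposes a call; nothing is computed until the application runs the tool. The sketch below is a plain TypeScript stand-in with no LangChain imports (ToolCall, localTools, and runToolCalls are hypothetical names) that dispatches the tool_calls shown above to local functions mirroring add and Multiply. It makes the cost of forcing visible: the question asked for 2 + 4, but the forced call evaluates to 8.

```typescript
// Shape of one entry in result.tool_calls, as printed above
// (narrowed to two numeric args for this sketch).
type ToolCall = {
  name: string;
  args: { a: number; b: number };
  type: "tool_call";
  id: string;
};

// Local stand-ins for the add and Multiply tools; like the originals,
// they return the result as a string.
const localTools: Record<string, (args: { a: number; b: number }) => string> = {
  add: ({ a, b }) => `${a + b}`,
  Multiply: ({ a, b }) => `${a * b}`,
};

// Run every proposed call and collect its output, failing loudly on unknown tools.
function runToolCalls(calls: ToolCall[]): string[] {
  return calls.map((call) => {
    const fn = localTools[call.name];
    if (fn === undefined) {
      throw new Error(`Unknown tool: ${call.name}`);
    }
    return fn(call.args);
  });
}

// The forced call from the example above; the user asked "what is 2 + 4".
const forcedCalls: ToolCall[] = [
  { name: "Multiply", args: { a: 2, b: 4 }, type: "tool_call", id: "call_d5isFbUkn17Wjr6yEtNz7dDF" },
];

console.log(runToolCalls(forcedCalls)); // [ '8' ]
```

Looking tools up by name means a misspelled or differently capitalized name fails at dispatch time instead of silently returning nothing.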
We can also force the model to select at least one of our tools, without naming a specific one, by passing the "any" keyword (or "required", which is OpenAI-specific) to the tool_choice parameter.
// "any" forces at least one tool call, but lets the model pick which tool.
const llmForcedToUseTool = llm.bindTools(tools, {
  tool_choice: "any",
});
const anyResult = await llmForcedToUseTool.invoke("What day is today?");
console.log(JSON.stringify(anyResult.tool_calls, null, 2));
[
  {
    "name": "add",
    "args": {
      "a": 2,
      "b": 3
    },
    "type": "tool_call",
    "id": "call_La72g7Aj0XHG0pfPX6Dwg2vT"
  }
]
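One way to picture these tool_choice values is as a translation into OpenAI's tool_choice field, which accepts "none", "auto", "required", or an object naming one function. The helper below is a hypothetical sketch of such a mapping, not @langchain/openai's actual code: it assumes "any" is translated to OpenAI's "required" and that any other non-keyword string is treated as the name of the tool to force.

```typescript
// OpenAI Chat Completions accepts these tool_choice values.
type OpenAIToolChoice =
  | "none"
  | "auto"
  | "required"
  | { type: "function"; function: { name: string } };

// Hypothetical helper: translate a LangChain-style tool_choice string into
// an OpenAI tool_choice value. A sketch, not the library's implementation.
function toOpenAIToolChoice(choice: string): OpenAIToolChoice {
  switch (choice) {
    case "any":
    case "required":
      // Call at least one tool; the model decides which.
      return "required";
    case "auto":
      // The model may answer directly or call a tool.
      return "auto";
    case "none":
      // The model must not call any tool.
      return "none";
    default:
      // Any other string is assumed to name the single tool to force.
      return { type: "function", function: { name: choice } };
  }
}

console.log(toOpenAIToolChoice("any")); // required
console.log(JSON.stringify(toOpenAIToolChoice("Multiply")));
// {"type":"function","function":{"name":"Multiply"}}
```

A consequence of this design: in the sketch, a tool literally named "any", "auto", "none", or "required" could not be forced by name, because those strings are read as keywords first.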
Note: this style of forced tool calling currently requires the following minimum package versions:
| Package Name | Min Package Version | Min @langchain/core Version |
|---|---|---|
| @langchain/aws | 0.0.3 | 0.2.17 |
| @langchain/openai | 0.2.4 | 0.2.17 |
| @langchain/groq | 0.0.14 | 0.2.17 |
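To check an installed setup against this table, versions must be compared numerically, component by component: a plain string comparison would wrongly rank 0.0.9 above 0.0.14. The sketch below hard-codes the table as data; MIN_VERSIONS, compareVersions, and meetsMinimum are hypothetical names, and it deliberately ignores pre-release tags such as -rc.0.

```typescript
// Minimum versions from the table above.
const MIN_VERSIONS: Record<string, { pkg: string; core: string }> = {
  "@langchain/aws": { pkg: "0.0.3", core: "0.2.17" },
  "@langchain/openai": { pkg: "0.2.4", core: "0.2.17" },
  "@langchain/groq": { pkg: "0.0.14", core: "0.2.17" },
};

// Compare dotted numeric versions component by component; returns -1, 0, or 1.
function compareVersions(a: string, b: string): number {
  const pa = a.split(".").map(Number);
  const pb = b.split(".").map(Number);
  for (let i = 0; i < Math.max(pa.length, pb.length); i++) {
    const diff = (pa[i] ?? 0) - (pb[i] ?? 0);
    if (diff !== 0) return Math.sign(diff);
  }
  return 0;
}

// True if both the integration package and @langchain/core meet the minimums.
function meetsMinimum(pkgName: string, pkgVersion: string, coreVersion: string): boolean {
  const min = MIN_VERSIONS[pkgName];
  if (min === undefined) return false; // package not listed in the table
  return compareVersions(pkgVersion, min.pkg) >= 0 && compareVersions(coreVersion, min.core) >= 0;
}

console.log(meetsMinimum("@langchain/openai", "0.2.4", "0.2.17")); // true
console.log(meetsMinimum("@langchain/groq", "0.0.9", "0.2.17")); // false
```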