AI Agents with MongoDB / Build the agent and add memory
Code Summary: Build the Agent's Decision-making Capabilities
This code creates a graph using LangGraph that gives the agent the ability to make decisions.
Link to code on GitHub
Import Packages
Update the imports in the `main.py` file to the following:

```python
import key_param
import voyageai
from pymongo import MongoClient
from typing import Annotated, List
from typing_extensions import TypedDict
from langchain.agents import tool
from langchain_core.messages import ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import add_messages
```
Define the State
Define a `GraphState` class that serves as the state for our agent:

```python
# Define the graph state type with messages that can accumulate
class GraphState(TypedDict):
    # The add_messages reducer appends new messages to the conversation
    # history instead of overwriting it
    messages: Annotated[list, add_messages]
```
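The `add_messages` annotation is what makes the state accumulate: each node returns only its new messages, and the reducer merges them into the existing list instead of replacing it. As a rough stdlib-only sketch of that merge semantics (the real LangGraph reducer additionally coerces message formats and upserts by message ID):

```python
# Stand-in for LangGraph's add_messages reducer: an append-only merge.
# The real reducer also coerces message formats and upserts by message ID.
def add_messages_sketch(existing: list, updates: list) -> list:
    return existing + updates

# Each node returns only its delta; the reducer folds it into the state.
state = {"messages": [("user", "What is MongoDB?")]}
state["messages"] = add_messages_sketch(
    state["messages"], [("ai", "MongoDB is a document database.")]
)

print(len(state["messages"]))  # 2: history grows instead of being overwritten
```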
Create the Agent Node
Create an agent node that takes the current state and the language model with tool bindings:
```python
def agent(state: GraphState, llm_with_tools) -> GraphState:
    """
    Agent node.

    Args:
        state (GraphState): The graph state.
        llm_with_tools: The LLM with tools bound to it.

    Returns:
        GraphState: The updated messages.
    """
    messages = state["messages"]
    result = llm_with_tools.invoke(messages)
    return {"messages": [result]}
```
Create the Tool Node
Create a tool node which receives the current state and a dictionary that maps tool names to their functions:
```python
def tool_node(state: GraphState, tools_by_name) -> GraphState:
    """
    Tool node.

    Args:
        state (GraphState): The graph state.
        tools_by_name (Dict[str, Callable]): The tools by name.

    Returns:
        GraphState: The updated messages.
    """
    result = []
    # The last message is the AI message that requested the tool calls
    tool_calls = state["messages"][-1].tool_calls
    for tool_call in tool_calls:
        tool = tools_by_name[tool_call["name"]]
        observation = tool.invoke(tool_call["args"])
        result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
    return {"messages": result}
```
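To see the dispatch pattern in isolation, here is a stdlib-only sketch with stub tools and a stub AI message. `StubAIMessage` and `StubTool` are illustrative stand-ins, not LangChain classes; they only mimic the `tool_calls` attribute and `invoke` method that `tool_node` relies on:

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins for the LangChain message and tool objects
@dataclass
class StubAIMessage:
    tool_calls: list = field(default_factory=list)

class StubTool:
    def __init__(self, fn):
        self.fn = fn

    def invoke(self, args):
        return self.fn(**args)

tools_by_name = {"add": StubTool(lambda a, b: a + b)}
last_message = StubAIMessage(
    tool_calls=[{"name": "add", "args": {"a": 2, "b": 3}, "id": "call_1"}]
)

# Same loop as tool_node: look up each requested tool by name and run it
observations = []
for tool_call in last_message.tool_calls:
    tool = tools_by_name[tool_call["name"]]
    observations.append((tool_call["id"], tool.invoke(tool_call["args"])))

print(observations)  # [('call_1', 5)]
```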
Create a Router Function
Create a router function that checks the latest message for tool calls. If found, it routes to the "tools" node for execution. Otherwise, it returns the final answer to the user, signifying the end of the processing cycle.
```python
def route_tools(state: GraphState):
    """
    Route to the tool node if the last message has tool calls. Otherwise, route to the end.

    Args:
        state (GraphState): The graph state.

    Returns:
        str: The next node to route to.
    """
    messages = state.get("messages", [])
    if len(messages) > 0:
        ai_message = messages[-1]
    else:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return END
```
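The routing rule can be exercised with plain objects. In this stdlib-only sketch, `StubMessage` stands in for a LangChain AI message and the `END` string mirrors the `langgraph.graph.END` sentinel:

```python
from dataclasses import dataclass, field

END = "__end__"  # sentinel, mirroring langgraph.graph.END

@dataclass
class StubMessage:
    tool_calls: list = field(default_factory=list)

def route(messages: list) -> str:
    # Same decision as route_tools: pending tool calls -> "tools", else END
    last = messages[-1]
    if getattr(last, "tool_calls", None):
        return "tools"
    return END

print(route([StubMessage(tool_calls=[{"name": "search"}])]))  # tools
print(route([StubMessage()]))                                 # __end__
```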
Create the Graph
Define an `init_graph` function that creates and connects all elements of the graph:
```python
def init_graph(llm_with_tools, tools_by_name):
    """
    Initialize the graph.

    Args:
        llm_with_tools: The LLM with tools bound to it.
        tools_by_name (Dict[str, Callable]): The tools by name.

    Returns:
        The compiled graph.
    """
    graph = StateGraph(GraphState)
    graph.add_node("agent", lambda state: agent(state, llm_with_tools))
    graph.add_node("tools", lambda state: tool_node(state, tools_by_name))
    graph.add_edge(START, "agent")
    graph.add_edge("tools", "agent")
    graph.add_conditional_edges("agent", route_tools, {"tools": "tools", END: END})
    return graph.compile()
```
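Conceptually, the compiled graph runs a loop: START → agent, then the conditional edge either hands control to tools (whose results feed back into agent) or ends. This stdlib-only sketch traces that control flow with a fake agent that requests one tool call and then answers; the stand-in functions are illustrative, not LangGraph APIs:

```python
END = "__end__"

# Fake agent: asks for a tool on the first turn, answers on the second
def fake_agent(messages):
    if not any(role == "tool" for role, _ in messages):
        return ("ai", "need_tool")  # stands in for an AI message with tool_calls
    return ("ai", "final answer")

def fake_tool_node(messages):
    return ("tool", "tool result")

# The agent -> tools -> agent loop that the graph edges encode
messages = [("user", "question")]
node = "agent"
while node != END:
    if node == "agent":
        msg = fake_agent(messages)
        messages.append(msg)
        node = "tools" if msg[1] == "need_tool" else END
    elif node == "tools":
        messages.append(fake_tool_node(messages))
        node = "agent"

print(messages[-1])  # ('ai', 'final answer')
```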
Run the Graph
Create an `execute_graph` function that receives the compiled graph and the user's input:
```python
def execute_graph(app, user_input: str) -> None:
    """
    Stream outputs from the graph.

    Args:
        app: The compiled graph application.
        user_input (str): The user's input.
    """
    inputs = {"messages": [("user", user_input)]}
    for output in app.stream(inputs):
        for key, value in output.items():
            print(f"Node {key}:")
            print(value)
    print("---FINAL ANSWER---")
    print(value["messages"][-1].content)
```
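`app.stream` yields one dictionary per executed node, keyed by node name, so the printing loop simply unpacks each update; after the loop, `value` still holds the last node's update, which contains the final answer. This stdlib-only sketch consumes a hypothetical hard-coded stream shaped like those yields (the node outputs here are made up for illustration):

```python
# Hypothetical stream shaped like app.stream()'s yields:
# one {node_name: state_update} dict per step
def fake_stream():
    yield {"agent": {"messages": ["<tool call>"]}}
    yield {"tools": {"messages": ["<tool result>"]}}
    yield {"agent": {"messages": ["final answer text"]}}

for output in fake_stream():
    for key, value in output.items():
        print(f"Node {key}:")
        print(value)

# After the loop, `value` is the last node's update, i.e. the final answer
print("---FINAL ANSWER---")
print(value["messages"][-1])
```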
Update the Main Function
Update the main function to initialize the graph and execute it with sample user queries:
```python
def main():
    """
    Main function to initialize and execute the graph.
    """
    mongodb_client, vs_collection, full_collection = init_mongodb()
    tools = [
        get_information_for_question_answering,
        get_page_content_for_summarization,
    ]
    llm = ChatOpenAI(openai_api_key=key_param.openai_api_key, temperature=0, model="gpt-4o")
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant."
                " You are provided with tools to answer questions and summarize technical documentation related to MongoDB."
                " Think step-by-step and use these tools to get the information required to answer the user query."
                " Do not re-run tools unless absolutely necessary."
                " If you are not able to get enough information using the tools, reply with I DON'T KNOW."
                " You have access to the following tools: {tool_names}.",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
    llm_with_tools = prompt | llm.bind_tools(tools)
    tools_by_name = {tool.name: tool for tool in tools}
    app = init_graph(llm_with_tools, tools_by_name)
    execute_graph(app, "What are some best practices for data backups in MongoDB?")
    execute_graph(app, "Give me a summary of the page titled Create a MongoDB Deployment")
```