AI Agents with MongoDB / Build the agent and add memory
Code Summary: Add Memory to the Agent
This code adds short-term memory to the agent, enabling it to remember details from earlier in the current session.
Link to code on GitHub
Import Packages
Update the imported packages in the main.py file with the following:
import key_param
from pymongo import MongoClient
from langchain.agents import tool
from typing import List
from typing import Annotated
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import ToolMessage
from langgraph.graph import END, StateGraph, START
import voyageai
Define the State
Define a GraphState class which will serve as the state for our agent with the following:
class GraphState(TypedDict):
    """State carried through the agent graph.

    The single ``messages`` field holds the conversation history. It is
    annotated with ``add_messages`` so that messages accumulate in the
    state over the course of a run.
    """

    # Conversation history for the current thread.
    messages: Annotated[list, add_messages]
Create the Agent Node
Update the imported packages in the main.py file with the following, which adds the `MongoDBSaver` checkpointer import used to persist the agent's memory:
import key_param
from pymongo import MongoClient
from langchain.agents import tool
from typing import List
from typing import Annotated
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import ToolMessage
from langgraph.graph import END, StateGraph, START
from langgraph.checkpoint.mongodb import MongoDBSaver
import voyageai
Add a Memory Store
Update `init_graph` to use the checkpointer, ensuring the MongoDB client (our memory store) is passed in.
def init_graph(llm_with_tools, tools_by_name, mongodb_client):
    """
    Build the agent graph and compile it with a MongoDB checkpointer.

    The graph has two nodes, "agent" and "tools". Execution starts at
    "agent"; `route_tools` decides whether to continue to "tools" or to
    finish, and "tools" always hands control back to "agent".

    Args:
        llm_with_tools: The LLM with tools, forwarded to `agent`.
        tools_by_name (Dict[str, Callable]): The tools by name, forwarded
            to `tool_node`.
        mongodb_client (MongoClient): The MongoDB client handed to
            `MongoDBSaver` as the checkpoint store.

    Returns:
        The graph compiled with the MongoDB checkpointer attached.
    """

    def run_agent(state):
        # Bind the LLM so the node only needs the graph state.
        return agent(state, llm_with_tools)

    def run_tools(state):
        # Bind the tool lookup so the node only needs the graph state.
        return tool_node(state, tools_by_name)

    builder = StateGraph(GraphState)

    # Nodes
    builder.add_node("agent", run_agent)
    builder.add_node("tools", run_tools)

    # Edges
    builder.add_edge(START, "agent")
    builder.add_edge("tools", "agent")
    builder.add_conditional_edges(
        "agent",
        route_tools,
        {"tools": "tools", END: END},
    )

    # Compiling with a checkpointer is what gives the agent its memory.
    return builder.compile(checkpointer=MongoDBSaver(mongodb_client))
Add a Session Identifier
Update the `execute_graph` function to accept a `thread_id`, which identifies the current session and is passed to the graph in its run configuration:
def execute_graph(app, thread_id: str, user_input: str) -> None:
    """
    Stream outputs from the graph and print the final answer.

    Each streamed output is a mapping of node name to that node's value;
    every entry is printed as it arrives. After the stream ends, the
    content of the last message in the last streamed value is printed as
    the final answer.

    Args:
        app: The compiled graph application (must provide `stream`).
        thread_id (str): The thread ID identifying the current session.
        user_input (str): The user's input.
    """
    # Named `graph_input` rather than `input` to avoid shadowing the builtin.
    graph_input = {"messages": [("user", user_input)]}
    config = {"configurable": {"thread_id": thread_id}}

    # Track the last streamed value explicitly instead of relying on the
    # loop variable leaking out of the loop, which raises NameError when
    # the stream yields nothing.
    last_value = None
    for output in app.stream(graph_input, config):
        for key, value in output.items():
            print(f"Node {key}:")
            print(value)
            last_value = value

    print("---FINAL ANSWER---")
    if last_value is None:
        print("No output was produced by the graph.")
        return
    print(last_value["messages"][-1].content)
Update the Main Function
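Update the main function to pass the MongoDB client to `init_graph` and a thread_id to each `execute_graph` call. Both calls below use the same thread_id ("1"), so they run in the same session: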
def main():
    """
    Main function to initialize and execute the graph.

    Sets up the MongoDB client, the tools, the chat model and the prompt,
    compiles the graph with the MongoDB client as its memory store, and
    then runs two queries on the same thread ID ("1").
    """
    mongodb_client, vs_collection, full_collection = init_mongodb()

    tools = [
        get_information_for_question_answering,
        get_page_content_for_summarization,
    ]

    llm = ChatOpenAI(
        openai_api_key=key_param.openai_api_key,
        temperature=0,
        model="gpt-4o",
    )

    # The instructions are tagged with the "system" role explicitly. A bare
    # string passed to `from_messages` is treated as a human message, which
    # would send these instructions as if the user had typed them.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant."
                " You are provided with tools to answer questions and summarize technical documentation related to MongoDB."
                " Think step-by-step and use these tools to get the information required to answer the user query."
                " Do not re-run tools unless absolutely necessary."
                " If you are not able to get enough information using the tools, reply with I DON'T KNOW."
                " You have access to the following tools: {tool_names}.",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    # Fill in the tool list now; only `messages` remains to be supplied at run time.
    prompt = prompt.partial(tool_names=", ".join([t.name for t in tools]))

    bind_tools = llm.bind_tools(tools)
    llm_with_tools = prompt | bind_tools
    tools_by_name = {t.name: t for t in tools}

    app = init_graph(llm_with_tools, tools_by_name, mongodb_client)

    # Both calls share thread ID "1", so the second question runs in the
    # same session as the first.
    execute_graph(app, "1", "What are some best practices for data backups in MongoDB?")
    execute_graph(app, "1", "What did I just ask you?")