cacheing, tests as dict, new tests
This commit is contained in:
@@ -1,10 +1,16 @@
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
|
||||
from libs.test_class import Test
|
||||
from langchain.tools import Tool
|
||||
from typing import Literal
|
||||
|
||||
from langgraph.graph import StateGraph, MessagesState
|
||||
# from langgraph.prebuilt import ToolNode
|
||||
import json
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||
|
||||
system_msg = test.runnable_input['system_msg']
|
||||
human_msg = test.runnable_input['human_msg']
|
||||
|
||||
@@ -19,8 +25,9 @@ def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||
ai_msg = llm.invoke(prompt)
|
||||
return ai_msg.content
|
||||
|
||||
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||
|
||||
|
||||
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||
system_msg = test.runnable_input['system_msg']
|
||||
human_msg = test.runnable_input['human_msg']
|
||||
tools_dict = test.runnable_input['tools']
|
||||
@@ -42,11 +49,233 @@ def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> st
|
||||
prompt.append(ai_msg)
|
||||
|
||||
try:
|
||||
tool_call = ai_msg.tool_calls[0]
|
||||
selected_tool = tools_dict[tool_call["name"].lower()]
|
||||
tool_msg = selected_tool.invoke(tool_call)
|
||||
prompt.append(tool_msg)
|
||||
ai_msg = llm.invoke(prompt)
|
||||
except IndexError:
|
||||
pass
|
||||
return ai_msg.content
|
||||
tool_calls = []
|
||||
for i in range(len(ai_msg.tool_calls)):
|
||||
tool_call = ai_msg.tool_calls[i]
|
||||
selected_tool = tools_dict[tool_call["name"].lower()]
|
||||
tool_msg = selected_tool.invoke(tool_call)
|
||||
prompt.append(tool_msg)
|
||||
ai_msg = llm.invoke(prompt)
|
||||
tool_calls.append({
|
||||
"tool": tool_call["name"],
|
||||
"args": tool_call["args"],
|
||||
"index": 0
|
||||
})
|
||||
except IndexError: # LLM didnt use a tool -> jsut return the content
|
||||
tool_calls = []
|
||||
return {
|
||||
"answer": ai_msg.content,
|
||||
"tool_calls": tool_calls
|
||||
}
|
||||
|
||||
|
||||
|
||||
def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||
|
||||
tool_calls = []
|
||||
index = -1
|
||||
|
||||
def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
nonlocal index
|
||||
if last_message.tool_calls:
|
||||
index += 1
|
||||
return "tools"
|
||||
return "__end__"
|
||||
|
||||
def call_llm(state: MessagesState):
|
||||
messages = state["messages"]
|
||||
response = llm.invoke(messages)
|
||||
return {"messages": [response]}
|
||||
|
||||
class NxToolNode:
|
||||
"""A node that runs the tools requested in the last AIMessage."""
|
||||
|
||||
def __init__(self, tools: list) -> None:
|
||||
self.tools_by_name = {tool.name: tool for tool in tools}
|
||||
|
||||
def __call__(self, inputs: dict):
|
||||
if messages := inputs.get("messages", []):
|
||||
message = messages[-1]
|
||||
else:
|
||||
raise ValueError("No message found in input")
|
||||
outputs = []
|
||||
for tool_call in message.tool_calls:
|
||||
|
||||
nonlocal tool_calls
|
||||
nonlocal index
|
||||
tool_calls.append({
|
||||
"tool": tool_call["name"],
|
||||
"args": tool_call["args"],
|
||||
"index": index
|
||||
})
|
||||
|
||||
try:
|
||||
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
|
||||
except KeyError as e:
|
||||
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
|
||||
except ValidationError as e:
|
||||
tool_result = 'Tool got invalid input:\n' + e
|
||||
except Exception as e:
|
||||
tool_result = 'Error: ' + str(e)
|
||||
|
||||
outputs.append(
|
||||
ToolMessage(
|
||||
content=json.dumps(tool_result),
|
||||
name=tool_call["name"],
|
||||
tool_call_id=tool_call["id"],
|
||||
)
|
||||
)
|
||||
return {"messages": outputs}
|
||||
|
||||
|
||||
tools_dict = test.runnable_input['tools']
|
||||
tools = []
|
||||
for key in tools_dict:
|
||||
tools.append(tools_dict[key])
|
||||
tool_node = NxToolNode(tools)
|
||||
llm = ChatOllama(
|
||||
model=model,
|
||||
seed=seed,
|
||||
base_url=base_url
|
||||
).bind_tools(tools)
|
||||
|
||||
workflow = StateGraph(MessagesState)
|
||||
|
||||
# Define the two nodes we will cycle between
|
||||
workflow.add_node("agent", call_llm)
|
||||
workflow.add_node("tools", tool_node)
|
||||
|
||||
workflow.add_edge("__start__", "agent")
|
||||
workflow.add_conditional_edges(
|
||||
"agent",
|
||||
should_continue,
|
||||
)
|
||||
workflow.add_edge("tools", "agent")
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
# example with a single tool call
|
||||
start_messages = [
|
||||
SystemMessage(content=test.runnable_input['system_msg']),
|
||||
HumanMessage(content=test.runnable_input['human_msg'])
|
||||
]
|
||||
|
||||
chunks = []
|
||||
|
||||
for chunk in graph.stream(
|
||||
{"messages": start_messages},
|
||||
stream_mode="values",
|
||||
): chunks.append(chunk["messages"][-1])
|
||||
|
||||
return {
|
||||
"answer": chunks[-1].content,
|
||||
"tool_calls": tool_calls
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def agent_with_tools_fsp(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||
|
||||
tool_calls = []
|
||||
index = -1
|
||||
|
||||
def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
nonlocal index
|
||||
if last_message.tool_calls:
|
||||
index += 1
|
||||
return "tools"
|
||||
return "__end__"
|
||||
|
||||
def call_llm(state: MessagesState):
|
||||
messages = state["messages"]
|
||||
response = llm.invoke(messages)
|
||||
return {"messages": [response]}
|
||||
|
||||
class NxToolNode:
|
||||
"""A node that runs the tools requested in the last AIMessage."""
|
||||
|
||||
def __init__(self, tools: list) -> None:
|
||||
self.tools_by_name = {tool.name: tool for tool in tools}
|
||||
|
||||
def __call__(self, inputs: dict):
|
||||
if messages := inputs.get("messages", []):
|
||||
message = messages[-1]
|
||||
else:
|
||||
raise ValueError("No message found in input")
|
||||
outputs = []
|
||||
for tool_call in message.tool_calls:
|
||||
|
||||
nonlocal tool_calls
|
||||
nonlocal index
|
||||
tool_calls.append({
|
||||
"tool": tool_call["name"],
|
||||
"args": tool_call["args"],
|
||||
"index": index
|
||||
})
|
||||
|
||||
try:
|
||||
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
|
||||
except KeyError as e:
|
||||
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
|
||||
except ValidationError as e:
|
||||
tool_result = 'Tool got invalid input:\n' + e
|
||||
except Exception as e:
|
||||
tool_result = 'Error: ' + str(e)
|
||||
|
||||
outputs.append(
|
||||
ToolMessage(
|
||||
content=json.dumps(tool_result),
|
||||
name=tool_call["name"],
|
||||
tool_call_id=tool_call["id"],
|
||||
)
|
||||
)
|
||||
return {"messages": outputs}
|
||||
|
||||
|
||||
tools_dict = test.runnable_input['tools']
|
||||
tools = []
|
||||
for key in tools_dict:
|
||||
tools.append(tools_dict[key])
|
||||
tool_node = NxToolNode(tools)
|
||||
llm = ChatOllama(
|
||||
model=model,
|
||||
seed=seed,
|
||||
base_url=base_url
|
||||
).bind_tools(tools)
|
||||
|
||||
workflow = StateGraph(MessagesState)
|
||||
|
||||
# Define the two nodes we will cycle between
|
||||
workflow.add_node("agent", call_llm)
|
||||
workflow.add_node("tools", tool_node)
|
||||
|
||||
workflow.add_edge("__start__", "agent")
|
||||
workflow.add_conditional_edges(
|
||||
"agent",
|
||||
should_continue,
|
||||
)
|
||||
workflow.add_edge("tools", "agent")
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
# example with a single tool call
|
||||
start_messages = [ SystemMessage(test.runnable_input['system_msg']) ] + test.runnable_input['fsp_messages'] + [ HumanMessage(test.runnable_input['human_msg']) ]
|
||||
|
||||
chunks = []
|
||||
|
||||
for chunk in graph.stream(
|
||||
{"messages": start_messages},
|
||||
stream_mode="values",
|
||||
): chunks.append(chunk["messages"][-1])
|
||||
|
||||
return {
|
||||
"answer": chunks[-1].content,
|
||||
"tool_calls": tool_calls
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user