282 lines
8.4 KiB
Python
282 lines
8.4 KiB
Python
from langchain_ollama.chat_models import ChatOllama
|
|
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
|
|
from libs.test_class import Test
|
|
from langchain.tools import Tool
|
|
from typing import Literal
|
|
|
|
from langgraph.graph import StateGraph, MessagesState
|
|
# from langgraph.prebuilt import ToolNode
|
|
import json
|
|
from pydantic import ValidationError
|
|
|
|
|
|
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
    """Send the test's prompt to an Ollama chat model and return the reply text.

    ``test.runnable_input`` must provide ``'system_msg'`` (``None`` omits the
    system prompt) and ``'human_msg'``.

    Plain-string prompts are wrapped in ``SystemMessage`` / ``HumanMessage``
    explicitly. LangChain coerces every bare string in a message list into a
    human message, so without wrapping the system prompt would lose its role.
    Values that are not strings are passed through unchanged.
    """
    system_msg = test.runnable_input['system_msg']
    human_msg = test.runnable_input['human_msg']

    prompt = []
    if system_msg is not None:
        prompt.append(
            SystemMessage(content=system_msg) if isinstance(system_msg, str) else system_msg
        )
    prompt.append(
        HumanMessage(content=human_msg) if isinstance(human_msg, str) else human_msg
    )

    llm = ChatOllama(
        model=model,
        seed=seed,
        base_url=base_url,
    )
    return llm.invoke(prompt).content
|
|
|
|
|
|
|
|
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> dict:
    """Run exactly one tool-calling round and return the model's final answer.

    Flow:
    1. The model is invoked once with the tools from
       ``test.runnable_input['tools']`` bound.
    2. If that response requests no tools, its content is the answer.
    3. Otherwise every requested tool is executed and its result appended to
       the conversation.
    4. The model is invoked one more time to produce the answer.

    Returns:
        ``{"answer": <content>, "tool_calls": [{"tool": name, "args": args,
        "index": 0}, ...]}``. ``index`` is always 0 because only one tool
        round is performed.

    Raises:
        KeyError: if the model requests a tool whose lowercased name is not a
            key of ``test.runnable_input['tools']``.
    """
    runnable_input = test.runnable_input
    system_msg = runnable_input['system_msg']
    human_msg = runnable_input['human_msg']
    tools_dict = runnable_input['tools']
    tools = list(tools_dict.values())

    # Wrap plain strings explicitly: LangChain turns bare strings in a message
    # list into human messages, which would strip the system prompt's role.
    prompt = []
    if system_msg is not None:
        prompt.append(
            SystemMessage(content=system_msg) if isinstance(system_msg, str) else system_msg
        )
    prompt.append(
        HumanMessage(content=human_msg) if isinstance(human_msg, str) else human_msg
    )

    llm = ChatOllama(
        model=model,
        seed=seed,
        base_url=base_url,
    ).bind_tools(tools)

    ai_msg = llm.invoke(prompt)
    # Snapshot the requested calls before anything re-binds ai_msg.
    requested = ai_msg.tool_calls
    if not requested:
        return {"answer": ai_msg.content, "tool_calls": []}

    prompt.append(ai_msg)
    tool_calls = []
    for tool_call in requested:
        selected_tool = tools_dict[tool_call["name"].lower()]
        # A LangChain tool invoked with a full ToolCall dict returns a
        # ToolMessage tied to that call's id, ready to append to the chat.
        prompt.append(selected_tool.invoke(tool_call))
        tool_calls.append({
            "tool": tool_call["name"],
            "args": tool_call["args"],
            "index": 0,
        })

    # Ask for the final answer only once every tool result is in the history.
    ai_msg = llm.invoke(prompt)
    return {
        "answer": ai_msg.content,
        "tool_calls": tool_calls,
    }
|
|
|
|
|
|
|
|
def _agent_start_messages(runnable_input: dict, few_shot=()) -> list:
    """Build the initial conversation for the tool-using agent.

    Layout: optional system prompt, then the messages in ``few_shot``, then the
    human prompt. The system prompt is skipped when ``runnable_input['system_msg']``
    is ``None`` instead of failing inside ``SystemMessage``.
    """
    system_msg = runnable_input['system_msg']
    messages = [] if system_msg is None else [SystemMessage(content=system_msg)]
    messages.extend(few_shot)
    messages.append(HumanMessage(content=runnable_input['human_msg']))
    return messages


class _RecordingToolNode:
    """LangGraph node that runs the tools requested in the last AIMessage.

    Besides returning one ``ToolMessage`` per requested call, every call is
    recorded in ``self.tool_calls`` as ``{"tool": name, "args": args,
    "index": round}``. ``round`` counts the agent turns that requested tools,
    starting at 0.

    Tool failures are reported back to the model as text rather than raised,
    so the agent can react to them.
    """

    def __init__(self, tools: list) -> None:
        self.tools_by_name = {tool.name: tool for tool in tools}
        self.tool_calls = []
        # Incremented once per node invocation; this node only runs when the
        # agent's last message requested tools.
        self.round = -1

    def _run_tool(self, name: str, args):
        """Invoke tool ``name`` with ``args``; return its result or an error string."""
        tool = self.tools_by_name.get(name)
        if tool is None:
            # The lookup is done outside the try so that KeyErrors raised by
            # the tool itself are not misreported as a missing tool.
            return (
                f'Error: Tool with name `{name}` does not exist. '
                f'Available tools are: {list(self.tools_by_name)}'
            )
        try:
            return tool.invoke(args)
        except ValidationError as e:
            return 'Tool got invalid input:\n' + str(e)
        except Exception as e:  # deliberate: surface any tool failure to the model
            return 'Error: ' + str(e)

    def __call__(self, inputs: dict) -> dict:
        messages = inputs.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        message = messages[-1]

        self.round += 1
        outputs = []
        for tool_call in message.tool_calls:
            self.tool_calls.append({
                "tool": tool_call["name"],
                "args": tool_call["args"],
                "index": self.round,
            })
            result = self._run_tool(tool_call["name"], tool_call["args"])
            outputs.append(
                ToolMessage(
                    # default=str: a tool returning a non-JSON type must not
                    # crash the whole run; its text form is sent to the model.
                    content=json.dumps(result, default=str),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}


def _run_tool_agent(model: str, seed: int, base_url: str, tools_dict: dict,
                    start_messages: list) -> dict:
    """Run an agent <-> tools LangGraph loop until the model stops calling tools.

    Args:
        model: Ollama model name.
        seed: sampling seed passed to ``ChatOllama``.
        base_url: Ollama server URL.
        tools_dict: mapping whose values are the tools to bind.
        start_messages: the initial conversation.

    Returns:
        ``{"answer": <content of the final message>, "tool_calls": [...]}``,
        with ``tool_calls`` as recorded by ``_RecordingToolNode``.
    """
    tools = list(tools_dict.values())
    tool_node = _RecordingToolNode(tools)
    llm = ChatOllama(
        model=model,
        seed=seed,
        base_url=base_url,
    ).bind_tools(tools)

    def call_llm(state: MessagesState):
        return {"messages": [llm.invoke(state["messages"])]}

    def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
        return "tools" if state["messages"][-1].tool_calls else "__end__"

    # Two nodes that alternate until the model answers without tool calls.
    workflow = StateGraph(MessagesState)
    workflow.add_node("agent", call_llm)
    workflow.add_node("tools", tool_node)
    workflow.add_edge("__start__", "agent")
    workflow.add_conditional_edges("agent", should_continue)
    workflow.add_edge("tools", "agent")
    graph = workflow.compile()

    final_state = graph.invoke({"messages": start_messages})
    return {
        "answer": final_state["messages"][-1].content,
        "tool_calls": tool_node.tool_calls,
    }


def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> dict:
    """Run the tool-calling agent on the test's system/human prompt.

    Reads ``'system_msg'`` (may be ``None``), ``'human_msg'`` and ``'tools'``
    from ``test.runnable_input``.

    Returns:
        ``{"answer": ..., "tool_calls": [{"tool", "args", "index"}, ...]}``.
    """
    runnable_input = test.runnable_input
    return _run_tool_agent(
        model, seed, base_url,
        runnable_input['tools'],
        _agent_start_messages(runnable_input),
    )


def agent_with_tools_fsp(model: str, seed: int, test: Test, base_url: str) -> dict:
    """Like ``agent_with_tools``, with extra messages inserted before the human prompt.

    The messages come from ``test.runnable_input['fsp_messages']`` (presumably
    few-shot examples) and are placed between the system and human prompts.
    """
    runnable_input = test.runnable_input
    return _run_tool_agent(
        model, seed, base_url,
        runnable_input['tools'],
        _agent_start_messages(runnable_input, runnable_input['fsp_messages']),
    )
|