from langchain_ollama.chat_models import ChatOllama
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from libs.test_class import Test
from langchain.tools import Tool
from typing import Literal
from langgraph.graph import StateGraph, MessagesState
# from langgraph.prebuilt import ToolNode
import json
from pydantic import ValidationError


def _build_prompt(system_msg, human_msg) -> list:
    """Return ``[human_msg]``, prefixed by ``system_msg`` when it is not ``None``."""
    if system_msg is None:
        return [human_msg]
    return [system_msg, human_msg]


def basic(model: str, seed: int, test: Test, base_url: str) -> str:
    """Send the test prompt to the model once and return the reply content.

    ``test.runnable_input`` must provide ``system_msg`` (may be ``None``) and
    ``human_msg``; both are passed to ``ChatOllama.invoke`` unchanged.
    """
    prompt = _build_prompt(
        test.runnable_input['system_msg'], test.runnable_input['human_msg']
    )
    llm = ChatOllama(model=model, seed=seed, base_url=base_url)
    return llm.invoke(prompt).content


def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> dict:
    """Run one round of tool calling and return the model's follow-up answer.

    The model is invoked with the tools bound. Every tool call in that first
    response is executed and its result appended to the conversation; then the
    model is invoked once more to produce the answer. If the first response
    requests no tools, its content is the answer.

    ``test.runnable_input`` must provide ``system_msg`` (may be ``None``),
    ``human_msg`` and ``tools`` (a dict keyed by lower-cased tool name).

    Returns:
        ``{"answer": <final content>, "tool_calls": [...]}`` where each entry is
        ``{"tool", "args", "index"}``; ``index`` is 0 because every call comes
        from the first model turn.

    Raises:
        KeyError: if the model requests a tool not present in ``tools``.
    """
    tools_dict = test.runnable_input['tools']
    prompt = _build_prompt(
        test.runnable_input['system_msg'], test.runnable_input['human_msg']
    )
    llm = ChatOllama(
        model=model, seed=seed, base_url=base_url
    ).bind_tools(list(tools_dict.values()))

    ai_msg = llm.invoke(prompt)
    prompt.append(ai_msg)

    tool_calls = []
    # Answer every tool call of the first response before asking the model
    # again; iterating a reassigned ai_msg would index into the follow-up reply.
    for tool_call in ai_msg.tool_calls:
        selected_tool = tools_dict[tool_call["name"].lower()]
        # Invoking with the full ToolCall yields a ToolMessage for the prompt.
        prompt.append(selected_tool.invoke(tool_call))
        tool_calls.append({
            "tool": tool_call["name"],
            "args": tool_call["args"],
            "index": 0
        })

    if tool_calls:
        ai_msg = llm.invoke(prompt)

    return {
        "answer": ai_msg.content,
        "tool_calls": tool_calls
    }


def _system_prefix(system_msg) -> list:
    """Return ``[SystemMessage(system_msg)]``, or ``[]`` when it is ``None``."""
    if system_msg is None:
        return []
    return [SystemMessage(content=system_msg)]


def _invoke_tool(tools_by_name: dict, tool_call: dict):
    """Run the requested tool; failures are returned as error text for the model."""
    tool = tools_by_name.get(tool_call["name"])
    # Look the tool up separately so a KeyError raised *inside* a tool is not
    # misreported as a missing tool.
    if tool is None:
        return (
            f'Error: Tool with name `{tool_call["name"]}` does not exist. '
            f'Available tools are: {list(tools_by_name)}'
        )
    try:
        return tool.invoke(tool_call["args"])
    except ValidationError as e:
        return 'Tool got invalid input:\n' + str(e)
    except Exception as e:
        return 'Error: ' + str(e)


def _run_tool_agent(model: str, seed: int, base_url: str,
                    tools_dict: dict, start_messages: list) -> dict:
    """Loop agent <-> tools in a LangGraph graph until no more tools are requested.

    Args:
        tools_dict: mapping whose values are tools exposing ``.name`` and
            ``.invoke``; they are bound to the model and dispatched by name.
        start_messages: initial conversation fed to the graph.

    Returns:
        ``{"answer": <content of the last message>, "tool_calls": [...]}``; each
        entry is ``{"tool", "args", "index"}`` where ``index`` is the 0-based
        count of model turns that requested tools.
    """
    tools = list(tools_dict.values())
    tools_by_name = {tool.name: tool for tool in tools}
    llm = ChatOllama(
        model=model, seed=seed, base_url=base_url
    ).bind_tools(tools)

    tool_calls = []
    # -1 until the first model turn that requests tools.
    turn_index = -1

    def call_llm(state: MessagesState):
        return {"messages": [llm.invoke(state["messages"])]}

    # The Literal annotation tells LangGraph which nodes this edge may route to.
    def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
        nonlocal turn_index
        if state["messages"][-1].tool_calls:
            turn_index += 1
            return "tools"
        return "__end__"

    def run_tools(state: MessagesState):
        messages = state.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        outputs = []
        for tool_call in messages[-1].tool_calls:
            tool_calls.append({
                "tool": tool_call["name"],
                "args": tool_call["args"],
                "index": turn_index
            })
            outputs.append(
                ToolMessage(
                    content=json.dumps(_invoke_tool(tools_by_name, tool_call)),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}

    workflow = StateGraph(MessagesState)
    workflow.add_node("agent", call_llm)
    workflow.add_node("tools", run_tools)
    workflow.add_edge("__start__", "agent")
    workflow.add_conditional_edges("agent", should_continue)
    workflow.add_edge("tools", "agent")
    graph = workflow.compile()

    last_message = None
    # stream_mode="values" yields the full state after each step; only the
    # final message is needed.
    for chunk in graph.stream({"messages": start_messages}, stream_mode="values"):
        last_message = chunk["messages"][-1]

    return {
        "answer": last_message.content,
        "tool_calls": tool_calls
    }


def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> dict:
    """Run the tool-calling agent on the test's system and human prompts.

    ``test.runnable_input`` must provide ``system_msg`` (string or ``None``),
    ``human_msg`` (string) and ``tools``. Returns the same dict as
    ``_run_tool_agent``.
    """
    start_messages = _system_prefix(test.runnable_input['system_msg']) + [
        HumanMessage(content=test.runnable_input['human_msg'])
    ]
    return _run_tool_agent(
        model, seed, base_url, test.runnable_input['tools'], start_messages
    )


def agent_with_tools_fsp(model: str, seed: int, test: Test, base_url: str) -> dict:
    """Run the tool-calling agent with few-shot messages before the human prompt.

    Like ``agent_with_tools``, but ``test.runnable_input['fsp_messages']`` (an
    iterable of messages) is inserted between the system and human messages.
    """
    start_messages = [
        *_system_prefix(test.runnable_input['system_msg']),
        *test.runnable_input['fsp_messages'],
        HumanMessage(content=test.runnable_input['human_msg']),
    ]
    return _run_tool_agent(
        model, seed, base_url, test.runnable_input['tools'], start_messages
    )