225 lines
7.3 KiB
Python
225 lines
7.3 KiB
Python
from types import NoneType
|
|
from langchain_ollama.chat_models import ChatOllama
|
|
from libs.ollama_functions import OllamaFunctionsLSM, OllamaFunctionsT2S
|
|
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
|
|
from libs.classes import Technique, Test, Model
|
|
from langchain.tools import Tool
|
|
from typing import Literal
|
|
|
|
from langgraph.graph import StateGraph, MessagesState
|
|
import json
|
|
from pydantic import ValidationError
|
|
|
|
from suite_settings.techniques import techniques
|
|
|
|
def _get_llm(model: Model, base_url: str, seed: int, technique: Technique, tools: list[Tool]|NoneType = None):
    """Build the chat model that implements ``technique`` for ``model``.

    Args:
        model: Model description; its ``identifier`` is passed as the model name.
        base_url: Forwarded unchanged to the chat model as ``base_url``.
        seed: Forwarded unchanged to the chat model as ``seed``.
        technique: Compared against entries of ``techniques`` to choose the
            chat-model class.
        tools: Optional tools. When the list is non-empty the tools are bound
            to the model through ``bind_tools``.

    Returns:
        The constructed chat model, or whatever ``bind_tools`` returns when
        tools were supplied.

    Raises:
        ValueError: If ``technique`` matches none of the supported entries.
    """
    # Pick the class and its technique-specific keyword arguments first; the
    # comparisons stay in an if/elif chain so later ``techniques`` entries are
    # only looked up when the earlier ones did not match.
    if technique == techniques[1]:  # Native
        llm_class, extra_kwargs = ChatOllama, {}
    elif technique == techniques[903]:  # Long System Message
        llm_class, extra_kwargs = OllamaFunctionsLSM, {"format": "json"}
    elif technique == techniques[572]:  # ToolMessages to SystemMessages
        llm_class, extra_kwargs = OllamaFunctionsT2S, {"format": "json"}
    else:
        raise ValueError("Unkown Technique in _get_llm()")

    chat_model = llm_class(
        model=model.identifier,
        seed=seed,
        base_url=base_url,
        **extra_kwargs,
    )

    # An empty or missing tool list leaves the model unbound.
    return chat_model.bind_tools(tools=tools) if tools else chat_model
|
|
|
|
|
|
def basic_prompt(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> str:
    """Send the test's prompt to the model once and return the text reply.

    The conversation is the system message, the optional few-shot messages
    stored under ``'fsp_messages'``, and finally the human message, all read
    from ``test.runnable_input``.

    Args:
        model: Model description handed to ``_get_llm``.
        seed: Seed handed to ``_get_llm``.
        test: Test whose ``runnable_input`` supplies the messages.
        technique: Technique handed to ``_get_llm``.
        base_url: Base URL handed to ``_get_llm``.

    Returns:
        The ``content`` of the model's reply.

    Raises:
        AssertionError: If the reply's ``content`` is not a ``str``.
    """
    conversation = [SystemMessage(test.runnable_input['system_msg'])]
    try:
        # Few-shot examples are optional; a missing key just means none.
        conversation.extend(test.runnable_input['fsp_messages'])
    except KeyError:
        pass
    conversation.append(HumanMessage(test.runnable_input['human_msg']))

    chat_model = _get_llm(model=model, base_url=base_url, technique=technique, seed=seed)
    reply = chat_model.invoke(conversation)

    assert isinstance(reply.content, str)
    return reply.content
|
|
|
|
|
|
|
|
def one_tool_call_answer(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> dict:
    """Prompt the model, run the tools it requests, and return its final answer.

    The model is invoked once with the system message, optional few-shot
    messages (``'fsp_messages'``) and the human message from
    ``test.runnable_input``. For every tool call in that first reply the tool
    is looked up in ``test.runnable_input['tools']`` by lower-cased name,
    invoked, its result appended to the conversation, and the model invoked
    again. While the follow-up reply is a ``SystemMessage`` it is appended to
    the conversation and the model is re-invoked, up to five retries per tool
    call.

    Args:
        model: Model description handed to ``_get_llm``.
        seed: Seed handed to ``_get_llm``.
        test: Test whose ``runnable_input`` supplies messages and tools.
        technique: Technique handed to ``_get_llm``.
        base_url: Base URL handed to ``_get_llm``.

    Returns:
        A dict with two keys:
        ``"answer"`` - the final reply's ``content``, or a ``>>...<<`` marker
        string when the model kept failing or answered with further tool
        calls instead of text;
        ``"tool_calls"`` - a list of dicts with ``"tool"`` and ``"args"``
        (plus ``"times_failed"`` for calls that were actually executed).

    Raises:
        AssertionError: If the first reply is not an ``AIMessage``.
        KeyError: If a required input key is missing or the model names a
            tool that is not in the tools mapping.
    """
    # Number of SystemMessage replies tolerated per tool call before giving up.
    max_retries = 5

    tools_dict = test.runnable_input['tools']
    tools = [tools_dict[key] for key in tools_dict]
    llm = _get_llm(model=model, base_url=base_url, seed=seed, technique=technique, tools=tools)

    messages = [SystemMessage(test.runnable_input['system_msg'])]
    try:
        # Few-shot examples are optional; a missing key just means none.
        messages += test.runnable_input['fsp_messages']
    except KeyError:
        pass
    messages += [HumanMessage(test.runnable_input['human_msg'])]

    ai_msg = llm.invoke(messages)
    messages += [ai_msg]

    try:
        tool_calls = []
        assert isinstance(ai_msg, AIMessage)
        calls = ai_msg.tool_calls
        for call in calls:
            selected_tool = tools_dict[call["name"].lower()]
            tool_msg = selected_tool.invoke(call)
            messages.append(tool_msg)
            ai_msg = llm.invoke(messages)

            # NOTE(review): a SystemMessage reply presumably signals that the
            # wrapper could not turn the model output into a usable message -
            # confirm against libs.ollama_functions.
            i = 0
            while isinstance(ai_msg, SystemMessage):
                i += 1
                # The original guard was `i <= 5`, which bailed out on the
                # very first failure and made the retry below unreachable.
                if i > max_retries:
                    return {
                        "answer": ">>LLM failed to use tools<<",
                        "tool_calls": tool_calls,
                    }
                # Feed the failure back to the model and try again.
                messages.append(ai_msg)
                ai_msg = llm.invoke(messages)

            tool_calls.append({
                "tool": call["name"],
                "args": call["args"],
                "times_failed": i,
            })
    except IndexError:  # LLM didn't use a tool -> just return the content
        tool_calls = []

    if len(ai_msg.tool_calls) > 0:
        # The last reply still asks for tools instead of answering in text;
        # report those pending calls alongside the executed ones.
        to_append_calls = [
            {"tool": call["name"], "args": call["args"]}
            for call in ai_msg.tool_calls
        ]
        return {
            "answer": ">>LLM did not respond conversationally<<",
            "tool_calls": tool_calls + to_append_calls,
        }

    return {
        "answer": ai_msg.content,
        "tool_calls": tool_calls,
    }
|
|
|
|
def agent_with_tools(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> dict[str, str|list]:
    """Run the model as a tool-using agent in a two-node LangGraph loop.

    The graph alternates between an ``"agent"`` node (invokes the model) and a
    ``"tools"`` node (executes the tool calls of the last message) until the
    model replies without tool calls or the graph's recursion limit of 10 is
    reached.

    Args:
        model: Model description handed to ``_get_llm``.
        seed: Seed handed to ``_get_llm``.
        test: Test whose ``runnable_input`` supplies messages and tools.
        technique: Technique handed to ``_get_llm``.
        base_url: Base URL handed to ``_get_llm``.

    Returns:
        A dict with two keys:
        ``"answer"`` - ``content`` of the last message the graph produced, or
        a ``>>...<<`` marker string when the recursion limit was hit;
        ``"tool_calls"`` - a list of dicts with ``"tool"``, ``"args"`` and
        ``"index"``, where ``index`` is the 0-based count of agent replies
        that requested tools (calls from the same reply share an index).
    """
    tools_dict = test.runnable_input['tools']
    tools = [tools_dict[key] for key in tools_dict]
    llm = _get_llm(model=model, base_url=base_url, seed=seed, technique=technique, tools=tools)

    tool_calls = []
    index = -1

    def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
        """Route to the tools node when the last message requests tools."""
        nonlocal index
        last_message = state["messages"][-1]
        if last_message.tool_calls:
            # One more agent reply that asked for tools.
            index += 1
            return "tools"
        return "__end__"

    def call_llm(state: MessagesState):
        """Invoke the model on the conversation so far."""
        response = llm.invoke(state["messages"])
        return {"messages": [response]}

    class NxToolNode:
        """A node that runs the tools requested in the last AIMessage."""

        def __init__(self, tools: list) -> None:
            self.tools_by_name = {tool.name: tool for tool in tools}

        def __call__(self, inputs: dict):
            if messages := inputs.get("messages", []):
                message = messages[-1]
            else:
                raise ValueError("No message found in input")

            outputs = []
            for tool_call in message.tool_calls:
                # Record every requested call, including ones that fail below.
                tool_calls.append({
                    "tool": tool_call["name"],
                    "args": tool_call["args"],
                    "index": index,
                })

                # Look the tool up separately from running it, so a KeyError
                # raised inside the tool itself is not misreported to the
                # model as "tool does not exist".
                tool = self.tools_by_name.get(tool_call["name"])
                if tool is None:
                    tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
                else:
                    try:
                        tool_result = tool.invoke(tool_call["args"])
                    except ValidationError as e:
                        tool_result = 'Tool got invalid input:\n' + str(e)
                    except Exception as e:
                        # Deliberately broad: any tool failure is fed back to
                        # the model as text instead of aborting the run.
                        tool_result = 'Error: ' + str(e)

                outputs.append(
                    ToolMessage(
                        content=json.dumps(tool_result),
                        name=tool_call["name"],
                        tool_call_id=tool_call["id"],
                    )
                )
            return {"messages": outputs}

    tool_node = NxToolNode(tools)

    workflow = StateGraph(MessagesState)

    # Define the two nodes we will cycle between
    workflow.add_node("agent", call_llm)
    workflow.add_node("tools", tool_node)

    workflow.add_edge("__start__", "agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
    )
    workflow.add_edge("tools", "agent")
    graph = workflow.compile()

    # Compose the "history"; supports few-shot prompting
    start_messages = [SystemMessage(test.runnable_input['system_msg'])]
    try:
        start_messages += test.runnable_input['fsp_messages']
    except KeyError:
        pass
    start_messages += [HumanMessage(test.runnable_input['human_msg'])]

    chunks = []

    try:
        for chunk in graph.stream({"messages": start_messages}, stream_mode="values", config={"recursion_limit": 10}):
            chunks.append(chunk["messages"][-1])
    except RecursionError:
        # NOTE(review): LangGraph is expected to signal the recursion limit
        # with a RecursionError subclass - confirm for the pinned version.
        return {
            "answer": ">>Model did not come to a conclusion (Recusion Error)<<",
            "tool_calls": tool_calls
        }

    return {
        "answer": chunks[-1].content,
        "tool_calls": tool_calls
    }
|