Files
test-small-llms/libs/runnables.py
Lennart J. Kurzweg (Nx2) 314077a63d typos
2024-10-05 20:47:56 +02:00

231 lines
7.8 KiB
Python

from types import NoneType
from langchain_ollama.chat_models import ChatOllama
from libs.ollama_functions import OllamaFunctionsLSM, OllamaFunctionsT2S
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
from libs.classes import Technique, Test, Model
from langchain.tools import Tool
from typing import Literal
from langgraph.graph import StateGraph, MessagesState
import json
from pydantic import ValidationError
from suite_settings.techniques import techniques
def _get_llm(model: Model, base_url: str, seed: int, technique: Technique, tools: list[Tool] | NoneType = None):
    """Build the chat model that implements ``technique`` for ``model``.

    Args:
        model: Model whose ``identifier`` is passed to the Ollama client.
        base_url: URL of the Ollama server.
        seed: Sampling seed forwarded to the client.
        technique: One of ``techniques[1]`` (native tool calling through
            ``ChatOllama``), ``techniques[903]`` (``OllamaFunctionsLSM``,
            "Long System Message") or ``techniques[572]``
            (``OllamaFunctionsT2S``, "ToolMessages to SystemMessages").
        tools: Tools to bind to the model. A falsy value (``None`` or an
            empty list) leaves the model unbound.

    Returns:
        The chat model, with ``tools`` bound via ``bind_tools`` when given.

    Raises:
        ValueError: If ``technique`` is none of the three supported ones.
    """
    if technique == techniques[1]:  # Native
        llm = ChatOllama(
            model=model.identifier,
            seed=seed,
            base_url=base_url,
        )
    elif technique == techniques[903]:  # Long System Message
        # format="json" puts Ollama in JSON output mode; presumably the wrapper
        # parses tool calls from that JSON (confirm in libs.ollama_functions).
        llm = OllamaFunctionsLSM(
            model=model.identifier,
            seed=seed,
            base_url=base_url,
            format="json",
        )
    elif technique == techniques[572]:  # ToolMessages to SystemMessages
        llm = OllamaFunctionsT2S(
            model=model.identifier,
            seed=seed,
            base_url=base_url,
            format="json",
        )
    else:
        # Name the rejected technique so a misconfigured suite is easy to spot.
        raise ValueError(f"Unknown Technique in _get_llm(): {technique!r}")
    if tools:
        llm = llm.bind_tools(tools=tools)
    return llm
def basic_prompt(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> str:
    """Prompt the model once, with no tools bound, and return its reply text.

    The conversation sent to the model is, in order:
    - the system message from ``test.runnable_input['system_msg']``;
    - any few-shot messages from ``test.runnable_input['fsp_messages']``
      (this key is optional);
    - the human message from ``test.runnable_input['human_msg']``.

    Raises:
        AssertionError: If the reply's ``content`` is not a ``str``.
    """
    run_input = test.runnable_input
    conversation = [SystemMessage(run_input['system_msg'])]
    try:
        few_shot = run_input['fsp_messages']
    except KeyError:
        few_shot = []  # few-shot examples are optional
    conversation += few_shot
    conversation.append(HumanMessage(run_input['human_msg']))

    chat_model = _get_llm(model=model, base_url=base_url, technique=technique, seed=seed)
    reply = chat_model.invoke(conversation)
    text = reply.content
    assert isinstance(text, str)
    return text
def one_tool_call_answer(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> dict:
    """Execute the tool calls the model requests, then return its follow-up answer.

    Every tool in ``test.runnable_input['tools']`` is bound to the model, which
    is prompted once. For each tool call in that first reply:

    1. The tool is looked up by its lower-cased name and invoked with the full
       tool-call dict.
    2. Its result is appended to the conversation. If the tool is unknown or
       raises, an error ``SystemMessage`` is appended instead.
    3. The model is invoked again.

    If the model answers with a ``SystemMessage``, it is re-prompted, up to
    ``max_retries`` times.

    Returns:
        ``{"answer": ..., "tool_calls": [...]}``.

        Each tool-call entry has ``"tool"``, ``"args"`` and ``"times_failed"``
        (how many ``SystemMessage`` replies preceded a usable one).

        ``"answer"`` is the final reply's content, or a ``>>...<<`` marker when:
        - the model kept failing, or
        - the final reply contains further tool calls instead of prose (those
          calls are appended without ``"times_failed"``).
    """
    max_retries = 5  # re-prompts allowed per tool call before giving up

    tools_dict = test.runnable_input['tools']
    tools = list(tools_dict.values())
    llm = _get_llm(model=model, base_url=base_url, seed=seed, technique=technique, tools=tools)

    messages = [SystemMessage(test.runnable_input['system_msg'])]
    try:
        messages += test.runnable_input['fsp_messages']
    except KeyError:
        pass  # few-shot messages are optional
    messages += [HumanMessage(test.runnable_input['human_msg'])]

    ai_msg = llm.invoke(messages)
    messages.append(ai_msg)

    tool_calls = []
    try:
        assert isinstance(ai_msg, AIMessage)
        # Capture the first reply's calls; ai_msg is rebound inside the loop.
        requested_calls = ai_msg.tool_calls
        for call in requested_calls:
            tool_name = call["name"].lower()
            # Look the tool up separately from invoking it, so a KeyError raised
            # *inside* a tool is not misreported as a missing tool.
            try:
                selected_tool = tools_dict[tool_name]
            except KeyError:
                tool_msg = SystemMessage(f"Tool '{tool_name}' does not exist. Available are {tools_dict.keys()}")
            else:
                try:
                    tool_msg = selected_tool.invoke(call)
                except Exception as e:
                    tool_msg = SystemMessage(f"Tool '{tool_name}' returned a input validation error:" + "\n" + str(e))
            messages.append(tool_msg)

            ai_msg = llm.invoke(messages)
            times_failed = 0
            while isinstance(ai_msg, SystemMessage):
                times_failed += 1
                # Give up only once the retry budget is exhausted; checking
                # `<=` here would bail out on the very first failure.
                if times_failed > max_retries:
                    return {
                        "answer": ">>LLM failed to use tools<<",
                        "tool_calls": tool_calls,
                    }
                messages.append(ai_msg)
                ai_msg = llm.invoke(messages)
            tool_calls.append({
                "tool": call["name"],
                "args": call["args"],
                "times_failed": times_failed,
            })
    except IndexError:
        # NOTE(review): iterating an empty tool_calls list never raises
        # IndexError, so a reply without tool calls simply skips the loop.
        # This handler only fires if something in the loop raises IndexError
        # (e.g. llm.invoke). It is kept for backward compatibility.
        tool_calls = []

    if ai_msg.tool_calls:
        # The model asked for more tools instead of answering in prose.
        trailing_calls = [{"tool": c["name"], "args": c["args"]} for c in ai_msg.tool_calls]
        return {
            "answer": ">>LLM did not respond conversationally<<",
            "tool_calls": tool_calls + trailing_calls,
        }
    return {
        "answer": ai_msg.content,
        "tool_calls": tool_calls,
    }
def agent_with_tools(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> dict[str, str|list]:
    """Run the model as a LangGraph tool-calling agent until it stops calling tools.

    The graph alternates between two nodes:
    - ``agent``: one model call;
    - ``tools``: runs every tool call in the newest message.

    It stops when the model replies without tool calls, or when the
    ``recursion_limit`` of 10 steps is reached.

    Returns:
        ``{"answer": ..., "tool_calls": [...]}``.

        Each tool-call entry has ``"tool"``, ``"args"`` and ``"index"``. The
        index is the 0-based count of tool-requesting model turns at the time
        of the call.

        ``"answer"`` is the content of the last streamed message, or a
        ``>>...<<`` marker if the step limit was reached.
    """
    # Every tool call the model makes, in order; also returned on early exit.
    tool_calls = []
    # Bumped each time the router sends the graph to "tools"; the first
    # tool-requesting turn therefore gets index 0.
    index = -1

    def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
        """Route to the tool node while the newest message requests tools."""
        nonlocal index
        last_message = state["messages"][-1]
        if last_message.tool_calls:
            index += 1
            return "tools"
        return "__end__"

    def call_llm(state: MessagesState):
        """Invoke the model on the conversation so far; return its reply as a state update."""
        response = llm.invoke(state["messages"])
        return {"messages": [response]}

    class NxToolNode:
        """A node that runs the tools requested in the last AIMessage."""

        def __init__(self, tools: list) -> None:
            self.tools_by_name = {tool.name: tool for tool in tools}

        def _run(self, tool_call: dict):
            """Invoke one requested tool; unknown tools and failures become error text."""
            selected = self.tools_by_name.get(tool_call["name"])
            if selected is None:
                return f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
            # Only the invocation is guarded, so a KeyError raised *inside* a
            # tool is reported as a tool error rather than as a missing tool.
            try:
                return selected.invoke(tool_call["args"])
            except ValidationError as e:
                return 'Tool got invalid input:\n' + str(e)
            except Exception as e:
                return 'Error: ' + str(e)

        def __call__(self, inputs: dict):
            messages = inputs.get("messages", [])
            if not messages:
                raise ValueError("No message found in input")
            message = messages[-1]
            outputs = []
            for tool_call in message.tool_calls:
                # Record the call before running it, tagged with the current turn.
                tool_calls.append({
                    "tool": tool_call["name"],
                    "args": tool_call["args"],
                    "index": index,
                })
                tool_result = self._run(tool_call)
                outputs.append(
                    ToolMessage(
                        content=json.dumps(tool_result),
                        name=tool_call["name"],
                        tool_call_id=tool_call["id"],
                    )
                )
            return {"messages": outputs}

    tools = list(test.runnable_input['tools'].values())
    tool_node = NxToolNode(tools)
    llm = _get_llm(model=model, base_url=base_url, seed=seed, technique=technique, tools=tools)

    workflow = StateGraph(MessagesState)
    # The two nodes the graph cycles between.
    workflow.add_node("agent", call_llm)
    workflow.add_node("tools", tool_node)
    workflow.add_edge("__start__", "agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
    )
    workflow.add_edge("tools", "agent")
    graph = workflow.compile()

    # Compose the starting "history"; optional few-shot messages sit between
    # the system message and the human message.
    start_messages = [SystemMessage(test.runnable_input['system_msg'])]
    try:
        start_messages += test.runnable_input['fsp_messages']
    except KeyError:
        pass
    start_messages.append(HumanMessage(test.runnable_input['human_msg']))

    chunks = []
    try:
        for chunk in graph.stream({"messages": start_messages}, stream_mode="values", config={"recursion_limit": 10}):
            chunks.append(chunk["messages"][-1])
    except RecursionError:
        # LangGraph reports an exhausted recursion_limit as GraphRecursionError,
        # which (per langgraph.errors) subclasses RecursionError.
        return {
            "answer": ">>Model did not come to a conclusion (Recursion Error)<<",
            "tool_calls": tool_calls
        }
    return {
        "answer": chunks[-1].content,
        "tool_calls": tool_calls
    }