Files
test-small-llms/libs/runnables.py
Lennart J. Kurzweg (Nx2) a578dd26a0 mega commit
2024-08-20 20:47:17 +02:00

200 lines
6.2 KiB
Python

from types import NoneType
from langchain_ollama.chat_models import ChatOllama
from libs.ollama_functions import OllamaFunctions
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
from libs.classes import Test, Model
from langchain.tools import Tool
from typing import Literal
from langgraph.graph import StateGraph, MessagesState
import json
from pydantic import ValidationError
def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType = None):
    """Construct the chat model wrapper for `model`, optionally binding tools.

    Models whose `supports_tools` flag is truthy are wrapped in `ChatOllama`;
    all others are wrapped in `OllamaFunctions` with `format="json"`.

    Args:
        model: Supplies `identifier` (passed as `model=`) and `supports_tools`.
        base_url: Passed through as `base_url=` to the wrapper.
        seed: Passed through as `seed=` to the wrapper.
        tools: When truthy, bound to the wrapper via `bind_tools`. `None` or
            an empty list skips binding.

    Returns:
        The wrapper itself, or the result of `bind_tools` when tools were given.
    """
    # Pick the wrapper class (and its extra kwargs) first, then build it once.
    if model.supports_tools:
        wrapper_cls, extra_kwargs = ChatOllama, {}
    else:
        wrapper_cls, extra_kwargs = OllamaFunctions, {"format": "json"}

    chat_model = wrapper_cls(
        model=model.identifier,
        seed=seed,
        base_url=base_url,
        **extra_kwargs,
    )

    if not tools:
        return chat_model
    return chat_model.bind_tools(tools=tools)
def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
    """Send a single tool-less prompt to the model and return its text reply.

    The conversation is assembled from `test.runnable_input`: a system
    message ('system_msg'), then the optional few-shot messages
    ('fsp_messages'), then the human message ('human_msg').

    Args:
        model: Forwarded to `_get_llm`.
        seed: Forwarded to `_get_llm`.
        test: Provides `runnable_input` with the keys listed above.
        base_url: Forwarded to `_get_llm`.

    Returns:
        The `content` of the message returned by the model.

    Raises:
        AssertionError: If the reply content is not a `str` (only while
            assertions are enabled).
    """
    conversation = [SystemMessage(test.runnable_input['system_msg'])]

    try:
        few_shot = test.runnable_input['fsp_messages']
    except KeyError:
        # Few-shot examples are optional; continue without them.
        pass
    else:
        conversation += few_shot

    conversation += [HumanMessage(test.runnable_input['human_msg'])]

    reply = _get_llm(model=model, base_url=base_url, seed=seed).invoke(conversation)
    assert isinstance(reply.content, str)
    return reply.content
def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) -> dict:
    """Prompt the model once, execute the tool calls it requests, and report.

    The conversation is built from `test.runnable_input` ('system_msg',
    optional 'fsp_messages', 'human_msg'). The tools come from
    `test.runnable_input['tools']` and are looked up by the lower-cased
    name of each requested call. After every executed tool the model is
    invoked again with the extended conversation.

    Args:
        model: Forwarded to `_get_llm`.
        seed: Forwarded to `_get_llm`.
        test: Provides `runnable_input` with the keys listed above.
        base_url: Forwarded to `_get_llm`.

    Returns:
        A dict with "answer" and "tool_calls". If the last model message
        still requests tools, "answer" is a fixed placeholder string and
        those pending calls are appended to the executed ones; otherwise
        "answer" is the content of the last model message.
    """
    tools_by_key = test.runnable_input['tools']
    llm = _get_llm(
        model=model,
        base_url=base_url,
        seed=seed,
        tools=[tools_by_key[key] for key in tools_by_key],
    )

    history = [SystemMessage(test.runnable_input['system_msg'])]
    try:
        few_shot = test.runnable_input['fsp_messages']
    except KeyError:
        # Few-shot examples are optional; continue without them.
        pass
    else:
        history += few_shot
    history += [HumanMessage(test.runnable_input['human_msg'])]

    ai_msg = llm.invoke(history)
    history += [ai_msg]

    executed_calls = []
    try:
        assert isinstance(ai_msg, AIMessage)
        # The iterable is fixed here: rebinding `ai_msg` inside the loop does
        # not change which calls are processed.
        for call in ai_msg.tool_calls:
            # NOTE(review): an unknown tool name would presumably raise
            # KeyError here (if the tools container is a dict), which the
            # handler below does not catch - confirm this is intended.
            tool_msg = tools_by_key[call["name"].lower()].invoke(call)
            history.append(tool_msg)
            ai_msg = llm.invoke(history)
            executed_calls.append({
                "tool": call["name"],
                "args": call["args"],
            })
    except IndexError:  # LLM didn't use a tool -> just return the content
        executed_calls = []

    # Tool calls still requested by the most recent model message.
    pending_calls = [
        {"tool": call["name"], "args": call["args"]}
        for call in ai_msg.tool_calls
    ]
    if pending_calls:
        return {
            "answer": ">>LLM did not respond conversationally<<",
            "tool_calls": executed_calls + pending_calls,
        }

    return {
        "answer": ai_msg.content,
        "tool_calls": executed_calls,
    }
def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict[str, str|list]:
    """Run a tool-using agent loop and report its final answer and tool usage.

    A two-node graph is built: "agent" invokes the LLM and "tools" executes
    every tool call of the latest AI message. The graph alternates between
    the two until the LLM replies without requesting a tool.

    Args:
        model: Forwarded to `_get_llm`.
        seed: Forwarded to `_get_llm`.
        test: `test.runnable_input` must provide 'tools' (a mapping whose
            values are the tools), 'system_msg' and 'human_msg';
            'fsp_messages' is optional and inserted between the two.
        base_url: Forwarded to `_get_llm`.

    Returns:
        A dict with "answer" (the content of the last message streamed by the
        graph) and "tool_calls" (one entry per requested call with "tool",
        "args" and "index", where "index" is the zero-based number of the
        tool-requesting agent turn the call belongs to).
    """
    tool_calls = []
    index = -1

    def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
        """Route to the tool node while the last AI message requests tools."""
        nonlocal index
        last_message = state["messages"][-1]
        assert isinstance(last_message, AIMessage)  # this is just so the type checker is happy
        if last_message.tool_calls:
            # Count this tool-requesting turn before its calls are recorded.
            index += 1
            return "tools"
        return "__end__"

    def call_llm(state: MessagesState):
        """Invoke the LLM on the accumulated messages."""
        response = llm.invoke(state["messages"])
        return {"messages": [response]}

    class NxToolNode:
        """A node that runs the tools requested in the last AIMessage."""

        def __init__(self, tools: list) -> None:
            self.tools_by_name = {tool.name: tool for tool in tools}
            # Kept as a list (not the dict keys) for the "unknown tool" message.
            self.tool_names = [tool.name for tool in tools]

        def __call__(self, inputs: dict):
            if messages := inputs.get("messages", []):
                message = messages[-1]
            else:
                raise ValueError("No message found in input")

            outputs = []
            for tool_call in message.tool_calls:
                # `tool_calls` is only mutated and `index` only read, so the
                # closure needs no `nonlocal` declaration here.
                tool_calls.append({
                    "tool": tool_call["name"],
                    "args": tool_call["args"],
                    "index": index
                })

                # Look the tool up separately from invoking it, so that a
                # KeyError raised *inside* a tool is not misreported to the
                # LLM as an unknown tool name.
                tool = self.tools_by_name.get(tool_call["name"])
                if tool is None:
                    tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {self.tool_names}'
                else:
                    # Failures are fed back to the LLM as the tool result
                    # instead of aborting the run.
                    try:
                        tool_result = tool.invoke(tool_call["args"])
                    except ValidationError as e:
                        tool_result = 'Tool got invalid input:\n' + str(e)
                    except Exception as e:
                        tool_result = 'Error: ' + str(e)

                outputs.append(
                    ToolMessage(
                        content=json.dumps(tool_result),
                        name=tool_call["name"],
                        tool_call_id=tool_call["id"],
                    )
                )
            return {"messages": outputs}

    tools_dict = test.runnable_input['tools']
    tools = [tools_dict[key] for key in tools_dict]
    tool_node = NxToolNode(tools)
    llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)

    workflow = StateGraph(MessagesState)
    # Define the two nodes we will cycle between
    workflow.add_node("agent", call_llm)
    workflow.add_node("tools", tool_node)
    workflow.add_edge("__start__", "agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
    )
    workflow.add_edge("tools", "agent")
    graph = workflow.compile()

    # Compose the "history"; supports few-shot prompting.
    start_messages = [SystemMessage(test.runnable_input['system_msg'])]
    try:
        start_messages += test.runnable_input['fsp_messages']
    except KeyError:
        # Few-shot examples are optional; continue without them.
        pass
    start_messages += [HumanMessage(test.runnable_input['human_msg'])]

    # With stream_mode="values" each chunk is a state snapshot; keep the
    # newest message of every snapshot and answer with the last one.
    chunks = []
    for chunk in graph.stream(
        {"messages": start_messages},
        stream_mode="values",
    ):
        chunks.append(chunk["messages"][-1])

    return {
        "answer": chunks[-1].content,
        "tool_calls": tool_calls
    }