mega commit

This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-20 20:47:17 +02:00
parent 4860179a1c
commit a578dd26a0
13 changed files with 608 additions and 305 deletions

View File

@@ -1,76 +1,100 @@
from types import NoneType
from langchain_ollama.chat_models import ChatOllama
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from libs.test_class import Test
from libs.ollama_functions import OllamaFunctions
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
from libs.classes import Test, Model
from langchain.tools import Tool
from typing import Literal
from langgraph.graph import StateGraph, MessagesState
# from langgraph.prebuilt import ToolNode
import json
from pydantic import ValidationError
def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType = None):
if model.supports_tools:
llm = ChatOllama(
model=model.identifier,
seed=seed,
base_url=base_url
)
else:
llm = OllamaFunctions(
model=model.identifier,
seed=seed,
base_url=base_url,
format="json"
)
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
system_msg = test.runnable_input['system_msg']
human_msg = test.runnable_input['human_msg']
if tools:
llm = llm.bind_tools(tools=tools)
if system_msg == None: prompt = [ human_msg ]
else: prompt = [ system_msg, human_msg ]
return llm
llm = ChatOllama(
model=model,
seed=seed,
base_url=base_url
)
ai_msg = llm.invoke(prompt)
def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
messages = [SystemMessage(test.runnable_input['system_msg'])]
try:
messages += test.runnable_input['fsp_messages']
except KeyError:
pass
messages += [ HumanMessage(test.runnable_input['human_msg']) ]
llm = _get_llm(model=model, base_url=base_url, seed=seed)
ai_msg = llm.invoke(messages)
assert isinstance(ai_msg.content, str)
return ai_msg.content
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
system_msg = test.runnable_input['system_msg']
human_msg = test.runnable_input['human_msg']
def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) -> dict:
tools_dict = test.runnable_input['tools']
tools = []
for key in tools_dict:
tools.append(tools_dict[key])
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
if system_msg == None: prompt = [ human_msg ]
else: prompt = [ system_msg, human_msg ]
messages = [SystemMessage(test.runnable_input['system_msg'])]
try:
messages += test.runnable_input['fsp_messages']
except KeyError:
pass
messages += [ HumanMessage(test.runnable_input['human_msg']) ]
llm = ChatOllama(
model=model,
seed=seed,
base_url=base_url
).bind_tools(tools)
ai_msg = llm.invoke(messages)
ai_msg = llm.invoke(prompt)
prompt.append(ai_msg)
messages += [ ai_msg ]
try:
tool_calls = []
for i in range(len(ai_msg.tool_calls)):
tool_call = ai_msg.tool_calls[i]
selected_tool = tools_dict[tool_call["name"].lower()]
tool_msg = selected_tool.invoke(tool_call)
prompt.append(tool_msg)
ai_msg = llm.invoke(prompt)
assert isinstance(ai_msg, AIMessage)
calls = ai_msg.tool_calls
for call in calls:
selected_tool = tools_dict[call["name"].lower()]
tool_msg = selected_tool.invoke(call)
messages.append(tool_msg)
ai_msg = llm.invoke(messages)
tool_calls.append({
"tool": tool_call["name"],
"args": tool_call["args"],
"index": 0
"tool": call["name"],
"args": call["args"],
})
except IndexError: # LLM didnt use a tool -> jsut return the content
tool_calls = []
if len(ai_msg.tool_calls) > 0:
to_append_calls = []
for call in ai_msg.tool_calls:
to_append_calls.append({ "tool": call["name"], "args": call["args"] })
return {
"answer": ">>LLM did not respond conversationally<<",
"tool_calls": tool_calls + to_append_calls,
}
return {
"answer": ai_msg.content,
"tool_calls": tool_calls
"tool_calls": tool_calls,
}
def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict[str, str|list]:
tool_calls = []
index = -1
@@ -79,6 +103,7 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
messages = state["messages"]
last_message = messages[-1]
nonlocal index
assert isinstance(last_message, AIMessage) # this is just so the type checker is happy
if last_message.tool_calls:
index += 1
return "tools"
@@ -113,10 +138,10 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
try:
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
except KeyError as e:
except KeyError:
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
except ValidationError as e:
tool_result = 'Tool got invalid input:\n' + e
tool_result = 'Tool got invalid input:\n' + str(e)
except Exception as e:
tool_result = 'Error: ' + str(e)
@@ -135,11 +160,7 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
for key in tools_dict:
tools.append(tools_dict[key])
tool_node = NxToolNode(tools)
llm = ChatOllama(
model=model,
seed=seed,
base_url=base_url
).bind_tools(tools)
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
workflow = StateGraph(MessagesState)
@@ -156,124 +177,21 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
graph = workflow.compile()
# example with a single tool call
start_messages = [
SystemMessage(content=test.runnable_input['system_msg']),
HumanMessage(content=test.runnable_input['human_msg'])
]
# compose "history" supprts few shot prompting
start_messages = [SystemMessage(test.runnable_input['system_msg'])]
try:
start_messages += test.runnable_input['fsp_messages']
except KeyError:
pass
start_messages += [ HumanMessage(test.runnable_input['human_msg']) ]
chunks = []
for chunk in graph.stream(
{"messages": start_messages},
stream_mode="values",
): chunks.append(chunk["messages"][-1])
return {
"answer": chunks[-1].content,
"tool_calls": tool_calls
}
def agent_with_tools_fsp(model: str, seed: int, test: Test, base_url: str) -> str:
tool_calls = []
index = -1
def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
messages = state["messages"]
last_message = messages[-1]
nonlocal index
if last_message.tool_calls:
index += 1
return "tools"
return "__end__"
def call_llm(state: MessagesState):
messages = state["messages"]
response = llm.invoke(messages)
return {"messages": [response]}
class NxToolNode:
"""A node that runs the tools requested in the last AIMessage."""
def __init__(self, tools: list) -> None:
self.tools_by_name = {tool.name: tool for tool in tools}
def __call__(self, inputs: dict):
if messages := inputs.get("messages", []):
message = messages[-1]
else:
raise ValueError("No message found in input")
outputs = []
for tool_call in message.tool_calls:
nonlocal tool_calls
nonlocal index
tool_calls.append({
"tool": tool_call["name"],
"args": tool_call["args"],
"index": index
})
try:
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
except KeyError as e:
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
except ValidationError as e:
tool_result = 'Tool got invalid input:\n' + e
except Exception as e:
tool_result = 'Error: ' + str(e)
outputs.append(
ToolMessage(
content=json.dumps(tool_result),
name=tool_call["name"],
tool_call_id=tool_call["id"],
)
)
return {"messages": outputs}
tools_dict = test.runnable_input['tools']
tools = []
for key in tools_dict:
tools.append(tools_dict[key])
tool_node = NxToolNode(tools)
llm = ChatOllama(
model=model,
seed=seed,
base_url=base_url
).bind_tools(tools)
workflow = StateGraph(MessagesState)
# Define the two nodes we will cycle between
workflow.add_node("agent", call_llm)
workflow.add_node("tools", tool_node)
workflow.add_edge("__start__", "agent")
workflow.add_conditional_edges(
"agent",
should_continue,
)
workflow.add_edge("tools", "agent")
graph = workflow.compile()
# example with a single tool call
start_messages = [ SystemMessage(test.runnable_input['system_msg']) ] + test.runnable_input['fsp_messages'] + [ HumanMessage(test.runnable_input['human_msg']) ]
chunks = []
for chunk in graph.stream(
{"messages": start_messages},
stream_mode="values",
): chunks.append(chunk["messages"][-1])
):
chunks.append(chunk["messages"][-1])
return {
"answer": chunks[-1].content,