mega commit

This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-20 20:47:17 +02:00
parent 4860179a1c
commit a578dd26a0
13 changed files with 608 additions and 305 deletions

18
libs/classes.py Normal file
View File

@@ -0,0 +1,18 @@
from dataclasses import dataclass
from typing import Callable
@dataclass
class Test:
    """A single benchmark case: how to produce an answer and how to grade it."""

    # Label of the test case.
    name: str
    # Callable that produces the answer for this test.
    runnable: Callable
    # Inputs read by `runnable` (presumably prompt texts and tools; verify against the runnables).
    runnable_input: dict
    # Callable that grades an answer produced by `runnable`.
    validator: Callable
    # Inputs read by `validator`.
    validation_input: dict
@dataclass
class Model:
    """Description of one LLM that the tests can be run against."""

    # Name shown to the user.
    display_name: str
    # Name under which the model is addressed at the backend.
    identifier: str
    # Whether the model offers native tool calling.
    supports_tools: bool
    # Model size; presumably in billions of parameters (per the field name) - TODO confirm.
    parameter_count_in_b: float

5
libs/functions.py Normal file
View File

@@ -0,0 +1,5 @@
def nxhash(text: str) -> str:  # @BenVida StackOverflow
    """Return a short, deterministic 32-bit hash of `text` as 8 upper-case hex digits.

    Not cryptographic; it only serves as a compact, stable identifier.
    """
    digest = 0
    for char in text:
        # Multiply-and-xor step, truncated to 32 bits on every round.
        digest = ((digest * 281) ^ (ord(char) * 997)) & 0xFFFFFFFF
    # Zero-padded, upper-case hexadecimal without the "0x" prefix.
    return f"{digest:08X}"

322
libs/ollama_functions.py Normal file
View File

@@ -0,0 +1,322 @@
import json
import uuid
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
Tuple,
)
from types import NoneType
from langchain_ollama.chat_models import ChatOllama
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool, Tool
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
from libs.functions import nxhash
# System prompt template used when the conversation contains no tool results yet.
# `{tools}` is filled with the JSON description of the available tools; the
# doubled braces are escapes that render as literal braces after formatting.
DEFAULT_SYTEM_PROMPT = """You have access to the following tools:
{tools}
You must always select one of the above tools and respond with only a JSON object matching the following schema:
{{
"tool": <name of the selected tool>,
"tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}
"""
# System prompt template used once tool results are part of the conversation.
# Placeholders: `{system_msg}` (original system message text), `{tools}` (JSON
# tool descriptions) and `{history}` (the prior messages rendered as plain text).
DEFAULT_SYTEM_PROMPT_WITH_HISTORY = """{system_msg}
You continue a chat history either conversationally or with a tool call.
You have access to the following tools:
{tools}
You must either select one of the above tools and respond with only a JSON object matching the following schema:
{{
"tool": <name of the selected tool>,
"tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}
or answer conversationally normally.
The conversation before consisted of the following messages:
{history}
Now you must answer accordingly either conversationally or with another tool call.
For conversational answers: Answer as if it was a continuous conversation. The Human only sees the conversational responses, and not anything about the tools. Do not mention the tools or the process of using them.
"""
# Pseudo-tool that is always offered to the model so it can answer in plain
# text (through the `response` parameter) instead of calling a real tool.
CONVERSATIONAL_RESPONSE_TOOL = {
"name": "__conversational_response",
"description": (
"Respond conversationally if no other tools should be called for a given query."
),
"parameters": {
"type": "object",
"properties": {
"response": {
"type": "string",
"description": "Conversational response to the user.",
},
},
"required": ["response"],
},
}
# Type variable restricted to pydantic `BaseModel` subclasses.
_BM = TypeVar("_BM", bound=BaseModel)
# Alias for "either a plain dict or a pydantic model".
_DictOrPydantic = Union[Dict, _BM]
def _is_pydantic_class(obj: Any) -> bool:
    """Tell whether `obj` is a class (not an instance) that is a pydantic model.

    A class qualifies when `is_basemodel_subclass` accepts it or when the
    imported `BaseModel` is one of its direct bases.
    """
    # Instances and other non-class objects never qualify.
    if not isinstance(obj, type):
        return False
    if is_basemodel_subclass(obj):
        return True
    return BaseModel in obj.__bases__
def convert_to_ollama_tool(tool: Any) -> Dict:
    """Convert a tool to an Ollama tool.

    Accepted inputs, checked in this order:
      * a pydantic model class - named after the schema's ``title``,
      * a ``BaseTool`` - uses its tool-call schema, name and description,
      * a pydantic model instance - uses its input schema, name and description,
      * a dict that already has ``name`` and ``parameters`` - returned as a
        shallow copy.

    Returns:
        A dict with ``name`` and ``parameters``, plus ``description`` when a
        non-empty description is available.

    Raises:
        ValueError: if `tool` is none of the accepted kinds.
    """
    tool_description = None
    if _is_pydantic_class(tool):
        json_schema = tool.construct().schema()
        tool_name = json_schema["title"]
    elif isinstance(tool, BaseTool):
        json_schema = tool.tool_call_schema.schema()
        tool_name = tool.get_name()
        tool_description = tool.description
    elif is_basemodel_instance(tool):
        json_schema = tool.get_input_schema().schema()
        tool_name = tool.get_name()
        tool_description = tool.description
    elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
        # Already in the target shape; copy so the caller's dict is not shared.
        return tool.copy()
    else:
        raise ValueError(
            f"""Cannot convert {tool} to an Ollama tool.
{tool} needs to be a Pydantic class, model, or a dict."""
        )

    converted = {"name": tool_name, "parameters": json_schema}
    # Only emit a description when there is something to say.
    if tool_description:
        converted["description"] = tool_description
    return converted
def parse_response(message: BaseMessage) -> str:
    """Extract the tool-call arguments from an `AIMessage` as a string.

    The arguments of the *last* entry in ``message.tool_calls`` are returned
    JSON-encoded. If there are no tool calls, the legacy
    ``additional_kwargs["function_call"]["arguments"]`` value is returned
    unchanged instead.

    Raises:
        ValueError: if `message` is not an `AIMessage`, if it carries neither
            tool calls nor a ``function_call``, or if the ``function_call``
            has no ``arguments``.
    """
    if not isinstance(message, AIMessage):
        raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")

    tool_calls = message.tool_calls
    if len(tool_calls) > 0:
        return json.dumps(tool_calls[-1].get("args"))

    kwargs = message.additional_kwargs
    if "function_call" in kwargs:
        if "arguments" in kwargs["function_call"]:
            return kwargs["function_call"]["arguments"]
        raise ValueError(f"`arguments` missing from `function_call` within AIMessage: {message}")

    # This message was missing its f-prefix and printed "{message}" literally.
    raise ValueError(f"`tool_calls` missing from AIMessage: {message}")
class OllamaFunctions(ChatOllama):
    """Function chat model that uses Ollama API.

    Emulates tool calling through prompting: the available tools are described
    in a system prompt, the model is asked to answer with a JSON object naming
    a tool and its input, and that JSON is translated back into an `AIMessage`
    (either plain content or a `tool_calls` entry).
    """

    tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
    tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind `tools` to the model; `_generate` receives them as the `functions` keyword."""
        return self.bind(functions=tools, **kwargs)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate one reply and convert the model's JSON answer into a message.

        When `messages` contains a `ToolMessage`, the whole conversation is
        flattened into a single system prompt (history template); otherwise the
        tool prompt is simply prepended to `messages`.

        Returns:
            A `ChatResult` with one generation. Its message carries either
            text content (conversational answer, or raw model output that is
            not a JSON object) or exactly one tool call.

        Raises:
            ValueError: if the model output is not a string, or if a JSON
                object was returned from which no textual response can be
                extracted.
            TypeError: if the history contains an unsupported message type.
        """

        def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
            """Split `messages` into the system text and a plain-text transcript."""

            def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
                # Shortened hashed ids let the model match calls to tool replies.
                call_list = [
                    {
                        "id": nxhash(c['id'])[-4:],
                        "tool": c['name'],
                        "args": c['args'],
                    }
                    for c in tool_calls
                ]
                # A single call is rendered as an object, several as a list.
                payload = call_list[0] if len(call_list) == 1 else call_list
                return json.dumps(obj=payload, ensure_ascii=False, indent=2)

            history_parts = []
            system_msg = ""
            for m in messages:
                if isinstance(m, SystemMessage):
                    system_msg += str(m.content)
                elif isinstance(m, HumanMessage):
                    history_parts.append("The Human said:\n" + str(m.content))
                elif isinstance(m, AIMessage) and m.tool_calls:
                    history_parts.append(
                        "So you called the tool"
                        + (":\n" if len(m.tool_calls) == 1 else "s:\n")
                        + _format_tools_for_history(m.tool_calls)
                    )
                elif isinstance(m, ToolMessage):
                    history_parts.append(
                        "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
                    )
                elif isinstance(m, AIMessage):
                    history_parts.append("You said:\n" + str(m.content))
                else:
                    raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
            # Joining the entries avoids stray separators for system messages,
            # which contribute no history entry of their own.
            return system_msg, "\n\n".join(history_parts)

        def _get_parsed_chat_result(self, chat_result_str: str) -> Union[dict, str]:
            """Parse the model output; anything that is not a JSON object stays text."""
            try:
                parsed = json.loads(chat_result_str)
            except json.JSONDecodeError:
                return chat_result_str
            if isinstance(parsed, (dict, str)):
                return parsed
            # Numbers, lists, booleans and null cannot describe a tool call;
            # hand the raw text back instead of failing on key lookups later.
            return chat_result_str

        def _get_called_tool(self, d: dict, functions_list: list[dict]) -> Optional[dict]:
            """Return the tool definition named in `d`, or None if none matches."""
            # "tool" is the key per spec; the others are variants models emit
            # (Phi3 often uses "name", "tool_name" or "action").
            for key in ("tool", "name", "tool_name", "action"):
                if key in d:
                    called_tool_name = d[key]
                    break
            else:
                return None
            for tool in functions_list:
                if tool['name'] == called_tool_name:
                    return tool
            return None  # a tool was named, but it does not exist

        def _extract_conversational_response(self, d: dict, raw_output: str) -> str:
            """Find the textual answer inside `d`, tolerating several key layouts."""
            failure = f"Failed to parse a response from {self.model} output: {raw_output}"
            # Preferred layout: {"<input key>": {"response": "..."}}. Only real
            # dicts are inspected, so a string value is not substring-matched.
            for key in ("tool_input", "input", "args"):
                nested = d.get(key)
                if isinstance(nested, dict) and "response" in nested:
                    response = nested["response"]
                    break
            else:
                # Fallback layouts: the answer sits directly under one of these keys.
                for key in ("response", "input", "args", "tool_input"):
                    if key in d:
                        response = d[key]
                        break
                else:
                    raise ValueError(failure)
            if not isinstance(response, str):
                raise ValueError(failure)
            return response

        def _extract_tool_args(self, d: dict) -> dict:
            """Return the tool arguments from `d`, or an empty dict if absent."""
            # "tool_input" is per spec; "input" is what Phi3 often produces.
            for key in ("tool_input", "input", "args"):
                if key in d:
                    return d[key]
            return {}

        # prepare generation
        # `functions` is consumed here and removed from kwargs so it is not
        # forwarded to the underlying chat model call.
        functions = kwargs.pop("functions", [])
        functions_list = [convert_to_ollama_tool(fn) for fn in functions]
        functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
        functions_str = json.dumps(functions_list, indent=2)

        if any(isinstance(m, ToolMessage) for m in messages):
            # prepare generation with history
            system_msg, formatted_history = _get_system_msg_and_formatted_history(self, messages=messages)
            system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
            system_message = system_message_prompt_template.format(
                tools=functions_str,
                history=formatted_history,
                system_msg=system_msg,
            )
            final_messages = [system_message]
        else:
            # prepare generation without history
            system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
            system_message = system_message_prompt_template.format(
                tools=functions_str,
            )
            final_messages = [system_message] + messages

        # generate chat result
        response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
        chat_result = response_message.generations[0].text

        # check for validity
        if not isinstance(chat_result, str):
            raise ValueError("OllamaFunctions does not support non-string output.")

        # make str to dict
        parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result)

        # if the model failed to return a JSON object, just return the whole thing
        if isinstance(parsed_chat_result, str):
            return ChatResult(generations=[ChatGeneration(message=AIMessage(content=parsed_chat_result))])

        # get the called tool from the dict
        called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list)

        if not called_tool or called_tool == CONVERSATIONAL_RESPONSE_TOOL:
            # No (known) tool requested: treat the payload as a plain answer.
            response_msg = AIMessage(
                content=_extract_conversational_response(self, d=parsed_chat_result, raw_output=chat_result)
            )
        else:
            response_msg = AIMessage(
                content="",
                tool_calls=[ToolCall(
                    name=called_tool['name'],
                    args=_extract_tool_args(self, d=parsed_chat_result),
                    id=f"call_{str(uuid.uuid4()).replace('-', '')}",
                )],
            )
        return ChatResult(generations=[ChatGeneration(message=response_msg)])

    @property
    def _llm_type(self) -> str:
        return "ollama_functions"

View File

@@ -1,30 +1,19 @@
from libs.test_class import Test
from libs.classes import Test, Model
from libs.functions import nxhash
from typing import Union
import json
def padd(list, element):
    """Return `str(element)` left-justified to the widest `str()` form found in `list`.

    An empty `list` yields a width of 0, i.e. `str(element)` unchanged.
    """
    width = max((len(str(item)) for item in list), default=0)
    return str(element).ljust(width)
def nxhash(text: str):  # @BenVida StackOverflow
    """Hash `text` into 8 upper-case hexadecimal digits (32 bit, non-cryptographic)."""
    value = 0
    for char in text:
        mixed = (value * 281) ^ (ord(char) * 997)
        value = mixed & 0xFFFFFFFF  # keep only the low 32 bits
    return format(value, "08X")
def get_len(collection: Union[list, dict]) -> int:
maximum_length = 0
if isinstance(collection, dict):
collection_type = "tests"
elif isinstance(collection, list):
if isinstance(collection[0], str):
collection_type = "models"
elif isinstance(collection[0], int):
collection_type = "seeds"
if isinstance(collection, list):
collection_type = "seeds"
elif isinstance(collection, dict):
if isinstance(collection[list(collection.keys())[0]], Model):
collection_type = "models"
elif isinstance(collection[list(collection.keys())[0]], Test):
collection_type = "tests"
else:
raise TypeError("get_len: unsupported collection_type")
else:
@@ -32,8 +21,8 @@ def get_len(collection: Union[list, dict]) -> int:
match collection_type:
case "models":
for model_name in collection:
maximum_length = max(maximum_length, len(model_name))
for model_id in collection:
maximum_length = max(maximum_length, len(collection[model_id].display_name))
case "seeds":
for seed in collection:
maximum_length = max(maximum_length, len(str(seed)))
@@ -48,40 +37,42 @@ def get_len(collection: Union[list, dict]) -> int:
def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: str):
def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], base_url: str):
try:
print("Trying to load saved_results.json")
with open("./saved_results.json", "r") as f:
saved_results = json.load(fp=f)
print("Loaded.")
except:
except FileNotFoundError:
print("saved_results.json not found. Initializing empty.")
saved_results = {}
# Get Results
run_results = {}
print("Starting to run Tests ... ")
for model in models:
for model_id in models:
model = models[model_id]
for test_id in tests:
test = tests[test_id]
for seed in seeds:
# Init dict
combination = {
'test_id': test_id,
'model': model,
'model_id': model_id,
'seed': seed,
}
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
combination['test_name'] = test.name
combination['model_name'] = model.display_name
# if hash_key == "DE3D137E":
# pass
if hash_key not in saved_results.keys():
try:
print("\033[0;35mModel '\033[0m" +
model +
model.display_name +
"\033[0;35m'" +
(" " * (get_len(models) - len(model))) +
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
@@ -96,7 +87,7 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
end=""
)
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
if isinstance(answer, str): # tool capabile return tools called as a list[dict]
if isinstance(answer, str):
combination['answer'] = answer
# combination['tool_calls'] = [] # no entry
del answer
@@ -105,15 +96,14 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
combination['tool_calls'] = answer['tool_calls']
del answer
else:
raise Exception(f"runnable returd unkown type {type(answer)}.")
raise Exception(f"runnable returned unkown type {type(answer)}.")
combination['test'] = test
run_results[hash_key] = combination
print("\r\033[0;32mModel '\033[0m" +
model +
model.display_name +
"\033[0;32m'" +
(" " * (get_len(models) - len(model))) +
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
@@ -127,12 +117,12 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
"\033[0;32m)\033[0m"
)
except Exception as e:
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...")
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ")
else:
print("\r\033[0;34mModel '\033[0m" +
model +
model.display_name +
"\033[0;34m'" +
(" " * (get_len(models) - len(model))) +
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
@@ -148,7 +138,8 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
# Validate Results
if run_results != {}: print("\nStarting validation of tests ...")
if run_results != {}:
print("\nStarting validation of tests ...")
for hash_key in run_results:
result = run_results[hash_key]
@@ -156,27 +147,28 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
entry = {
'test_name': result['test_name'],
'test_id': result['test_id'],
'model': result['model'],
'model_name': result['model_name'],
'model_id': result['model_id'],
'seed': result['seed'],
'answer': result['answer'],
'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url),
}
except Exception as e:
print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m")
print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m ")
continue
try:
entry['tool_calls'] = result['tool_calls']
except:
except KeyError:
pass
saved_results[hash_key] = entry # add result with validation to saved results
print("\033[0;36mTest results of model '\033[0m" +
entry['model'] +
entry['model_name'] +
"\033[0;36m'" +
(" " * (get_len(models) - len(entry['model']))) +
(" " * (get_len(models) - len(entry['model_name']))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(entry['seed'])))) +
"\033[0m" +
@@ -188,7 +180,7 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
" (\033[0m" +
hash_key +
"\033[0;36m) evaluated to \033[0m" +
('\033[0;32mcorrect\033[0m' if entry['validation'] == True else '\033[0;31mincorrect\033[0m')
('\033[0;32mcorrect\033[0m' if entry['validation'] else '\033[0;31mincorrect\033[0m')
)
with open("./saved_results.json", "w") as f:

View File

@@ -1,76 +1,100 @@
from types import NoneType
from langchain_ollama.chat_models import ChatOllama
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from libs.test_class import Test
from libs.ollama_functions import OllamaFunctions
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
from libs.classes import Test, Model
from langchain.tools import Tool
from typing import Literal
from langgraph.graph import StateGraph, MessagesState
# from langgraph.prebuilt import ToolNode
import json
from pydantic import ValidationError
def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType = None):
if model.supports_tools:
llm = ChatOllama(
model=model.identifier,
seed=seed,
base_url=base_url
)
else:
llm = OllamaFunctions(
model=model.identifier,
seed=seed,
base_url=base_url,
format="json"
)
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
system_msg = test.runnable_input['system_msg']
human_msg = test.runnable_input['human_msg']
if tools:
llm = llm.bind_tools(tools=tools)
if system_msg == None: prompt = [ human_msg ]
else: prompt = [ system_msg, human_msg ]
return llm
llm = ChatOllama(
model=model,
seed=seed,
base_url=base_url
)
ai_msg = llm.invoke(prompt)
def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
    """Send the test's prompt to the model without tools and return the text reply.

    The prompt consists of the system message, the optional few-shot messages
    (`fsp_messages`, if that key is present) and the human message, all taken
    from `test.runnable_input`.
    """
    run_input = test.runnable_input
    system_message = SystemMessage(run_input['system_msg'])
    try:
        few_shot_messages = run_input['fsp_messages']
    except KeyError:
        # Few-shot examples are optional.
        few_shot_messages = []
    messages = [system_message, *few_shot_messages, HumanMessage(run_input['human_msg'])]

    llm = _get_llm(model=model, base_url=base_url, seed=seed)
    reply = llm.invoke(messages)
    assert isinstance(reply.content, str)
    return reply.content
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
system_msg = test.runnable_input['system_msg']
human_msg = test.runnable_input['human_msg']
def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) -> dict:
tools_dict = test.runnable_input['tools']
tools = []
for key in tools_dict:
tools.append(tools_dict[key])
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
if system_msg == None: prompt = [ human_msg ]
else: prompt = [ system_msg, human_msg ]
messages = [SystemMessage(test.runnable_input['system_msg'])]
try:
messages += test.runnable_input['fsp_messages']
except KeyError:
pass
messages += [ HumanMessage(test.runnable_input['human_msg']) ]
llm = ChatOllama(
model=model,
seed=seed,
base_url=base_url
).bind_tools(tools)
ai_msg = llm.invoke(messages)
ai_msg = llm.invoke(prompt)
prompt.append(ai_msg)
messages += [ ai_msg ]
try:
tool_calls = []
for i in range(len(ai_msg.tool_calls)):
tool_call = ai_msg.tool_calls[i]
selected_tool = tools_dict[tool_call["name"].lower()]
tool_msg = selected_tool.invoke(tool_call)
prompt.append(tool_msg)
ai_msg = llm.invoke(prompt)
assert isinstance(ai_msg, AIMessage)
calls = ai_msg.tool_calls
for call in calls:
selected_tool = tools_dict[call["name"].lower()]
tool_msg = selected_tool.invoke(call)
messages.append(tool_msg)
ai_msg = llm.invoke(messages)
tool_calls.append({
"tool": tool_call["name"],
"args": tool_call["args"],
"index": 0
"tool": call["name"],
"args": call["args"],
})
except IndexError: # LLM didnt use a tool -> jsut return the content
tool_calls = []
if len(ai_msg.tool_calls) > 0:
to_append_calls = []
for call in ai_msg.tool_calls:
to_append_calls.append({ "tool": call["name"], "args": call["args"] })
return {
"answer": ">>LLM did not respond conversationally<<",
"tool_calls": tool_calls + to_append_calls,
}
return {
"answer": ai_msg.content,
"tool_calls": tool_calls
"tool_calls": tool_calls,
}
def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict[str, str|list]:
tool_calls = []
index = -1
@@ -79,6 +103,7 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
messages = state["messages"]
last_message = messages[-1]
nonlocal index
assert isinstance(last_message, AIMessage) # this is just so the type checker is happy
if last_message.tool_calls:
index += 1
return "tools"
@@ -113,10 +138,10 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
try:
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
except KeyError as e:
except KeyError:
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
except ValidationError as e:
tool_result = 'Tool got invalid input:\n' + e
tool_result = 'Tool got invalid input:\n' + str(e)
except Exception as e:
tool_result = 'Error: ' + str(e)
@@ -135,11 +160,7 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
for key in tools_dict:
tools.append(tools_dict[key])
tool_node = NxToolNode(tools)
llm = ChatOllama(
model=model,
seed=seed,
base_url=base_url
).bind_tools(tools)
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
workflow = StateGraph(MessagesState)
@@ -156,124 +177,21 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
graph = workflow.compile()
# example with a single tool call
start_messages = [
SystemMessage(content=test.runnable_input['system_msg']),
HumanMessage(content=test.runnable_input['human_msg'])
]
# compose the message "history"; supports few-shot prompting
start_messages = [SystemMessage(test.runnable_input['system_msg'])]
try:
start_messages += test.runnable_input['fsp_messages']
except KeyError:
pass
start_messages += [ HumanMessage(test.runnable_input['human_msg']) ]
chunks = []
for chunk in graph.stream(
{"messages": start_messages},
stream_mode="values",
): chunks.append(chunk["messages"][-1])
return {
"answer": chunks[-1].content,
"tool_calls": tool_calls
}
def agent_with_tools_fsp(model: str, seed: int, test: Test, base_url: str) -> str:
tool_calls = []
index = -1
def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
messages = state["messages"]
last_message = messages[-1]
nonlocal index
if last_message.tool_calls:
index += 1
return "tools"
return "__end__"
def call_llm(state: MessagesState):
messages = state["messages"]
response = llm.invoke(messages)
return {"messages": [response]}
class NxToolNode:
"""A node that runs the tools requested in the last AIMessage."""
def __init__(self, tools: list) -> None:
self.tools_by_name = {tool.name: tool for tool in tools}
def __call__(self, inputs: dict):
if messages := inputs.get("messages", []):
message = messages[-1]
else:
raise ValueError("No message found in input")
outputs = []
for tool_call in message.tool_calls:
nonlocal tool_calls
nonlocal index
tool_calls.append({
"tool": tool_call["name"],
"args": tool_call["args"],
"index": index
})
try:
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
except KeyError as e:
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
except ValidationError as e:
tool_result = 'Tool got invalid input:\n' + e
except Exception as e:
tool_result = 'Error: ' + str(e)
outputs.append(
ToolMessage(
content=json.dumps(tool_result),
name=tool_call["name"],
tool_call_id=tool_call["id"],
)
)
return {"messages": outputs}
tools_dict = test.runnable_input['tools']
tools = []
for key in tools_dict:
tools.append(tools_dict[key])
tool_node = NxToolNode(tools)
llm = ChatOllama(
model=model,
seed=seed,
base_url=base_url
).bind_tools(tools)
workflow = StateGraph(MessagesState)
# Define the two nodes we will cycle between
workflow.add_node("agent", call_llm)
workflow.add_node("tools", tool_node)
workflow.add_edge("__start__", "agent")
workflow.add_conditional_edges(
"agent",
should_continue,
)
workflow.add_edge("tools", "agent")
graph = workflow.compile()
# example with a single tool call
start_messages = [ SystemMessage(test.runnable_input['system_msg']) ] + test.runnable_input['fsp_messages'] + [ HumanMessage(test.runnable_input['human_msg']) ]
chunks = []
for chunk in graph.stream(
{"messages": start_messages},
stream_mode="values",
): chunks.append(chunk["messages"][-1])
):
chunks.append(chunk["messages"][-1])
return {
"answer": chunks[-1].content,

View File

@@ -1,10 +0,0 @@
from dataclasses import dataclass, field
from typing import Callable, Any
@dataclass
class Test:
    """A single benchmark case: how to produce an answer and how to grade it."""

    # Label of the test case.
    name: str
    # Callable that produces the answer for this test.
    runnable: Callable
    # Inputs read by `runnable`.
    runnable_input: dict
    # Callable that grades an answer produced by `runnable`.
    validator: Callable
    # Inputs read by `validator`.
    validation_input: dict

View File

@@ -6,14 +6,14 @@ from typing import Union
@tool
def add(a: float, b: float) -> str:
"""Adds a+b and retuns the sum"""
"""Adds a+b and returns the sum"""
af = float(a)
bf = float(b)
return f"{a} + {b} = {a+b}"
@tool
def multiply(a: float, b: float) -> str:
"""Multiplies a*b and retuns the product"""
"""Multiplies a*b and returns the product"""
af = float(a)
bf = float(b)
return f"{a} * {b} = {a*b}"

View File

@@ -1,7 +1,7 @@
from langchain_ollama.chat_models import ChatOllama
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.tools import tool
from libs.test_class import Test
from libs.classes import Test
from re import search
from textwrap import dedent
@@ -16,7 +16,7 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
SystemMessagePromptTemplate.from_template(template=dedent("""You evaluate LLMs. Rate the LLM answer as correct, if the answer is
{validation_input}
else as incorrect. Only use the rate tool. Do not answer conversationally.""")),
else as incorrect. Only use the `rate` tool. You do not have accesss to any other tools. Do not answer conversationally.""")),
HumanMessagePromptTemplate.from_template(template=dedent("""System Message:
{system_msg}
@@ -50,7 +50,10 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
elif ret_str.lower() == 'false': return False
else: raise Exception(f"rate tool retured {ret_str}")
except IndexError as e:
print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
print(f"\033[0;31mValidation Error of\033[0m {test.name} \033[0;31m<\033[0m{ai_msg.content[:20]}\033[0;31m...> Retrying...\033[0m")
return system_human_answer_match(test=test, answer=answer, base_url=base_url)
except Exception as e:
print(f"\033[0;31mValidation Error \033[0mof {test.name} \033[0;31m<\033[0m{e}\033[0;31m> Retrying...\033[0m")
return system_human_answer_match(test=test, answer=answer, base_url=base_url)
def regex_match_any(test: Test, answer: str, base_url: str) -> bool: