mega commit

This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-20 20:47:17 +02:00
parent 4860179a1c
commit a578dd26a0
13 changed files with 608 additions and 305 deletions

18
libs/classes.py Normal file
View File

@@ -0,0 +1,18 @@
from dataclasses import dataclass
from typing import Callable
@dataclass
class Test:
name: str
runnable: Callable
runnable_input: dict
validator: Callable
validation_input: dict
@dataclass
class Model:
display_name: str
identifier: str
supports_tools: bool
parameter_count_in_b: float

5
libs/functions.py Normal file
View File

@@ -0,0 +1,5 @@
def nxhash(text:str) -> str: # @BenVida StackOverflow
hash=0
for ch in text:
hash = ( hash*281 ^ ord(ch)*997) & 0xFFFFFFFF
return str(hex(hash)[2:].upper().zfill(8))

322
libs/ollama_functions.py Normal file
View File

@@ -0,0 +1,322 @@
import json
import uuid
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
Tuple,
)
from types import NoneType
from langchain_ollama.chat_models import ChatOllama
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool, Tool
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
from libs.functions import nxhash
DEFAULT_SYTEM_PROMPT = """You have access to the following tools:
{tools}
You must always select one of the above tools and respond with only a JSON object matching the following schema:
{{
"tool": <name of the selected tool>,
"tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}
"""
DEFAULT_SYTEM_PROMPT_WITH_HISTORY = """{system_msg}
You continue a chat history either conversationally or with a tool call.
You have access to the following tools:
{tools}
You must either select one of the above tools and respond with only a JSON object matching the following schema:
{{
"tool": <name of the selected tool>,
"tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}
or answer conversationally normally.
The conversation before consisted of the following messages:
{history}
Now you must answer accordingly either conversationally or with another tool call.
For conversational answers: Answer as if it was a continuous conversation. The Human only sees the conversational responses, and not anything about the tools. Do not mention the tools or the process of using them.
"""
CONVERSATIONAL_RESPONSE_TOOL = {
"name": "__conversational_response",
"description": (
"Respond conversationally if no other tools should be called for a given query."
),
"parameters": {
"type": "object",
"properties": {
"response": {
"type": "string",
"description": "Conversational response to the user.",
},
},
"required": ["response"],
},
}
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydantic = Union[Dict, _BM]
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and (
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
)
def convert_to_ollama_tool(tool: Any) -> Dict:
"""Convert a tool to an Ollama tool."""
description = None
if _is_pydantic_class(tool):
schema = tool.construct().schema()
name = schema["title"]
elif isinstance(tool, BaseTool):
schema = tool.tool_call_schema.schema()
name = tool.get_name()
description = tool.description
elif is_basemodel_instance(tool):
schema = tool.get_input_schema().schema()
name = tool.get_name()
description = tool.description
elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
return tool.copy()
else:
raise ValueError(
f"""Cannot convert {tool} to an Ollama tool.
{tool} needs to be a Pydantic class, model, or a dict."""
)
definition = {"name": name, "parameters": schema}
if description:
definition["description"] = description
return definition
def parse_response(message: BaseMessage) -> str:
"""Extract `function_call` from `AIMessage`."""
if isinstance(message, AIMessage):
kwargs = message.additional_kwargs
tool_calls = message.tool_calls
if len(tool_calls) > 0:
tool_call = tool_calls[-1]
args = tool_call.get("args")
return json.dumps(args)
elif "function_call" in kwargs:
if "arguments" in kwargs["function_call"]:
return kwargs["function_call"]["arguments"]
raise ValueError(f"`arguments` missing from `function_call` within AIMessage: {message}")
else:
raise ValueError("`tool_calls` missing from AIMessage: {message}")
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
class OllamaFunctions(ChatOllama):
"""Function chat model that uses Ollama API."""
tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(functions=tools, **kwargs)
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
call_list = []
for c in tool_calls:
call_list.append({
"id": nxhash(c['id'])[-4:],
"tool": c['name'],
"args": c['args']
})
if len(call_list) == 1:
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
else:
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
formated_history = ""
system_msg = ""
for m in messages:
if formated_history != "":
formated_history += "\n\n"
if isinstance(m, SystemMessage):
system_msg += str(m.content)
elif isinstance(m, HumanMessage):
formated_history += "The Human said:\n" + str(m.content)
elif isinstance(m, AIMessage) and m.tool_calls:
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
elif isinstance(m, ToolMessage):
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
elif isinstance(m, AIMessage) and not m.tool_calls:
formated_history += "You said:\n" + str(m.content)
else:
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
return system_msg, formated_history
def _get_parsed_chat_result(self, chat_result_str: str) -> Union[dict, str]:
try:
parsed_chat_result = json.loads(chat_result_str)
except json.JSONDecodeError:
parsed_chat_result = chat_result_str
return parsed_chat_result
def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType:
if not parsed_chat_result:
called_tool_name = None
elif "tool" in parsed_chat_result:
called_tool_name = d["tool"] # per spec
elif "name" in d:
called_tool_name = d["name"] # Phi3 often does this
elif "tool_name" in d:
called_tool_name = d["tool_name"] # Phi3 often does this
elif "action" in d:
called_tool_name = d["action"] # Phi3 does this
else:
return None
try:
called_tool = [tool for tool in functions_list if tool['name'] == called_tool_name][0]
except IndexError:
return None # when a tool is called, but the tool doesnt exist
return called_tool
def _extract_conversaional_response(self, d: dict) -> str:
if ("tool_input" in d and "response" in d["tool_input"]):
response = d["tool_input"]["response"]
elif ("input" in d and "response" in d["input"]):
response = d["input"]["response"]
elif ("args" in d and "response" in d["args"]):
response = d["args"]["response"]
elif "response" in d:
response = d["response"]
elif "input" in d:
response = d["input"]
elif "args" in d:
response = d["args"]
elif "tool_input" in d:
response = d["tool_input"]
else:
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}")
try:
assert isinstance(response, str)
except AssertionError:
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}")
return response
def _extract_tool_args(self, d: dict) -> dict:
if "tool_input" in parsed_chat_result:
called_tool_args = d["tool_input"] # per spec
elif "input" in d:
called_tool_args = d["input"] # Phi3 often does this
elif "args" in d:
called_tool_args = d["args"]
else:
called_tool_args = {}
return called_tool_args
# prepare generation
functions_list = [convert_to_ollama_tool(fn) for fn in kwargs.get("functions", [])]
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
functions_str = json.dumps(functions_list, indent=2)
# prepare generation with history
if True in [ isinstance(m, ToolMessage) for m in messages ]:
system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
system_message = system_message_prompt_template.format(
tools=functions_str,
history=formated_history,
system_msg=system_msg
)
final_messages = [ system_message ]
# prepare generation without history
else:
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
system_message = system_message_prompt_template.format(
tools=functions_str
)
final_messages = [ system_message ] + messages
# genrerate chat result
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
chat_result = response_message.generations[0].text
# chekc for validity
if not isinstance(chat_result, str):
raise ValueError("OllamaFunctions does not support non-string output.")
# make str to dict
parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result)
# if model failed to return vailid json, just retrun the whole thing
if isinstance(parsed_chat_result, str):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=parsed_chat_result))])
# get the called tool from the dict
called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list)
if not called_tool:
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
elif called_tool == CONVERSATIONAL_RESPONSE_TOOL:
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
else:
response_msg = AIMessage(
content="",
tool_calls=[ToolCall(
name=called_tool['name'],
args=_extract_tool_args(self, d=parsed_chat_result),
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
)],
)
return ChatResult(generations=[ChatGeneration(message=response_msg)])
@property
def _llm_type(self) -> str:
return "ollama_functions"

View File

@@ -1,30 +1,19 @@
from libs.test_class import Test from libs.classes import Test, Model
from libs.functions import nxhash
from typing import Union from typing import Union
import json import json
def padd(list, element):
longest = 0
for s in list:
longest = max(longest, len(str(s)))
return str(element).ljust(longest)
def nxhash(text:str): # @BenVida StackOverflow
hash=0
for ch in text:
hash = ( hash*281 ^ ord(ch)*997) & 0xFFFFFFFF
return hex(hash)[2:].upper().zfill(8)
def get_len(collection: Union[list, dict]) -> int: def get_len(collection: Union[list, dict]) -> int:
maximum_length = 0 maximum_length = 0
if isinstance(collection, dict): if isinstance(collection, list):
collection_type = "tests" collection_type = "seeds"
elif isinstance(collection, list): elif isinstance(collection, dict):
if isinstance(collection[0], str): if isinstance(collection[list(collection.keys())[0]], Model):
collection_type = "models" collection_type = "models"
elif isinstance(collection[0], int): elif isinstance(collection[list(collection.keys())[0]], Test):
collection_type = "seeds" collection_type = "tests"
else: else:
raise TypeError("get_len: unsupported collection_type") raise TypeError("get_len: unsupported collection_type")
else: else:
@@ -32,8 +21,8 @@ def get_len(collection: Union[list, dict]) -> int:
match collection_type: match collection_type:
case "models": case "models":
for model_name in collection: for model_id in collection:
maximum_length = max(maximum_length, len(model_name)) maximum_length = max(maximum_length, len(collection[model_id].display_name))
case "seeds": case "seeds":
for seed in collection: for seed in collection:
maximum_length = max(maximum_length, len(str(seed))) maximum_length = max(maximum_length, len(str(seed)))
@@ -48,40 +37,42 @@ def get_len(collection: Union[list, dict]) -> int:
def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: str): def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], base_url: str):
try: try:
print("Trying to load saved_results.json") print("Trying to load saved_results.json")
with open("./saved_results.json", "r") as f: with open("./saved_results.json", "r") as f:
saved_results = json.load(fp=f) saved_results = json.load(fp=f)
print("Loaded.") print("Loaded.")
except: except FileNotFoundError:
print("saved_results.json not found. Initializing empty.") print("saved_results.json not found. Initializing empty.")
saved_results = {} saved_results = {}
# Get Results # Get Results
run_results = {} run_results = {}
print("Starting to run Tests ... ") print("Starting to run Tests ... ")
for model in models: for model_id in models:
model = models[model_id]
for test_id in tests: for test_id in tests:
test = tests[test_id] test = tests[test_id]
for seed in seeds: for seed in seeds:
# Init dict # Init dict
combination = { combination = {
'test_id': test_id, 'test_id': test_id,
'model': model, 'model_id': model_id,
'seed': seed, 'seed': seed,
} }
hash_key = str(nxhash(json.dumps(combination, sort_keys=True))) hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
combination['test_name'] = test.name combination['test_name'] = test.name
combination['model_name'] = model.display_name
# if hash_key == "DE3D137E": # if hash_key == "DE3D137E":
# pass # pass
if hash_key not in saved_results.keys(): if hash_key not in saved_results.keys():
try: try:
print("\033[0;35mModel '\033[0m" + print("\033[0;35mModel '\033[0m" +
model + model.display_name +
"\033[0;35m'" + "\033[0;35m'" +
(" " * (get_len(models) - len(model))) + (" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" + " with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) + ("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" + "\033[0m" +
@@ -96,7 +87,7 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
end="" end=""
) )
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url) answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
if isinstance(answer, str): # tool capabile return tools called as a list[dict] if isinstance(answer, str):
combination['answer'] = answer combination['answer'] = answer
# combination['tool_calls'] = [] # no entry # combination['tool_calls'] = [] # no entry
del answer del answer
@@ -105,15 +96,14 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
combination['tool_calls'] = answer['tool_calls'] combination['tool_calls'] = answer['tool_calls']
del answer del answer
else: else:
raise Exception(f"runnable returd unkown type {type(answer)}.") raise Exception(f"runnable returned unkown type {type(answer)}.")
combination['test'] = test combination['test'] = test
run_results[hash_key] = combination run_results[hash_key] = combination
print("\r\033[0;32mModel '\033[0m" + print("\r\033[0;32mModel '\033[0m" +
model + model.display_name +
"\033[0;32m'" + "\033[0;32m'" +
(" " * (get_len(models) - len(model))) + (" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" + " with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) + ("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" + "\033[0m" +
@@ -127,12 +117,12 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
"\033[0;32m)\033[0m" "\033[0;32m)\033[0m"
) )
except Exception as e: except Exception as e:
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...") print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ")
else: else:
print("\r\033[0;34mModel '\033[0m" + print("\r\033[0;34mModel '\033[0m" +
model + model.display_name +
"\033[0;34m'" + "\033[0;34m'" +
(" " * (get_len(models) - len(model))) + (" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" + " with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) + ("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" + "\033[0m" +
@@ -148,7 +138,8 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
# Validate Results # Validate Results
if run_results != {}: print("\nStarting validation of tests ...") if run_results != {}:
print("\nStarting validation of tests ...")
for hash_key in run_results: for hash_key in run_results:
result = run_results[hash_key] result = run_results[hash_key]
@@ -156,27 +147,28 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
entry = { entry = {
'test_name': result['test_name'], 'test_name': result['test_name'],
'test_id': result['test_id'], 'test_id': result['test_id'],
'model': result['model'], 'model_name': result['model_name'],
'model_id': result['model_id'],
'seed': result['seed'], 'seed': result['seed'],
'answer': result['answer'], 'answer': result['answer'],
'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url), 'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url),
} }
except Exception as e: except Exception as e:
print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m") print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m ")
continue continue
try: try:
entry['tool_calls'] = result['tool_calls'] entry['tool_calls'] = result['tool_calls']
except: except KeyError:
pass pass
saved_results[hash_key] = entry # add result with validation to saved results saved_results[hash_key] = entry # add result with validation to saved results
print("\033[0;36mTest results of model '\033[0m" + print("\033[0;36mTest results of model '\033[0m" +
entry['model'] + entry['model_name'] +
"\033[0;36m'" + "\033[0;36m'" +
(" " * (get_len(models) - len(entry['model']))) + (" " * (get_len(models) - len(entry['model_name']))) +
" with seed \033[0m\033[0;30m" + " with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(entry['seed'])))) + ("0" * (get_len(seeds) - len(str(entry['seed'])))) +
"\033[0m" + "\033[0m" +
@@ -188,7 +180,7 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
" (\033[0m" + " (\033[0m" +
hash_key + hash_key +
"\033[0;36m) evaluated to \033[0m" + "\033[0;36m) evaluated to \033[0m" +
('\033[0;32mcorrect\033[0m' if entry['validation'] == True else '\033[0;31mincorrect\033[0m') ('\033[0;32mcorrect\033[0m' if entry['validation'] else '\033[0;31mincorrect\033[0m')
) )
with open("./saved_results.json", "w") as f: with open("./saved_results.json", "w") as f:

View File

@@ -1,76 +1,100 @@
from types import NoneType
from langchain_ollama.chat_models import ChatOllama from langchain_ollama.chat_models import ChatOllama
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage from libs.ollama_functions import OllamaFunctions
from libs.test_class import Test from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
from libs.classes import Test, Model
from langchain.tools import Tool from langchain.tools import Tool
from typing import Literal from typing import Literal
from langgraph.graph import StateGraph, MessagesState from langgraph.graph import StateGraph, MessagesState
# from langgraph.prebuilt import ToolNode
import json import json
from pydantic import ValidationError from pydantic import ValidationError
def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType = None):
if model.supports_tools:
llm = ChatOllama(
model=model.identifier,
seed=seed,
base_url=base_url
)
else:
llm = OllamaFunctions(
model=model.identifier,
seed=seed,
base_url=base_url,
format="json"
)
def basic(model: str, seed: int, test: Test, base_url: str) -> str: if tools:
system_msg = test.runnable_input['system_msg'] llm = llm.bind_tools(tools=tools)
human_msg = test.runnable_input['human_msg']
if system_msg == None: prompt = [ human_msg ] return llm
else: prompt = [ system_msg, human_msg ]
llm = ChatOllama(
model=model, def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
seed=seed,
base_url=base_url messages = [SystemMessage(test.runnable_input['system_msg'])]
) try:
ai_msg = llm.invoke(prompt) messages += test.runnable_input['fsp_messages']
except KeyError:
pass
messages += [ HumanMessage(test.runnable_input['human_msg']) ]
llm = _get_llm(model=model, base_url=base_url, seed=seed)
ai_msg = llm.invoke(messages)
assert isinstance(ai_msg.content, str)
return ai_msg.content return ai_msg.content
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str: def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) -> dict:
system_msg = test.runnable_input['system_msg']
human_msg = test.runnable_input['human_msg']
tools_dict = test.runnable_input['tools'] tools_dict = test.runnable_input['tools']
tools = [] tools = []
for key in tools_dict: for key in tools_dict:
tools.append(tools_dict[key]) tools.append(tools_dict[key])
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
if system_msg == None: prompt = [ human_msg ] messages = [SystemMessage(test.runnable_input['system_msg'])]
else: prompt = [ system_msg, human_msg ] try:
messages += test.runnable_input['fsp_messages']
except KeyError:
pass
messages += [ HumanMessage(test.runnable_input['human_msg']) ]
llm = ChatOllama( ai_msg = llm.invoke(messages)
model=model,
seed=seed,
base_url=base_url
).bind_tools(tools)
ai_msg = llm.invoke(prompt) messages += [ ai_msg ]
prompt.append(ai_msg)
try: try:
tool_calls = [] tool_calls = []
for i in range(len(ai_msg.tool_calls)): assert isinstance(ai_msg, AIMessage)
tool_call = ai_msg.tool_calls[i] calls = ai_msg.tool_calls
selected_tool = tools_dict[tool_call["name"].lower()] for call in calls:
tool_msg = selected_tool.invoke(tool_call) selected_tool = tools_dict[call["name"].lower()]
prompt.append(tool_msg) tool_msg = selected_tool.invoke(call)
ai_msg = llm.invoke(prompt) messages.append(tool_msg)
ai_msg = llm.invoke(messages)
tool_calls.append({ tool_calls.append({
"tool": tool_call["name"], "tool": call["name"],
"args": tool_call["args"], "args": call["args"],
"index": 0
}) })
except IndexError: # LLM didnt use a tool -> jsut return the content except IndexError: # LLM didnt use a tool -> jsut return the content
tool_calls = [] tool_calls = []
if len(ai_msg.tool_calls) > 0:
to_append_calls = []
for call in ai_msg.tool_calls:
to_append_calls.append({ "tool": call["name"], "args": call["args"] })
return {
"answer": ">>LLM did not respond conversationally<<",
"tool_calls": tool_calls + to_append_calls,
}
return { return {
"answer": ai_msg.content, "answer": ai_msg.content,
"tool_calls": tool_calls "tool_calls": tool_calls,
} }
def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict[str, str|list]:
def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
tool_calls = [] tool_calls = []
index = -1 index = -1
@@ -79,6 +103,7 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
messages = state["messages"] messages = state["messages"]
last_message = messages[-1] last_message = messages[-1]
nonlocal index nonlocal index
assert isinstance(last_message, AIMessage) # this is just so the type checker is happy
if last_message.tool_calls: if last_message.tool_calls:
index += 1 index += 1
return "tools" return "tools"
@@ -113,10 +138,10 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
try: try:
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"]) tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
except KeyError as e: except KeyError:
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}' tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
except ValidationError as e: except ValidationError as e:
tool_result = 'Tool got invalid input:\n' + e tool_result = 'Tool got invalid input:\n' + str(e)
except Exception as e: except Exception as e:
tool_result = 'Error: ' + str(e) tool_result = 'Error: ' + str(e)
@@ -135,11 +160,7 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
for key in tools_dict: for key in tools_dict:
tools.append(tools_dict[key]) tools.append(tools_dict[key])
tool_node = NxToolNode(tools) tool_node = NxToolNode(tools)
llm = ChatOllama( llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
model=model,
seed=seed,
base_url=base_url
).bind_tools(tools)
workflow = StateGraph(MessagesState) workflow = StateGraph(MessagesState)
@@ -156,124 +177,21 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
graph = workflow.compile() graph = workflow.compile()
# example with a single tool call # compose "history" supprts few shot prompting
start_messages = [ start_messages = [SystemMessage(test.runnable_input['system_msg'])]
SystemMessage(content=test.runnable_input['system_msg']), try:
HumanMessage(content=test.runnable_input['human_msg']) start_messages += test.runnable_input['fsp_messages']
] except KeyError:
pass
start_messages += [ HumanMessage(test.runnable_input['human_msg']) ]
chunks = [] chunks = []
for chunk in graph.stream( for chunk in graph.stream(
{"messages": start_messages}, {"messages": start_messages},
stream_mode="values", stream_mode="values",
): chunks.append(chunk["messages"][-1]) ):
chunks.append(chunk["messages"][-1])
return {
"answer": chunks[-1].content,
"tool_calls": tool_calls
}
def agent_with_tools_fsp(model: str, seed: int, test: Test, base_url: str) -> str:
tool_calls = []
index = -1
def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
messages = state["messages"]
last_message = messages[-1]
nonlocal index
if last_message.tool_calls:
index += 1
return "tools"
return "__end__"
def call_llm(state: MessagesState):
messages = state["messages"]
response = llm.invoke(messages)
return {"messages": [response]}
class NxToolNode:
"""A node that runs the tools requested in the last AIMessage."""
def __init__(self, tools: list) -> None:
self.tools_by_name = {tool.name: tool for tool in tools}
def __call__(self, inputs: dict):
if messages := inputs.get("messages", []):
message = messages[-1]
else:
raise ValueError("No message found in input")
outputs = []
for tool_call in message.tool_calls:
nonlocal tool_calls
nonlocal index
tool_calls.append({
"tool": tool_call["name"],
"args": tool_call["args"],
"index": index
})
try:
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
except KeyError as e:
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
except ValidationError as e:
tool_result = 'Tool got invalid input:\n' + e
except Exception as e:
tool_result = 'Error: ' + str(e)
outputs.append(
ToolMessage(
content=json.dumps(tool_result),
name=tool_call["name"],
tool_call_id=tool_call["id"],
)
)
return {"messages": outputs}
tools_dict = test.runnable_input['tools']
tools = []
for key in tools_dict:
tools.append(tools_dict[key])
tool_node = NxToolNode(tools)
llm = ChatOllama(
model=model,
seed=seed,
base_url=base_url
).bind_tools(tools)
workflow = StateGraph(MessagesState)
# Define the two nodes we will cycle between
workflow.add_node("agent", call_llm)
workflow.add_node("tools", tool_node)
workflow.add_edge("__start__", "agent")
workflow.add_conditional_edges(
"agent",
should_continue,
)
workflow.add_edge("tools", "agent")
graph = workflow.compile()
# example with a single tool call
start_messages = [ SystemMessage(test.runnable_input['system_msg']) ] + test.runnable_input['fsp_messages'] + [ HumanMessage(test.runnable_input['human_msg']) ]
chunks = []
for chunk in graph.stream(
{"messages": start_messages},
stream_mode="values",
): chunks.append(chunk["messages"][-1])
return { return {
"answer": chunks[-1].content, "answer": chunks[-1].content,

View File

@@ -1,10 +0,0 @@
from dataclasses import dataclass, field
from typing import Callable, Any
@dataclass
class Test:
name: str
runnable: Callable
runnable_input: dict
validator: Callable
validation_input: dict

View File

@@ -6,14 +6,14 @@ from typing import Union
@tool @tool
def add(a: float, b: float) -> str: def add(a: float, b: float) -> str:
"""Adds a+b and retuns the sum""" """Adds a+b and returns the sum"""
af = float(a) af = float(a)
bf = float(b) bf = float(b)
return f"{a} + {b} = {a+b}" return f"{a} + {b} = {a+b}"
@tool @tool
def multiply(a: float, b: float) -> str: def multiply(a: float, b: float) -> str:
"""Multiplies a*b and retuns the product""" """Multiplies a*b and returns the product"""
af = float(a) af = float(a)
bf = float(b) bf = float(b)
return f"{a} * {b} = {a*b}" return f"{a} * {b} = {a*b}"

View File

@@ -1,7 +1,7 @@
from langchain_ollama.chat_models import ChatOllama from langchain_ollama.chat_models import ChatOllama
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.tools import tool from langchain.tools import tool
from libs.test_class import Test from libs.classes import Test
from re import search from re import search
from textwrap import dedent from textwrap import dedent
@@ -16,7 +16,7 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
SystemMessagePromptTemplate.from_template(template=dedent("""You evaluate LLMs. Rate the LLM answer as correct, if the answer is SystemMessagePromptTemplate.from_template(template=dedent("""You evaluate LLMs. Rate the LLM answer as correct, if the answer is
{validation_input} {validation_input}
else as incorrect. Only use the rate tool. Do not answer conversationally.""")), else as incorrect. Only use the `rate` tool. You do not have accesss to any other tools. Do not answer conversationally.""")),
HumanMessagePromptTemplate.from_template(template=dedent("""System Message: HumanMessagePromptTemplate.from_template(template=dedent("""System Message:
{system_msg} {system_msg}
@@ -50,7 +50,10 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
elif ret_str.lower() == 'false': return False elif ret_str.lower() == 'false': return False
else: raise Exception(f"rate tool retured {ret_str}") else: raise Exception(f"rate tool retured {ret_str}")
except IndexError as e: except IndexError as e:
print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...") print(f"\033[0;31mValidation Error of\033[0m {test.name} \033[0;31m<\033[0m{ai_msg.content[:20]}\033[0;31m...> Retrying...\033[0m")
return system_human_answer_match(test=test, answer=answer, base_url=base_url)
except Exception as e:
print(f"\033[0;31mValidation Error \033[0mof {test.name} \033[0;31m<\033[0m{e}\033[0;31m> Retrying...\033[0m")
return system_human_answer_match(test=test, answer=answer, base_url=base_url) return system_human_answer_match(test=test, answer=answer, base_url=base_url)
def regex_match_any(test: Test, answer: str, base_url: str) -> bool: def regex_match_any(test: Test, answer: str, base_url: str) -> bool:

View File

@@ -1,6 +1,7 @@
langchain langchain
langchain-core langchain-core
langchain-ollama langchain-ollama
langchain-community
langgraph langgraph
seaborn seaborn
pandas pandas

View File

@@ -1,13 +1,112 @@
models = [ from libs.classes import Model
"llama3.1", # 8b
"llama3.1:70b", models = {
"llama3-groq-tool-use", # latest 245: Model(
"llama3-groq-tool-use:70b", display_name="llama3.1 8b",
# "mixtral:8x7b", identifier="llama3.1",
"mixtral:8x22b", supports_tools=True,
# "gemma2:2b", parameter_count_in_b=8
# "phi3", # 3.8b ),
# "tinyllama:1.1b", 238: Model(
"mistral-nemo:12b", display_name="llama3.1 70b",
"command-r-plus:104b", identifier="llama3.1:70b",
] supports_tools=True,
parameter_count_in_b=70
),
120: Model(
display_name="llama3 groq TU 8b",
identifier="llama3-groq-tool-use",
supports_tools=True,
parameter_count_in_b=8
),
890: Model(
display_name="llama3 groq TU 70b",
identifier="llama3-groq-tool-use:70b",
supports_tools=True,
parameter_count_in_b=70
),
348: Model(
display_name="Mixtral MoE 8x7b",
identifier="mixtral:8x7b",
supports_tools=False,
parameter_count_in_b=13,
),
789: Model(
display_name="Mixtral MoE 8x22b",
identifier="mixtral:8x22b",
supports_tools=True,
parameter_count_in_b=39
),
445: Model(
display_name="Gemma2 2b",
identifier="gemma2:2b",
supports_tools=False,
parameter_count_in_b=2
),
475: Model(
display_name="Gemma2 9b",
identifier="gemma2:2b",
supports_tools=False,
parameter_count_in_b=9
),
626: Model(
display_name="Gemma2 27b",
identifier="gemma2:2b",
supports_tools=False,
parameter_count_in_b=27
),
229: Model(
display_name="Phi3 3.8b",
identifier="phi3",
supports_tools=False,
parameter_count_in_b=3.8
),
903: Model(
display_name="Tinyllama 1.1b",
identifier="tinyllama:1.1b",
supports_tools=False,
parameter_count_in_b=1.1
),
670: Model(
display_name="Mistral Nemo 12b",
identifier="mistral-nemo:12b",
supports_tools=True,
parameter_count_in_b=12
),
404: Model(
display_name="Command R+ 104b",
identifier="command-r-plus:104b",
supports_tools=True,
parameter_count_in_b=104
),
701: Model(
display_name="Yi 6b",
identifier="yi:7b",
supports_tools=False,
parameter_count_in_b=6
),
704: Model(
display_name="Yi 6b",
identifier="yi:7b",
supports_tools=False,
parameter_count_in_b=6
),
724: Model(
display_name="Yi 34b",
identifier="yi:34b",
supports_tools=False,
parameter_count_in_b=34
),
129: Model(
display_name="Yi 34b",
identifier="yi:34b",
supports_tools=False,
parameter_count_in_b=34
),
853: Model(
display_name="Qwen2 0.5b",
identifier="qwen2:0.5b",
supports_tools=False,
parameter_count_in_b=0.5
),
}

View File

@@ -1,21 +1,21 @@
from libs.test_class import Test from libs.classes import Test
from libs.runnables import * from libs.runnables import basic_prompt, one_tool_call_answer, agent_with_tools
from libs.validators import * from libs.validators import regex_match_any, system_human_answer_match
from libs.tools import * from libs.tools import add, multiply, get_current_date_and_time, get_notes_in_timespan, get_notes_containing, write_note
from textwrap import dedent from textwrap import dedent
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
tests = { tests = {
607: Test( 607: Test(
name="Healthy Vegetables in Chinese", name="Healthy Vegetables in Chinese",
runnable=basic, runnable=basic_prompt,
runnable_input={ runnable_input={
"system_msg": "You are a helpful assistant. You serve people across the globe.", "system_msg": "You are a helpful assistant. You serve people across the globe.",
"human_msg": "什么蔬菜最健康?", "human_msg": "什么蔬菜最健康?",
}, },
validator=system_human_answer_match, validator=system_human_answer_match,
validation_input={ validation_input={
"criteria": dedent("""- in Mandarin Chinese from front to finnish "criteria": dedent("""- in Mandarin Chinese from front to finnish
- factually correct - factually correct
- about healthy vegetables - about healthy vegetables
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes) - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)
@@ -23,7 +23,7 @@ tests = {
Again, the message has to be entirely in Manadarin Chineese. Again, the message has to be entirely in Manadarin Chineese.
That means If the answer is not in Chinese the answer is NOT correct! Only if the message in in Chinese rate as correct"""), That means If the answer is not in Chinese the answer is NOT correct! Only if the message in in Chinese rate as correct"""),
} }
), ),
693: Test( 693: Test(
name="Simple Multiplication", name="Simple Multiplication",
runnable=one_tool_call_answer, runnable=one_tool_call_answer,
@@ -52,12 +52,12 @@ tests = {
"multiply": multiply "multiply": multiply
} }
}, },
validator=regex_match_any, validator=regex_match_any,
validation_input={ validation_input={
"patterns": [ "6134205", "6.134.205", "6,134,205" ] "patterns": [ "6134205", "6.134.205", "6,134,205" ]
} }
), ),
283: Test( 283: Test(
name="Notes from last Saturday", name="Notes from last Saturday",
runnable=agent_with_tools, runnable=agent_with_tools,
runnable_input={ runnable_input={
@@ -67,16 +67,16 @@ tests = {
"get_current_date_and_time": get_current_date_and_time, "get_current_date_and_time": get_current_date_and_time,
"get_notes_in_timespan": get_notes_in_timespan, "get_notes_in_timespan": get_notes_in_timespan,
"get_notes_containing": get_notes_containing, "get_notes_containing": get_notes_containing,
"Write note": write_note "Write note": write_note,
} }
}, },
validator=system_human_answer_match, validator=system_human_answer_match,
validation_input={ validation_input={
"criteria": dedent("""- containing the information that the Human should call Wolfgang "criteria": dedent("""- containing the information that the Human should call Wolfgang
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""") - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
} }
), ),
260: Test( 260: Test(
name="Notes from last Saturday TSO", # time span only name="Notes from last Saturday TSO", # time span only
runnable=agent_with_tools, runnable=agent_with_tools,
runnable_input={ runnable_input={
@@ -88,15 +88,15 @@ tests = {
"Write note": write_note "Write note": write_note
} }
}, },
validator=system_human_answer_match, validator=system_human_answer_match,
validation_input={ validation_input={
"criteria": dedent("""- containing the information that the Human should call Wolfgang "criteria": dedent("""- containing the information that the Human should call Wolfgang
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""") - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
} }
), ),
856: Test( 856: Test(
name="Notes from last Saturday TSO FSP", name="Notes from last Saturday TSO FSP",
runnable=agent_with_tools_fsp, runnable=agent_with_tools,
runnable_input={ runnable_input={
"system_msg": "You are a helpful assistant. You can use tools to accomplish tasks. Once you've called a tool, the resulting tool_message content can be taken into consideration again. With that you can do \"multiple rounds\" of tool calling. To know the date, use the tool get_current_date_and_time.", "system_msg": "You are a helpful assistant. You can use tools to accomplish tasks. Once you've called a tool, the resulting tool_message content can be taken into consideration again. With that you can do \"multiple rounds\" of tool calling. To know the date, use the tool get_current_date_and_time.",
"fsp_messages": [ "fsp_messages": [
@@ -121,12 +121,12 @@ tests = {
"Write note": write_note "Write note": write_note
} }
}, },
validator=system_human_answer_match, validator=system_human_answer_match,
validation_input={ validation_input={
"criteria": dedent("""- containing the information that the Human should call Wolfgang "criteria": dedent("""- containing the information that the Human should call Wolfgang
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""") - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
} }
), ),
# 363: Test(), # 363: Test(),
# 600: Test(), # 600: Test(),
# 221: Test(), # 221: Test(),

View File

@@ -1,15 +1,16 @@
from libs.run_tests import run_tests from libs.run_tests import run_tests
from suite_settings.models import models from suite_settings.models import models
from suite_settings.seeds import seeds from suite_settings.seeds import seeds
from suite_settings.tests import tests from suite_settings.tests import tests
def main(): def main():
results = run_tests( run_tests(
models=models, models=models,
seeds=seeds, seeds=seeds,
tests=tests, tests=tests,
base_url="http://bolt.hs-mittweida.de:11434" base_url="http://bolt.hs-mittweida.de:11434",
) )
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -3,7 +3,6 @@ import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import seaborn as sns import seaborn as sns
from math import pi
# Load the JSON data # Load the JSON data
with open('saved_results.json', 'r') as f: with open('saved_results.json', 'r') as f:
@@ -14,7 +13,7 @@ results = []
for test_hash, test_data in data.items(): for test_hash, test_data in data.items():
results.append({ results.append({
"hash": test_hash, "hash": test_hash,
"model": test_data['model'], "model": test_data['model_name'],
"seed": test_data['seed'], "seed": test_data['seed'],
"test_name": test_data['test_name'], "test_name": test_data['test_name'],
"validation": test_data['validation'] "validation": test_data['validation']
@@ -61,52 +60,7 @@ plt.savefig('validation_results_by_test_name.png')
## 3rd Chart ## 3rd Chart
# Prepare data for the spider chart
models = df['model'].unique()
# Calculate the pass rate for each model on each test
pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0) pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0)
tests = df['test_name'].unique().tolist()
# Initialize the spider plot
num_vars = len(pass_rate)-1
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
angles += [ angles[0] ]
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
# Plot each model's performance
for model in models:
values = pass_rate.loc[model].tolist()
values += [ values[0] ]
ax.fill(angles, values, alpha=0.25)
ax.plot(angles, values, label=model)
#
# Configure the spider chart
ax.set_theta_offset(pi / 2)
ax.set_theta_direction(-1)
tests.append(tests[0])
tests.pop(0)
ax.set_xticks(angles[:-1])
ax.set_xticklabels(tests)
ax.set_yticks(np.linspace(0, 1, 5))
ax.set_yticklabels([f'{int(i * 100)}%' for i in np.linspace(0, 1, 5)], color="grey", size=8)
ax.set_ylim(0, 1)
plt.title('Model Performance on Each Test')
plt.legend(loc='upper right', bbox_to_anchor=(1.1, 1.1))
plt.tight_layout()
plt.savefig('model_performance_spider_chart.png')
# 4th chart
# Create a heatmap # Create a heatmap
plt.figure(figsize=(8, 8)) plt.figure(figsize=(8, 8))
sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10}) sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})