diff --git a/libs/classes.py b/libs/classes.py new file mode 100644 index 0000000..4bbd754 --- /dev/null +++ b/libs/classes.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from typing import Callable + +@dataclass +class Test: + name: str + runnable: Callable + runnable_input: dict + validator: Callable + validation_input: dict + +@dataclass +class Model: + display_name: str + identifier: str + supports_tools: bool + parameter_count_in_b: float + diff --git a/libs/functions.py b/libs/functions.py new file mode 100644 index 0000000..3ec1b67 --- /dev/null +++ b/libs/functions.py @@ -0,0 +1,5 @@ +def nxhash(text:str) -> str: # @BenVida StackOverflow + hash=0 + for ch in text: + hash = ( hash*281 ^ ord(ch)*997) & 0xFFFFFFFF + return str(hex(hash)[2:].upper().zfill(8)) \ No newline at end of file diff --git a/libs/ollama_functions.py b/libs/ollama_functions.py new file mode 100644 index 0000000..b831ea5 --- /dev/null +++ b/libs/ollama_functions.py @@ -0,0 +1,322 @@ +import json +import uuid +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Type, + TypeVar, + Union, + Tuple, +) +from types import NoneType + +from langchain_ollama.chat_models import ChatOllama +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import LanguageModelInput +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage, ToolCall +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.prompts import SystemMessagePromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool, Tool +from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass + +from libs.functions import nxhash + +DEFAULT_SYTEM_PROMPT = """You have access to the following tools: + +{tools} + +You must always select one of the above tools and respond with only a JSON object matching the following schema: + +{{ + "tool": , + "tool_input": +}} +""" + + +DEFAULT_SYTEM_PROMPT_WITH_HISTORY = """{system_msg} + +You continue a chat history either conversationally or with a tool call. + +You have access to the following tools: + +{tools} + +You must either select one of the above tools and respond with only a JSON object matching the following schema: + +{{ + "tool": , + "tool_input": +}} + +or answer conversationally normally. + +The conversation before consisted of the following messages: + +{history} + +Now you must answer accordingly either conversationally or with another tool call. + +For conversational answers: Answer as if it was a continuous conversation. The Human only sees the conversational responses, and not anything about the tools. Do not mention the tools or the process of using them. +""" +CONVERSATIONAL_RESPONSE_TOOL = { + "name": "__conversational_response", + "description": ( + "Respond conversationally if no other tools should be called for a given query." + ), + "parameters": { + "type": "object", + "properties": { + "response": { + "type": "string", + "description": "Conversational response to the user.", + }, + }, + "required": ["response"], + }, +} + + +_BM = TypeVar("_BM", bound=BaseModel) +_DictOrPydantic = Union[Dict, _BM] + + +def _is_pydantic_class(obj: Any) -> bool: + return isinstance(obj, type) and ( + is_basemodel_subclass(obj) or BaseModel in obj.__bases__ + ) + + +def convert_to_ollama_tool(tool: Any) -> Dict: + """Convert a tool to an Ollama tool.""" + description = None + if _is_pydantic_class(tool): + schema = tool.construct().schema() + name = schema["title"] + elif isinstance(tool, BaseTool): + schema = tool.tool_call_schema.schema() + name = tool.get_name() + description = tool.description + elif is_basemodel_instance(tool): + schema = tool.get_input_schema().schema() + name = tool.get_name() + description = tool.description + elif isinstance(tool, dict) and "name" in tool and "parameters" in tool: + return tool.copy() + else: + raise ValueError( + f"""Cannot convert {tool} to an Ollama tool. + {tool} needs to be a Pydantic class, model, or a dict.""" + ) + definition = {"name": name, "parameters": schema} + if description: + definition["description"] = description + + return definition + +def parse_response(message: BaseMessage) -> str: + """Extract `function_call` from `AIMessage`.""" + if isinstance(message, AIMessage): + kwargs = message.additional_kwargs + tool_calls = message.tool_calls + if len(tool_calls) > 0: + tool_call = tool_calls[-1] + args = tool_call.get("args") + return json.dumps(args) + elif "function_call" in kwargs: + if "arguments" in kwargs["function_call"]: + return kwargs["function_call"]["arguments"] + raise ValueError(f"`arguments` missing from `function_call` within AIMessage: {message}") + else: + raise ValueError("`tool_calls` missing from AIMessage: {message}") + raise ValueError(f"`message` is not an instance of `AIMessage`: {message}") + + + +class OllamaFunctions(ChatOllama): + """Function chat model that uses Ollama API.""" + + tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT + tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + return self.bind(functions=tools, **kwargs) + + + def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult: + + def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]: + def _format_tools_for_history(tool_calls: list[ToolCall]) -> str: + call_list = [] + for c in tool_calls: + call_list.append({ + "id": nxhash(c['id'])[-4:], + "tool": c['name'], + "args": c['args'] + }) + if len(call_list) == 1: + return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2) + else: + return json.dumps(obj=call_list, ensure_ascii=False, indent=2) + formated_history = "" + system_msg = "" + for m in messages: + + if formated_history != "": + formated_history += "\n\n" + + if isinstance(m, SystemMessage): + system_msg += str(m.content) + elif isinstance(m, HumanMessage): + formated_history += "The Human said:\n" + str(m.content) + elif isinstance(m, AIMessage) and m.tool_calls: + formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls) + elif isinstance(m, ToolMessage): + formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content) + elif isinstance(m, AIMessage) and not m.tool_calls: + formated_history += "You said:\n" + str(m.content) + else: + raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m))) + + return system_msg, formated_history + + + def _get_parsed_chat_result(self, chat_result_str: str) -> Union[dict, str]: + try: + parsed_chat_result = json.loads(chat_result_str) + except json.JSONDecodeError: + parsed_chat_result = chat_result_str + return parsed_chat_result + + def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType: + if not parsed_chat_result: + called_tool_name = None + elif "tool" in parsed_chat_result: + called_tool_name = d["tool"] # per spec + elif "name" in d: + called_tool_name = d["name"] # Phi3 often does this + elif "tool_name" in d: + called_tool_name = d["tool_name"] # Phi3 often does this + elif "action" in d: + called_tool_name = d["action"] # Phi3 does this + else: + return None + + try: + called_tool = [tool for tool in functions_list if tool['name'] == called_tool_name][0] + except IndexError: + return None # when a tool is called, but the tool doesnt exist + + return called_tool + + def _extract_conversaional_response(self, d: dict) -> str: + if ("tool_input" in d and "response" in d["tool_input"]): + response = d["tool_input"]["response"] + elif ("input" in d and "response" in d["input"]): + response = d["input"]["response"] + elif ("args" in d and "response" in d["args"]): + response = d["args"]["response"] + elif "response" in d: + response = d["response"] + elif "input" in d: + response = d["input"] + elif "args" in d: + response = d["args"] + elif "tool_input" in d: + response = d["tool_input"] + else: + raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}") + + try: + assert isinstance(response, str) + except AssertionError: + raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}") + + return response + + def _extract_tool_args(self, d: dict) -> dict: + if "tool_input" in parsed_chat_result: + called_tool_args = d["tool_input"] # per spec + elif "input" in d: + called_tool_args = d["input"] # Phi3 often does this + elif "args" in d: + called_tool_args = d["args"] + else: + called_tool_args = {} + return called_tool_args + + # prepare generation + functions_list = [convert_to_ollama_tool(fn) for fn in kwargs.get("functions", [])] + functions_list.append(CONVERSATIONAL_RESPONSE_TOOL) + functions_str = json.dumps(functions_list, indent=2) + + # prepare generation with history + if True in [ isinstance(m, ToolMessage) for m in messages ]: + system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages) + + system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history) + system_message = system_message_prompt_template.format( + tools=functions_str, + history=formated_history, + system_msg=system_msg + ) + final_messages = [ system_message ] + + # prepare generation without history + else: + system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template) + system_message = system_message_prompt_template.format( + tools=functions_str + ) + final_messages = [ system_message ] + messages + + # genrerate chat result + response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs) + chat_result = response_message.generations[0].text + + # chekc for validity + if not isinstance(chat_result, str): + raise ValueError("OllamaFunctions does not support non-string output.") + + # make str to dict + parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result) + # if model failed to return vailid json, just retrun the whole thing + if isinstance(parsed_chat_result, str): + return ChatResult(generations=[ChatGeneration(message=AIMessage(content=parsed_chat_result))]) + + + # get the called tool from the dict + called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list) + + if not called_tool: + response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result)) + elif called_tool == CONVERSATIONAL_RESPONSE_TOOL: + response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result)) + else: + response_msg = AIMessage( + content="", + tool_calls=[ToolCall( + name=called_tool['name'], + args=_extract_tool_args(self, d=parsed_chat_result), + id=f"call_{str(uuid.uuid4()).replace('-', '')}", + )], + ) + + return ChatResult(generations=[ChatGeneration(message=response_msg)]) + + @property + def _llm_type(self) -> str: + return "ollama_functions" diff --git a/libs/run_tests.py b/libs/run_tests.py index c043dec..5a250a1 100644 --- a/libs/run_tests.py +++ b/libs/run_tests.py @@ -1,30 +1,19 @@ -from libs.test_class import Test +from libs.classes import Test, Model +from libs.functions import nxhash from typing import Union import json -def padd(list, element): - longest = 0 - for s in list: - longest = max(longest, len(str(s))) - return str(element).ljust(longest) - -def nxhash(text:str): # @BenVida StackOverflow - hash=0 - for ch in text: - hash = ( hash*281 ^ ord(ch)*997) & 0xFFFFFFFF - return hex(hash)[2:].upper().zfill(8) - def get_len(collection: Union[list, dict]) -> int: maximum_length = 0 - if isinstance(collection, dict): - collection_type = "tests" - elif isinstance(collection, list): - if isinstance(collection[0], str): - collection_type = "models" - elif isinstance(collection[0], int): - collection_type = "seeds" + if isinstance(collection, list): + collection_type = "seeds" + elif isinstance(collection, dict): + if isinstance(collection[list(collection.keys())[0]], Model): + collection_type = "models" + elif isinstance(collection[list(collection.keys())[0]], Test): + collection_type = "tests" else: raise TypeError("get_len: unsupported collection_type") else: @@ -32,8 +21,8 @@ def get_len(collection: Union[list, dict]) -> int: match collection_type: case "models": - for model_name in collection: - maximum_length = max(maximum_length, len(model_name)) + for model_id in collection: + maximum_length = max(maximum_length, len(collection[model_id].display_name)) case "seeds": for seed in collection: maximum_length = max(maximum_length, len(str(seed))) @@ -48,40 +37,42 @@ def get_len(collection: Union[list, dict]) -> int: -def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: str): +def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], base_url: str): try: print("Trying to load saved_results.json") with open("./saved_results.json", "r") as f: saved_results = json.load(fp=f) print("Loaded.") - except: + except FileNotFoundError: print("saved_results.json not found. Initializing empty.") saved_results = {} # Get Results run_results = {} print("Starting to run Tests ... ") - for model in models: + for model_id in models: + model = models[model_id] for test_id in tests: test = tests[test_id] for seed in seeds: # Init dict combination = { 'test_id': test_id, - 'model': model, + 'model_id': model_id, 'seed': seed, } hash_key = str(nxhash(json.dumps(combination, sort_keys=True))) combination['test_name'] = test.name - + combination['model_name'] = model.display_name + # if hash_key == "DE3D137E": # pass if hash_key not in saved_results.keys(): try: print("\033[0;35mModel '\033[0m" + - model + + model.display_name + "\033[0;35m'" + - (" " * (get_len(models) - len(model))) + + (" " * (get_len(models) - len(model.display_name))) + " with seed \033[0m\033[0;30m" + ("0" * (get_len(seeds) - len(str(seed)))) + "\033[0m" + @@ -96,7 +87,7 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: end="" ) answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url) - if isinstance(answer, str): # tool capabile return tools called as a list[dict] + if isinstance(answer, str): combination['answer'] = answer # combination['tool_calls'] = [] # no entry del answer @@ -105,15 +96,14 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: combination['tool_calls'] = answer['tool_calls'] del answer else: - raise Exception(f"runnable returd unkown type {type(answer)}.") - + raise Exception(f"runnable returned unkown type {type(answer)}.") combination['test'] = test run_results[hash_key] = combination print("\r\033[0;32mModel '\033[0m" + - model + + model.display_name + "\033[0;32m'" + - (" " * (get_len(models) - len(model))) + + (" " * (get_len(models) - len(model.display_name))) + " with seed \033[0m\033[0;30m" + ("0" * (get_len(seeds) - len(str(seed)))) + "\033[0m" + @@ -127,12 +117,12 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: "\033[0;32m)\033[0m" ) except Exception as e: - print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...") + print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ") else: print("\r\033[0;34mModel '\033[0m" + - model + + model.display_name + "\033[0;34m'" + - (" " * (get_len(models) - len(model))) + + (" " * (get_len(models) - len(model.display_name))) + " with seed \033[0m\033[0;30m" + ("0" * (get_len(seeds) - len(str(seed)))) + "\033[0m" + @@ -148,7 +138,8 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: # Validate Results - if run_results != {}: print("\nStarting validation of tests ...") + if run_results != {}: + print("\nStarting validation of tests ...") for hash_key in run_results: result = run_results[hash_key] @@ -156,27 +147,28 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: entry = { 'test_name': result['test_name'], 'test_id': result['test_id'], - 'model': result['model'], + 'model_name': result['model_name'], + 'model_id': result['model_id'], 'seed': result['seed'], 'answer': result['answer'], 'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url), } except Exception as e: - print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m") + print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m ") continue try: entry['tool_calls'] = result['tool_calls'] - except: + except KeyError: pass saved_results[hash_key] = entry # add result with validation to saved results print("\033[0;36mTest results of model '\033[0m" + - entry['model'] + + entry['model_name'] + "\033[0;36m'" + - (" " * (get_len(models) - len(entry['model']))) + + (" " * (get_len(models) - len(entry['model_name']))) + " with seed \033[0m\033[0;30m" + ("0" * (get_len(seeds) - len(str(entry['seed'])))) + "\033[0m" + @@ -188,7 +180,7 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: " (\033[0m" + hash_key + "\033[0;36m) evaluated to \033[0m" + - ('\033[0;32mcorrect\033[0m' if entry['validation'] == True else '\033[0;31mincorrect\033[0m') + ('\033[0;32mcorrect\033[0m' if entry['validation'] else '\033[0;31mincorrect\033[0m') ) with open("./saved_results.json", "w") as f: diff --git a/libs/runnables.py b/libs/runnables.py index de05d5f..d1f7065 100644 --- a/libs/runnables.py +++ b/libs/runnables.py @@ -1,76 +1,100 @@ +from types import NoneType from langchain_ollama.chat_models import ChatOllama -from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage -from libs.test_class import Test +from libs.ollama_functions import OllamaFunctions +from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage +from libs.classes import Test, Model from langchain.tools import Tool from typing import Literal from langgraph.graph import StateGraph, MessagesState -# from langgraph.prebuilt import ToolNode import json from pydantic import ValidationError +def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType = None): + if model.supports_tools: + llm = ChatOllama( + model=model.identifier, + seed=seed, + base_url=base_url + ) + else: + llm = OllamaFunctions( + model=model.identifier, + seed=seed, + base_url=base_url, + format="json" + ) -def basic(model: str, seed: int, test: Test, base_url: str) -> str: - system_msg = test.runnable_input['system_msg'] - human_msg = test.runnable_input['human_msg'] + if tools: + llm = llm.bind_tools(tools=tools) - if system_msg == None: prompt = [ human_msg ] - else: prompt = [ system_msg, human_msg ] + return llm - llm = ChatOllama( - model=model, - seed=seed, - base_url=base_url - ) - ai_msg = llm.invoke(prompt) + +def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str: + + messages = [SystemMessage(test.runnable_input['system_msg'])] + try: + messages += test.runnable_input['fsp_messages'] + except KeyError: + pass + messages += [ HumanMessage(test.runnable_input['human_msg']) ] + + llm = _get_llm(model=model, base_url=base_url, seed=seed) + ai_msg = llm.invoke(messages) + assert isinstance(ai_msg.content, str) return ai_msg.content -def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str: - system_msg = test.runnable_input['system_msg'] - human_msg = test.runnable_input['human_msg'] +def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) -> dict: + tools_dict = test.runnable_input['tools'] tools = [] for key in tools_dict: tools.append(tools_dict[key]) + llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools) - if system_msg == None: prompt = [ human_msg ] - else: prompt = [ system_msg, human_msg ] + messages = [SystemMessage(test.runnable_input['system_msg'])] + try: + messages += test.runnable_input['fsp_messages'] + except KeyError: + pass + messages += [ HumanMessage(test.runnable_input['human_msg']) ] - llm = ChatOllama( - model=model, - seed=seed, - base_url=base_url - ).bind_tools(tools) + ai_msg = llm.invoke(messages) - ai_msg = llm.invoke(prompt) - - prompt.append(ai_msg) + messages += [ ai_msg ] try: tool_calls = [] - for i in range(len(ai_msg.tool_calls)): - tool_call = ai_msg.tool_calls[i] - selected_tool = tools_dict[tool_call["name"].lower()] - tool_msg = selected_tool.invoke(tool_call) - prompt.append(tool_msg) - ai_msg = llm.invoke(prompt) + assert isinstance(ai_msg, AIMessage) + calls = ai_msg.tool_calls + for call in calls: + selected_tool = tools_dict[call["name"].lower()] + tool_msg = selected_tool.invoke(call) + messages.append(tool_msg) + ai_msg = llm.invoke(messages) tool_calls.append({ - "tool": tool_call["name"], - "args": tool_call["args"], - "index": 0 + "tool": call["name"], + "args": call["args"], }) except IndexError: # LLM didnt use a tool -> jsut return the content tool_calls = [] + if len(ai_msg.tool_calls) > 0: + to_append_calls = [] + for call in ai_msg.tool_calls: + to_append_calls.append({ "tool": call["name"], "args": call["args"] }) + return { + "answer": ">>LLM did not respond conversationally<<", + "tool_calls": tool_calls + to_append_calls, + } return { "answer": ai_msg.content, - "tool_calls": tool_calls + "tool_calls": tool_calls, } - - -def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str: +def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict[str, str|list]: tool_calls = [] index = -1 @@ -79,6 +103,7 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str: messages = state["messages"] last_message = messages[-1] nonlocal index + assert isinstance(last_message, AIMessage) # this is just so the type checker is happy if last_message.tool_calls: index += 1 return "tools" @@ -113,10 +138,10 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str: try: tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"]) - except KeyError as e: + except KeyError: tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}' except ValidationError as e: - tool_result = 'Tool got invalid input:\n' + e + tool_result = 'Tool got invalid input:\n' + str(e) except Exception as e: tool_result = 'Error: ' + str(e) @@ -135,11 +160,7 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str: for key in tools_dict: tools.append(tools_dict[key]) tool_node = NxToolNode(tools) - llm = ChatOllama( - model=model, - seed=seed, - base_url=base_url - ).bind_tools(tools) + llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools) workflow = StateGraph(MessagesState) @@ -156,124 +177,21 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str: graph = workflow.compile() - # example with a single tool call - start_messages = [ - SystemMessage(content=test.runnable_input['system_msg']), - HumanMessage(content=test.runnable_input['human_msg']) - ] + # compose "history" supprts few shot prompting + start_messages = [SystemMessage(test.runnable_input['system_msg'])] + try: + start_messages += test.runnable_input['fsp_messages'] + except KeyError: + pass + start_messages += [ HumanMessage(test.runnable_input['human_msg']) ] chunks = [] for chunk in graph.stream( {"messages": start_messages}, stream_mode="values", - ): chunks.append(chunk["messages"][-1]) - - return { - "answer": chunks[-1].content, - "tool_calls": tool_calls - } - - - - - - -def agent_with_tools_fsp(model: str, seed: int, test: Test, base_url: str) -> str: - - tool_calls = [] - index = -1 - - def should_continue(state: MessagesState) -> Literal["tools", "__end__"]: - messages = state["messages"] - last_message = messages[-1] - nonlocal index - if last_message.tool_calls: - index += 1 - return "tools" - return "__end__" - - def call_llm(state: MessagesState): - messages = state["messages"] - response = llm.invoke(messages) - return {"messages": [response]} - - class NxToolNode: - """A node that runs the tools requested in the last AIMessage.""" - - def __init__(self, tools: list) -> None: - self.tools_by_name = {tool.name: tool for tool in tools} - - def __call__(self, inputs: dict): - if messages := inputs.get("messages", []): - message = messages[-1] - else: - raise ValueError("No message found in input") - outputs = [] - for tool_call in message.tool_calls: - - nonlocal tool_calls - nonlocal index - tool_calls.append({ - "tool": tool_call["name"], - "args": tool_call["args"], - "index": index - }) - - try: - tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"]) - except KeyError as e: - tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}' - except ValidationError as e: - tool_result = 'Tool got invalid input:\n' + e - except Exception as e: - tool_result = 'Error: ' + str(e) - - outputs.append( - ToolMessage( - content=json.dumps(tool_result), - name=tool_call["name"], - tool_call_id=tool_call["id"], - ) - ) - return {"messages": outputs} - - - tools_dict = test.runnable_input['tools'] - tools = [] - for key in tools_dict: - tools.append(tools_dict[key]) - tool_node = NxToolNode(tools) - llm = ChatOllama( - model=model, - seed=seed, - base_url=base_url - ).bind_tools(tools) - - workflow = StateGraph(MessagesState) - - # Define the two nodes we will cycle between - workflow.add_node("agent", call_llm) - workflow.add_node("tools", tool_node) - - workflow.add_edge("__start__", "agent") - workflow.add_conditional_edges( - "agent", - should_continue, - ) - workflow.add_edge("tools", "agent") - - graph = workflow.compile() - - # example with a single tool call - start_messages = [ SystemMessage(test.runnable_input['system_msg']) ] + test.runnable_input['fsp_messages'] + [ HumanMessage(test.runnable_input['human_msg']) ] - - chunks = [] - - for chunk in graph.stream( - {"messages": start_messages}, - stream_mode="values", - ): chunks.append(chunk["messages"][-1]) + ): + chunks.append(chunk["messages"][-1]) return { "answer": chunks[-1].content, diff --git a/libs/test_class.py b/libs/test_class.py deleted file mode 100644 index dfa57b3..0000000 --- a/libs/test_class.py +++ /dev/null @@ -1,10 +0,0 @@ -from dataclasses import dataclass, field -from typing import Callable, Any - -@dataclass -class Test: - name: str - runnable: Callable - runnable_input: dict - validator: Callable - validation_input: dict diff --git a/libs/tools.py b/libs/tools.py index 0fdad07..b8f848a 100644 --- a/libs/tools.py +++ b/libs/tools.py @@ -6,14 +6,14 @@ from typing import Union @tool def add(a: float, b: float) -> str: - """Adds a+b and retuns the sum""" + """Adds a+b and returns the sum""" af = float(a) bf = float(b) return f"{a} + {b} = {a+b}" @tool def multiply(a: float, b: float) -> str: - """Multiplies a*b and retuns the product""" + """Multiplies a*b and returns the product""" af = float(a) bf = float(b) return f"{a} * {b} = {a*b}" diff --git a/libs/validators.py b/libs/validators.py index 6760a03..f49000d 100644 --- a/libs/validators.py +++ b/libs/validators.py @@ -1,7 +1,7 @@ from langchain_ollama.chat_models import ChatOllama from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate from langchain.tools import tool -from libs.test_class import Test +from libs.classes import Test from re import search from textwrap import dedent @@ -16,7 +16,7 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool: SystemMessagePromptTemplate.from_template(template=dedent("""You evaluate LLMs. Rate the LLM answer as correct, if the answer is {validation_input} - else as incorrect. Only use the rate tool. Do not answer conversationally.""")), + else as incorrect. Only use the `rate` tool. You do not have accesss to any other tools. Do not answer conversationally.""")), HumanMessagePromptTemplate.from_template(template=dedent("""System Message: {system_msg} @@ -50,7 +50,10 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool: elif ret_str.lower() == 'false': return False else: raise Exception(f"rate tool retured {ret_str}") except IndexError as e: - print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...") + print(f"\033[0;31mValidation Error of\033[0m {test.name} \033[0;31m<\033[0m{ai_msg.content[:20]}\033[0;31m...> Retrying...\033[0m") + return system_human_answer_match(test=test, answer=answer, base_url=base_url) + except Exception as e: + print(f"\033[0;31mValidation Error \033[0mof {test.name} \033[0;31m<\033[0m{e}\033[0;31m> Retrying...\033[0m") return system_human_answer_match(test=test, answer=answer, base_url=base_url) def regex_match_any(test: Test, answer: str, base_url: str) -> bool: diff --git a/requirements.txt b/requirements.txt index 8564c9d..22897d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ langchain langchain-core langchain-ollama +langchain-community langgraph seaborn pandas diff --git a/suite_settings/models.py b/suite_settings/models.py index 4b3181d..6f50eba 100644 --- a/suite_settings/models.py +++ b/suite_settings/models.py @@ -1,13 +1,112 @@ -models = [ - "llama3.1", # 8b - "llama3.1:70b", - "llama3-groq-tool-use", # latest - "llama3-groq-tool-use:70b", - # "mixtral:8x7b", - "mixtral:8x22b", - # "gemma2:2b", - # "phi3", # 3.8b - # "tinyllama:1.1b", - "mistral-nemo:12b", - "command-r-plus:104b", -] +from libs.classes import Model + +models = { + 245: Model( + display_name="llama3.1 8b", + identifier="llama3.1", + supports_tools=True, + parameter_count_in_b=8 + ), + 238: Model( + display_name="llama3.1 70b", + identifier="llama3.1:70b", + supports_tools=True, + parameter_count_in_b=70 + ), + 120: Model( + display_name="llama3 groq TU 8b", + identifier="llama3-groq-tool-use", + supports_tools=True, + parameter_count_in_b=8 + ), + 890: Model( + display_name="llama3 groq TU 70b", + identifier="llama3-groq-tool-use:70b", + supports_tools=True, + parameter_count_in_b=70 + ), + 348: Model( + display_name="Mixtral MoE 8x7b", + identifier="mixtral:8x7b", + supports_tools=False, + parameter_count_in_b=13, + ), + 789: Model( + display_name="Mixtral MoE 8x22b", + identifier="mixtral:8x22b", + supports_tools=True, + parameter_count_in_b=39 + ), + 445: Model( + display_name="Gemma2 2b", + identifier="gemma2:2b", + supports_tools=False, + parameter_count_in_b=2 + ), + 475: Model( + display_name="Gemma2 9b", + identifier="gemma2:2b", + supports_tools=False, + parameter_count_in_b=9 + ), + 626: Model( + display_name="Gemma2 27b", + identifier="gemma2:2b", + supports_tools=False, + parameter_count_in_b=27 + ), + 229: Model( + display_name="Phi3 3.8b", + identifier="phi3", + supports_tools=False, + parameter_count_in_b=3.8 + ), + 903: Model( + display_name="Tinyllama 1.1b", + identifier="tinyllama:1.1b", + supports_tools=False, + parameter_count_in_b=1.1 + ), + 670: Model( + display_name="Mistral Nemo 12b", + identifier="mistral-nemo:12b", + supports_tools=True, + parameter_count_in_b=12 + ), + 404: Model( + display_name="Command R+ 104b", + identifier="command-r-plus:104b", + supports_tools=True, + parameter_count_in_b=104 + ), + 701: Model( + display_name="Yi 6b", + identifier="yi:7b", + supports_tools=False, + parameter_count_in_b=6 + ), + 704: Model( + display_name="Yi 6b", + identifier="yi:7b", + supports_tools=False, + parameter_count_in_b=6 + ), + 724: Model( + display_name="Yi 34b", + identifier="yi:34b", + supports_tools=False, + parameter_count_in_b=34 + ), + 129: Model( + display_name="Yi 34b", + identifier="yi:34b", + supports_tools=False, + parameter_count_in_b=34 + ), + 853: Model( + display_name="Qwen2 0.5b", + identifier="qwen2:0.5b", + supports_tools=False, + parameter_count_in_b=0.5 + ), +} diff --git a/suite_settings/tests.py b/suite_settings/tests.py index c004ae2..8dc7325 100644 --- a/suite_settings/tests.py +++ b/suite_settings/tests.py @@ -1,21 +1,21 @@ -from libs.test_class import Test -from libs.runnables import * -from libs.validators import * -from libs.tools import * +from libs.classes import Test +from libs.runnables import basic_prompt, one_tool_call_answer, agent_with_tools +from libs.validators import regex_match_any, system_human_answer_match +from libs.tools import add, multiply, get_current_date_and_time, get_notes_in_timespan, get_notes_containing, write_note from textwrap import dedent from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage tests = { 607: Test( name="Healthy Vegetables in Chinese", - runnable=basic, + runnable=basic_prompt, runnable_input={ "system_msg": "You are a helpful assistant. You serve people across the globe.", "human_msg": "什么蔬菜最健康?", }, - validator=system_human_answer_match, + validator=system_human_answer_match, validation_input={ - "criteria": dedent("""- in Mandarin Chinese from front to finnish + "criteria": dedent("""- in Mandarin Chinese from front to finnish - factually correct - about healthy vegetables - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes) @@ -23,7 +23,7 @@ tests = { Again, the message has to be entirely in Manadarin Chineese. That means If the answer is not in Chinese the answer is NOT correct! Only if the message in in Chinese rate as correct"""), } - ), + ), 693: Test( name="Simple Multiplication", runnable=one_tool_call_answer, @@ -52,12 +52,12 @@ tests = { "multiply": multiply } }, - validator=regex_match_any, + validator=regex_match_any, validation_input={ "patterns": [ "6134205", "6.134.205", "6,134,205" ] } - ), - 283: Test( + ), + 283: Test( name="Notes from last Saturday", runnable=agent_with_tools, runnable_input={ @@ -67,16 +67,16 @@ tests = { "get_current_date_and_time": get_current_date_and_time, "get_notes_in_timespan": get_notes_in_timespan, "get_notes_containing": get_notes_containing, - "Write note": write_note + "Write note": write_note, } }, - validator=system_human_answer_match, + validator=system_human_answer_match, validation_input={ - "criteria": dedent("""- containing the information that the Human should call Wolfgang + "criteria": dedent("""- containing the information that the Human should call Wolfgang - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""") } ), - 260: Test( + 260: Test( name="Notes from last Saturday TSO", # time span only runnable=agent_with_tools, runnable_input={ @@ -88,15 +88,15 @@ tests = { "Write note": write_note } }, - validator=system_human_answer_match, + validator=system_human_answer_match, validation_input={ - "criteria": dedent("""- containing the information that the Human should call Wolfgang + "criteria": dedent("""- containing the information that the Human should call Wolfgang - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""") } - ), + ), 856: Test( name="Notes from last Saturday TSO FSP", - runnable=agent_with_tools_fsp, + runnable=agent_with_tools, runnable_input={ "system_msg": "You are a helpful assistant. You can use tools to accomplish tasks. Once you've called a tool, the resulting tool_message content can be taken into consideration again. With that you can do \"multiple rounds\" of tool calling. To know the date, use the tool get_current_date_and_time.", "fsp_messages": [ @@ -121,12 +121,12 @@ tests = { "Write note": write_note } }, - validator=system_human_answer_match, + validator=system_human_answer_match, validation_input={ - "criteria": dedent("""- containing the information that the Human should call Wolfgang + "criteria": dedent("""- containing the information that the Human should call Wolfgang - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""") } - ), + ), # 363: Test(), # 600: Test(), # 221: Test(), diff --git a/test_suite.py b/test_suite.py index aa91b29..945b4ea 100644 --- a/test_suite.py +++ b/test_suite.py @@ -1,15 +1,16 @@ -from libs.run_tests import run_tests +from libs.run_tests import run_tests from suite_settings.models import models -from suite_settings.seeds import seeds -from suite_settings.tests import tests +from suite_settings.seeds import seeds +from suite_settings.tests import tests + def main(): - results = run_tests( + run_tests( models=models, seeds=seeds, tests=tests, - base_url="http://bolt.hs-mittweida.de:11434" + base_url="http://bolt.hs-mittweida.de:11434", ) if __name__ == "__main__": diff --git a/visualize.py b/visualize.py index 7674f85..6631993 100644 --- a/visualize.py +++ b/visualize.py @@ -3,7 +3,6 @@ import matplotlib.pyplot as plt import pandas as pd import numpy as np import seaborn as sns -from math import pi # Load the JSON data with open('saved_results.json', 'r') as f: @@ -14,7 +13,7 @@ results = [] for test_hash, test_data in data.items(): results.append({ "hash": test_hash, - "model": test_data['model'], + "model": test_data['model_name'], "seed": test_data['seed'], "test_name": test_data['test_name'], "validation": test_data['validation'] @@ -61,52 +60,7 @@ plt.savefig('validation_results_by_test_name.png') ## 3rd Chart -# Prepare data for the spider chart -models = df['model'].unique() - -# Calculate the pass rate for each model on each test pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0) -tests = df['test_name'].unique().tolist() - -# Initialize the spider plot -num_vars = len(pass_rate)-1 -angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist() -angles += [ angles[0] ] - -fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True)) - -# Plot each model's performance -for model in models: - values = pass_rate.loc[model].tolist() - values += [ values[0] ] - ax.fill(angles, values, alpha=0.25) - ax.plot(angles, values, label=model) -# - -# Configure the spider chart -ax.set_theta_offset(pi / 2) -ax.set_theta_direction(-1) - -tests.append(tests[0]) -tests.pop(0) - -ax.set_xticks(angles[:-1]) -ax.set_xticklabels(tests) - -ax.set_yticks(np.linspace(0, 1, 5)) -ax.set_yticklabels([f'{int(i * 100)}%' for i in np.linspace(0, 1, 5)], color="grey", size=8) -ax.set_ylim(0, 1) - -plt.title('Model Performance on Each Test') -plt.legend(loc='upper right', bbox_to_anchor=(1.1, 1.1)) -plt.tight_layout() -plt.savefig('model_performance_spider_chart.png') - - - - - -# 4th chart # Create a heatmap plt.figure(figsize=(8, 8)) sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})