@dataclass
class Technique:
    """A tool-calling technique that the test suite can combine with a model.

    The two flags state which kind of model the technique is meant for; use
    :meth:`applies_to` instead of combining them by hand at the call site.
    """

    # Human-readable label (printed in the progress output of the test runner).
    name: str
    # True when the technique should be run for models with ``supports_tools``.
    for_supports_tools: bool
    # True when the technique should be run for models without ``supports_tools``.
    for_not_supports_tools: bool

    def applies_to(self, supports_tools: bool) -> bool:
        """Return whether this technique should be run for a model.

        Args:
            supports_tools: The model's ``supports_tools`` flag.

        Returns:
            ``for_supports_tools`` for tool-capable models, otherwise
            ``for_not_supports_tools``. A technique with both flags False
            therefore applies to no model, one with both True to every model.
        """
        return self.for_supports_tools if supports_tools else self.for_not_supports_tools


class OllamaError(Exception):
    """Raised when the model output cannot be turned into a tool call or answer.

    The text is kept on ``message`` so that code catching the error can read it
    back without going through ``args``.
    """

    def __init__(self, message: str) -> None:
        self.message = message  # Store the message
        super().__init__(message)


def _is_pydantic_class(obj: Any) -> bool:
    """Return True if *obj* is a class (not an instance) derived from ``BaseModel``."""
    return isinstance(obj, type) and (
        is_basemodel_subclass(obj) or BaseModel in obj.__bases__
    )
def __init__(self, max_tool_call_fails: int = 5, **kwargs: Any) -> None:
    """Create the chat model and set how many failed generations are retried.

    Args:
        max_tool_call_fails: Number of unusable model outputs that are
            retried before the model gives up. Defaults to 5, the same
            default the class declares for the ``max_tool_call_fails`` field.
        **kwargs: Passed through unchanged to the ``ChatOllama`` initializer.
    """
    # The value has to be forwarded explicitly: as a named parameter it is no
    # longer part of **kwargs, so without this the field silently keeps its
    # class default no matter what the caller passes.
    # NOTE(review): assumes the ChatOllama (pydantic) initializer accepts
    # declared fields as keyword arguments - confirm.
    super().__init__(max_tool_call_fails=max_tool_call_fails, **kwargs)
tool.description + elif isinstance(tool, dict) and "name" in tool and "parameters" in tool: + return tool.copy() + else: + raise ValueError( + f"""Cannot convert {tool} to an Ollama tool. + {tool} needs to be a Pydantic class, model, or a dict.""" + ) + definition = {"name": name, "parameters": schema} + if description: + definition["description"] = description - def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]: - def _format_tools_for_history(tool_calls: list[ToolCall]) -> str: - call_list = [] - for c in tool_calls: - call_list.append({ - "id": nxhash(c['id'])[-4:], - "tool": c['name'], - "args": c['args'] - }) - if len(call_list) == 1: - return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2) - else: - return json.dumps(obj=call_list, ensure_ascii=False, indent=2) - formated_history = "" - system_msg = "" - for m in messages: + return definition - if formated_history != "": - formated_history += "\n\n" - - if isinstance(m, SystemMessage): - system_msg += str(m.content) - elif isinstance(m, HumanMessage): - formated_history += "The Human said:\n" + str(m.content) - elif isinstance(m, AIMessage) and m.tool_calls: - formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls) - elif isinstance(m, ToolMessage): - formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content) - elif isinstance(m, AIMessage) and not m.tool_calls: - formated_history += "You said:\n" + str(m.content) - else: - raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m))) - - return system_msg, formated_history - def _get_parsed_chat_result(self, chat_result_str: str) -> Union[dict, str]: + def _get_parsed_chat_result(self, chat_result_str: str) -> dict: try: parsed_chat_result = json.loads(chat_result_str) + return parsed_chat_result except json.JSONDecodeError: - 
parsed_chat_result = chat_result_str - return parsed_chat_result + raise OllamaError(message="Error. Message is not valid JSON.") def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType: - if not parsed_chat_result: + if not d: called_tool_name = None - elif "tool" in parsed_chat_result: + elif "tool" in d: called_tool_name = d["tool"] # per spec elif "name" in d: called_tool_name = d["name"] # Phi3 often does this elif "tool_name" in d: called_tool_name = d["tool_name"] # Phi3 often does this elif "action" in d: - called_tool_name = d["action"] # Phi3 does this + called_tool_name = d["action"] # Gemma2 does this + elif "task" in d: + called_tool_name = d["task"] # Gemma2 does this else: return None @@ -238,17 +192,25 @@ class OllamaFunctions(ChatOllama): elif "tool_input" in d: response = d["tool_input"] else: - raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}") + raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" + + "{\n" + + ' "tool": ,\n' + + ' "tool_input": \n' + + "}") try: assert isinstance(response, str) except AssertionError: - raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}") + raise OllamaError("Error: Failed to parse response. 
def _get_final_message(self, messages: list, functions_str: str) -> list:
    """Build the list of messages that is sent to the model.

    Without any ``ToolMessage`` in *messages*, the tool system prompt is put
    in front of the unchanged conversation. As soon as a ``ToolMessage`` is
    present, the whole conversation is folded into one system message built
    from ``tool_system_prompt_template_with_history``.

    Args:
        messages: The conversation so far.
        functions_str: JSON description of the available tools, inserted
            into the prompt template as ``tools``.

    Returns:
        The messages to generate from.

    Raises:
        TypeError: If the history contains a message that is not a
            SystemMessage, HumanMessage, ToolMessage or AIMessage.
    """

    def _format_tool_calls(tool_calls: list) -> str:
        """Serialise tool calls as JSON; a single call is not wrapped in a list."""
        calls = [
            {
                "id": nxhash(call['id'])[-4:],
                "tool": call['name'],
                "args": call['args'],
            }
            for call in tool_calls
        ]
        payload = calls[0] if len(calls) == 1 else calls
        return json.dumps(obj=payload, ensure_ascii=False, indent=2)

    def _format_history_entry(m: Any) -> str:
        """Render one message as a paragraph of the textual history."""
        if isinstance(m, SystemMessage):
            return "The system provided the info:\n" + str(m.content)
        if isinstance(m, HumanMessage):
            return "The Human said:\n" + str(m.content)
        if isinstance(m, ToolMessage):
            return "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
        if isinstance(m, AIMessage):
            if m.tool_calls:
                plural = ":\n" if len(m.tool_calls) == 1 else "s:\n"
                return "So you called the tool" + plural + _format_tool_calls(m.tool_calls)
            return "You said:\n" + str(m.content)
        raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))

    # prepare generation without history
    if not any(isinstance(m, ToolMessage) for m in messages):
        template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
        return [template.format(tools=functions_str)] + messages

    # prepare generation with history
    # Only a leading SystemMessage is the caller's system prompt, and the
    # template needs its text, not the message object (formatting the object
    # would insert its repr). Any other first message is part of the
    # conversation and must stay in the history.
    if isinstance(messages[0], SystemMessage):
        system_msg = str(messages[0].content)
        history_messages = messages[1:]
    else:
        system_msg = ""
        history_messages = messages

    formatted_history = "\n\n".join(_format_history_entry(m) for m in history_messages)

    template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
    system_message = template.format(
        tools=functions_str,
        history=formatted_history,
        system_msg=system_msg,
    )
    return [system_message]
CONVERSATIONAL_RESPONSE_TOOL): + response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result)) + else: + response_msg = AIMessage( + content="", + tool_calls=[ToolCall( + name=called_tool['name'], + args=_extract_tool_args(self, d=parsed_chat_result), + id=f"call_{str(uuid.uuid4()).replace('-', '')}", + )], + ) + return ChatResult(generations=[ChatGeneration(message=response_msg)]) + except OllamaError as e: + if failed_tool_calls < self.max_tool_call_fails: + # retry + messages.append(AIMessage(chat_result)) + messages.append(SystemMessage(e.message)) + return gen(self, failed_tool_calls+1, messages=messages) + else: + # return error + # return ChatResult(generations=[ChatGeneration(message=SystemMessage(content=e.message))]) + return ChatResult(generations=[ChatGeneration(message=AIMessage(content=">>Model failed<<"))]) + + # inital call with no failed runs + return gen(self, failed_tool_calls=0, messages=messages) + @property def _llm_type(self) -> str: diff --git a/libs/run_tests.py b/libs/run_tests.py index 5a250a1..8f1faa6 100644 --- a/libs/run_tests.py +++ b/libs/run_tests.py @@ -1,4 +1,5 @@ -from libs.classes import Test, Model +from os import name +from libs.classes import Technique, Test, Model from libs.functions import nxhash from typing import Union @@ -14,6 +15,8 @@ def get_len(collection: Union[list, dict]) -> int: collection_type = "models" elif isinstance(collection[list(collection.keys())[0]], Test): collection_type = "tests" + elif isinstance(collection[list(collection.keys())[0]], Technique): + collection_type = "techniques" else: raise TypeError("get_len: unsupported collection_type") else: @@ -29,6 +32,9 @@ def get_len(collection: Union[list, dict]) -> int: case "tests": for test_id in collection: maximum_length = max(maximum_length, len(collection[test_id].name)) + case "techniques": + for technique_id in collection: + maximum_length = max(maximum_length, len(collection[technique_id].name)) case _: for 
model_name in collection: raise TypeError("get_len: unsupported collection_type") @@ -37,7 +43,7 @@ def get_len(collection: Union[list, dict]) -> int: -def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], base_url: str): +def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], techniques: dict[int, Technique], base_url: str): try: print("Trying to load saved_results.json") with open("./saved_results.json", "r") as f: @@ -53,88 +59,109 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test] model = models[model_id] for test_id in tests: test = tests[test_id] - for seed in seeds: - # Init dict - combination = { - 'test_id': test_id, - 'model_id': model_id, - 'seed': seed, - } - hash_key = str(nxhash(json.dumps(combination, sort_keys=True))) - combination['test_name'] = test.name - combination['model_name'] = model.display_name - - # if hash_key == "DE3D137E": - # pass + for technique_id in techniques: + technique = techniques[technique_id] + if ((model.supports_tools != technique.for_supports_tools) and (model.supports_tools == technique.for_not_supports_tools)): + continue + for seed in seeds: + # Init dict + combination = { + 'test_id': test_id, + 'model_id': model_id, + 'seed': seed, + 'technique_id': technique_id + } + hash_key = str(nxhash(json.dumps(combination, sort_keys=True))) - if hash_key not in saved_results.keys(): - try: - print("\033[0;35mModel '\033[0m" + - model.display_name + - "\033[0;35m'" + - (" " * (get_len(models) - len(model.display_name))) + - " with seed \033[0m\033[0;30m" + - ("0" * (get_len(seeds) - len(str(seed)))) + - "\033[0m" + - str(seed) + - "\033[0;35m now runs test '\033[0m" + - test.name + - "\033[0;35m'" + - (" " * (get_len(tests) - len(test.name))) + - " (\033[0m" + - hash_key + - "\033[0;35m)\033[0m", - end="" - ) - answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url) - if isinstance(answer, str): - combination['answer'] = 
answer - # combination['tool_calls'] = [] # no entry - del answer - elif isinstance(answer, dict): # calls - combination['answer'] = answer['answer'] - combination['tool_calls'] = answer['tool_calls'] - del answer - else: - raise Exception(f"runnable returned unkown type {type(answer)}.") + combination.update({ + 'test_name': test.name, + 'model_name': model.display_name, + 'technique_name': technique.name, + }) + + # if hash_key == "DE3D137E": + # pass + + if hash_key not in saved_results.keys(): + try: + print("\033[0;35mModel '\033[0m" + + model.display_name + + "\033[0;35m'" + + (" " * (get_len(models) - len(model.display_name))) + + " with seed \033[0m\033[0;30m" + + ("0" * (get_len(seeds) - len(str(seed)))) + + "\033[0m" + + str(seed) + + "\033[0;35m using technique '\033[0m" + + technique.name + + "\033[0;35m'" + + (" " * (get_len(techniques) - len(technique.name))) + + "\033[0;35m now runs test '\033[0m" + + test.name + + "\033[0;35m'" + + (" " * (get_len(tests) - len(test.name))) + + " (\033[0m" + + hash_key + + "\033[0;35m)\033[0m", + end="" + ) + answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url) + if isinstance(answer, str): + combination['answer'] = answer + # combination['tool_calls'] = [] # no entry + del answer + elif isinstance(answer, dict): # calls + combination['answer'] = answer['answer'] + combination['tool_calls'] = answer['tool_calls'] + del answer + else: + raise Exception(f"runnable returned unkown type {type(answer)}.") - combination['test'] = test - run_results[hash_key] = combination - print("\r\033[0;32mModel '\033[0m" + + combination['test'] = test + run_results[hash_key] = combination + print("\r\033[0;32mModel '\033[0m" + + model.display_name + + "\033[0;32m'" + + (" " * (get_len(models) - len(model.display_name))) + + " with seed \033[0m\033[0;30m" + + ("0" * (get_len(seeds) - len(str(seed)))) + + "\033[0m" + + str(seed) + + "\033[0;32m using technique '\033[0m" + + technique.name + + "\033[0;32m'" + + (" 
" * (get_len(techniques) - len(technique.name))) + + "\033[0;32m finished test '\033[0m" + + test.name + + "\033[0;32m'" + + (" " * (get_len(tests) - len(test.name))) + + " (\033[0m" + + hash_key + + "\033[0;32m)\033[0m" + ) + except Exception as e: + print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ") + else: + print("\r\033[0;34mModel '\033[0m" + model.display_name + - "\033[0;32m'" + + "\033[0;34m'" + (" " * (get_len(models) - len(model.display_name))) + " with seed \033[0m\033[0;30m" + ("0" * (get_len(seeds) - len(str(seed)))) + "\033[0m" + str(seed) + - "\033[0;32m finished test '\033[0m" + + "\033[0;34m using technique '\033[0m" + + technique.name + + "\033[0;34m'" + + (" " * (get_len(techniques) - len(technique.name))) + + "\033[0;34m skipped test '\033[0m" + test.name + - "\033[0;32m'" + + "\033[0;34m'" + (" " * (get_len(tests) - len(test.name))) + " (\033[0m" + hash_key + - "\033[0;32m)\033[0m" + "\033[0;34m) becasue its results exists in saved_results.json\033[0m" ) - except Exception as e: - print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). 
Continuing...\033[0m ") - else: - print("\r\033[0;34mModel '\033[0m" + - model.display_name + - "\033[0;34m'" + - (" " * (get_len(models) - len(model.display_name))) + - " with seed \033[0m\033[0;30m" + - ("0" * (get_len(seeds) - len(str(seed)))) + - "\033[0m" + - str(seed) + - "\033[0;34m skipped test '\033[0m" + - test.name + - "\033[0;34m'" + - (" " * (get_len(tests) - len(test.name))) + - " (\033[0m" + - hash_key + - "\033[0;34m) becasue its results exists in saved_results.json\033[0m" - ) # Validate Results diff --git a/libs/runnables.py b/libs/runnables.py index d1f7065..a3cf00f 100644 --- a/libs/runnables.py +++ b/libs/runnables.py @@ -22,7 +22,9 @@ def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType model=model.identifier, seed=seed, base_url=base_url, - format="json" + format="json", + max_tool_call_fails=3, + temperature=0.0 ) if tools: @@ -75,9 +77,21 @@ def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) -> tool_msg = selected_tool.invoke(call) messages.append(tool_msg) ai_msg = llm.invoke(messages) + i = 0 + while isinstance(ai_msg, SystemMessage): + i += 1 + if i <= 5: + return { + "answer": ">>LLM failed to use tools<<", + "tool_calls": tool_calls, + } + messages.append(ai_msg) + ai_msg = llm.invoke(messages) + tool_calls.append({ "tool": call["name"], "args": call["args"], + "times_failed": i }) except IndexError: # LLM didnt use a tool -> jsut return the content tool_calls = [] @@ -103,7 +117,6 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict messages = state["messages"] last_message = messages[-1] nonlocal index - assert isinstance(last_message, AIMessage) # this is just so the type checker is happy if last_message.tool_calls: index += 1 return "tools" @@ -174,9 +187,9 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict should_continue, ) workflow.add_edge("tools", "agent") - graph = workflow.compile() + # compose "history" 
# Techniques the suite can run, keyed by a fixed numeric id.
# NOTE(review): the id appears to feed into the hash key of saved results, so
# changing an existing id would presumably invalidate them - confirm.
_NATIVE = Technique(
    name="Native",
    for_supports_tools=True,
    for_not_supports_tools=False,
)

_LONG_SYSTEM_MESSAGE = Technique(
    name="Long System Message",
    for_supports_tools=False,
    for_not_supports_tools=True,
)

# Id 572 is set aside for a disabled "Tool to System Messages" technique
# (for_supports_tools=False, for_not_supports_tools=True).
techniques = {
    190: _NATIVE,
    903: _LONG_SYSTEM_MESSAGE,
}
a/suite_settings/tests.py b/suite_settings/tests.py index 8dc7325..50bd7e0 100644 --- a/suite_settings/tests.py +++ b/suite_settings/tests.py @@ -121,12 +121,13 @@ tests = { "Write note": write_note } }, - validator=system_human_answer_match, + validator=system_human_answer_match, validation_input={ "criteria": dedent("""- containing the information that the Human should call Wolfgang - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""") } - ), + ), + # 363: Test(), # 600: Test(), # 221: Test(), diff --git a/test_suite.py b/test_suite.py index 945b4ea..9c400f9 100644 --- a/test_suite.py +++ b/test_suite.py @@ -1,7 +1,9 @@ +from libs.classes import Technique from libs.run_tests import run_tests from suite_settings.models import models from suite_settings.seeds import seeds from suite_settings.tests import tests +from suite_settings.techniques import techniques def main(): @@ -10,6 +12,7 @@ def main(): models=models, seeds=seeds, tests=tests, + techniques=techniques, base_url="http://bolt.hs-mittweida.de:11434", )