diff --git a/libs/run_tests.py b/libs/run_tests.py index 1c44ff4..c043dec 100644 --- a/libs/run_tests.py +++ b/libs/run_tests.py @@ -1,6 +1,5 @@ from libs.test_class import Test -from libs.validators import system_human_answer_match -from libs.runnables import basic +from typing import Union import json @@ -16,18 +15,36 @@ def nxhash(text:str): # @BenVida StackOverflow hash = ( hash*281 ^ ord(ch)*997) & 0xFFFFFFFF return hex(hash)[2:].upper().zfill(8) -def get_len(l: list) -> int: - m = 0 - for e in l: - if isinstance(e, Test): - m = max(m, len(e.name)) - elif isinstance(e, str): - m = max(m, len(e)) - elif isinstance(e, int): - m = max(m, len(str(e))) - else: - raise Exception(f"get_len() only supports lits of Test, str or int but got {type(e)}") - return m +def get_len(collection: Union[list, dict]) -> int: + maximum_length = 0 + + if isinstance(collection, dict): + collection_type = "tests" + elif isinstance(collection, list): + if isinstance(collection[0], str): + collection_type = "models" + elif isinstance(collection[0], int): + collection_type = "seeds" + else: + raise TypeError("get_len: unsupported collection_type") + else: + raise TypeError("get_len: unsupported collection_type") + + match collection_type: + case "models": + for model_name in collection: + maximum_length = max(maximum_length, len(model_name)) + case "seeds": + for seed in collection: + maximum_length = max(maximum_length, len(str(seed))) + case "tests": + for test_id in collection: + maximum_length = max(maximum_length, len(collection[test_id].name)) + case _: + for model_name in collection: + raise TypeError("get_len: unsupported collection_type") + + return maximum_length @@ -44,29 +61,63 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: run_results = {} print("Starting to run Tests ... ") for model in models: - for seed in seeds: - for test in tests: - + for test_id in tests: + test = tests[test_id] + for seed in seeds: # Init dict combination = { - 'test_name': test.name, + 'test_id': test_id, 'model': model, 'seed': seed, } hash_key = str(nxhash(json.dumps(combination, sort_keys=True))) + combination['test_name'] = test.name + + # if hash_key == "DE3D137E": + # pass if hash_key not in saved_results.keys(): try: - combination['answer'] = test.runnable(model=model, seed=seed, test=test, base_url=base_url) + print("\033[0;35mModel '\033[0m" + + model + + "\033[0;35m'" + + (" " * (get_len(models) - len(model))) + + " with seed \033[0m\033[0;30m" + + ("0" * (get_len(seeds) - len(str(seed)))) + + "\033[0m" + + str(seed) + + "\033[0;35m now runs test '\033[0m" + + test.name + + "\033[0;35m'" + + (" " * (get_len(tests) - len(test.name))) + + " (\033[0m" + + hash_key + + "\033[0;35m)\033[0m", + end="" + ) + answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url) + if isinstance(answer, str): # tool capabile return tools called as a list[dict] + combination['answer'] = answer + # combination['tool_calls'] = [] # no entry + del answer + elif isinstance(answer, dict): # calls + combination['answer'] = answer['answer'] + combination['tool_calls'] = answer['tool_calls'] + del answer + else: + raise Exception(f"runnable returd unkown type {type(answer)}.") + + combination['test'] = test run_results[hash_key] = combination - print("\033[0;32mModel '\033[0m" + + print("\r\033[0;32mModel '\033[0m" + model + "\033[0;32m'" + (" " * (get_len(models) - len(model))) + - " with seed \033[0m" + + " with seed \033[0m\033[0;30m" + + ("0" * (get_len(seeds) - len(str(seed)))) + + "\033[0m" + str(seed) + - (" " * (get_len(seeds) - len(str(seed)))) + "\033[0;32m finished test '\033[0m" + test.name + "\033[0;32m'" + @@ -76,15 +127,16 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: "\033[0;32m)\033[0m" ) except Exception as e: - print("\033[0;31mError: <\033[0m " + str(e) + "\033[0;31m>\033[0m trying to continue...") + print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...") else: - print("\033[0;34mModel '\033[0m" + + print("\r\033[0;34mModel '\033[0m" + model + "\033[0;34m'" + (" " * (get_len(models) - len(model))) + - " with seed \033[0m" + + " with seed \033[0m\033[0;30m" + + ("0" * (get_len(seeds) - len(str(seed)))) + + "\033[0m" + str(seed) + - (" " * (get_len(seeds) - len(str(seed)))) + "\033[0;34m skipped test '\033[0m" + test.name + "\033[0;34m'" + @@ -100,25 +152,37 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: for hash_key in run_results: result = run_results[hash_key] - entry = { - 'test_name': result['test_name'], - 'model': result['model'], - 'seed': result['seed'], - 'answer': result['answer'], - 'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url) - } + try: + entry = { + 'test_name': result['test_name'], + 'test_id': result['test_id'], + 'model': result['model'], + 'seed': result['seed'], + 'answer': result['answer'], + 'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url), + } + except Exception as e: + print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m") + continue + + try: + entry['tool_calls'] = result['tool_calls'] + except: + pass + saved_results[hash_key] = entry # add result with validation to saved results print("\033[0;36mTest results of model '\033[0m" + - model + + entry['model'] + "\033[0;36m'" + (" " * (get_len(models) - len(entry['model']))) + - " with seed \033[0m" + - str(seed) + - (" " * (get_len(seeds) - len(str(entry['seed'])))) + + " with seed \033[0m\033[0;30m" + + ("0" * (get_len(seeds) - len(str(entry['seed'])))) + + "\033[0m" + + str(entry['seed']) + "\033[0;36m on test '\033[0m" + - test.name + + entry['test_name'] + "\033[0;36m'" + (" " * (get_len(tests) - len(entry['test_name']))) + " (\033[0m" + diff --git a/libs/runnables.py b/libs/runnables.py index 34129c4..de05d5f 100644 --- a/libs/runnables.py +++ b/libs/runnables.py @@ -1,10 +1,16 @@ from langchain_ollama.chat_models import ChatOllama -from langchain_core.messages import SystemMessage, HumanMessage +from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage from libs.test_class import Test from langchain.tools import Tool +from typing import Literal + +from langgraph.graph import StateGraph, MessagesState +# from langgraph.prebuilt import ToolNode +import json +from pydantic import ValidationError + def basic(model: str, seed: int, test: Test, base_url: str) -> str: - system_msg = test.runnable_input['system_msg'] human_msg = test.runnable_input['human_msg'] @@ -19,8 +25,9 @@ def basic(model: str, seed: int, test: Test, base_url: str) -> str: ai_msg = llm.invoke(prompt) return ai_msg.content -def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str: + +def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str: system_msg = test.runnable_input['system_msg'] human_msg = test.runnable_input['human_msg'] tools_dict = test.runnable_input['tools'] @@ -42,11 +49,233 @@ def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> st prompt.append(ai_msg) try: - tool_call = ai_msg.tool_calls[0] - selected_tool = tools_dict[tool_call["name"].lower()] - tool_msg = selected_tool.invoke(tool_call) - prompt.append(tool_msg) - ai_msg = llm.invoke(prompt) - except IndexError: - pass - return ai_msg.content + tool_calls = [] + for i in range(len(ai_msg.tool_calls)): + tool_call = ai_msg.tool_calls[i] + selected_tool = tools_dict[tool_call["name"].lower()] + tool_msg = selected_tool.invoke(tool_call) + prompt.append(tool_msg) + ai_msg = llm.invoke(prompt) + tool_calls.append({ + "tool": tool_call["name"], + "args": tool_call["args"], + "index": 0 + }) + except IndexError: # LLM didnt use a tool -> jsut return the content + tool_calls = [] + return { + "answer": ai_msg.content, + "tool_calls": tool_calls + } + + + +def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str: + + tool_calls = [] + index = -1 + + def should_continue(state: MessagesState) -> Literal["tools", "__end__"]: + messages = state["messages"] + last_message = messages[-1] + nonlocal index + if last_message.tool_calls: + index += 1 + return "tools" + return "__end__" + + def call_llm(state: MessagesState): + messages = state["messages"] + response = llm.invoke(messages) + return {"messages": [response]} + + class NxToolNode: + """A node that runs the tools requested in the last AIMessage.""" + + def __init__(self, tools: list) -> None: + self.tools_by_name = {tool.name: tool for tool in tools} + + def __call__(self, inputs: dict): + if messages := inputs.get("messages", []): + message = messages[-1] + else: + raise ValueError("No message found in input") + outputs = [] + for tool_call in message.tool_calls: + + nonlocal tool_calls + nonlocal index + tool_calls.append({ + "tool": tool_call["name"], + "args": tool_call["args"], + "index": index + }) + + try: + tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"]) + except KeyError as e: + tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}' + except ValidationError as e: + tool_result = 'Tool got invalid input:\n' + e + except Exception as e: + tool_result = 'Error: ' + str(e) + + outputs.append( + ToolMessage( + content=json.dumps(tool_result), + name=tool_call["name"], + tool_call_id=tool_call["id"], + ) + ) + return {"messages": outputs} + + + tools_dict = test.runnable_input['tools'] + tools = [] + for key in tools_dict: + tools.append(tools_dict[key]) + tool_node = NxToolNode(tools) + llm = ChatOllama( + model=model, + seed=seed, + base_url=base_url + ).bind_tools(tools) + + workflow = StateGraph(MessagesState) + + # Define the two nodes we will cycle between + workflow.add_node("agent", call_llm) + workflow.add_node("tools", tool_node) + + workflow.add_edge("__start__", "agent") + workflow.add_conditional_edges( + "agent", + should_continue, + ) + workflow.add_edge("tools", "agent") + + graph = workflow.compile() + + # example with a single tool call + start_messages = [ + SystemMessage(content=test.runnable_input['system_msg']), + HumanMessage(content=test.runnable_input['human_msg']) + ] + + chunks = [] + + for chunk in graph.stream( + {"messages": start_messages}, + stream_mode="values", + ): chunks.append(chunk["messages"][-1]) + + return { + "answer": chunks[-1].content, + "tool_calls": tool_calls + } + + + + + + +def agent_with_tools_fsp(model: str, seed: int, test: Test, base_url: str) -> str: + + tool_calls = [] + index = -1 + + def should_continue(state: MessagesState) -> Literal["tools", "__end__"]: + messages = state["messages"] + last_message = messages[-1] + nonlocal index + if last_message.tool_calls: + index += 1 + return "tools" + return "__end__" + + def call_llm(state: MessagesState): + messages = state["messages"] + response = llm.invoke(messages) + return {"messages": [response]} + + class NxToolNode: + """A node that runs the tools requested in the last AIMessage.""" + + def __init__(self, tools: list) -> None: + self.tools_by_name = {tool.name: tool for tool in tools} + + def __call__(self, inputs: dict): + if messages := inputs.get("messages", []): + message = messages[-1] + else: + raise ValueError("No message found in input") + outputs = [] + for tool_call in message.tool_calls: + + nonlocal tool_calls + nonlocal index + tool_calls.append({ + "tool": tool_call["name"], + "args": tool_call["args"], + "index": index + }) + + try: + tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"]) + except KeyError as e: + tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}' + except ValidationError as e: + tool_result = 'Tool got invalid input:\n' + e + except Exception as e: + tool_result = 'Error: ' + str(e) + + outputs.append( + ToolMessage( + content=json.dumps(tool_result), + name=tool_call["name"], + tool_call_id=tool_call["id"], + ) + ) + return {"messages": outputs} + + + tools_dict = test.runnable_input['tools'] + tools = [] + for key in tools_dict: + tools.append(tools_dict[key]) + tool_node = NxToolNode(tools) + llm = ChatOllama( + model=model, + seed=seed, + base_url=base_url + ).bind_tools(tools) + + workflow = StateGraph(MessagesState) + + # Define the two nodes we will cycle between + workflow.add_node("agent", call_llm) + workflow.add_node("tools", tool_node) + + workflow.add_edge("__start__", "agent") + workflow.add_conditional_edges( + "agent", + should_continue, + ) + workflow.add_edge("tools", "agent") + + graph = workflow.compile() + + # example with a single tool call + start_messages = [ SystemMessage(test.runnable_input['system_msg']) ] + test.runnable_input['fsp_messages'] + [ HumanMessage(test.runnable_input['human_msg']) ] + + chunks = [] + + for chunk in graph.stream( + {"messages": start_messages}, + stream_mode="values", + ): chunks.append(chunk["messages"][-1]) + + return { + "answer": chunks[-1].content, + "tool_calls": tool_calls + } diff --git a/libs/tools.py b/libs/tools.py index 8f88a62..0fdad07 100644 --- a/libs/tools.py +++ b/libs/tools.py @@ -1,11 +1,150 @@ from langchain.tools import tool +from datetime import datetime, timedelta +from re import search +from dataclasses import dataclass +from typing import Union @tool -def add(a: float, b: float) -> float: +def add(a: float, b: float) -> str: """Adds a+b and retuns the sum""" - return a+b + af = float(a) + bf = float(b) + return f"{a} + {b} = {a+b}" @tool -def multiply(a: float, b: float) -> float: +def multiply(a: float, b: float) -> str: """Multiplies a*b and retuns the product""" - return a*b + af = float(a) + bf = float(b) + return f"{a} * {b} = {a*b}" + +@tool +def get_current_date_and_time() -> str: + """Return current Date and time""" + return "Thursday the 8th of August 2024 18:03" + + +@dataclass +class Entry: + time: datetime + content: str + +note_entries = [ + Entry( + time=datetime.strptime("2024/08/03 14:58", "%Y/%m/%d %H:%M"), + content="Granny Petra says I should call Wolfgang to ask him when Susanne comes back when he comes back from his holidays." + ), + Entry( + time=datetime.strptime("2024/08/07 09:15", "%Y/%m/%d %H:%M"), + content="Mom says to buy some fresh flowers for the living room before Aunt Linda visits." + ), + Entry( + time=datetime.strptime("2024/08/06 18:30", "%Y/%m/%d %H:%M"), + content="Pick up the dry cleaning on Thursday; they close early on Fridays." + ), + Entry( + time=datetime.strptime("2024/08/05 11:45", "%Y/%m/%d %H:%M"), + content="Ask Dr. Mills about the side effects of the new medication he got me." + ), + Entry( + time=datetime.strptime("2024/08/04 16:00", "%Y/%m/%d %H:%M"), + content="Call the plumber to fix the leak in the upstairs bathroom." + ), + Entry( + time=datetime.strptime("2024/08/03 08:00", "%Y/%m/%d %H:%M"), + content="Schedule a car service appointment before the road trip to the mountains." + ), + Entry( + time=datetime.strptime("2024/08/02 20:10", "%Y/%m/%d %H:%M"), + content="Check if the library has a copy of the new mystery novel everyone is talking about." + ), + Entry( + time=datetime.strptime("2024/08/01 14:30", "%Y/%m/%d %H:%M"), + content="Send a thank-you note to Mrs. Jenkins for the lovely dinner last weekend." + ), + Entry( + time=datetime.strptime("2024/07/31 12:05", "%Y/%m/%d %H:%M"), + content="Email the project update to the team by the end of the week." + ), + Entry( + time=datetime.strptime("2024/07/30 07:50", "%Y/%m/%d %H:%M"), + content="Pick up a birthday card for Uncle George before the family gathering on Sunday." + ), + Entry( + time=datetime.strptime("2024/07/29 15:20", "%Y/%m/%d %H:%M"), + content="Research local yoga classes; consider signing up for the weekend session." + ), + Entry( + time=datetime.strptime("2023/08/01 07:21", "%Y/%m/%d %H:%M"), + content="Talk to Joffrey for the insurance!" + ), + Entry( + time=datetime.strptime("2023/08/01 23:10", "%Y/%m/%d %H:%M"), + content="Went out with Charlotte for our anniversary. Pizza at Cavalinos. She loved the Necklace!" + ) +] + +@tool +def get_notes_in_timespan(begin: str, to: str) -> str: + """Recieves the Notes saved in a time span. + + aguments: + begin: str # start of the time span (incluive) %Y/%m/%d + to: str # end of the timespan (incluive) %Y/%m/%d + + exaples: + {"begin": "2012/08/31", "to": "2012/09/06"} # 7 days from the 31st 00:00 till the 6th 23:59 + {"begin": "2019/04/14", "to": "2019/04/14"} # All notes from the 19th of April 2019""" + + try: + begin_d = datetime.strptime(begin, "%Y/%m/%d") + to_d = datetime.strptime(to+" 23:59", "%Y/%m/%d %H:%M") + except: return "Error: Invalid input. Date format is %Y/%m/%d" + + try: assert begin_d < to_d + except: return "Error: from time has to be before to time." + + filtered_entries = [entry for entry in note_entries if begin_d <= entry.time <= to_d] + + if filtered_entries == []: + return "No entries were found for that time period." + + ret = "" + is_first = True + for entry in filtered_entries: + ret += '' if is_first else '\n\n' + ret += f"{datetime.strftime(entry.time, '%Y/%m/%d %H:%M')} {entry.content}" + is_first = False + + return ret + +@tool +def get_notes_containing(patterns: Union[list[str], str]) -> str: + """Recieves the Notes matching any of the RegEx patterns. + + aguments: + patterns: Union[list[str], str] # case insensitive pattern(s) notes are to be mached against + + exaples: + {"patterns": [ "Aunt(ie)?", "Sabine" ]} # Looks for Notes related to Aunt Sabine""" + + if isinstance(patterns, list): big_pattern = '|'.join(f"({s})" for s in patterns) + elif isinstance(patterns, str): big_pattern = patterns + else: return f"Error: Invalid Input type. `patterns` can either be a list of strings or a single string. But got {type(patterns)}." + + filtered_entries = [entry for entry in note_entries if search(big_pattern.lower(), entry.content.lower())] + + ret = "" + is_first = True + for entry in filtered_entries: + ret += '' if is_first else '\n\n' + ret += f"{datetime.strftime(entry.time, '%Y/%m/%d %H:%M')} {entry.content}" + is_first = False + + return ret + +@tool +def write_note(content: str) -> str: + """Write a not with the current time to the database.""" + return content + diff --git a/libs/validators.py b/libs/validators.py index 394fadc..6760a03 100644 --- a/libs/validators.py +++ b/libs/validators.py @@ -2,6 +2,8 @@ from langchain_ollama.chat_models import ChatOllama from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate from langchain.tools import tool from libs.test_class import Test +from re import search +from textwrap import dedent def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool: @@ -11,27 +13,19 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool: return rating prompt = ChatPromptTemplate.from_messages([ - SystemMessagePromptTemplate.from_template(template="""You evaluate LLMs. Rate the LLM answer as correct, if the answer is -{validation_input} + SystemMessagePromptTemplate.from_template(template=dedent("""You evaluate LLMs. Rate the LLM answer as correct, if the answer is + {validation_input} -else as incorrect. Only use the rate tool. Do not answer conversationally."""), -# SystemMessagePromptTemplate.from_template(template="""You are a rating machine. You are given 3 things: The system message, the Human query, and the AI response. You evaluate the response as correct if -# {validation_input} + else as incorrect. Only use the rate tool. Do not answer conversationally.""")), + HumanMessagePromptTemplate.from_template(template=dedent("""System Message: + {system_msg} -# If the answer does not match these criteria, rate the answer as incorrect. If the answer is a "refusal" or a "declaration of incapability", the answer is automatically incorrect. + Human query: + {human_msg} -# **Only use the rate tool. Do not under any circumstances answer conversationally**. -# DO NOT ANSWER WITH or anything like it. -# Use the rate tool!"""), - HumanMessagePromptTemplate.from_template(template="""System Message: -{system_msg} - -Human query: -{human_msg} - -LLM answer: -{answer} -""") + LLM answer: + {answer} + """)) ]).invoke({ "validation_input": test.validation_input['criteria'], "system_msg": test.runnable_input['system_msg'], @@ -48,7 +42,10 @@ LLM answer: ai_msg = llm.invoke(prompt) try: - ret_str = rate.invoke(ai_msg.tool_calls[0]).content + tool_call = ai_msg.tool_calls[0] + if tool_call['name'] != "rate": + raise Exception(f"Verificaiton model tried to tool `{tool_call['name']}` not `rate`") + ret_str = rate.invoke(tool_call).content if ret_str.lower() == 'true': return True elif ret_str.lower() == 'false': return False else: raise Exception(f"rate tool retured {ret_str}") @@ -56,8 +53,6 @@ LLM answer: print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...") return system_human_answer_match(test=test, answer=answer, base_url=base_url) -from re import search - def regex_match_any(test: Test, answer: str, base_url: str) -> bool: match = False for pattern in test.validation_input['patterns']: