diff --git a/libs/run_tests.py b/libs/run_tests.py
index 8e2595f..fd9390f 100644
--- a/libs/run_tests.py
+++ b/libs/run_tests.py
@@ -19,10 +19,10 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
                     results.append({"test": test,"model": model, "seed": seed, "result": result})
                     print(f"Model {padd(models, model)} starting with seed {padd(seeds, seed)} is done with test '{test.name}'.")
                 except Exception as e:
-                    print("\033[0;31mError:\033[0m" + e)
+                    print("\033[0;31mError:\033[0m " + str(e))
 
     for result in results:
-        result['validation'] = test.validator(test=result['test'], answer=result['result'], base_url=base_url)
+        result['validation'] = result['test'].validator(test=result['test'], answer=result['result'], base_url=base_url)
         print(f"Validation of answer from test {result['test'].name} by {result['model']} with seed {result['seed']} evaluated to " + ('\033[0;32mcorrect\033[0m' if result['validation'] == True else '\033[0;31mincorrect\033[0m'))
 
 
diff --git a/libs/runnables.py b/libs/runnables.py
index 31e4d0d..34129c4 100644
--- a/libs/runnables.py
+++ b/libs/runnables.py
@@ -1,11 +1,15 @@
 from langchain_ollama.chat_models import ChatOllama
 from langchain_core.messages import SystemMessage, HumanMessage
 from libs.test_class import Test
+from langchain.tools import Tool
 
 def basic(model: str, seed: int, test: Test, base_url: str) -> str:
 
-    if test.system_msg == None: prompt = [ test.human_msg ]
-    else: prompt = [ test.system_msg, test.human_msg ]
+    system_msg = test.runnable_input['system_msg']
+    human_msg = test.runnable_input['human_msg']
+
+    if system_msg is None: prompt = [ human_msg ]
+    else: prompt = [ system_msg, human_msg ]
 
     llm = ChatOllama(
         model=model,
@@ -14,3 +18,38 @@ def basic(model: str, seed: int, test: Test, base_url: str) -> str:
     )
     ai_msg = llm.invoke(prompt)
     return ai_msg.content
+
+def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
+    """Answer the test prompt, running at most one tool call requested by the model.
+
+    The model is bound to the tools in test.runnable_input['tools'] (a dict keyed
+    by lower-case tool name). If the first reply contains tool calls, only the
+    first one is executed and the model is invoked again with the tool result.
+    Raises KeyError if the model requests a tool that is not in that dict.
+    """
+    system_msg = test.runnable_input['system_msg']
+    human_msg = test.runnable_input['human_msg']
+    tools_dict = test.runnable_input['tools']
+
+    if system_msg is None: prompt = [ human_msg ]
+    else: prompt = [ system_msg, human_msg ]
+
+    llm = ChatOllama(
+        model=model,
+        seed=seed,
+        base_url=base_url
+    ).bind_tools(list(tools_dict.values()))
+
+    ai_msg = llm.invoke(prompt)
+
+    # No tool requested: the first reply is already the final answer.
+    if not ai_msg.tool_calls:
+        return ai_msg.content
+
+    # Keep the lookup and tool run outside any try block so that errors raised
+    # by the tool or the follow-up LLM call are not silently swallowed.
+    tool_call = ai_msg.tool_calls[0]
+    selected_tool = tools_dict[tool_call["name"].lower()]
+    prompt.append(ai_msg)
+    prompt.append(selected_tool.invoke(tool_call))
+    return llm.invoke(prompt).content
diff --git a/libs/test_class.py b/libs/test_class.py
index 5020fd7..dfa57b3 100644
--- a/libs/test_class.py
+++ b/libs/test_class.py
@@ -1,14 +1,10 @@
 from dataclasses import dataclass, field
-from typing import Callable
+from typing import Callable, Any
 
 @dataclass
 class Test:
     name: str
-    system_msg: field(default="You are a helful AI assistant.")
-    human_msg: str
-    validation_info: field(default="""- it is factually correct
-- it fits/answers the system message and human query
-- it is just the answer, and doesn't have any AI fragments (A/B versions, "end of message" parts, unfiting discalimers or notes)""")
     runnable: Callable
+    runnable_input: dict[str, Any]  # read by the runnable, e.g. 'system_msg', 'human_msg', 'tools'
     validator: Callable
-
+    validation_input: dict[str, Any]  # read by the validator, e.g. 'criteria' or 'patterns'
diff --git a/libs/tools.py b/libs/tools.py
new file mode 100644
index 0000000..8f88a62
--- /dev/null
+++ b/libs/tools.py
@@ -0,0 +1,11 @@
+from langchain.tools import tool
+
+@tool
+def add(a: float, b: float) -> float:
+    """Adds a+b and returns the sum"""
+    return a+b
+
+@tool
+def multiply(a: float, b: float) -> float:
+    """Multiplies a*b and returns the product"""
+    return a*b
diff --git a/libs/validators.py b/libs/validators.py
index 6eda1ae..3b45fd2 100644
--- a/libs/validators.py
+++ b/libs/validators.py
@@ -10,11 +10,11 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
 
     prompt = ChatPromptTemplate.from_messages([
         SystemMessagePromptTemplate.from_template(template="""Rate the answer as correct, if the answer is
-    {validation_info}
+    {validation_input}
 else as incorrect. Only use the rate tool. Do not answer conversationally."""),
 
 #         SystemMessagePromptTemplate.from_template(template="""You are a rating machine. You are given 3 things: The system message, the Human query, and the AI response. You evaluate the response as correct if
-# {validation_info}
+# {validation_input}
 # If the answer does not match these criteria, rate the answer as incorrect. If the answer is a "refusal" or a "declaration of incapability", the answer is automatically incorrect.
 
 
@@ -31,9 +31,9 @@ Answer:
 {answer}
 """)
     ]).invoke({
-        "validation_info": test.validation_info,
-        "system_msg": test.system_msg,
-        "human_msg": test.human_msg,
+        "validation_input": test.validation_input['criteria'],
+        "system_msg": test.runnable_input['system_msg'],
+        "human_msg": test.runnable_input['human_msg'],
         "answer": answer
     })
 
@@ -51,3 +51,8 @@ Answer:
         print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
-        return system_human_answer_match(test=test, answer=answer)
+        return system_human_answer_match(test=test, answer=answer, base_url=base_url)
 
+from re import search
+
+def regex_match_any(test: Test, answer: str, base_url: str) -> bool:
+    """Return True if any regex in test.validation_input['patterns'] is found in answer; base_url is unused and only accepted to match the validator call signature."""
+    return any(search(pattern, answer) for pattern in test.validation_input['patterns'])
diff --git a/test_small_llms.py b/test_small_llms.py
index 9140501..542b15d 100644
--- a/test_small_llms.py
+++ b/test_small_llms.py
@@ -2,12 +2,13 @@ from libs.test_class import Test
 from libs.run_tests import run_tests
 from libs.runnables import *
 from libs.validators import *
+from libs.tools import *
 from pprint import pprint
 
 def main():
 
     models = [
-        # "llama3.1", # 8b
+        "llama3.1", # 8b
         # "llama3.1:70b",
         # "llama3-groq-tool-use", # latest
         # "llama3-groq-tool-use:70b",
@@ -15,29 +16,45 @@ def main():
         # "mixtral:8x22b",
         # "gemma2:2b",
         # "phi3", # 3.8b
-        "tinyllama:1.1b",
+        # "tinyllama:1.1b",
     ]
     seeds = [
         # 2,
         222,
-        # 22222,
-        2222222
+        22222,
+        # 2222222
     ]
     tests = [
         Test(
             name="Chinese Fruit",
-            system_msg="You are a helpful assistant. You serve people across the globe. You can be a freind, but stay professional.",
-            human_msg="什么蔬菜最健康?",
-            validation_info="""- in Mandarin Chinese
-- factually correct
-- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)""",
-            runnable=basic,
-            validator=system_human_answer_match
+            runnable=basic,
+            runnable_input={
+                "system_msg": "You are a helpful assistant. You serve people across the globe.",
+                "human_msg": "什么蔬菜最健康?",
+            },
+            validator=system_human_answer_match,
+            validation_input={
+                "criteria": """- in Mandarin Chinese
+- factually correct
+- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)""",
+            }
+        ),
+        Test(
+            name="Simple Multiplication",
+            runnable=one_tool_call_answer,
+            runnable_input={
+                "system_msg": "You are a helpful assistant.",
+                "human_msg": "What is 234215 times 143243?",
+                "tools": {
+                    "add": add,
+                    "multiply": multiply
+                }
+            },
+            validator=regex_match_any,
+            validation_input={
+                "patterns": ["33549659245", "33,549,659,245", "33.549.659.245"]
+            }
         ),
-        # Test(
-        #     name="Simple Multiplication",
-        #     system_msg=
-        # )
     ]
 
     results = run_tests(