Mul test, valdidation works, but printing it doesnt

This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-05 14:09:06 +02:00
parent 52a180b936
commit abd6320ce9
6 changed files with 99 additions and 31 deletions

View File

@@ -19,10 +19,10 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
results.append({"test": test,"model": model, "seed": seed, "result": result})
print(f"Model {padd(models, model)} starting with seed {padd(seeds, seed)} is done with test '{test.name}'.")
except Exception as e:
print("\033[0;31mError:\033[0m" + e)
print("\033[0;31mError:\033[0m " + str(e))
for result in results:
result['validation'] = test.validator(test=result['test'], answer=result['result'], base_url=base_url)
result['validation'] = result['test'].validator(test=result['test'], answer=result['result'], base_url=base_url)
print(f"Validation of answer from test {result['test'].name} by {result['model']} with seed {result['seed']} evaluated to " + ('\033[0;32mcorrect\033[0m' if result['validation'] == True else '\033[0;31mincorrect\033[0m'))

View File

@@ -1,11 +1,15 @@
from langchain_ollama.chat_models import ChatOllama
from langchain_core.messages import SystemMessage, HumanMessage
from libs.test_class import Test
from langchain.tools import Tool
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
if test.system_msg == None: prompt = [ test.human_msg ]
else: prompt = [ test.system_msg, test.human_msg ]
system_msg = test.runnable_input['system_msg']
human_msg = test.runnable_input['human_msg']
if system_msg == None: prompt = [ human_msg ]
else: prompt = [ system_msg, human_msg ]
llm = ChatOllama(
model=model,
@@ -14,3 +18,35 @@ def basic(model: str, seed: int, test: Test, base_url: str) -> str:
)
ai_msg = llm.invoke(prompt)
return ai_msg.content
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
system_msg = test.runnable_input['system_msg']
human_msg = test.runnable_input['human_msg']
tools_dict = test.runnable_input['tools']
tools = []
for key in tools_dict:
tools.append(tools_dict[key])
if system_msg == None: prompt = [ human_msg ]
else: prompt = [ system_msg, human_msg ]
llm = ChatOllama(
model=model,
seed=seed,
base_url=base_url
).bind_tools(tools)
ai_msg = llm.invoke(prompt)
prompt.append(ai_msg)
try:
tool_call = ai_msg.tool_calls[0]
selected_tool = tools_dict[tool_call["name"].lower()]
tool_msg = selected_tool.invoke(tool_call)
prompt.append(tool_msg)
ai_msg = llm.invoke(prompt)
except IndexError:
pass
return ai_msg.content

View File

@@ -1,14 +1,10 @@
from dataclasses import dataclass, field
from typing import Callable
from typing import Callable, Any
@dataclass
class Test:
name: str
system_msg: field(default="You are a helful AI assistant.")
human_msg: str
validation_info: field(default="""- it is factually correct
- it fits/answers the system message and human query
- it is just the answer, and doesn't have any AI fragments (A/B versions, "end of message" parts, unfiting discalimers or notes)""")
runnable: Callable
runnable_input: dict
validator: Callable
validation_input: dict

11
libs/tools.py Normal file
View File

@@ -0,0 +1,11 @@
from langchain.tools import tool
@tool
def add(a: float, b: float) -> float:
"""Adds a+b and retuns the sum"""
return a+b
@tool
def multiply(a: float, b: float) -> float:
"""Multiplies a*b and retuns the product"""
return a*b

View File

@@ -10,11 +10,11 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
prompt = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template(template="""Rate the answer as correct, if the answer is
{validation_info}
{validation_input}
else as incorrect. Only use the rate tool. Do not answer conversationally."""),
# SystemMessagePromptTemplate.from_template(template="""You are a rating machine. You are given 3 things: The system message, the Human query, and the AI response. You evaluate the response as correct if
# {validation_info}
# {validation_input}
# If the answer does not match these criteria, rate the answer as incorrect. If the answer is a "refusal" or a "declaration of incapability", the answer is automatically incorrect.
@@ -31,9 +31,9 @@ Answer:
{answer}
""")
]).invoke({
"validation_info": test.validation_info,
"system_msg": test.system_msg,
"human_msg": test.human_msg,
"validation_input": test.validation_input,
"system_msg": test.runnable_input['system_msg'],
"human_msg": test.runnable_input['human_msg'],
"answer": answer
})
@@ -51,3 +51,11 @@ Answer:
print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
return system_human_answer_match(test=test, answer=answer)
from re import search
def regex_match_any(test: Test, answer: str, base_url: str) -> bool:
match = False
for pattern in test.validation_input['patterns']:
if search(pattern, answer):
match = True
return match

View File

@@ -2,12 +2,13 @@ from libs.test_class import Test
from libs.run_tests import run_tests
from libs.runnables import *
from libs.validators import *
from libs.tools import *
from pprint import pprint
def main():
models = [
# "llama3.1", # 8b
"llama3.1", # 8b
# "llama3.1:70b",
# "llama3-groq-tool-use", # latest
# "llama3-groq-tool-use:70b",
@@ -15,29 +16,45 @@ def main():
# "mixtral:8x22b",
# "gemma2:2b",
# "phi3", # 3.8b
"tinyllama:1.1b",
# "tinyllama:1.1b",
]
seeds = [
# 2,
222,
# 22222,
2222222
22222,
# 2222222
]
tests = [
Test(
name="Chinese Fruit",
system_msg="You are a helpful assistant. You serve people across the globe. You can be a freind, but stay professional.",
human_msg="什么蔬菜最健康?",
validation_info="""- in Mandarin Chinese
runnable=basic,
runnable_input={
"system_msg": "You are a helpful assistant. You serve people across the globe.",
"human_msg": "什么蔬菜最健康?",
},
validator=system_human_answer_match,
validation_input={
"criteria": """- in Mandarin Chinese
- factually correct
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)""",
runnable=basic,
validator=system_human_answer_match
}
),
Test(
name="Simple Multiplication",
runnable=one_tool_call_answer,
runnable_input={
"system_msg": "You are a helpful assistant.",
"human_msg": "What is 234215 times 143243?",
"tools": {
"add": add,
"multiply": multiply
}
},
validator=regex_match_any,
validation_input={
"patterns": ["33549659245", "33,549,659,245", "33.549.659.245"]
}
),
# Test(
# name="Simple Multiplication",
# system_msg=
# )
]
results = run_tests(