Mul test: validation works, but printing it doesn't
This commit is contained in:
@@ -19,10 +19,10 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
||||
results.append({"test": test,"model": model, "seed": seed, "result": result})
|
||||
print(f"Model {padd(models, model)} starting with seed {padd(seeds, seed)} is done with test '{test.name}'.")
|
||||
except Exception as e:
|
||||
print("\033[0;31mError:\033[0m" + e)
|
||||
print("\033[0;31mError:\033[0m " + str(e))
|
||||
|
||||
for result in results:
|
||||
result['validation'] = test.validator(test=result['test'], answer=result['result'], base_url=base_url)
|
||||
result['validation'] = result['test'].validator(test=result['test'], answer=result['result'], base_url=base_url)
|
||||
|
||||
print(f"Validation of answer from test {result['test'].name} by {result['model']} with seed {result['seed']} evaluated to " + ('\033[0;32mcorrect\033[0m' if result['validation'] == True else '\033[0;31mincorrect\033[0m'))
|
||||
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
from libs.test_class import Test
|
||||
from langchain.tools import Tool
|
||||
|
||||
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||
|
||||
if test.system_msg == None: prompt = [ test.human_msg ]
|
||||
else: prompt = [ test.system_msg, test.human_msg ]
|
||||
system_msg = test.runnable_input['system_msg']
|
||||
human_msg = test.runnable_input['human_msg']
|
||||
|
||||
if system_msg == None: prompt = [ human_msg ]
|
||||
else: prompt = [ system_msg, human_msg ]
|
||||
|
||||
llm = ChatOllama(
|
||||
model=model,
|
||||
@@ -14,3 +18,35 @@ def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||
)
|
||||
ai_msg = llm.invoke(prompt)
|
||||
return ai_msg.content
|
||||
|
||||
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
    """Ask the model once, run at most one requested tool, and return the reply.

    Reads ``system_msg``, ``human_msg`` and ``tools`` (a name -> tool mapping)
    from ``test.runnable_input``. The model is bound to those tools and
    invoked. If its reply carries tool calls, only the first one is executed;
    the tool result is appended to the conversation and the model is invoked
    a second time.

    Args:
        model: Model name passed to ``ChatOllama``.
        seed: Seed passed to ``ChatOllama``.
        test: Test whose ``runnable_input`` supplies the messages and tools.
        base_url: Base URL passed to ``ChatOllama``.

    Returns:
        The ``content`` of the last model message (the first reply when no
        tool was requested, otherwise the reply after the tool result).

    Raises:
        KeyError: If ``runnable_input`` lacks one of the expected keys, or the
            model requests a tool whose lower-cased name is not in ``tools``.
    """
    system_msg = test.runnable_input['system_msg']
    human_msg = test.runnable_input['human_msg']
    tools_dict = test.runnable_input['tools']

    # NOTE(review): the messages are passed through exactly as stored in
    # runnable_input. If they are plain strings, confirm that the chat model
    # delivers the system message with the intended role.
    prompt = [human_msg] if system_msg is None else [system_msg, human_msg]

    llm = ChatOllama(
        model=model,
        seed=seed,
        base_url=base_url
    ).bind_tools(list(tools_dict.values()))

    ai_msg = llm.invoke(prompt)
    prompt.append(ai_msg)

    # No tool requested: the first reply is the answer. Checking explicitly
    # (instead of catching IndexError around the whole tool round-trip) keeps
    # errors raised inside the tool or the second model call visible.
    if not ai_msg.tool_calls:
        return ai_msg.content

    tool_call = ai_msg.tool_calls[0]
    # Tool names are looked up lower-cased, so the keys of the tools mapping
    # are expected to be lower-case.
    selected_tool = tools_dict[tool_call["name"].lower()]
    prompt.append(selected_tool.invoke(tool_call))

    return llm.invoke(prompt).content
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
from typing import Callable, Any
|
||||
|
||||
@dataclass
|
||||
class Test:
|
||||
name: str
|
||||
system_msg: field(default="You are a helful AI assistant.")
|
||||
human_msg: str
|
||||
validation_info: field(default="""- it is factually correct
|
||||
- it fits/answers the system message and human query
|
||||
- it is just the answer, and doesn't have any AI fragments (A/B versions, "end of message" parts, unfiting discalimers or notes)""")
|
||||
runnable: Callable
|
||||
runnable_input: dict
|
||||
validator: Callable
|
||||
|
||||
validation_input: dict
|
||||
|
||||
11
libs/tools.py
Normal file
11
libs/tools.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from langchain.tools import tool
|
||||
|
||||
# NOTE(review): with LangChain's @tool the docstring presumably doubles as the
# tool description shown to the model, so its wording is part of the prompt —
# confirm against the LangChain docs before rephrasing it further.
@tool
def add(a: float, b: float) -> float:
    """Adds a+b and returns the sum"""
    return a + b
|
||||
|
||||
# NOTE(review): with LangChain's @tool the docstring presumably doubles as the
# tool description shown to the model, so its wording is part of the prompt —
# confirm against the LangChain docs before rephrasing it further.
@tool
def multiply(a: float, b: float) -> float:
    """Multiplies a*b and returns the product"""
    return a * b
|
||||
@@ -10,11 +10,11 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
SystemMessagePromptTemplate.from_template(template="""Rate the answer as correct, if the answer is
|
||||
{validation_info}
|
||||
{validation_input}
|
||||
|
||||
else as incorrect. Only use the rate tool. Do not answer conversationally."""),
|
||||
# SystemMessagePromptTemplate.from_template(template="""You are a rating machine. You are given 3 things: The system message, the Human query, and the AI response. You evaluate the response as correct if
|
||||
# {validation_info}
|
||||
# {validation_input}
|
||||
|
||||
# If the answer does not match these criteria, rate the answer as incorrect. If the answer is a "refusal" or a "declaration of incapability", the answer is automatically incorrect.
|
||||
|
||||
@@ -31,9 +31,9 @@ Answer:
|
||||
{answer}
|
||||
""")
|
||||
]).invoke({
|
||||
"validation_info": test.validation_info,
|
||||
"system_msg": test.system_msg,
|
||||
"human_msg": test.human_msg,
|
||||
"validation_input": test.validation_input,
|
||||
"system_msg": test.runnable_input['system_msg'],
|
||||
"human_msg": test.runnable_input['human_msg'],
|
||||
"answer": answer
|
||||
})
|
||||
|
||||
@@ -51,3 +51,11 @@ Answer:
|
||||
print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
|
||||
return system_human_answer_match(test=test, answer=answer)
|
||||
|
||||
from re import search
|
||||
|
||||
def regex_match_any(test: Test, answer: str, base_url: str) -> bool:
    """Report whether any of the test's regex patterns occurs in the answer.

    Args:
        test: Test whose ``validation_input['patterns']`` holds the patterns
            handed to ``re.search``.
        answer: Text to search.
        base_url: Unused here; accepted because validators are called with
            ``test=``, ``answer=`` and ``base_url=`` keyword arguments.

    Returns:
        True if at least one pattern matches anywhere in ``answer``; False
        otherwise, including when the pattern list is empty.
    """
    # any() stops at the first hit instead of scanning the remaining patterns.
    return any(
        search(pattern, answer)
        for pattern in test.validation_input['patterns']
    )
|
||||
|
||||
@@ -2,12 +2,13 @@ from libs.test_class import Test
|
||||
from libs.run_tests import run_tests
|
||||
from libs.runnables import *
|
||||
from libs.validators import *
|
||||
from libs.tools import *
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
def main():
|
||||
models = [
|
||||
# "llama3.1", # 8b
|
||||
"llama3.1", # 8b
|
||||
# "llama3.1:70b",
|
||||
# "llama3-groq-tool-use", # latest
|
||||
# "llama3-groq-tool-use:70b",
|
||||
@@ -15,29 +16,45 @@ def main():
|
||||
# "mixtral:8x22b",
|
||||
# "gemma2:2b",
|
||||
# "phi3", # 3.8b
|
||||
"tinyllama:1.1b",
|
||||
# "tinyllama:1.1b",
|
||||
]
|
||||
seeds = [
|
||||
# 2,
|
||||
222,
|
||||
# 22222,
|
||||
2222222
|
||||
22222,
|
||||
# 2222222
|
||||
]
|
||||
tests = [
|
||||
Test(
|
||||
name="Chinese Fruit",
|
||||
system_msg="You are a helpful assistant. You serve people across the globe. You can be a freind, but stay professional.",
|
||||
human_msg="什么蔬菜最健康?",
|
||||
validation_info="""- in Mandarin Chinese
|
||||
- factually correct
|
||||
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)""",
|
||||
runnable=basic,
|
||||
validator=system_human_answer_match
|
||||
runnable_input={
|
||||
"system_msg": "You are a helpful assistant. You serve people across the globe.",
|
||||
"human_msg": "什么蔬菜最健康?",
|
||||
},
|
||||
validator=system_human_answer_match,
|
||||
validation_input={
|
||||
"criteria": """- in Mandarin Chinese
|
||||
- factually correct
|
||||
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)""",
|
||||
}
|
||||
),
|
||||
Test(
|
||||
name="Simple Multiplication",
|
||||
runnable=one_tool_call_answer,
|
||||
runnable_input={
|
||||
"system_msg": "You are a helpful assistant.",
|
||||
"human_msg": "What is 234215 times 143243?",
|
||||
"tools": {
|
||||
"add": add,
|
||||
"multiply": multiply
|
||||
}
|
||||
},
|
||||
validator=regex_match_any,
|
||||
validation_input={
|
||||
"patterns": ["33549659245", "33,549,659,245", "33.549.659.245"]
|
||||
}
|
||||
),
|
||||
# Test(
|
||||
# name="Simple Multiplication",
|
||||
# system_msg=
|
||||
# )
|
||||
]
|
||||
|
||||
results = run_tests(
|
||||
|
||||
Reference in New Issue
Block a user