Mul test, valdidation works, but printing it doesnt
This commit is contained in:
@@ -19,10 +19,10 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
results.append({"test": test,"model": model, "seed": seed, "result": result})
|
results.append({"test": test,"model": model, "seed": seed, "result": result})
|
||||||
print(f"Model {padd(models, model)} starting with seed {padd(seeds, seed)} is done with test '{test.name}'.")
|
print(f"Model {padd(models, model)} starting with seed {padd(seeds, seed)} is done with test '{test.name}'.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("\033[0;31mError:\033[0m" + e)
|
print("\033[0;31mError:\033[0m " + str(e))
|
||||||
|
|
||||||
for result in results:
|
for result in results:
|
||||||
result['validation'] = test.validator(test=result['test'], answer=result['result'], base_url=base_url)
|
result['validation'] = result['test'].validator(test=result['test'], answer=result['result'], base_url=base_url)
|
||||||
|
|
||||||
print(f"Validation of answer from test {result['test'].name} by {result['model']} with seed {result['seed']} evaluated to " + ('\033[0;32mcorrect\033[0m' if result['validation'] == True else '\033[0;31mincorrect\033[0m'))
|
print(f"Validation of answer from test {result['test'].name} by {result['model']} with seed {result['seed']} evaluated to " + ('\033[0;32mcorrect\033[0m' if result['validation'] == True else '\033[0;31mincorrect\033[0m'))
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
from langchain_ollama.chat_models import ChatOllama
|
from langchain_ollama.chat_models import ChatOllama
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage
|
from langchain_core.messages import SystemMessage, HumanMessage
|
||||||
from libs.test_class import Test
|
from libs.test_class import Test
|
||||||
|
from langchain.tools import Tool
|
||||||
|
|
||||||
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||||
|
|
||||||
if test.system_msg == None: prompt = [ test.human_msg ]
|
system_msg = test.runnable_input['system_msg']
|
||||||
else: prompt = [ test.system_msg, test.human_msg ]
|
human_msg = test.runnable_input['human_msg']
|
||||||
|
|
||||||
|
if system_msg == None: prompt = [ human_msg ]
|
||||||
|
else: prompt = [ system_msg, human_msg ]
|
||||||
|
|
||||||
llm = ChatOllama(
|
llm = ChatOllama(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -14,3 +18,35 @@ def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
|||||||
)
|
)
|
||||||
ai_msg = llm.invoke(prompt)
|
ai_msg = llm.invoke(prompt)
|
||||||
return ai_msg.content
|
return ai_msg.content
|
||||||
|
|
||||||
|
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||||
|
|
||||||
|
system_msg = test.runnable_input['system_msg']
|
||||||
|
human_msg = test.runnable_input['human_msg']
|
||||||
|
tools_dict = test.runnable_input['tools']
|
||||||
|
tools = []
|
||||||
|
for key in tools_dict:
|
||||||
|
tools.append(tools_dict[key])
|
||||||
|
|
||||||
|
if system_msg == None: prompt = [ human_msg ]
|
||||||
|
else: prompt = [ system_msg, human_msg ]
|
||||||
|
|
||||||
|
llm = ChatOllama(
|
||||||
|
model=model,
|
||||||
|
seed=seed,
|
||||||
|
base_url=base_url
|
||||||
|
).bind_tools(tools)
|
||||||
|
|
||||||
|
ai_msg = llm.invoke(prompt)
|
||||||
|
|
||||||
|
prompt.append(ai_msg)
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_call = ai_msg.tool_calls[0]
|
||||||
|
selected_tool = tools_dict[tool_call["name"].lower()]
|
||||||
|
tool_msg = selected_tool.invoke(tool_call)
|
||||||
|
prompt.append(tool_msg)
|
||||||
|
ai_msg = llm.invoke(prompt)
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
return ai_msg.content
|
||||||
|
|||||||
@@ -1,14 +1,10 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable
|
from typing import Callable, Any
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Test:
|
class Test:
|
||||||
name: str
|
name: str
|
||||||
system_msg: field(default="You are a helful AI assistant.")
|
|
||||||
human_msg: str
|
|
||||||
validation_info: field(default="""- it is factually correct
|
|
||||||
- it fits/answers the system message and human query
|
|
||||||
- it is just the answer, and doesn't have any AI fragments (A/B versions, "end of message" parts, unfiting discalimers or notes)""")
|
|
||||||
runnable: Callable
|
runnable: Callable
|
||||||
|
runnable_input: dict
|
||||||
validator: Callable
|
validator: Callable
|
||||||
|
validation_input: dict
|
||||||
|
|||||||
11
libs/tools.py
Normal file
11
libs/tools.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from langchain.tools import tool
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def add(a: float, b: float) -> float:
|
||||||
|
"""Adds a+b and retuns the sum"""
|
||||||
|
return a+b
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def multiply(a: float, b: float) -> float:
|
||||||
|
"""Multiplies a*b and retuns the product"""
|
||||||
|
return a*b
|
||||||
@@ -10,11 +10,11 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
|
|||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages([
|
prompt = ChatPromptTemplate.from_messages([
|
||||||
SystemMessagePromptTemplate.from_template(template="""Rate the answer as correct, if the answer is
|
SystemMessagePromptTemplate.from_template(template="""Rate the answer as correct, if the answer is
|
||||||
{validation_info}
|
{validation_input}
|
||||||
|
|
||||||
else as incorrect. Only use the rate tool. Do not answer conversationally."""),
|
else as incorrect. Only use the rate tool. Do not answer conversationally."""),
|
||||||
# SystemMessagePromptTemplate.from_template(template="""You are a rating machine. You are given 3 things: The system message, the Human query, and the AI response. You evaluate the response as correct if
|
# SystemMessagePromptTemplate.from_template(template="""You are a rating machine. You are given 3 things: The system message, the Human query, and the AI response. You evaluate the response as correct if
|
||||||
# {validation_info}
|
# {validation_input}
|
||||||
|
|
||||||
# If the answer does not match these criteria, rate the answer as incorrect. If the answer is a "refusal" or a "declaration of incapability", the answer is automatically incorrect.
|
# If the answer does not match these criteria, rate the answer as incorrect. If the answer is a "refusal" or a "declaration of incapability", the answer is automatically incorrect.
|
||||||
|
|
||||||
@@ -31,9 +31,9 @@ Answer:
|
|||||||
{answer}
|
{answer}
|
||||||
""")
|
""")
|
||||||
]).invoke({
|
]).invoke({
|
||||||
"validation_info": test.validation_info,
|
"validation_input": test.validation_input,
|
||||||
"system_msg": test.system_msg,
|
"system_msg": test.runnable_input['system_msg'],
|
||||||
"human_msg": test.human_msg,
|
"human_msg": test.runnable_input['human_msg'],
|
||||||
"answer": answer
|
"answer": answer
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -51,3 +51,11 @@ Answer:
|
|||||||
print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
|
print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
|
||||||
return system_human_answer_match(test=test, answer=answer)
|
return system_human_answer_match(test=test, answer=answer)
|
||||||
|
|
||||||
|
from re import search
|
||||||
|
|
||||||
|
def regex_match_any(test: Test, answer: str, base_url: str) -> bool:
|
||||||
|
match = False
|
||||||
|
for pattern in test.validation_input['patterns']:
|
||||||
|
if search(pattern, answer):
|
||||||
|
match = True
|
||||||
|
return match
|
||||||
|
|||||||
@@ -2,12 +2,13 @@ from libs.test_class import Test
|
|||||||
from libs.run_tests import run_tests
|
from libs.run_tests import run_tests
|
||||||
from libs.runnables import *
|
from libs.runnables import *
|
||||||
from libs.validators import *
|
from libs.validators import *
|
||||||
|
from libs.tools import *
|
||||||
|
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
models = [
|
models = [
|
||||||
# "llama3.1", # 8b
|
"llama3.1", # 8b
|
||||||
# "llama3.1:70b",
|
# "llama3.1:70b",
|
||||||
# "llama3-groq-tool-use", # latest
|
# "llama3-groq-tool-use", # latest
|
||||||
# "llama3-groq-tool-use:70b",
|
# "llama3-groq-tool-use:70b",
|
||||||
@@ -15,29 +16,45 @@ def main():
|
|||||||
# "mixtral:8x22b",
|
# "mixtral:8x22b",
|
||||||
# "gemma2:2b",
|
# "gemma2:2b",
|
||||||
# "phi3", # 3.8b
|
# "phi3", # 3.8b
|
||||||
"tinyllama:1.1b",
|
# "tinyllama:1.1b",
|
||||||
]
|
]
|
||||||
seeds = [
|
seeds = [
|
||||||
# 2,
|
# 2,
|
||||||
222,
|
222,
|
||||||
# 22222,
|
22222,
|
||||||
2222222
|
# 2222222
|
||||||
]
|
]
|
||||||
tests = [
|
tests = [
|
||||||
Test(
|
Test(
|
||||||
name="Chinese Fruit",
|
name="Chinese Fruit",
|
||||||
system_msg="You are a helpful assistant. You serve people across the globe. You can be a freind, but stay professional.",
|
|
||||||
human_msg="什么蔬菜最健康?",
|
|
||||||
validation_info="""- in Mandarin Chinese
|
|
||||||
- factually correct
|
|
||||||
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)""",
|
|
||||||
runnable=basic,
|
runnable=basic,
|
||||||
validator=system_human_answer_match
|
runnable_input={
|
||||||
|
"system_msg": "You are a helpful assistant. You serve people across the globe.",
|
||||||
|
"human_msg": "什么蔬菜最健康?",
|
||||||
|
},
|
||||||
|
validator=system_human_answer_match,
|
||||||
|
validation_input={
|
||||||
|
"criteria": """- in Mandarin Chinese
|
||||||
|
- factually correct
|
||||||
|
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)""",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
Test(
|
||||||
|
name="Simple Multiplication",
|
||||||
|
runnable=one_tool_call_answer,
|
||||||
|
runnable_input={
|
||||||
|
"system_msg": "You are a helpful assistant.",
|
||||||
|
"human_msg": "What is 234215 times 143243?",
|
||||||
|
"tools": {
|
||||||
|
"add": add,
|
||||||
|
"multiply": multiply
|
||||||
|
}
|
||||||
|
},
|
||||||
|
validator=regex_match_any,
|
||||||
|
validation_input={
|
||||||
|
"patterns": ["33549659245", "33,549,659,245", "33.549.659.245"]
|
||||||
|
}
|
||||||
),
|
),
|
||||||
# Test(
|
|
||||||
# name="Simple Multiplication",
|
|
||||||
# system_msg=
|
|
||||||
# )
|
|
||||||
]
|
]
|
||||||
|
|
||||||
results = run_tests(
|
results = run_tests(
|
||||||
|
|||||||
Reference in New Issue
Block a user