building of pipeline (validation flaky)

This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-04 20:50:11 +02:00
parent e56fa9225c
commit 52a180b936
8 changed files with 168 additions and 53 deletions

3
.gitignore vendored
View File

@@ -1,3 +1,4 @@
.venv .venv
__pycache */__pycache__/*
.direnv
.vscode .vscode

0
libs/__init__.py Normal file
View File

View File

@@ -1,52 +0,0 @@
from langchain_ollama.chat_models import ChatOllama
from langchain_core.messages import SystemMessage
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.tools import Tool
def query_fits_to_answer(query: str, answer: str, max_retries: "int | None" = None) -> bool:
    """Ask a tool-calling LLM whether ``answer`` is a correct reply to ``query``.

    The model is bound to a single ``rate`` tool and instructed to call it.
    The ``rating`` argument of the first tool call in the reply is returned.
    When the reply carries no usable tool call, the model is invoked again
    with the same prompt.

    Args:
        query: The question that was asked.
        answer: The answer that should be rated.
        max_retries: Number of additional attempts allowed after the first
            one. ``None`` (the default) keeps retrying until a rating is
            obtained.

    Returns:
        The ``rating`` value the model passed to the ``rate`` tool.

    Raises:
        RuntimeError: If ``max_retries`` is set and no rating was produced
            within that many retries.
    """

    def rate(rating: bool) -> None:
        """Rate answer as correct (True) or as incorrect (False)."""

    # The prompt and the tool-bound model do not change between attempts,
    # so they are built once and reused for every retry.
    prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content="""You are a rating machine. You rate answers as correct if they are
1. factually correct (every statement made)
2. fitting response to the query answering all questions prompted
if the answer does not mach these criteria, rate the answer as incorrect. **Only use the rate tool. Do not answer conversationally**.
Do not answer with <I'm sorry but I do not have the capability to perform this task for you, I am happy to help you with any other queries you may have.> or anything like it. Just use the `rate` tool."""),
        HumanMessagePromptTemplate.from_template(template="""Query:
{query}
Answer:
{answer}
"""),
    ]).invoke({"query": query, "answer": answer})
    llm = ChatOllama(model="llama3-groq-tool-use:70b").bind_tools([rate])

    retries_used = 0
    while True:
        ai_msg = llm.invoke(prompt)
        try:
            return ai_msg.tool_calls[0]['args']['rating']
        except (IndexError, KeyError) as err:
            # IndexError: the model answered in prose instead of calling the tool.
            # KeyError: it called the tool but left out the ``rating`` argument.
            if max_retries is not None and retries_used >= max_retries:
                raise RuntimeError(
                    f"No rating obtained after {retries_used + 1} attempt(s)"
                ) from err
            retries_used += 1
            print(f"\rValidation Error of <{ai_msg.content}> Retrying...")
if __name__ == "__main__":
    # Manual smoke test: rates one sample answer and prints the verdict.
    # The commented-out calls below are alternative probes for the same
    # query; swap one in to see how the rater reacts.
    # print(query_fits_to_answer(
    #     query="Who is Obama?",
    #     answer="Barack Obama was the 44th President of the United States, serving two terms from January 2009 to January 2017. He was a significant figure in American politics and made history as the first African American to hold the office."
    # ))
    # print(query_fits_to_answer(
    #     query="Who is Obama?",
    #     answer="Quantum computing is a revolutionary technology that uses the principles of quantum mechanics to perform calculations and operations on data. It's a fundamentally different approach from classical computing, which is based on bits (0s and 1s) and transistors."
    # ))
    # print(query_fits_to_answer(
    #     query="Who is Obama?",
    #     answer="Barack Obama was the 72th President of the United States, serving two terms from January 2005 to January 2013. He was a significant figure in American politics and made history as the first Chinese American to hold the office."
    # ))
    sample_query = "Who is Obama?"
    # This answer contains factual errors; presumably it is meant to check
    # that the rater returns False for a plausible-looking wrong answer.
    sample_answer = "Barack Obama was the 45th President of the United States, serving two terms from January 2009 to January 2017. He was a significant figure in American politics and made history as the first Chinese American to hold the office."
    verdict = query_fits_to_answer(query=sample_query, answer=sample_answer)
    print(verdict)

29
libs/run_tests.py Normal file
View File

@@ -0,0 +1,29 @@
from libs.test_class import Test
from libs.validators import system_human_answer_match
from libs.runnables import basic
def padd(list, element):
    """Return ``str(element)`` left-justified to the widest item of ``list``.

    Every item is measured by the length of its ``str()`` form. An empty
    ``list`` yields a width of 0, so the element is returned unpadded.

    Args:
        list: Values whose string forms determine the column width.
        element: Value to render and pad with trailing spaces.

    Returns:
        The padded string form of ``element``.
    """
    # The parameter name shadows the builtin but is part of the call
    # interface, so it is kept as-is.
    width = max((len(str(item)) for item in list), default=0)
    return str(element).ljust(width)
def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: str) -> list[dict]:
    """Run every test for every model/seed combination, then validate the answers.

    Phase 1 calls ``test.runnable`` for each (model, seed, test) triple and
    records the answer. A failing runnable is reported and skipped, so one
    failure does not abort the remaining combinations. Phase 2 calls the
    validator of the test that produced each answer and stores the verdict
    under the ``"validation"`` key.

    Args:
        models: Model identifiers passed through to each runnable.
        seeds: Seeds passed through to each runnable.
        tests: Test cases providing ``name``, ``runnable`` and ``validator``.
        base_url: Server URL passed through to runnables and validators.

    Returns:
        One dict per successful run with the keys ``"test"``, ``"model"``,
        ``"seed"``, ``"result"`` and ``"validation"``.
    """
    results = []
    for model in models:
        for seed in seeds:
            for test in tests:
                try:
                    result = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
                    results.append({"test": test, "model": model, "seed": seed, "result": result})
                    print(f"Model {padd(models, model)} starting with seed {padd(seeds, seed)} is done with test '{test.name}'.")
                except Exception as e:
                    # Format the exception instead of concatenating it to a
                    # str, which would raise TypeError inside this handler.
                    print(f"\033[0;31mError:\033[0m {e}")
    for result in results:
        # Use the validator of the test that produced this result, not the
        # loop variable left over from the last iteration above.
        result_test = result['test']
        result['validation'] = result_test.validator(test=result_test, answer=result['result'], base_url=base_url)
        # Strict comparison on purpose: only a real True (or 1) counts as
        # correct; other truthy values such as the string "true" do not.
        print(f"Validation of answer from test {result_test.name} by {result['model']} with seed {result['seed']} evaluated to " + ('\033[0;32mcorrect\033[0m' if result['validation'] == True else '\033[0;31mincorrect\033[0m'))
    return results

16
libs/runnables.py Normal file
View File

@@ -0,0 +1,16 @@
from langchain_ollama.chat_models import ChatOllama
from langchain_core.messages import SystemMessage, HumanMessage
from libs.test_class import Test
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
    """Send the test's prompt to a ``ChatOllama`` model and return the reply text.

    Args:
        model: Model identifier passed to ``ChatOllama``.
        seed: Seed passed to ``ChatOllama``.
        test: Test case providing ``system_msg`` and ``human_msg``. When
            ``system_msg`` is ``None`` only the human message is sent.
        base_url: Server URL passed to ``ChatOllama``.

    Returns:
        The ``content`` of the model's reply.
    """

    def as_message(text, message_cls):
        # Plain strings in a message list are treated as human messages by
        # LangChain, so each string is wrapped in the class carrying the
        # intended role. Anything else is passed through untouched.
        return message_cls(content=text) if isinstance(text, str) else text

    prompt = [as_message(test.human_msg, HumanMessage)]
    if test.system_msg is not None:
        prompt.insert(0, as_message(test.system_msg, SystemMessage))

    llm = ChatOllama(
        model=model,
        seed=seed,
        base_url=base_url,
    )
    ai_msg = llm.invoke(prompt)
    return ai_msg.content

14
libs/test_class.py Normal file
View File

@@ -0,0 +1,14 @@
from dataclasses import dataclass, field
from typing import Callable
@dataclass
class Test:
    """A single prompt test case: what to ask, how to run it, how to judge it.

    ``system_msg`` and ``validation_info`` have real defaults. Python does
    not allow a required field after a defaulted one, so the remaining
    fields carry a ``None`` placeholder and ``__post_init__`` rejects a
    missing value. Positional order and keyword names are unchanged.

    Raises:
        TypeError: If ``human_msg``, ``runnable`` or ``validator`` is not
            supplied.
    """

    # Name used in progress and result output.
    name: str
    # System prompt. ``None`` is accepted; runnables may treat it as
    # "send no system message".
    system_msg: "str | None" = "You are a helful AI assistant."
    # The user query. Required (placeholder default, checked below).
    human_msg: str = None
    # Bullet list of criteria that validators insert into their prompt.
    validation_info: str = """- it is factually correct
- it fits/answers the system message and human query
- it is just the answer, and doesn't have any AI fragments (A/B versions, "end of message" parts, unfiting discalimers or notes)"""
    # Callable that produces an answer for this test. Required.
    runnable: Callable = None
    # Callable that judges an answer for this test. Required.
    validator: Callable = None

    def __post_init__(self) -> None:
        """Reject construction when a required field was left out."""
        missing = [
            field_name
            for field_name in ("human_msg", "runnable", "validator")
            if getattr(self, field_name) is None
        ]
        if missing:
            raise TypeError(
                f"Test() missing required argument(s): {', '.join(missing)}"
            )

53
libs/validators.py Normal file
View File

@@ -0,0 +1,53 @@
from langchain_ollama.chat_models import ChatOllama
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.tools import Tool
from libs.test_class import Test
def system_human_answer_match(test: Test, answer: str, base_url: str, max_retries: "int | None" = None) -> bool:
    """Ask a tool-calling LLM whether ``answer`` satisfies the test's criteria.

    The rater receives ``test.validation_info`` as its rating criteria and
    is shown ``test.system_msg``, ``test.human_msg`` and ``answer``. It is
    bound to a single ``rate`` tool; the ``rating`` argument of the first
    tool call in its reply is returned. When the reply carries no usable
    tool call, the rater is invoked again with the same prompt.

    Args:
        test: Test case supplying the criteria and the original messages.
        answer: The answer to be rated.
        base_url: Server URL passed to ``ChatOllama``.
        max_retries: Number of additional attempts allowed after the first
            one. ``None`` (the default) keeps retrying until a rating is
            obtained.

    Returns:
        The ``rating`` value the model passed to the ``rate`` tool.

    Raises:
        RuntimeError: If ``max_retries`` is set and no rating was produced
            within that many retries.
    """

    def rate(rating: bool) -> None:
        """Rate answer as correct (True) or as incorrect (False)."""

    # The prompt and the tool-bound model do not change between attempts,
    # so they are built once and reused for every retry.
    prompt = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(template="""Rate the answer as correct, if the answer is
{validation_info}
else as incorrect. Only use the rate tool. Do not answer conversationally."""),
        # Alternative, more forceful system prompt kept for experimentation:
        # SystemMessagePromptTemplate.from_template(template="""You are a rating machine. You are given 3 things: The system message, the Human query, and the AI response. You evaluate the response as correct if
        # {validation_info}
        # If the answer does not match these criteria, rate the answer as incorrect. If the answer is a "refusal" or a "declaration of incapability", the answer is automatically incorrect.
        # **Only use the rate tool. Do not under any circumstances answer conversationally**.
        # DO NOT ANSWER WITH <I'm sorry but I do not have the capability to perform this task for you...> or anything like it.
        # Use the rate tool!"""),
        HumanMessagePromptTemplate.from_template(template="""System Message:
{system_msg}
Query:
{human_msg}
Answer:
{answer}
"""),
    ]).invoke({
        "validation_info": test.validation_info,
        "system_msg": test.system_msg,
        "human_msg": test.human_msg,
        "answer": answer,
    })
    llm = ChatOllama(
        model="llama3.1:70b",
        # model="llama3-groq-tool-use:70b",
        base_url=base_url,
    ).bind_tools([rate])

    retries_used = 0
    while True:
        ai_msg = llm.invoke(prompt)
        try:
            return ai_msg.tool_calls[0]['args']['rating']
        except (IndexError, KeyError) as err:
            # IndexError: the model answered in prose instead of calling the tool.
            # KeyError: it called the tool but left out the ``rating`` argument.
            if max_retries is not None and retries_used >= max_retries:
                raise RuntimeError(
                    f"No rating obtained for test {test.name!r} after {retries_used + 1} attempt(s)"
                ) from err
            retries_used += 1
            print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")

54
test_small_llms.py Normal file
View File

@@ -0,0 +1,54 @@
from libs.test_class import Test
from libs.run_tests import run_tests
from libs.runnables import *
from libs.validators import *
from pprint import pprint
def main():
    """Run the configured tests for the selected models and seeds, then print a report."""
    # Model tags to evaluate; uncomment an entry to include it in the run.
    model_names = [
        # "llama3.1", # 8b
        # "llama3.1:70b",
        # "llama3-groq-tool-use", # latest
        # "llama3-groq-tool-use:70b",
        # "mixtral:8x7b",
        # "mixtral:8x22b",
        # "gemma2:2b",
        # "phi3", # 3.8b
        "tinyllama:1.1b",
    ]
    # Every model is run once per seed.
    seed_values = [
        # 2,
        222,
        # 22222,
        2222222,
    ]

    chinese_fruit = Test(
        name="Chinese Fruit",
        system_msg="You are a helpful assistant. You serve people across the globe. You can be a freind, but stay professional.",
        human_msg="什么蔬菜最健康?",
        validation_info="""- in Mandarin Chinese
- factually correct
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)""",
        runnable=basic,
        validator=system_human_answer_match,
    )
    test_cases = [
        chinese_fruit,
        # Test(
        #     name="Simple Multiplication",
        #     system_msg=
        # )
    ]

    outcomes = run_tests(
        models=model_names,
        seeds=seed_values,
        tests=test_cases,
        base_url="http://bolt.hs-mittweida.de:11434",
    )

    print()
    # One coloured, labelled report per recorded run.
    for outcome in outcomes:
        print(
            f"\n\033[0;36mtest_name:\033[0m {outcome['test'].name}"
            f"\n\033[0;36mmodel:\033[0m {outcome['model']}"
            f"\n\033[0;36mseed:\033[0m {outcome['seed']}"
            f"\n\033[0;36mvalidation_result:\033[0m {outcome['validation']}"
            f"\n\033[0;36manswer: ⏎\033[0m\n{outcome['result']}"
        )


if __name__ == "__main__":
    main()