building of pipeline (validation flaky)
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,3 +1,4 @@
|
|||||||
.venv
|
.venv
|
||||||
__pycache
|
*/__pycache__/*
|
||||||
|
.direnv
|
||||||
.vscode
|
.vscode
|
||||||
|
|||||||
0
libs/__init__.py
Normal file
0
libs/__init__.py
Normal file
@@ -1,52 +0,0 @@
|
|||||||
from langchain_ollama.chat_models import ChatOllama
|
|
||||||
from langchain_core.messages import SystemMessage
|
|
||||||
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
|
|
||||||
from langchain.tools import Tool
|
|
||||||
|
|
||||||
def query_fits_to_answer(query: str, answer: str) -> bool:
|
|
||||||
|
|
||||||
def rate(rating: bool) -> None:
|
|
||||||
"""Rate answer as correct (True) or as incorrect (False)."""
|
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages([
|
|
||||||
SystemMessage(content="""You are a rating machine. You rate answers as correct if they are
|
|
||||||
1. factually correct (every statement made)
|
|
||||||
2. fitting response to the query answering all questions prompted
|
|
||||||
|
|
||||||
if the answer does not mach these criteria, rate the answer as incorrect. **Only use the rate tool. Do not answer conversationally**.
|
|
||||||
Do not answer with <I'm sorry but I do not have the capability to perform this task for you, I am happy to help you with any other queries you may have.> or anything like it. Just use the `rate` tool."""),
|
|
||||||
HumanMessagePromptTemplate.from_template(template="""Query:
|
|
||||||
{query}
|
|
||||||
|
|
||||||
Answer:
|
|
||||||
{answer}
|
|
||||||
""")
|
|
||||||
]).invoke({"query": query, "answer": answer})
|
|
||||||
|
|
||||||
llm = ChatOllama(model="llama3-groq-tool-use:70b").bind_tools([rate])
|
|
||||||
|
|
||||||
ai_msg = llm.invoke(prompt)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return ai_msg.tool_calls[0]['args']['rating']
|
|
||||||
except IndexError as e:
|
|
||||||
print(f"\rValidation Error of <{ai_msg.content}> Retrying...")
|
|
||||||
return query_fits_to_answer(query=query, answer=answer)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# print(query_fits_to_answer(
|
|
||||||
# query="Who is Obama?",
|
|
||||||
# answer="Barack Obama was the 44th President of the United States, serving two terms from January 2009 to January 2017. He was a significant figure in American politics and made history as the first African American to hold the office."
|
|
||||||
# ))
|
|
||||||
# print(query_fits_to_answer(
|
|
||||||
# query="Who is Obama?",
|
|
||||||
# answer="Quantum computing is a revolutionary technology that uses the principles of quantum mechanics to perform calculations and operations on data. It's a fundamentally different approach from classical computing, which is based on bits (0s and 1s) and transistors."
|
|
||||||
# ))
|
|
||||||
# print(query_fits_to_answer(
|
|
||||||
# query="Who is Obama?",
|
|
||||||
# answer="Barack Obama was the 72th President of the United States, serving two terms from January 2005 to January 2013. He was a significant figure in American politics and made history as the first Chinese American to hold the office."
|
|
||||||
# ))
|
|
||||||
print(query_fits_to_answer(
|
|
||||||
query="Who is Obama?",
|
|
||||||
answer="Barack Obama was the 45th President of the United States, serving two terms from January 2009 to January 2017. He was a significant figure in American politics and made history as the first Chinese American to hold the office."
|
|
||||||
))
|
|
||||||
29
libs/run_tests.py
Normal file
29
libs/run_tests.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from libs.test_class import Test
|
||||||
|
from libs.validators import system_human_answer_match
|
||||||
|
from libs.runnables import basic
|
||||||
|
|
||||||
|
def padd(list, element):
|
||||||
|
longest = 0
|
||||||
|
for s in list:
|
||||||
|
longest = max(longest, len(str(s)))
|
||||||
|
return str(element).ljust(longest)
|
||||||
|
|
||||||
|
def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: str):
|
||||||
|
results = []
|
||||||
|
esc = "\033"
|
||||||
|
for model in models:
|
||||||
|
for seed in seeds:
|
||||||
|
for test in tests:
|
||||||
|
try:
|
||||||
|
result = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
|
||||||
|
results.append({"test": test,"model": model, "seed": seed, "result": result})
|
||||||
|
print(f"Model {padd(models, model)} starting with seed {padd(seeds, seed)} is done with test '{test.name}'.")
|
||||||
|
except Exception as e:
|
||||||
|
print("\033[0;31mError:\033[0m" + e)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
result['validation'] = test.validator(test=result['test'], answer=result['result'], base_url=base_url)
|
||||||
|
|
||||||
|
print(f"Validation of answer from test {result['test'].name} by {result['model']} with seed {result['seed']} evaluated to " + ('\033[0;32mcorrect\033[0m' if result['validation'] == True else '\033[0;31mincorrect\033[0m'))
|
||||||
|
|
||||||
|
return results
|
||||||
16
libs/runnables.py
Normal file
16
libs/runnables.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from langchain_ollama.chat_models import ChatOllama
|
||||||
|
from langchain_core.messages import SystemMessage, HumanMessage
|
||||||
|
from libs.test_class import Test
|
||||||
|
|
||||||
|
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||||
|
|
||||||
|
if test.system_msg == None: prompt = [ test.human_msg ]
|
||||||
|
else: prompt = [ test.system_msg, test.human_msg ]
|
||||||
|
|
||||||
|
llm = ChatOllama(
|
||||||
|
model=model,
|
||||||
|
seed=seed,
|
||||||
|
base_url=base_url
|
||||||
|
)
|
||||||
|
ai_msg = llm.invoke(prompt)
|
||||||
|
return ai_msg.content
|
||||||
14
libs/test_class.py
Normal file
14
libs/test_class.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Test:
|
||||||
|
name: str
|
||||||
|
system_msg: field(default="You are a helful AI assistant.")
|
||||||
|
human_msg: str
|
||||||
|
validation_info: field(default="""- it is factually correct
|
||||||
|
- it fits/answers the system message and human query
|
||||||
|
- it is just the answer, and doesn't have any AI fragments (A/B versions, "end of message" parts, unfiting discalimers or notes)""")
|
||||||
|
runnable: Callable
|
||||||
|
validator: Callable
|
||||||
|
|
||||||
53
libs/validators.py
Normal file
53
libs/validators.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from langchain_ollama.chat_models import ChatOllama
|
||||||
|
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate
|
||||||
|
from langchain.tools import Tool
|
||||||
|
from libs.test_class import Test
|
||||||
|
|
||||||
|
def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
|
||||||
|
|
||||||
|
def rate(rating: bool) -> None:
|
||||||
|
"""Rate answer as correct (True) or as incorrect (False)."""
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate.from_messages([
|
||||||
|
SystemMessagePromptTemplate.from_template(template="""Rate the answer as correct, if the answer is
|
||||||
|
{validation_info}
|
||||||
|
|
||||||
|
else as incorrect. Only use the rate tool. Do not answer conversationally."""),
|
||||||
|
# SystemMessagePromptTemplate.from_template(template="""You are a rating machine. You are given 3 things: The system message, the Human query, and the AI response. You evaluate the response as correct if
|
||||||
|
# {validation_info}
|
||||||
|
|
||||||
|
# If the answer does not match these criteria, rate the answer as incorrect. If the answer is a "refusal" or a "declaration of incapability", the answer is automatically incorrect.
|
||||||
|
|
||||||
|
# **Only use the rate tool. Do not under any circumstances answer conversationally**.
|
||||||
|
# DO NOT ANSWER WITH <I'm sorry but I do not have the capability to perform this task for you...> or anything like it.
|
||||||
|
# Use the rate tool!"""),
|
||||||
|
HumanMessagePromptTemplate.from_template(template="""System Message:
|
||||||
|
{system_msg}
|
||||||
|
|
||||||
|
Query:
|
||||||
|
{human_msg}
|
||||||
|
|
||||||
|
Answer:
|
||||||
|
{answer}
|
||||||
|
""")
|
||||||
|
]).invoke({
|
||||||
|
"validation_info": test.validation_info,
|
||||||
|
"system_msg": test.system_msg,
|
||||||
|
"human_msg": test.human_msg,
|
||||||
|
"answer": answer
|
||||||
|
})
|
||||||
|
|
||||||
|
llm = ChatOllama(
|
||||||
|
model="llama3.1:70b",
|
||||||
|
# model="llama3-groq-tool-use:70b",
|
||||||
|
base_url=base_url
|
||||||
|
).bind_tools([rate])
|
||||||
|
|
||||||
|
ai_msg = llm.invoke(prompt)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return ai_msg.tool_calls[0]['args']['rating']
|
||||||
|
except IndexError as e:
|
||||||
|
print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
|
||||||
|
return system_human_answer_match(test=test, answer=answer)
|
||||||
|
|
||||||
54
test_small_llms.py
Normal file
54
test_small_llms.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from libs.test_class import Test
|
||||||
|
from libs.run_tests import run_tests
|
||||||
|
from libs.runnables import *
|
||||||
|
from libs.validators import *
|
||||||
|
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
def main():
|
||||||
|
models = [
|
||||||
|
# "llama3.1", # 8b
|
||||||
|
# "llama3.1:70b",
|
||||||
|
# "llama3-groq-tool-use", # latest
|
||||||
|
# "llama3-groq-tool-use:70b",
|
||||||
|
# "mixtral:8x7b",
|
||||||
|
# "mixtral:8x22b",
|
||||||
|
# "gemma2:2b",
|
||||||
|
# "phi3", # 3.8b
|
||||||
|
"tinyllama:1.1b",
|
||||||
|
]
|
||||||
|
seeds = [
|
||||||
|
# 2,
|
||||||
|
222,
|
||||||
|
# 22222,
|
||||||
|
2222222
|
||||||
|
]
|
||||||
|
tests = [
|
||||||
|
Test(
|
||||||
|
name="Chinese Fruit",
|
||||||
|
system_msg="You are a helpful assistant. You serve people across the globe. You can be a freind, but stay professional.",
|
||||||
|
human_msg="什么蔬菜最健康?",
|
||||||
|
validation_info="""- in Mandarin Chinese
|
||||||
|
- factually correct
|
||||||
|
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)""",
|
||||||
|
runnable=basic,
|
||||||
|
validator=system_human_answer_match
|
||||||
|
),
|
||||||
|
# Test(
|
||||||
|
# name="Simple Multiplication",
|
||||||
|
# system_msg=
|
||||||
|
# )
|
||||||
|
]
|
||||||
|
|
||||||
|
results = run_tests(
|
||||||
|
models=models,
|
||||||
|
seeds=seeds,
|
||||||
|
tests=tests,
|
||||||
|
base_url="http://bolt.hs-mittweida.de:11434"
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
for result in results: print(f"\n\033[0;36mtest_name:\033[0m {result['test'].name}\n\033[0;36mmodel:\033[0m {result['model']}\n\033[0;36mseed:\033[0m {result['seed']}\n\033[0;36mvalidation_result:\033[0m {result['validation']}\n\033[0;36manswer: ⏎\033[0m\n{result['result']}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user