From 52a180b936b8dca88b893b97287b0c5e0d21123b Mon Sep 17 00:00:00 2001
From: "Lennart J. Kurzweg (Nx2)"
Date: Sun, 4 Aug 2024 20:50:11 +0200
Subject: [PATCH] building of pipeline (validation flaky)

---
NOTE(review): line breaks and indentation of this patch were reconstructed
from a whitespace-collapsed copy. The removed-file hunk and the blob ids on
the "index" lines are carried over as-is and may not match the repository
byte-for-byte -- regenerate with `git format-patch` before applying.

 .gitignore                   |  3 +-
 libs/__init__.py             |  0
 libs/query_fits_to_answer.py | 52 ------------------------
 libs/run_tests.py            | 49 +++++++++++++++++++++++
 libs/runnables.py            | 23 +++++++++++
 libs/test_class.py           | 29 ++++++++++++++
 libs/validators.py           | 56 ++++++++++++++++++++++++++
 test_small_llms.py           | 62 +++++++++++++++++++++++++++++
 8 files changed, 221 insertions(+), 53 deletions(-)
 create mode 100644 libs/__init__.py
 delete mode 100644 libs/query_fits_to_answer.py
 create mode 100644 libs/run_tests.py
 create mode 100644 libs/runnables.py
 create mode 100644 libs/test_class.py
 create mode 100644 libs/validators.py
 create mode 100644 test_small_llms.py

diff --git a/.gitignore b/.gitignore
index 2fc6d6a..33563d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .venv
-__pycache
+__pycache__/
+.direnv
 .vscode
diff --git a/libs/__init__.py b/libs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/libs/query_fits_to_answer.py b/libs/query_fits_to_answer.py
deleted file mode 100644
index f42fa3a..0000000
--- a/libs/query_fits_to_answer.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from langchain_ollama.chat_models import ChatOllama
-from langchain_core.messages import SystemMessage
-from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
-from langchain.tools import Tool
-
-def query_fits_to_answer(query: str, answer: str) -> bool:
-
-    def rate(rating: bool) -> None:
-        """Rate answer as correct (True) or as incorrect (False)."""
-
-    prompt = ChatPromptTemplate.from_messages([
-        SystemMessage(content="""You are a rating machine. You rate answers as correct if they are
-        1. factually correct (every statement made)
-        2. fitting response to the query answering all questions prompted
-
-        if the answer does not mach these criteria, rate the answer as incorrect. **Only use the rate tool. Do not answer conversationally**.
-        Do not answer with or anything like it. Just use the `rate` tool."""),
-        HumanMessagePromptTemplate.from_template(template="""Query:
-        {query}
-
-        Answer:
-        {answer}
-        """)
-    ]).invoke({"query": query, "answer": answer})
-
-    llm = ChatOllama(model="llama3-groq-tool-use:70b").bind_tools([rate])
-
-    ai_msg = llm.invoke(prompt)
-
-    try:
-        return ai_msg.tool_calls[0]['args']['rating']
-    except IndexError as e:
-        print(f"\rValidation Error of <{ai_msg.content}> Retrying...")
-        return query_fits_to_answer(query=query, answer=answer)
-
-if __name__ == "__main__":
-    # print(query_fits_to_answer(
-    #     query="Who is Obama?",
-    #     answer="Barack Obama was the 44th President of the United States, serving two terms from January 2009 to January 2017. He was a significant figure in American politics and made history as the first African American to hold the office."
-    # ))
-    # print(query_fits_to_answer(
-    #     query="Who is Obama?",
-    #     answer="Quantum computing is a revolutionary technology that uses the principles of quantum mechanics to perform calculations and operations on data. It's a fundamentally different approach from classical computing, which is based on bits (0s and 1s) and transistors."
-    # ))
-    # print(query_fits_to_answer(
-    #     query="Who is Obama?",
-    #     answer="Barack Obama was the 72th President of the United States, serving two terms from January 2005 to January 2013. He was a significant figure in American politics and made history as the first Chinese American to hold the office."
-    # ))
-    print(query_fits_to_answer(
-        query="Who is Obama?",
-        answer="Barack Obama was the 45th President of the United States, serving two terms from January 2009 to January 2017. He was a significant figure in American politics and made history as the first Chinese American to hold the office."
-    ))
diff --git a/libs/run_tests.py b/libs/run_tests.py
new file mode 100644
index 0000000..8e2595f
--- /dev/null
+++ b/libs/run_tests.py
@@ -0,0 +1,49 @@
+from libs.test_class import Test
+from libs.validators import system_human_answer_match
+from libs.runnables import basic
+
+
+def padd(list, element):
+    """Return str(element) left-justified to the widest str() in ``list``.
+
+    An empty ``list`` leaves the element unpadded.
+    """
+    # The parameter keeps its original name for callers; the builtin of the
+    # same name is deliberately not used inside this function.
+    longest = max((len(str(item)) for item in list), default=0)
+    return str(element).ljust(longest)
+
+
+def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: str):
+    """Run every test for every model/seed pair, then validate each answer.
+
+    Returns one dict per successful run with the keys "test", "model",
+    "seed", "result" and "validation". Runs whose runnable raises are
+    reported and skipped; "validation" is None when the validator raises.
+    """
+    results = []
+    for model in models:
+        for seed in seeds:
+            for test in tests:
+                try:
+                    result = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
+                except Exception as e:
+                    # Keep going: one failing run must not abort the whole matrix.
+                    print(f"\033[0;31mError:\033[0m {e}")
+                    continue
+                results.append({"test": test, "model": model, "seed": seed, "result": result})
+                print(f"Model {padd(models, model)} starting with seed {padd(seeds, seed)} is done with test '{test.name}'.")
+
+    for result in results:
+        test = result['test']
+        try:
+            # Use the validator of the test that produced this result.
+            result['validation'] = test.validator(test=test, answer=result['result'], base_url=base_url)
+        except Exception as e:
+            result['validation'] = None
+            print(f"\033[0;31mError:\033[0m validation of test {test.name} failed: {e}")
+            continue
+
+        print(f"Validation of answer from test {result['test'].name} by {result['model']} with seed {result['seed']} evaluated to " + ('\033[0;32mcorrect\033[0m' if result['validation'] == True else '\033[0;31mincorrect\033[0m'))
+
+    return results
diff --git a/libs/runnables.py b/libs/runnables.py
new file mode 100644
index 0000000..31e4d0d
--- /dev/null
+++ b/libs/runnables.py
@@ -0,0 +1,23 @@
+from langchain_ollama.chat_models import ChatOllama
+from langchain_core.messages import SystemMessage, HumanMessage
+from libs.test_class import Test
+
+
+def basic(model: str, seed: int, test: Test, base_url: str) -> str:
+    """Send the test's messages to ``model`` once and return the reply text.
+
+    The system message is left out of the prompt when it is None.
+    """
+    # Wrap the strings explicitly: LangChain treats a bare str in a message
+    # list as a human message, which would demote the system prompt.
+    prompt = [HumanMessage(content=test.human_msg)]
+    if test.system_msg is not None:
+        prompt.insert(0, SystemMessage(content=test.system_msg))
+
+    llm = ChatOllama(
+        model=model,
+        seed=seed,
+        base_url=base_url
+    )
+    ai_msg = llm.invoke(prompt)
+    return ai_msg.content
diff --git a/libs/test_class.py b/libs/test_class.py
new file mode 100644
index 0000000..5020fd7
--- /dev/null
+++ b/libs/test_class.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass, field
+from typing import Callable, Optional
+
+
+@dataclass
+class Test:
+    """One prompt plus the callables used to run and validate it.
+
+    ``human_msg``, ``runnable`` and ``validator`` are required. They carry a
+    None default only because dataclass fields that follow a defaulted field
+    need one; ``__post_init__`` rejects a Test that leaves them unset.
+    """
+
+    name: str
+    # None means the runnable sends no system message.
+    system_msg: Optional[str] = "You are a helful AI assistant."
+    human_msg: Optional[str] = None
+    validation_info: str = """- it is factually correct
+- it fits/answers the system message and human query
+- it is just the answer, and doesn't have any AI fragments (A/B versions, "end of message" parts, unfiting discalimers or notes)"""
+    runnable: Optional[Callable] = None
+    validator: Optional[Callable] = None
+
+    def __post_init__(self):
+        """Raise TypeError when a required field was not supplied."""
+        required = ("human_msg", "runnable", "validator")
+        missing = [attr for attr in required if getattr(self, attr) is None]
+        if missing:
+            raise TypeError(f"Test() missing required argument(s): {', '.join(missing)}")
diff --git a/libs/validators.py b/libs/validators.py
new file mode 100644
index 0000000..6eda1ae
--- /dev/null
+++ b/libs/validators.py
@@ -0,0 +1,56 @@
+from langchain_ollama.chat_models import ChatOllama
+from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate
+from langchain.tools import Tool
+from libs.test_class import Test
+
+
+def system_human_answer_match(test: Test, answer: str, base_url: str, max_attempts: int = 5) -> bool:
+    """Ask the judge model to rate ``answer`` against ``test.validation_info``.
+
+    The judge is asked to call the ``rate`` tool; the ``rating`` argument of
+    its first tool call is returned. A reply without a usable tool call is
+    retried, up to ``max_attempts`` requests in total.
+
+    Raises:
+        RuntimeError: if no attempt produced a rating.
+    """
+
+    def rate(rating: bool) -> None:
+        """Rate answer as correct (True) or as incorrect (False)."""
+
+    prompt = ChatPromptTemplate.from_messages([
+        SystemMessagePromptTemplate.from_template(template="""Rate the answer as correct, if the answer is
+        {validation_info}
+
+        else as incorrect. Only use the rate tool. Do not answer conversationally."""),
+        HumanMessagePromptTemplate.from_template(template="""System Message:
+{system_msg}
+
+Query:
+{human_msg}
+
+Answer:
+{answer}
+""")
+    ]).invoke({
+        "validation_info": test.validation_info,
+        "system_msg": test.system_msg,
+        "human_msg": test.human_msg,
+        "answer": answer
+    })
+
+    llm = ChatOllama(
+        model="llama3.1:70b",
+        base_url=base_url
+    ).bind_tools([rate])
+
+    # Retry when the reply carries no usable tool call; bounded so a judge
+    # that never calls the tool cannot loop forever.
+    for _ in range(max_attempts):
+        ai_msg = llm.invoke(prompt)
+        try:
+            return ai_msg.tool_calls[0]['args']['rating']
+        except (IndexError, KeyError):
+            print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
+
+    raise RuntimeError(f"No rating for test {test.name} after {max_attempts} attempts")
diff --git a/test_small_llms.py b/test_small_llms.py
new file mode 100644
index 0000000..9140501
--- /dev/null
+++ b/test_small_llms.py
@@ -0,0 +1,62 @@
+import os
+from pprint import pprint
+
+from libs.test_class import Test
+from libs.run_tests import run_tests
+from libs.runnables import basic
+from libs.validators import system_human_answer_match
+
+# Ollama endpoint used when OLLAMA_BASE_URL is not set in the environment.
+DEFAULT_BASE_URL = "http://bolt.hs-mittweida.de:11434"
+
+
+def main():
+    """Run the configured tests against the selected models and print a report."""
+    models = [
+        # "llama3.1", # 8b
+        # "llama3.1:70b",
+        # "llama3-groq-tool-use", # latest
+        # "llama3-groq-tool-use:70b",
+        # "mixtral:8x7b",
+        # "mixtral:8x22b",
+        # "gemma2:2b",
+        # "phi3", # 3.8b
+        "tinyllama:1.1b",
+    ]
+    seeds = [
+        # 2,
+        222,
+        # 22222,
+        2222222
+    ]
+    tests = [
+        Test(
+            name="Chinese Fruit",
+            system_msg="You are a helpful assistant. You serve people across the globe. You can be a freind, but stay professional.",
+            human_msg="什么蔬菜最健康?",
+            validation_info="""- in Mandarin Chinese
+- factually correct
+- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes)""",
+            runnable=basic,
+            validator=system_human_answer_match
+        ),
+        # Test(
+        #     name="Simple Multiplication",
+        #     system_msg=
+        # )
+    ]
+
+    results = run_tests(
+        models=models,
+        seeds=seeds,
+        tests=tests,
+        base_url=os.environ.get("OLLAMA_BASE_URL", DEFAULT_BASE_URL)
+    )
+
+    print()
+    for result in results:
+        print(f"\n\033[0;36mtest_name:\033[0m {result['test'].name}\n\033[0;36mmodel:\033[0m {result['model']}\n\033[0;36mseed:\033[0m {result['seed']}\n\033[0;36mvalidation_result:\033[0m {result['validation']}\n\033[0;36manswer: ⏎\033[0m\n{result['result']}")
+
+
+if __name__ == "__main__":
+    main()