From 11f37009d3d1cbd71d8a9c78f72fafa28fa013c4 Mon Sep 17 00:00:00 2001 From: "Lennart J. Kurzweg (Nx2)" Date: Thu, 8 Aug 2024 16:52:44 +0200 Subject: [PATCH] extra file for printing --- print_saved_results.py | 96 ++++++++++++++++++++++++++++++++++++++++++ test_small_llms.py | 26 ++++-------- 2 files changed, 105 insertions(+), 17 deletions(-) create mode 100644 print_saved_results.py diff --git a/print_saved_results.py b/print_saved_results.py new file mode 100644 index 0000000..22cd804 --- /dev/null +++ b/print_saved_results.py @@ -0,0 +1,96 @@ +import json +import os +import sys + +def print_help(): + print("""Example usages: + +python print_saved_results.py +python print_saved_results.py -m llama3.1 +python print_saved_results.py -m llama3.1,mixtral-nemo:12b +python print_saved_results.py -m llama3.1 -s 2222,2 -t "Healthy Vegetables in Chinese" + +Note: If one of the "fileters" does not exist, no error is thrown.""") + +def main(argv: list[str]) -> None: + try: + with open("./saved_results.json", "r") as f: + saved_results = json.load(fp=f) + except: + print("saved_results.json not found. Try running test_suite.py first.") + exit(1) + + if "-h" in argv: + print_help() + exit(0) + + try: + if "-m" in argv: + test_str = argv[argv.index("-m")+1] + assert test_str[0] != "-" + models = test_str.split(",") + argv.pop(argv.index("-m")+1) + argv.pop(argv.index("-m")) + else: + models = None + + if "-s" in argv: + test_str = argv[argv.index("-s")+1] + assert test_str[0] != "-" + seeds = test_str.split(",") + argv.pop(argv.index("-s")+1) + argv.pop(argv.index("-s")) + else: + seeds = None + + if "-t" in argv: + test_str = argv[argv.index("-t")+1] + assert test_str[0] != "-" + tests = test_str.split(",") + argv.pop(argv.index("-t")+1) + argv.pop(argv.index("-t")) + else: + tests = None + except: + print("Syntax error. Run `python print_saved_results.py -h` for help.") + print_help() + exit(1) + + argv.pop(0) # remove filename entry + if argv != []: + print("Syntax error. Run `python print_saved_results.py -h` for help.") + print(f"Got unkown argument{'s' if len(argv) != 1 else ''}: {argv}") + print_help() + exit(1) + + + first_print = True + term_size = os.get_terminal_size() + + for hash_key in saved_results: + result = saved_results[hash_key] + if models == None or result['model'] in models: + if seeds == None or str(result['seed']) in seeds: + if tests == None or result['test_name'] in tests: + if not first_print: print('-' * term_size.columns) + + print( + "\n" + + "\033[0;36mTest name:\033[0m " + + result['test_name'] + + "\n\033[0;36mModel:\033[0m " + + result['model'] + + "\n\033[0;36mSeed:\033[0m " + + str(result['seed']) + + "\n\033[0;36mValidation result:\033[0m " + + str(result['validation']) + + "\n\033[0;36mAnswer: »\033[0m" + + result['answer'] + + "\033[0;36m«\033[0m" + + "\n" + ) + + first_print = False + +if __name__ == "__main__": + main(argv=sys.argv) diff --git a/test_small_llms.py b/test_small_llms.py index 060f501..1d1af44 100644 --- a/test_small_llms.py +++ b/test_small_llms.py @@ -9,20 +9,22 @@ from pprint import pprint def main(): models = [ "llama3.1", # 8b - # "llama3.1:70b", - # "llama3-groq-tool-use", # latest - # "llama3-groq-tool-use:70b", + "llama3.1:70b", + "llama3-groq-tool-use", # latest + "llama3-groq-tool-use:70b", # "mixtral:8x7b", - # "mixtral:8x22b", + "mixtral:8x22b", # "gemma2:2b", # "phi3", # 3.8b # "tinyllama:1.1b", + "mistral-nemo:12b", + # "command-r-plus:104b", ] seeds = [ 2, - # 222, - # 22222, - # 2222222 + 222, + 22222, + 2222222 ] tests = [ Test( @@ -69,15 +71,5 @@ That means If the answer is not in Chinese the answer is NOT correct! Only if th base_url="http://bolt.hs-mittweida.de:11434" ) - print() - for hash_key in results: - result = results[hash_key] - print(f""" -\033[0;36mtest_name:\033[0m {result['test_name']} -\033[0;36mmodel:\033[0m {result['model']} -\033[0;36mseed:\033[0m {result['seed']} -\033[0;36mvalidation_result:\033[0m {result['validation']} -\033[0;36manswer: »\033[0m{result['answer']}\033[0;36m«\033[0m""") - if __name__ == "__main__": main()