97 lines
3.0 KiB
Python
97 lines
3.0 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
|
|
def print_help():
    """Print usage information, the supported options and example invocations."""
    print("""Usage: python print_saved_results.py [-h] [-m MODELS] [-s SEEDS] [-t TESTS]

Options (each value is a comma-separated list; only results matching
every given filter are printed):
    -h          Show this help and exit.
    -m MODELS   Filter by model name.
    -s SEEDS    Filter by seed.
    -t TESTS    Filter by test name.

Example usages:

    python print_saved_results.py
    python print_saved_results.py -m llama3.1
    python print_saved_results.py -m llama3.1,mixtral-nemo:12b
    python print_saved_results.py -m llama3.1 -s 2222,2 -t "Healthy Vegetables in Chinese"

Note: If a filter value matches no saved result, no error is raised;
nothing is printed for it.""")
|
|
|
|
def _pop_option(args: list[str], flag: str) -> "list[str] | None":
    """Remove ``flag`` and the value after it from ``args``; return the value split on commas.

    Only the first occurrence of ``flag`` is consumed; a repeated flag stays in
    ``args`` and is later reported as an unknown argument.

    Args:
        args: Command-line arguments without the script name; modified in place.
        flag: Option to look for, e.g. ``"-m"``.

    Returns:
        The option's value split on ``","``, or None if ``flag`` is absent.

    Raises:
        ValueError: If ``flag`` has no value, or the value is empty or starts
            with ``-`` (i.e. it looks like another option).
    """
    if flag not in args:
        return None
    index = args.index(flag)
    if index + 1 >= len(args):
        raise ValueError(f"missing value for {flag}")
    value = args[index + 1]
    # Reject e.g. `-m -s 1`, where the "value" is really the next option.
    # An explicit check (not `assert`) so it still runs under `python -O`.
    if not value or value.startswith("-"):
        raise ValueError(f"invalid value for {flag}: {value!r}")
    del args[index:index + 2]
    return value.split(",")


def main(argv: list[str]) -> None:
    """Print saved test results from ./saved_results.json, optionally filtered.

    Args:
        argv: Full command line including the script name at index 0
            (normally ``sys.argv``). The list is not modified.

    Recognised options (each value is a comma-separated list; a result is
    printed only if it matches every filter given):
        -h   print usage help and exit with status 0
        -m   keep results whose ``model`` is in the list
        -s   keep results whose ``seed`` (compared as a string) is in the list
        -t   keep results whose ``test_name`` is in the list

    Exits with status 1 on a syntax error, an unknown argument, or when
    saved_results.json is missing or cannot be read/parsed.
    """
    import shutil  # function-scope import keeps this block self-contained

    syntax_error = "Syntax error. Run `python print_saved_results.py -h` for help."

    # Work on a copy without the script name so the caller's list
    # (usually sys.argv) is left untouched.
    args = list(argv[1:])

    # Handle -h before touching the results file so help works without it.
    if "-h" in args:
        print_help()
        sys.exit(0)

    try:
        models = _pop_option(args, "-m")
        seeds = _pop_option(args, "-s")
        tests = _pop_option(args, "-t")
    except ValueError:
        print(syntax_error)
        print_help()
        sys.exit(1)

    if args:
        print(syntax_error)
        print(f"Got unknown argument{'s' if len(args) != 1 else ''}: {args}")
        print_help()
        sys.exit(1)

    try:
        with open("./saved_results.json", "r", encoding="utf-8") as f:
            saved_results = json.load(f)
    except FileNotFoundError:
        print("saved_results.json not found. Try running test_suite.py first.")
        sys.exit(1)
    except (OSError, ValueError) as err:
        # ValueError covers malformed JSON (json.JSONDecodeError) and invalid UTF-8.
        print(f"Could not read saved_results.json: {err}")
        sys.exit(1)

    # shutil falls back to 80 columns when stdout is not a terminal
    # (piped/redirected), where os.get_terminal_size() would raise OSError.
    separator = "-" * shutil.get_terminal_size().columns
    cyan, reset = "\033[0;36m", "\033[0m"

    first_print = True
    for result in saved_results.values():
        if models is not None and result["model"] not in models:
            continue
        if seeds is not None and str(result["seed"]) not in seeds:
            continue
        if tests is not None and result["test_name"] not in tests:
            continue

        # Separate consecutive results with a full-width rule.
        if not first_print:
            print(separator)
        print(
            f"\n{cyan}Test name:{reset} {result['test_name']}"
            f"\n{cyan}Model:{reset} {result['model']}"
            f"\n{cyan}Seed:{reset} {result['seed']}"
            f"\n{cyan}Validation result:{reset} {result['validation']}"
            f"\n{cyan}Answer: »{reset}{result['answer']}{cyan}«{reset}"
            "\n"
        )
        first_print = False
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand the full command line (script name included) to main().
    main(sys.argv)
|