This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-25 20:10:53 +02:00
parent a578dd26a0
commit 2723ced901
8 changed files with 307 additions and 229 deletions

View File

@@ -1,4 +1,5 @@
from libs.classes import Test, Model
from os import name
from libs.classes import Technique, Test, Model
from libs.functions import nxhash
from typing import Union
@@ -14,6 +15,8 @@ def get_len(collection: Union[list, dict]) -> int:
collection_type = "models"
elif isinstance(collection[list(collection.keys())[0]], Test):
collection_type = "tests"
elif isinstance(collection[list(collection.keys())[0]], Technique):
collection_type = "techniques"
else:
raise TypeError("get_len: unsupported collection_type")
else:
@@ -29,6 +32,9 @@ def get_len(collection: Union[list, dict]) -> int:
case "tests":
for test_id in collection:
maximum_length = max(maximum_length, len(collection[test_id].name))
case "techniques":
for technique_id in collection:
maximum_length = max(maximum_length, len(collection[technique_id].name))
case _:
for model_name in collection:
raise TypeError("get_len: unsupported collection_type")
@@ -37,7 +43,7 @@ def get_len(collection: Union[list, dict]) -> int:
def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], base_url: str):
def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], techniques: dict[int, Technique], base_url: str):
try:
print("Trying to load saved_results.json")
with open("./saved_results.json", "r") as f:
@@ -53,88 +59,109 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
model = models[model_id]
for test_id in tests:
test = tests[test_id]
for seed in seeds:
# Init dict
combination = {
'test_id': test_id,
'model_id': model_id,
'seed': seed,
}
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
combination['test_name'] = test.name
combination['model_name'] = model.display_name
# if hash_key == "DE3D137E":
# pass
for technique_id in techniques:
technique = techniques[technique_id]
if ((model.supports_tools != technique.for_supports_tools) and (model.supports_tools == technique.for_not_supports_tools)):
continue
for seed in seeds:
# Init dict
combination = {
'test_id': test_id,
'model_id': model_id,
'seed': seed,
'technique_id': technique_id
}
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
if hash_key not in saved_results.keys():
try:
print("\033[0;35mModel '\033[0m" +
model.display_name +
"\033[0;35m'" +
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
str(seed) +
"\033[0;35m now runs test '\033[0m" +
test.name +
"\033[0;35m'" +
(" " * (get_len(tests) - len(test.name))) +
" (\033[0m" +
hash_key +
"\033[0;35m)\033[0m",
end=""
)
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
if isinstance(answer, str):
combination['answer'] = answer
# combination['tool_calls'] = [] # no entry
del answer
elif isinstance(answer, dict): # calls
combination['answer'] = answer['answer']
combination['tool_calls'] = answer['tool_calls']
del answer
else:
raise Exception(f"runnable returned unkown type {type(answer)}.")
combination.update({
'test_name': test.name,
'model_name': model.display_name,
'technique_name': technique.name,
})
# if hash_key == "DE3D137E":
# pass
if hash_key not in saved_results.keys():
try:
print("\033[0;35mModel '\033[0m" +
model.display_name +
"\033[0;35m'" +
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
str(seed) +
"\033[0;35m using technique '\033[0m" +
technique.name +
"\033[0;35m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;35m now runs test '\033[0m" +
test.name +
"\033[0;35m'" +
(" " * (get_len(tests) - len(test.name))) +
" (\033[0m" +
hash_key +
"\033[0;35m)\033[0m",
end=""
)
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
if isinstance(answer, str):
combination['answer'] = answer
# combination['tool_calls'] = [] # no entry
del answer
elif isinstance(answer, dict): # calls
combination['answer'] = answer['answer']
combination['tool_calls'] = answer['tool_calls']
del answer
else:
raise Exception(f"runnable returned unkown type {type(answer)}.")
combination['test'] = test
run_results[hash_key] = combination
print("\r\033[0;32mModel '\033[0m" +
combination['test'] = test
run_results[hash_key] = combination
print("\r\033[0;32mModel '\033[0m" +
model.display_name +
"\033[0;32m'" +
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
str(seed) +
"\033[0;32m using technique '\033[0m" +
technique.name +
"\033[0;32m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;32m finished test '\033[0m" +
test.name +
"\033[0;32m'" +
(" " * (get_len(tests) - len(test.name))) +
" (\033[0m" +
hash_key +
"\033[0;32m)\033[0m"
)
except Exception as e:
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ")
else:
print("\r\033[0;34mModel '\033[0m" +
model.display_name +
"\033[0;32m'" +
"\033[0;34m'" +
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
str(seed) +
"\033[0;32m finished test '\033[0m" +
"\033[0;34m using technique '\033[0m" +
technique.name +
"\033[0;34m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;34m skipped test '\033[0m" +
test.name +
"\033[0;32m'" +
"\033[0;34m'" +
(" " * (get_len(tests) - len(test.name))) +
" (\033[0m" +
hash_key +
"\033[0;32m)\033[0m"
"\033[0;34m) becasue its results exists in saved_results.json\033[0m"
)
except Exception as e:
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ")
else:
print("\r\033[0;34mModel '\033[0m" +
model.display_name +
"\033[0;34m'" +
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
str(seed) +
"\033[0;34m skipped test '\033[0m" +
test.name +
"\033[0;34m'" +
(" " * (get_len(tests) - len(test.name))) +
" (\033[0m" +
hash_key +
"\033[0;34m) becasue its results exists in saved_results.json\033[0m"
)
# Validate Results