diff --git a/visualize.py b/visualize.py index 2dd3b87..1e6db47 100644 --- a/visualize.py +++ b/visualize.py @@ -18,16 +18,19 @@ def get_df() -> pd.DataFrame: data = json.load(f) raw_data = [] for test_hash, test_data in data.items(): - raw_data.append({ - "hash": test_hash, - "model_name": models[test_data['model_id']].display_name, - "model_size": models[test_data['model_id']].parameter_count_in_b, - "technique_name": techniques[test_data['technique_id']].name, - "model_technique": f"{models[test_data['model_id']].display_name}:{ techniques[test_data['technique_id']].name}", - "seed": test_data['seed'], - "test_name": tests[test_data['test_id']].name, - "validation": test_data['validation'] - }) + try: + raw_data.append({ + "hash": test_hash, + "model_name": models[test_data['model_id']].display_name, + "model_size": models[test_data['model_id']].parameter_count_in_b, + "technique_name": techniques[test_data['technique_id']].name, + "model_technique": f"{models[test_data['model_id']].display_name}:{ techniques[test_data['technique_id']].name}", + "seed": test_data['seed'], + "test_name": tests[test_data['test_id']].name, + "validation": test_data['validation'] + }) + except KeyError: + pass df = pd.DataFrame(raw_data)