mega commit
This commit is contained in:
48
visualize.py
48
visualize.py
@@ -3,7 +3,6 @@ import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
from math import pi
|
||||
|
||||
# Load the JSON data
|
||||
with open('saved_results.json', 'r') as f:
|
||||
@@ -14,7 +13,7 @@ results = []
|
||||
for test_hash, test_data in data.items():
|
||||
results.append({
|
||||
"hash": test_hash,
|
||||
"model": test_data['model'],
|
||||
"model": test_data['model_name'],
|
||||
"seed": test_data['seed'],
|
||||
"test_name": test_data['test_name'],
|
||||
"validation": test_data['validation']
|
||||
@@ -61,52 +60,7 @@ plt.savefig('validation_results_by_test_name.png')
|
||||
|
||||
|
||||
## 3rd Chart
|
||||
# Prepare data for the spider chart
|
||||
models = df['model'].unique()
|
||||
|
||||
# Calculate the pass rate for each model on each test
|
||||
pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0)
|
||||
tests = df['test_name'].unique().tolist()
|
||||
|
||||
# Initialize the spider plot
|
||||
num_vars = len(pass_rate)-1
|
||||
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
||||
angles += [ angles[0] ]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
|
||||
|
||||
# Plot each model's performance
|
||||
for model in models:
|
||||
values = pass_rate.loc[model].tolist()
|
||||
values += [ values[0] ]
|
||||
ax.fill(angles, values, alpha=0.25)
|
||||
ax.plot(angles, values, label=model)
|
||||
#
|
||||
|
||||
# Configure the spider chart
|
||||
ax.set_theta_offset(pi / 2)
|
||||
ax.set_theta_direction(-1)
|
||||
|
||||
tests.append(tests[0])
|
||||
tests.pop(0)
|
||||
|
||||
ax.set_xticks(angles[:-1])
|
||||
ax.set_xticklabels(tests)
|
||||
|
||||
ax.set_yticks(np.linspace(0, 1, 5))
|
||||
ax.set_yticklabels([f'{int(i * 100)}%' for i in np.linspace(0, 1, 5)], color="grey", size=8)
|
||||
ax.set_ylim(0, 1)
|
||||
|
||||
plt.title('Model Performance on Each Test')
|
||||
plt.legend(loc='upper right', bbox_to_anchor=(1.1, 1.1))
|
||||
plt.tight_layout()
|
||||
plt.savefig('model_performance_spider_chart.png')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 4th chart
|
||||
# Create a heatmap
|
||||
plt.figure(figsize=(8, 8))
|
||||
sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})
|
||||
|
||||
Reference in New Issue
Block a user