visualize

This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-14 21:01:38 +02:00
parent 11f37009d3
commit 15973d723f

132
visualize.py Normal file
View File

@@ -0,0 +1,132 @@
import json
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from math import pi
# Load the JSON data
with open('saved_results.json', 'r') as f:
data = json.load(f)
# Convert JSON data into a DataFrame
results = []
for test_hash, test_data in data.items():
results.append({
"hash": test_hash,
"model": test_data['model'],
"seed": test_data['seed'],
"test_name": test_data['test_name'],
"validation": test_data['validation']
})
df = pd.DataFrame(results)
## 1st Chart
# Count the number of validation results for each model
validation_counts = df.groupby(['model', 'validation']).size().unstack().fillna(0)
validation_counts.columns = ['Failed', 'Passed']
# Plot the validation results by model
plt.figure(figsize=(10, 6))
validation_counts.plot(kind='bar', stacked=True, color=['red', 'green'], ax=plt.gca())
plt.title('Validation Results by Model')
plt.xlabel('Model')
plt.ylabel('Number of Tests')
plt.xticks(rotation=45, ha='right')
plt.legend(title='Validation')
plt.tight_layout()
plt.savefig('validation_results_by_model.png')
## 2nd Chart
# Plot the validation results by test name
test_name_counts = df.groupby(['test_name', 'validation']).size().unstack().fillna(0)
test_name_counts.columns = ['Failed', 'Passed']
plt.figure(figsize=(10, 6))
test_name_counts.plot(kind='barh', stacked=True, color=['red', 'green'], ax=plt.gca())
plt.title('Validation Results by Test Name')
plt.xlabel('Number of Tests')
plt.ylabel('Test Name')
plt.legend(title='Validation')
plt.tight_layout()
plt.savefig('validation_results_by_test_name.png')
## 3rd Chart
# Prepare data for the spider chart
models = df['model'].unique()
# Calculate the pass rate for each model on each test
pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0)
tests = df['test_name'].unique().tolist()
# Initialize the spider plot
num_vars = len(pass_rate)-1
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
angles += [ angles[0] ]
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
# Plot each model's performance
for model in models:
values = pass_rate.loc[model].tolist()
values += [ values[0] ]
ax.fill(angles, values, alpha=0.25)
ax.plot(angles, values, label=model)
#
# Configure the spider chart
ax.set_theta_offset(pi / 2)
ax.set_theta_direction(-1)
tests.append(tests[0])
tests.pop(0)
ax.set_xticks(angles[:-1])
ax.set_xticklabels(tests)
ax.set_yticks(np.linspace(0, 1, 5))
ax.set_yticklabels([f'{int(i * 100)}%' for i in np.linspace(0, 1, 5)], color="grey", size=8)
ax.set_ylim(0, 1)
plt.title('Model Performance on Each Test')
plt.legend(loc='upper right', bbox_to_anchor=(1.1, 1.1))
plt.tight_layout()
plt.savefig('model_performance_spider_chart.png')
# 4th chart
# Create a heatmap
plt.figure(figsize=(8, 8))
sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})
# Add percentage sign to annotations
for text in plt.gca().texts:
text.set_text(f"{text.get_text()}%")
# Customize the plot with labels and a title
plt.title('Model Performance on Each Test', fontsize=16)
plt.xlabel('Test Name', fontsize=14)
plt.ylabel('Model', fontsize=14)
# Rotate x-axis labels by 45 degrees
plt.xticks(rotation=45, ha='right')
# Adjust layout to ensure labels don't get cut off
plt.tight_layout()
# Save the heatmap
plt.savefig('model_performance_heatmap.png')