visualize
This commit is contained in:
132
visualize.py
Normal file
132
visualize.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
import json
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import seaborn as sns
|
||||||
|
from math import pi
|
||||||
|
|
||||||
|
# Load the JSON data
|
||||||
|
with open('saved_results.json', 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
# Convert JSON data into a DataFrame
|
||||||
|
results = []
|
||||||
|
for test_hash, test_data in data.items():
|
||||||
|
results.append({
|
||||||
|
"hash": test_hash,
|
||||||
|
"model": test_data['model'],
|
||||||
|
"seed": test_data['seed'],
|
||||||
|
"test_name": test_data['test_name'],
|
||||||
|
"validation": test_data['validation']
|
||||||
|
})
|
||||||
|
|
||||||
|
df = pd.DataFrame(results)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 1st Chart
|
||||||
|
# Count the number of validation results for each model
|
||||||
|
validation_counts = df.groupby(['model', 'validation']).size().unstack().fillna(0)
|
||||||
|
validation_counts.columns = ['Failed', 'Passed']
|
||||||
|
|
||||||
|
# Plot the validation results by model
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
validation_counts.plot(kind='bar', stacked=True, color=['red', 'green'], ax=plt.gca())
|
||||||
|
plt.title('Validation Results by Model')
|
||||||
|
plt.xlabel('Model')
|
||||||
|
plt.ylabel('Number of Tests')
|
||||||
|
plt.xticks(rotation=45, ha='right')
|
||||||
|
plt.legend(title='Validation')
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('validation_results_by_model.png')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 2nd Chart
|
||||||
|
# Plot the validation results by test name
|
||||||
|
test_name_counts = df.groupby(['test_name', 'validation']).size().unstack().fillna(0)
|
||||||
|
test_name_counts.columns = ['Failed', 'Passed']
|
||||||
|
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
test_name_counts.plot(kind='barh', stacked=True, color=['red', 'green'], ax=plt.gca())
|
||||||
|
plt.title('Validation Results by Test Name')
|
||||||
|
plt.xlabel('Number of Tests')
|
||||||
|
plt.ylabel('Test Name')
|
||||||
|
plt.legend(title='Validation')
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('validation_results_by_test_name.png')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 3rd Chart
|
||||||
|
# Prepare data for the spider chart
|
||||||
|
models = df['model'].unique()
|
||||||
|
|
||||||
|
# Calculate the pass rate for each model on each test
|
||||||
|
pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0)
|
||||||
|
tests = df['test_name'].unique().tolist()
|
||||||
|
|
||||||
|
# Initialize the spider plot
|
||||||
|
num_vars = len(pass_rate)-1
|
||||||
|
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
||||||
|
angles += [ angles[0] ]
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
|
||||||
|
|
||||||
|
# Plot each model's performance
|
||||||
|
for model in models:
|
||||||
|
values = pass_rate.loc[model].tolist()
|
||||||
|
values += [ values[0] ]
|
||||||
|
ax.fill(angles, values, alpha=0.25)
|
||||||
|
ax.plot(angles, values, label=model)
|
||||||
|
#
|
||||||
|
|
||||||
|
# Configure the spider chart
|
||||||
|
ax.set_theta_offset(pi / 2)
|
||||||
|
ax.set_theta_direction(-1)
|
||||||
|
|
||||||
|
tests.append(tests[0])
|
||||||
|
tests.pop(0)
|
||||||
|
|
||||||
|
ax.set_xticks(angles[:-1])
|
||||||
|
ax.set_xticklabels(tests)
|
||||||
|
|
||||||
|
ax.set_yticks(np.linspace(0, 1, 5))
|
||||||
|
ax.set_yticklabels([f'{int(i * 100)}%' for i in np.linspace(0, 1, 5)], color="grey", size=8)
|
||||||
|
ax.set_ylim(0, 1)
|
||||||
|
|
||||||
|
plt.title('Model Performance on Each Test')
|
||||||
|
plt.legend(loc='upper right', bbox_to_anchor=(1.1, 1.1))
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('model_performance_spider_chart.png')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 4th chart
|
||||||
|
# Create a heatmap
|
||||||
|
plt.figure(figsize=(8, 8))
|
||||||
|
sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})
|
||||||
|
|
||||||
|
# Add percentage sign to annotations
|
||||||
|
for text in plt.gca().texts:
|
||||||
|
text.set_text(f"{text.get_text()}%")
|
||||||
|
|
||||||
|
# Customize the plot with labels and a title
|
||||||
|
plt.title('Model Performance on Each Test', fontsize=16)
|
||||||
|
plt.xlabel('Test Name', fontsize=14)
|
||||||
|
plt.ylabel('Model', fontsize=14)
|
||||||
|
|
||||||
|
# Rotate x-axis labels by 45 degrees
|
||||||
|
plt.xticks(rotation=45, ha='right')
|
||||||
|
|
||||||
|
# Adjust layout to ensure labels don't get cut off
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# Save the heatmap
|
||||||
|
plt.savefig('model_performance_heatmap.png')
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user