vis
This commit is contained in:
92
visualize.py
92
visualize.py
@@ -1,9 +1,12 @@
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
|
||||
from suite_settings.models import models
|
||||
from suite_settings.techniques import techniques
|
||||
from suite_settings.tests import tests
|
||||
|
||||
# Load the JSON data
|
||||
with open('saved_results.json', 'r') as f:
|
||||
data = json.load(f)
|
||||
@@ -13,9 +16,12 @@ results = []
|
||||
for test_hash, test_data in data.items():
|
||||
results.append({
|
||||
"hash": test_hash,
|
||||
"model": test_data['model_name'],
|
||||
"model_name": models[test_data['model_id']].display_name,
|
||||
"model_size": models[test_data['model_id']].parameter_count_in_b,
|
||||
"technique_name": techniques[test_data['technique_id']].name,
|
||||
"model_technique": f"{models[test_data['model_id']].display_name}:{ techniques[test_data['technique_id']].name}",
|
||||
"seed": test_data['seed'],
|
||||
"test_name": test_data['test_name'],
|
||||
"test_name": tests[test_data['test_id']].name,
|
||||
"validation": test_data['validation']
|
||||
})
|
||||
|
||||
@@ -23,21 +29,34 @@ df = pd.DataFrame(results)
|
||||
|
||||
|
||||
|
||||
df['technique_name'] = pd.Categorical(df['technique_name'], categories=[techniques[1].name, techniques[572].name, techniques[903].name],ordered=True)
|
||||
df['test_name'] = pd.Categorical(df['test_name'], categories=[tests[607].name, tests[693].name, tests[120].name, tests[283].name, tests[260].name, tests[856].name],ordered=True)
|
||||
sorted_df = df.sort_values('model_size')
|
||||
|
||||
# Perform the groupby and unstack operation
|
||||
result_df = (
|
||||
sorted_df.groupby(['model_name', 'validation']).size()
|
||||
.unstack(fill_value=0) # Unstack and fill NaN with 0
|
||||
)
|
||||
|
||||
## 1st Chart
|
||||
# Count the number of validation results for each model
|
||||
validation_counts = df.groupby(['model', 'validation']).size().unstack().fillna(0)
|
||||
# Count the number of validation results for each technique_name
|
||||
validation_counts = result_df.loc[sorted_df['model_name'].drop_duplicates()]
|
||||
validation_counts.columns = ['Failed', 'Passed']
|
||||
|
||||
# Plot the validation results by model
|
||||
# Plot the validation results by technique_name
|
||||
plt.figure(figsize=(10, 6))
|
||||
validation_counts.plot(kind='bar', stacked=True, color=['red', 'green'], ax=plt.gca())
|
||||
plt.title('Validation Results by Model')
|
||||
plt.xlabel('Model')
|
||||
plt.title('Validation Results by Model and Technique')
|
||||
plt.xlabel('Model and Technique')
|
||||
plt.ylabel('Number of Tests')
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
plt.legend(title='Validation')
|
||||
plt.tight_layout()
|
||||
plt.savefig('validation_results_by_model.png')
|
||||
plt.savefig('model-bar-chart.png')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -54,25 +73,46 @@ plt.xlabel('Number of Tests')
|
||||
plt.ylabel('Test Name')
|
||||
plt.legend(title='Validation')
|
||||
plt.tight_layout()
|
||||
plt.savefig('validation_results_by_test_name.png')
|
||||
plt.savefig('test-bar-chart.png')
|
||||
|
||||
|
||||
|
||||
sorted_df = df.sort_values('model_size' )
|
||||
|
||||
## 3rd Chart
|
||||
pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0)
|
||||
# Create a heatmap
|
||||
plt.figure(figsize=(8, 8))
|
||||
sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})
|
||||
# Get the unique order of 'model_technique' based on sorted_df
|
||||
ordered_techniques = sorted_df['model_technique'].unique()
|
||||
|
||||
# Create the pivot table with the correct order of model_technique
|
||||
pass_rate = pd.pivot_table(
|
||||
sorted_df,
|
||||
values='validation',
|
||||
index='model_technique',
|
||||
columns='test_name',
|
||||
aggfunc="mean",
|
||||
fill_value=0
|
||||
)
|
||||
|
||||
# Reorder the rows in the pivot table based on the ordered techniques
|
||||
pass_rate = pass_rate.loc[ordered_techniques]
|
||||
|
||||
# Plot the heatmap
|
||||
plt.figure(figsize=(8, 10))
|
||||
sns.heatmap(
|
||||
pass_rate * 100,
|
||||
annot=True,
|
||||
fmt=".0f",
|
||||
cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True),
|
||||
cbar=True,
|
||||
annot_kws={"size": 10}
|
||||
)
|
||||
# Add percentage sign to annotations
|
||||
for text in plt.gca().texts:
|
||||
text.set_text(f"{text.get_text()}%")
|
||||
|
||||
# Customize the plot with labels and a title
|
||||
plt.title('Model Performance on Each Test', fontsize=16)
|
||||
plt.title('Model Technique Performance on Each Test', fontsize=16)
|
||||
plt.xlabel('Test Name', fontsize=14)
|
||||
plt.ylabel('Model', fontsize=14)
|
||||
plt.ylabel('Model and Technique', fontsize=14)
|
||||
|
||||
# Rotate x-axis labels by 45 degrees
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
@@ -81,6 +121,22 @@ plt.xticks(rotation=45, ha='right')
|
||||
plt.tight_layout()
|
||||
|
||||
# Save the heatmap
|
||||
plt.savefig('model_performance_heatmap.png')
|
||||
plt.savefig('modelTechnique_heatmap.png')
|
||||
|
||||
|
||||
## 4th Chart: Technique Performance on Each Test (Aggregated Heatmap)
|
||||
technique_pass_rate = pd.pivot_table(sorted_df, values='validation', index='test_name', columns='technique_name', aggfunc="mean", fill_value=0)
|
||||
plt.figure(figsize=(8, 4))
|
||||
sns.heatmap(technique_pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})
|
||||
|
||||
# Add percentage sign to annotations
|
||||
for text in plt.gca().texts:
|
||||
text.set_text(f"{text.get_text()}%")
|
||||
|
||||
# Customize the plot with labels and a title
|
||||
plt.title('Technique Performance on Each Test', fontsize=16)
|
||||
plt.ylabel('Test Name', fontsize=14)
|
||||
plt.xlabel('Technique', fontsize=14)
|
||||
plt.xticks(rotation=0)
|
||||
plt.tight_layout()
|
||||
plt.savefig('technique_heatmap.png')
|
||||
|
||||
Reference in New Issue
Block a user