Files
test-small-llms/visualize.py
Lennart J. Kurzweg (Nx2) a578dd26a0 mega commit
2024-08-20 20:47:17 +02:00

87 lines
2.4 KiB
Python

import json
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
# Load the saved test results. The file is opened explicitly as UTF-8 so that
# decoding does not depend on the platform's default locale encoding.
with open('saved_results.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Flatten the top-level {hash: result} mapping into one record per entry,
# keeping only the fields the charts below need.
results = [
    {
        "hash": test_hash,
        "model": test_data['model_name'],
        "seed": test_data['seed'],
        "test_name": test_data['test_name'],
        "validation": test_data['validation'],
    }
    for test_hash, test_data in data.items()
]
df = pd.DataFrame(results)
## 1st Chart
# Count the number of validation results for each model.
# NOTE(review): assumes `validation` is boolean (False = failed, True = passed),
# which is what the 'Failed'/'Passed' label order below relies on — confirm.
validation_counts = (
    df.groupby(['model', 'validation'])
    .size()
    .unstack(fill_value=0)
    # Force both outcome columns to exist, in a fixed order. Without this, data
    # containing only passes (or only failures) yields a single column and the
    # two-label assignment below raises a length-mismatch ValueError.
    .reindex(columns=[False, True], fill_value=0)
)
validation_counts.columns = ['Failed', 'Passed']

# Plot the validation results by model as a stacked bar chart.
fig, ax = plt.subplots(figsize=(10, 6))
validation_counts.plot(kind='bar', stacked=True, color=['red', 'green'], ax=ax)
ax.set_title('Validation Results by Model')
ax.set_xlabel('Model')
ax.set_ylabel('Number of Tests')
plt.xticks(rotation=45, ha='right')
ax.legend(title='Validation')
plt.tight_layout()
plt.savefig('validation_results_by_model.png')
## 2nd Chart
# Count the number of validation results for each test name.
# NOTE(review): assumes `validation` is boolean (False = failed, True = passed),
# which is what the 'Failed'/'Passed' label order below relies on — confirm.
test_name_counts = (
    df.groupby(['test_name', 'validation'])
    .size()
    .unstack(fill_value=0)
    # Force both outcome columns to exist, in a fixed order. Without this, data
    # containing only passes (or only failures) yields a single column and the
    # two-label assignment below raises a length-mismatch ValueError.
    .reindex(columns=[False, True], fill_value=0)
)
test_name_counts.columns = ['Failed', 'Passed']

# Plot the validation results by test name as a horizontal stacked bar chart.
fig, ax = plt.subplots(figsize=(10, 6))
test_name_counts.plot(kind='barh', stacked=True, color=['red', 'green'], ax=ax)
ax.set_title('Validation Results by Test Name')
ax.set_xlabel('Number of Tests')
ax.set_ylabel('Test Name')
ax.legend(title='Validation')
plt.tight_layout()
plt.savefig('validation_results_by_test_name.png')
## 3rd Chart
# Mean of `validation` for every (model, test name) pair; pairs with no
# recorded results are filled with 0.
pass_rate = pd.pivot_table(
    df,
    values='validation',
    index='model',
    columns='test_name',
    aggfunc="mean",
    fill_value=0,
)

# Draw the heatmap of the rates scaled to percentages.
plt.figure(figsize=(8, 8))
heatmap_cmap = sns.color_palette("blend:#100,#255,#4a3", as_cmap=True)
heatmap_ax = sns.heatmap(
    pass_rate * 100,
    annot=True,
    fmt=".0f",
    cmap=heatmap_cmap,
    cbar=True,
    annot_kws={"size": 10},
)

# Append a percent sign to each cell annotation.
for cell_label in heatmap_ax.texts:
    cell_label.set_text(f"{cell_label.get_text()}%")

# Title and axis labels.
heatmap_ax.set_title('Model Performance on Each Test', fontsize=16)
heatmap_ax.set_xlabel('Test Name', fontsize=14)
heatmap_ax.set_ylabel('Model', fontsize=14)

# Slant the x tick labels by 45 degrees, right-aligned.
plt.xticks(rotation=45, ha='right')

# Tighten the layout so labels are not cut off, then write the image.
plt.tight_layout()
plt.savefig('model_performance_heatmap.png')