mc
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -3,5 +3,5 @@
|
|||||||
.direnv
|
.direnv
|
||||||
.vscode
|
.vscode
|
||||||
saved_results.json
|
saved_results.json
|
||||||
*.png
|
*.eps
|
||||||
|
|
||||||
|
|||||||
BIN
assets/NewCM10-Regular.otf
Normal file
BIN
assets/NewCM10-Regular.otf
Normal file
Binary file not shown.
@@ -71,7 +71,7 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
|
|||||||
'seed': seed,
|
'seed': seed,
|
||||||
'technique_id': technique_id
|
'technique_id': technique_id
|
||||||
}
|
}
|
||||||
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
|
hash_key = nxhash(json.dumps(combination, sort_keys=True))
|
||||||
|
|
||||||
combination.update({
|
combination.update({
|
||||||
'test_name': test.name,
|
'test_name': test.name,
|
||||||
|
|||||||
@@ -157,9 +157,9 @@ def get_notes_containing(patterns: Union[list[str], str]) -> str:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def write_note(command: str) -> str:
|
def write_note(note: str) -> str:
|
||||||
"""Write a not with the current time to the database."""
|
"""Write a not with the current time to the database."""
|
||||||
return command
|
return "Written."
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def save_python_repl(command: str):
|
def save_python_repl(command: str):
|
||||||
@@ -171,6 +171,7 @@ def save_python_repl(command: str):
|
|||||||
"^ *subprocess\\.",
|
"^ *subprocess\\.",
|
||||||
"^ *(with)? ?open\\(",
|
"^ *(with)? ?open\\(",
|
||||||
"^ *shutil\\.",
|
"^ *shutil\\.",
|
||||||
|
"^ *requests\\.",
|
||||||
]
|
]
|
||||||
valid = True
|
valid = True
|
||||||
for pattern in blocked_patterns:
|
for pattern in blocked_patterns:
|
||||||
|
|||||||
0
suite_settings/__init__.py
Normal file
0
suite_settings/__init__.py
Normal file
@@ -33,10 +33,7 @@ tests = {
|
|||||||
"tools": {"add": add, "multiply": multiply},
|
"tools": {"add": add, "multiply": multiply},
|
||||||
},
|
},
|
||||||
validator=regex_match_any,
|
validator=regex_match_any,
|
||||||
validation_input={
|
validation_input={"patterns": ["33549659245", "33,549,659,245", "33.549.659.245"]},
|
||||||
"patterns": ["33549659245", "33,549,659,245", "33.549.659.245"]
|
|
||||||
# "patterns": ["3[,\. ]?3[,\. ]?5[,\. ]?4[,\. ]?9[,\. ]?6[,\. ]?5[,\. ]?9[,\. ]?2[,\. ]?4[,\. ]?5"] # Would accept 3.354.965.9245
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
120: Test(
|
120: Test(
|
||||||
name="Complex Multiplication",
|
name="Complex Multiplication",
|
||||||
@@ -47,10 +44,10 @@ tests = {
|
|||||||
"tools": {"add": add, "multiply": multiply},
|
"tools": {"add": add, "multiply": multiply},
|
||||||
},
|
},
|
||||||
validator=regex_match_any,
|
validator=regex_match_any,
|
||||||
validation_input={"patterns": ["6134205", "6.134.205", "6,134,205"]},
|
validation_input={"patterns": ["6134205", "6,134,205"]},
|
||||||
),
|
),
|
||||||
363: Test(
|
363: Test(
|
||||||
name="Complex Multiplication Python",
|
name="Python Remainder",
|
||||||
runnable=one_tool_call_answer,
|
runnable=one_tool_call_answer,
|
||||||
runnable_input={
|
runnable_input={
|
||||||
"system_msg": 'You are a helpful assistant.',
|
"system_msg": 'You are a helpful assistant.',
|
||||||
@@ -58,9 +55,7 @@ tests = {
|
|||||||
"tools": { "save_python_repl": save_python_repl },
|
"tools": { "save_python_repl": save_python_repl },
|
||||||
},
|
},
|
||||||
validator=regex_match_any,
|
validator=regex_match_any,
|
||||||
validation_input={
|
validation_input={"patterns": [ "236", "two ?hundred and thirty ?six", "two ?hundred thirty ?six" ]}
|
||||||
"patterns": [ "236", "two ?hundred and thirty ?six", "two ?hundred thirty ?six" ]
|
|
||||||
}
|
|
||||||
),
|
),
|
||||||
283: Test(
|
283: Test(
|
||||||
name="Notes from last Saturday",
|
name="Notes from last Saturday",
|
||||||
@@ -113,11 +108,11 @@ tests = {
|
|||||||
AIMessage("I'm afraid I cannot be of great help, since I obviously know charlotte way less than you, but last year you two went out to Cavalinons and you got her a rose necklace as a present. And she liked it. So maybe a pair of earrings would be something she'd like?", name="example_assistant"),
|
AIMessage("I'm afraid I cannot be of great help, since I obviously know charlotte way less than you, but last year you two went out to Cavalinons and you got her a rose necklace as a present. And she liked it. So maybe a pair of earrings would be something she'd like?", name="example_assistant"),
|
||||||
|
|
||||||
HumanMessage("Did I write down anything yesterday or the day before that?"),
|
HumanMessage("Did I write down anything yesterday or the day before that?"),
|
||||||
AIMessage( "", tool_calls=[{"name": "get_current_date_and_time", "args": {}, "id": "21"}]),
|
AIMessage("", tool_calls=[{"name": "get_current_date_and_time", "args": {}, "id": "21"}]),
|
||||||
ToolMessage("Wednesday the 7th of August 2024 16:23", tool_call_id="21"),
|
ToolMessage("Wednesday the 7th of August 2024 16:23", tool_call_id="21"),
|
||||||
AIMessage( "", tool_calls=[ { "name": "get_notes_in_timespan", "args": {"begin": "2024/08/05", "to": "2024/08/06"}, "id": "22"}]),
|
AIMessage("", tool_calls=[ { "name": "get_notes_in_timespan", "args": {"begin": "2024/08/05", "to": "2024/08/06"}, "id": "22"}]),
|
||||||
ToolMessage( "2024/08/05 11:45 Ask Dr. Mills about the side effects of the new medication he got me.\n\n2024/08/06 18:30 Pick up the dry cleaning on Thursday; they close early on Fridays.", tool_call_id="22"),
|
ToolMessage( "2024/08/05 11:45 Ask Dr. Mills about the side effects of the new medication he got me.\n\n2024/08/06 18:30 Pick up the dry cleaning on Thursday; they close early on Fridays.", tool_call_id="22"),
|
||||||
AIMessage( "Yes. I found two entries.\n- From yesterday stating that you wanted to pickup the dry cleaning on Thursday, because they close early on Fridays\n- From Monday a note saying that you want to ask Dr. Mills about the side effects of the new medication he got you.", name="example_assistant"),
|
AIMessage("Yes. I found two entries.\n- From yesterday stating that you wanted to pickup the dry cleaning on Thursday, because they close early on Fridays\n- From Monday a note saying that you want to ask Dr. Mills about the side effects of the new medication he got you.", name="example_assistant"),
|
||||||
],
|
],
|
||||||
"human_msg": "Last Saturday, who did grandma want me to call?",
|
"human_msg": "Last Saturday, who did grandma want me to call?",
|
||||||
"tools": {
|
"tools": {
|
||||||
|
|||||||
362
visualize.py
362
visualize.py
@@ -1,20 +1,24 @@
|
|||||||
|
from typing import Callable, Optional
|
||||||
import json
|
import json
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from suite_settings.models import models
|
from suite_settings.models import models
|
||||||
from suite_settings.techniques import techniques
|
from suite_settings.techniques import techniques
|
||||||
from suite_settings.tests import tests
|
from suite_settings.tests import tests
|
||||||
|
|
||||||
# Load the JSON data
|
|
||||||
with open('saved_results.json', 'r') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
# Convert JSON data into a DataFrame
|
FONT_FAMILY = "NewComputerModern08"
|
||||||
results = []
|
|
||||||
for test_hash, test_data in data.items():
|
|
||||||
results.append({
|
def get_df() -> pd.DataFrame:
|
||||||
|
with open('saved_results.json', 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
raw_data = []
|
||||||
|
for test_hash, test_data in data.items():
|
||||||
|
raw_data.append({
|
||||||
"hash": test_hash,
|
"hash": test_hash,
|
||||||
"model_name": models[test_data['model_id']].display_name,
|
"model_name": models[test_data['model_id']].display_name,
|
||||||
"model_size": models[test_data['model_id']].parameter_count_in_b,
|
"model_size": models[test_data['model_id']].parameter_count_in_b,
|
||||||
@@ -25,118 +29,282 @@ for test_hash, test_data in data.items():
|
|||||||
"validation": test_data['validation']
|
"validation": test_data['validation']
|
||||||
})
|
})
|
||||||
|
|
||||||
df = pd.DataFrame(results)
|
df = pd.DataFrame(raw_data)
|
||||||
|
|
||||||
|
# Categorical ordering for 'technique_name'
|
||||||
|
df['technique_name'] = pd.Categorical(
|
||||||
|
df['technique_name'],
|
||||||
|
categories=[
|
||||||
|
techniques[1].name,
|
||||||
|
techniques[572].name,
|
||||||
|
techniques[903].name
|
||||||
|
],
|
||||||
|
ordered=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Categorical ordering for 'test_name'
|
||||||
|
df['test_name'] = pd.Categorical(
|
||||||
|
df['test_name'],
|
||||||
|
categories=[
|
||||||
|
tests[693].name,
|
||||||
|
tests[363].name,
|
||||||
|
tests[120].name,
|
||||||
|
tests[283].name,
|
||||||
|
tests[260].name,
|
||||||
|
tests[856].name
|
||||||
|
],
|
||||||
|
ordered=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sort by model_size first, then alphabetically by model_name
|
||||||
|
sorted_df = df.sort_values(['model_size', 'model_name'], ascending=[True, True])
|
||||||
|
|
||||||
|
return sorted_df
|
||||||
|
|
||||||
|
|
||||||
|
def insert_average_models(pt: pd.DataFrame, df: pd.DataFrame, pivot: int) -> pd.DataFrame:
    """Add three summary rows of per-column means to a pivot table.

    The rows of ``pt`` are split by model size into "up to ``pivot``" and
    "above ``pivot``". The result contains:

    * ``"Average up to {pivot}b"`` - inserted directly after the last row
      whose model size is ``<= pivot``,
    * ``"Average above {pivot}b"`` - appended at the end,
    * ``"Average Total"`` - appended last, the mean over the original rows.

    Args:
        pt: Pivot table whose index labels are values of
            ``df['model_technique']``.
        df: Source frame; only its ``model_technique`` and ``model_size``
            columns are read here.
        pivot: Size threshold separating the two groups.

    Returns:
        A new, re-indexed frame with the three summary rows added.
    """
    # One size per index label: the first 'model_size' seen for each label.
    sizes = df.groupby('model_technique')['model_size'].first().loc[pt.index]
    up_to_pivot = pt.index[(sizes <= pivot).to_numpy()]
    above_pivot = pt.index[(sizes > pivot).to_numpy()]

    # All three means are taken from the untouched table. Previously the
    # total was computed after the two group rows had been inserted, so
    # those summary rows were averaged in as if they were models.
    avg_up_to_pivot = pt.loc[up_to_pivot].mean()
    avg_above_pivot = pt.loc[above_pivot].mean()
    avg_total = pt.mean()

    small_label = f"Average up to {pivot}b"
    new_index = list(pt.index)
    # Place the first summary row right below the last "small" model; with no
    # such model it goes to the top (its values are NaN in that case) instead
    # of failing with an IndexError.
    if len(up_to_pivot) > 0:
        insert_at = new_index.index(up_to_pivot[-1]) + 1
    else:
        insert_at = 0
    new_index.insert(insert_at, small_label)

    pt = pt.reindex(new_index)
    pt.loc[small_label] = avg_up_to_pivot
    pt.loc[f"Average above {pivot}b"] = avg_above_pivot
    pt.loc["Average Total"] = avg_total
    return pt
|
||||||
|
|
||||||
|
def insert_average_test_y(pt: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
    """Append an "Average" row containing the mean of every column of ``pt``.

    ``pt`` is modified in place and returned. ``df`` is accepted but not
    read by this function.
    """
    pt.loc["Average"] = pt.loc[pt.index].mean()
    return pt
|
||||||
|
|
||||||
|
def insert_average_test_x(pt: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
    """Add an "Average" column holding the mean of each row of ``pt``.

    ``pt`` is modified in place and returned. ``df`` is accepted but not
    read by this function.
    """
    row_means = pt.mean(axis=1)
    pt["Average"] = row_means
    return pt
|
||||||
|
|
||||||
|
def insert_average_technique(pt: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
    """Add an "Average" column with the row-wise mean across all columns.

    ``pt`` is modified in place and returned. ``df`` is accepted but not
    read by this function.
    """
    pt['Average'] = pt.mean(axis="columns")
    return pt
|
||||||
|
|
||||||
|
# Keys into ``models`` whose rows remove_unfitting() filters out.
UNFITTING = [
    903,  # tinyllama
    404,  # llama3 groq TU
    120,  # llama3 groq TU 70b
    890,  # Command R+
]
|
||||||
|
|
||||||
|
def remove_unfitting(df: pd.DataFrame) -> pd.DataFrame:
    """Drop all rows that belong to a model listed in ``UNFITTING``.

    Each id in ``UNFITTING`` is resolved to its ``display_name`` through
    ``models`` and compared against the ``model_name`` column.

    Args:
        df: Frame with a ``model_name`` column.

    Returns:
        ``df`` itself when ``UNFITTING`` is empty, otherwise the subset of
        rows whose ``model_name`` matches none of the excluded models.

    Raises:
        KeyError: If an id in ``UNFITTING`` is not a key of ``models``.
    """
    if not UNFITTING:
        return df
    excluded_names = [models[model_id].display_name for model_id in UNFITTING]
    # A single boolean mask replaces the previous chain of one filter per id.
    return df.loc[~df['model_name'].isin(excluded_names)]
|
||||||
|
|
||||||
|
def trendline(df: pd.DataFrame) -> None:
    """Scatter the pass rate per model size, add a linear fit, save as EPS.

    The ``validation`` column of ``df`` is converted to integers in place,
    so callers that need the original frame should pass a copy. The figure
    is written to ``size-trendline.eps`` in the working directory.

    Args:
        df: Frame with ``model_size`` and ``validation`` columns.
    """
    # As 1/0 values, the mean of 'validation' is the share of passed runs.
    df['validation'] = df['validation'].astype(int)
    pass_rate_df = (
        df.groupby('model_size')
        .agg(pass_rate=('validation', 'mean'))
        .reset_index()
    )
    model_sizes = pass_rate_df['model_size']
    pass_rate_percent = pass_rate_df['pass_rate'] * 100

    plt.figure(figsize=(10, 6))
    plt.scatter(model_sizes, pass_rate_percent, label='Pass Rate (%)', color='blue')

    # Degree-1 least-squares fit drawn over the same x positions.
    linear_fit = np.poly1d(np.polyfit(model_sizes, pass_rate_percent, 1))
    plt.plot(model_sizes, linear_fit(model_sizes), color='red', label='Trendline')

    # Font settings for the tick labels of both axes.
    tick_style = {'fontname': FONT_FAMILY, 'fontsize': 12}
    plt.xticks(**tick_style)
    plt.yticks(**tick_style)

    # Title, axis labels, grid and legend.
    plt.title('Model Size vs Pass Rate', font=FONT_FAMILY)
    plt.xlabel('Model Size (in billions of parameters)', font=FONT_FAMILY)
    plt.ylabel('Pass Rate (%)', font=FONT_FAMILY)
    plt.grid(True)
    plt.legend(prop={'family': FONT_FAMILY})

    plt.savefig('size-trendline.eps', format='eps', dpi=1200)
|
||||||
|
|
||||||
|
|
||||||
|
def heatmap_models_plus_techniues(df: pd.DataFrame, color: str, title: Optional[str]= None, get_weight: Optional[Callable]= None) -> None:
    """Draw the model+technique x test heatmap of pass rates and save it as EPS.

    The mean of ``validation`` is pivoted with ``model_technique`` as rows and
    ``test_name`` as columns, keeping the row order in which the labels first
    appear in ``df``. Summary rows (``insert_average_models`` with a pivot of
    10) and an "Average" column (``insert_average_test_x``) are added before
    plotting. Tick labels containing "average" are coloured red.

    Args:
        df: Frame with ``model_technique``, ``test_name``, ``validation`` and
            the columns required by ``insert_average_models``.
        color: Palette specification handed to ``sns.color_palette``.
        title: Optional name of the weighting. When given it is added to the
            plot title and the file name, and the "%" suffix on the cell
            annotations is omitted.
        get_weight: Optional callable that receives a one-element list with a
            model's parameter count and returns the factor its row is
            multiplied by.

    Raises:
        IndexError: If ``get_weight`` is given and a row label does not start
            with the ``display_name`` of an entry in ``models``.
    """
    ordered_techniques = df['model_technique'].unique()
    pt = pd.pivot_table(
        df,
        values='validation',
        index='model_technique',
        columns='test_name',
        observed=False,
        aggfunc="mean",
        fill_value=0
    )
    pt = pt.loc[ordered_techniques]

    if get_weight:
        def get_model_size_by_name(name: str) -> float:
            for model_id in models:
                if name == models[model_id].display_name:
                    return models[model_id].parameter_count_in_b
            raise IndexError(f"Model {name} not found in models.")

        for (index, row) in pt.iterrows():
            # The text before the first ':' of the row label is used as the
            # model name - presumably labels look like "<model>:<technique>";
            # verify against where 'model_technique' is built.
            pt.loc[index] = row * get_weight([ get_model_size_by_name(index.split(":")[0]) ])

    pt = insert_average_models(pt=pt, df=df, pivot=10)
    pt = insert_average_test_x(pt=pt, df=df)

    plt.figure(figsize=(8, 12))

    sns.heatmap(
        pt * 100,
        annot=True,
        # One decimal when the bottom-right average is small (weighted plots).
        fmt=".0f" if pt.tail(1)['Average'].item() > 0.1 else ".1f",
        cmap=sns.color_palette(color, as_cmap=True),
        cbar=True,
        annot_kws={"size": 10, "fontname": FONT_FAMILY}
    )

    ax = plt.gca()

    # Show "0" instead of "0.0"; the percent sign is used on unweighted plots only.
    for text in ax.texts:
        o = text.get_text()
        text.set_text(f"{o if o != '0.0' else '0'}{'%' if not title else ''}")
        text.set_fontname(FONT_FAMILY)

    # Highlight the summary rows/columns on both axes.
    for axis in (ax.yaxis, ax.xaxis):
        for text in axis.get_ticklabels():
            if 'average' in text.get_text().lower():
                text.set_color('red')
            text.set_fontname(FONT_FAMILY)

    # Fixed: the suffix used to be appended as "adjsuted" without a
    # separating space, giving titles such as "...: Sizeadjsuted".
    heading = 'Model+Technique Performance'
    if title:
        heading += f": {title} adjusted"
    plt.title(heading, fontsize=16, fontname=FONT_FAMILY)
    plt.xlabel('Test Name', fontsize=14, fontname=FONT_FAMILY)
    plt.ylabel('Model and Technique', fontsize=14, fontname=FONT_FAMILY)

    plt.xticks(rotation=45, ha='right', fontname=FONT_FAMILY)
    plt.yticks(fontname=FONT_FAMILY)

    plt.tight_layout()
    plt.savefig(f"modeles-plus-techniques-heatmap{'' if not title else '-' + title.lower()}.eps", format='eps', dpi=1200)
|
||||||
|
|
||||||
|
|
||||||
## 4th Chart: Technique Performance on Each Test (Aggregated Heatmap)
|
def heatmap_techniques(df: pd.DataFrame, color: str, title: Optional[str]= None, get_weight: Optional[Callable]= None) -> None:
    """Draw the test x technique heatmap of pass rates and save it as EPS.

    The mean of ``validation`` is pivoted with ``test_name`` as rows and
    ``technique_name`` as columns. An "Average" row and an "Average" column
    are added before plotting. Tick labels containing "average" are coloured
    red.

    Args:
        df: Frame with ``test_name``, ``technique_name`` and ``validation``.
        color: Palette specification handed to ``sns.color_palette``.
        title: Optional name of the weighting. When given it is added to the
            plot title and the file name, and the "%" suffix on the cell
            annotations is omitted.
        get_weight: Optional callable mapping a list of parameter counts to a
            factor. It is called once with the sizes of all ``models`` that
            support tools (applied to the 'Native' column) and once with the
            sizes of the others (applied to 'LSM' and 'T2S').
    """
    pt = pd.pivot_table(
        df,
        values='validation',
        index='test_name',
        observed=False,
        columns='technique_name',
        aggfunc="mean",
        fill_value=0
    )

    if get_weight:
        native = []
        artificial = []
        for key in models:
            bucket = native if models[key].supports_tools else artificial
            bucket.append(models[key].parameter_count_in_b)
        weight_native = get_weight(native)
        weight_artificial = get_weight(artificial)
        column_weights = (
            ('Native', weight_native),
            ('LSM', weight_artificial),
            ('T2S', weight_artificial),
        )
        for column, weight in column_weights:
            pt[column] = pt[column] * weight

    pt = insert_average_test_y(pt=pt, df=df)
    pt = insert_average_technique(pt=pt, df=df)

    plt.figure(figsize=(8, 4))

    # One decimal when the bottom-right average is small (weighted plots).
    number_format = ".0f" if pt.tail(1)['Average'].item() > 0.2 else ".1f"
    sns.heatmap(
        pt * 100,
        annot=True,
        fmt=number_format,
        cmap=sns.color_palette(color, as_cmap=True),
        cbar=True,
        annot_kws={"size": 10, "fontname": FONT_FAMILY}
    )

    ax = plt.gca()

    # Show "0" instead of "0.0"; the percent sign is used on unweighted plots only.
    suffix = '%' if not title else ''
    for cell in ax.texts:
        shown = cell.get_text()
        cell.set_text(f"{shown if shown != '0.0' else '0'}{suffix}")
        cell.set_fontname(FONT_FAMILY)

    # Highlight the summary row/column on both axes.
    for axis in (ax.yaxis, ax.xaxis):
        for label in axis.get_ticklabels():
            if 'average' in label.get_text().lower():
                label.set_color('red')
            label.set_fontname(FONT_FAMILY)

    # Title, axis labels and tick fonts.
    plt.title(f"Technique Performance{'' if not title else ': ' + title + ' adjusted'}", fontsize=16, fontname=FONT_FAMILY)
    plt.ylabel('Test Name', fontsize=14, fontname=FONT_FAMILY)
    plt.xlabel('Technique', fontsize=14, fontname=FONT_FAMILY)

    plt.xticks(rotation=0, fontname=FONT_FAMILY)
    plt.yticks(fontname=FONT_FAMILY)

    plt.tight_layout()
    plt.savefig(f"techniques-heatmap{'' if not title else '-' + title.lower()}.eps", format='eps', dpi=1200)
|
||||||
|
|
||||||
|
|
||||||
|
def size(sizes: list[float]) -> float:
    """Return 100 divided by the arithmetic mean of ``sizes``.

    Raises:
        ZeroDivisionError: If ``sizes`` is empty or its mean is zero.
    """
    mean_size = sum(sizes) / len(sizes)
    return 100 / mean_size
|
||||||
|
|
||||||
|
def performance(sizes: list[float]) -> float:
    """Return the mean of ``1 / ln(x + 1)`` over all ``x`` in ``sizes``."""
    inverse_logs = [1/np.log(value+1) for value in sizes]
    return sum(inverse_logs) / len(inverse_logs)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: load the results, then render every chart.
    df = get_df()
    dff = remove_unfitting(df)

    trendline(dff.copy())

    # Plain, size-weighted and performance-weighted variants, in that order.
    # Every call gets its own copy of the frame.
    model_heatmap_variants = [
        {"color": "blend:#100,#255,#4a3"},
        {"color": "blend:#100,#236,#44a,#a4d,#fff,#ffc,#ffa,#ff7,#ff4,#ff0,#af0,#7f0,#3f0,#0f0", "title": "Size", "get_weight": size},
        {"color": "blend:#100,#d44,#dc2,#dcc,#cff", "title": "Performance", "get_weight": performance},
    ]
    for variant in model_heatmap_variants:
        heatmap_models_plus_techniues(df.copy(), **variant)

    technique_heatmap_variants = [
        {"color": "blend:#100,#255,#4a3"},
        {"color": "blend:#100,#236,#44a,#a4d", "title": "Size", "get_weight": size},
        {"color": "blend:#100,#d44,#dc2", "title": "Performance", "get_weight": performance},
    ]
    for variant in technique_heatmap_variants:
        heatmap_techniques(df=dff.copy(), **variant)
|
||||||
|
|||||||
Reference in New Issue
Block a user