Merge branch 'master' of ssh://git.nx2.site:20022/nx2/test-small-llms

This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-28 20:46:41 +02:00
3 changed files with 130 additions and 31 deletions

View File

@@ -4,19 +4,22 @@ from re import search
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union from typing import Union
from langchain_core.tools import Tool
from langchain_experimental.utilities import PythonREPL
@tool
def add(a: float, b: float) -> str:
    """Adds a+b and returns the sum"""
    # Normalise both operands to float so the printed equation and the
    # computed sum are built from the same values.
    af = float(a)
    bf = float(b)
    # The right-hand operand shown must be bf; printing af twice produced
    # equations such as "2.0 + 2.0 = 5.0".
    return f"{af} + {bf} = {af + bf}"
@tool
def multiply(a: float, b: float) -> str:
    """Multiplies a*b and returns the product"""
    # Convert once up front; the same float values feed both the displayed
    # operands and the product.
    left, right = float(a), float(b)
    product = left * right
    return f"{left} * {right} = {product}"
@tool @tool
def get_current_date_and_time() -> str: def get_current_date_and_time() -> str:
@@ -99,10 +102,13 @@ def get_notes_in_timespan(begin: str, to: str) -> str:
try: try:
begin_d = datetime.strptime(begin, "%Y/%m/%d") begin_d = datetime.strptime(begin, "%Y/%m/%d")
to_d = datetime.strptime(to+" 23:59", "%Y/%m/%d %H:%M") to_d = datetime.strptime(to+" 23:59", "%Y/%m/%d %H:%M")
except: return "Error: Invalid input. Date format is %Y/%m/%d" except ValueError:
return "Error: Invalid input. Date format is %Y/%m/%d"
try: assert begin_d < to_d try:
except: return "Error: from time has to be before to time." assert begin_d < to_d
except AssertionError:
return "Error: from time has to be before to time."
filtered_entries = [entry for entry in note_entries if begin_d <= entry.time <= to_d] filtered_entries = [entry for entry in note_entries if begin_d <= entry.time <= to_d]
@@ -128,9 +134,12 @@ def get_notes_containing(patterns: Union[list[str], str]) -> str:
exaples: exaples:
{"patterns": [ "Aunt(ie)?", "Sabine" ]} # Looks for Notes related to Aunt Sabine""" {"patterns": [ "Aunt(ie)?", "Sabine" ]} # Looks for Notes related to Aunt Sabine"""
if isinstance(patterns, list): big_pattern = '|'.join(f"({s})" for s in patterns) if isinstance(patterns, list):
elif isinstance(patterns, str): big_pattern = patterns big_pattern = '|'.join(f"({s})" for s in patterns)
else: return f"Error: Invalid Input type. `patterns` can either be a list of strings or a single string. But got {type(patterns)}." elif isinstance(patterns, str):
big_pattern = patterns
else:
return f"Error: Invalid Input type. `patterns` can either be a list of strings or a single string. But got {type(patterns)}."
filtered_entries = [entry for entry in note_entries if search(big_pattern.lower(), entry.content.lower())] filtered_entries = [entry for entry in note_entries if search(big_pattern.lower(), entry.content.lower())]
@@ -147,7 +156,29 @@ def get_notes_containing(patterns: Union[list[str], str]) -> str:
return ret return ret
@tool
def write_note(command: str) -> str:
    """Write a note with the current time to the database."""
    # NOTE(review): placeholder implementation -- nothing is persisted in this
    # block; the input is returned unchanged.
    return command
@tool
def save_python_repl(command: str) -> str:
    """Simulates the normal python repl but with certain patterns blocked for safety reasons"""
    # NOTE(security): `command` is executed as Python code. A regex denylist is
    # only a best-effort filter, not a sandbox -- e.g. `__import__("os")` or
    # `from os import system` still pass. Run this tool only in an isolated
    # environment.
    #
    # The patterns use a word boundary rather than a `^` anchor: without
    # re.MULTILINE, `^` matches only at the very start of the string, so a
    # command like "import os\nos.system(...)" was never caught.
    blocked_patterns = [
        r"\bos\.",
        r"\bsubprocess\.",
        r"\bwith\s+open\(",
    ]
    for pattern in blocked_patterns:
        if search(pattern, command):
            # Report the first offending pattern instead of executing anything.
            return f"Command not executed, because the blocked pattern `{pattern}` was found in the command."
    # No blocked pattern matched: run the command in a fresh REPL and return
    # whatever `run` produces.
    return PythonREPL().run(command)

View File

@@ -1,7 +1,7 @@
from libs.classes import Test from libs.classes import Test
from libs.runnables import basic_prompt, one_tool_call_answer, agent_with_tools from libs.runnables import basic_prompt, one_tool_call_answer, agent_with_tools
from libs.validators import regex_match_any, system_human_answer_match from libs.validators import regex_match_any, system_human_answer_match
from libs.tools import add, multiply, get_current_date_and_time, get_notes_in_timespan, get_notes_containing, write_note from libs.tools import add, multiply, get_current_date_and_time, get_notes_in_timespan, get_notes_containing, write_note, save_python_repl
from textwrap import dedent from textwrap import dedent
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
@@ -49,6 +49,19 @@ tests = {
validator=regex_match_any, validator=regex_match_any,
validation_input={"patterns": ["6134205", "6.134.205", "6,134,205"]}, validation_input={"patterns": ["6134205", "6.134.205", "6,134,205"]},
), ),
363: Test(
name="Complex Multiplication Python",
runnable=one_tool_call_answer,
runnable_input={
"system_msg": 'You are a helpful assistant.',
"human_msg": 'Is 31515261 divisible by 425? If not, whats the remainder?',
"tools": { "python_repl": save_python_repl },
},
validator=regex_match_any,
validation_input={
"patterns": [ "236", "two ?hundred and thirty ?six", "two ?hundred thirty ?six" ]
}
),
283: Test( 283: Test(
name="Notes from last Saturday", name="Notes from last Saturday",
runnable=agent_with_tools, runnable=agent_with_tools,
@@ -119,7 +132,6 @@ tests = {
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""") - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
}, },
), ),
# 363: Test(),
# 600: Test(), # 600: Test(),
# 221: Test(), # 221: Test(),
# 985: Test(), # 985: Test(),

View File

@@ -1,9 +1,12 @@
import json import json
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
import numpy as np
import seaborn as sns import seaborn as sns
from suite_settings.models import models
from suite_settings.techniques import techniques
from suite_settings.tests import tests
# Load the JSON data # Load the JSON data
with open('saved_results.json', 'r') as f: with open('saved_results.json', 'r') as f:
data = json.load(f) data = json.load(f)
@@ -13,9 +16,12 @@ results = []
for test_hash, test_data in data.items(): for test_hash, test_data in data.items():
results.append({ results.append({
"hash": test_hash, "hash": test_hash,
"model": test_data['model_name'], "model_name": models[test_data['model_id']].display_name,
"model_size": models[test_data['model_id']].parameter_count_in_b,
"technique_name": techniques[test_data['technique_id']].name,
"model_technique": f"{models[test_data['model_id']].display_name}:{ techniques[test_data['technique_id']].name}",
"seed": test_data['seed'], "seed": test_data['seed'],
"test_name": test_data['test_name'], "test_name": tests[test_data['test_id']].name,
"validation": test_data['validation'] "validation": test_data['validation']
}) })
@@ -23,21 +29,34 @@ df = pd.DataFrame(results)
# Turn the name columns into ordered categoricals with a fixed, hand-picked
# order (looked up by id in `techniques` / `tests`).
# NOTE(review): pandas replaces values that are not in `categories` with NaN,
# so a technique or test whose id is not listed here would lose its name --
# confirm both lists cover every id present in the results.
df['technique_name'] = pd.Categorical(df['technique_name'], categories=[techniques[1].name, techniques[572].name, techniques[903].name],ordered=True)
df['test_name'] = pd.Categorical(df['test_name'], categories=[tests[607].name, tests[693].name, tests[120].name, tests[283].name, tests[260].name, tests[856].name],ordered=True)
# Rows ordered by ascending model size.
sorted_df = df.sort_values('model_size')
# Perform the groupby and unstack operation
# Result: one row per model_name, one column per distinct 'validation' value,
# each cell holding the number of matching rows.
result_df = (
sorted_df.groupby(['model_name', 'validation']).size()
.unstack(fill_value=0) # Unstack and fill NaN with 0
)
## 1st Chart ## 1st Chart
# Count the number of validation results for each model # Count the number of validation results for each technique_name
validation_counts = df.groupby(['model', 'validation']).size().unstack().fillna(0) validation_counts = result_df.loc[sorted_df['model_name'].drop_duplicates()]
validation_counts.columns = ['Failed', 'Passed'] validation_counts.columns = ['Failed', 'Passed']
# Plot the validation results by model # Plot the validation results by technique_name
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
validation_counts.plot(kind='bar', stacked=True, color=['red', 'green'], ax=plt.gca()) validation_counts.plot(kind='bar', stacked=True, color=['red', 'green'], ax=plt.gca())
plt.title('Validation Results by Model') plt.title('Validation Results by Model and Technique')
plt.xlabel('Model') plt.xlabel('Model and Technique')
plt.ylabel('Number of Tests') plt.ylabel('Number of Tests')
plt.xticks(rotation=45, ha='right') plt.xticks(rotation=45, ha='right')
plt.legend(title='Validation') plt.legend(title='Validation')
plt.tight_layout() plt.tight_layout()
plt.savefig('validation_results_by_model.png') plt.savefig('model-bar-chart.png')
@@ -54,25 +73,46 @@ plt.xlabel('Number of Tests')
plt.ylabel('Test Name') plt.ylabel('Test Name')
plt.legend(title='Validation') plt.legend(title='Validation')
plt.tight_layout() plt.tight_layout()
plt.savefig('validation_results_by_test_name.png') plt.savefig('test-bar-chart.png')
sorted_df = df.sort_values('model_size' )
## 3rd Chart # Get the unique order of 'model_technique' based on sorted_df
pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0) ordered_techniques = sorted_df['model_technique'].unique()
# Create a heatmap
plt.figure(figsize=(8, 8))
sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})
# Create the pivot table with the correct order of model_technique
pass_rate = pd.pivot_table(
sorted_df,
values='validation',
index='model_technique',
columns='test_name',
aggfunc="mean",
fill_value=0
)
# Reorder the rows in the pivot table based on the ordered techniques
pass_rate = pass_rate.loc[ordered_techniques]
# Plot the heatmap
plt.figure(figsize=(8, 10))
sns.heatmap(
pass_rate * 100,
annot=True,
fmt=".0f",
cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True),
cbar=True,
annot_kws={"size": 10}
)
# Add percentage sign to annotations # Add percentage sign to annotations
for text in plt.gca().texts: for text in plt.gca().texts:
text.set_text(f"{text.get_text()}%") text.set_text(f"{text.get_text()}%")
# Customize the plot with labels and a title # Customize the plot with labels and a title
plt.title('Model Performance on Each Test', fontsize=16) plt.title('Model Technique Performance on Each Test', fontsize=16)
plt.xlabel('Test Name', fontsize=14) plt.xlabel('Test Name', fontsize=14)
plt.ylabel('Model', fontsize=14) plt.ylabel('Model and Technique', fontsize=14)
# Rotate x-axis labels by 45 degrees # Rotate x-axis labels by 45 degrees
plt.xticks(rotation=45, ha='right') plt.xticks(rotation=45, ha='right')
@@ -81,6 +121,22 @@ plt.xticks(rotation=45, ha='right')
plt.tight_layout() plt.tight_layout()
# Save the heatmap # Save the heatmap
plt.savefig('model_performance_heatmap.png') plt.savefig('modelTechnique_heatmap.png')
## 4th Chart: Technique Performance on Each Test (Aggregated Heatmap)
# Mean of 'validation' for every (test_name, technique_name) pair;
# combinations without data are filled with 0.
technique_pass_rate = pd.pivot_table(
    sorted_df,
    values='validation',
    index='test_name',
    columns='technique_name',
    aggfunc="mean",
    fill_value=0,
)

# Draw the rates as whole-number percentages on a custom blended colour map.
plt.figure(figsize=(8, 4))
sns.heatmap(
    technique_pass_rate * 100,
    annot=True,
    fmt=".0f",
    cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True),
    cbar=True,
    annot_kws={"size": 10},
)

# Append a percent sign to every cell annotation on the current axes.
for text in plt.gca().texts:
    text.set_text(f"{text.get_text()}%")

# Title, axis labels and horizontal tick labels.
plt.title('Technique Performance on Each Test', fontsize=16)
plt.xlabel('Technique', fontsize=14)
plt.ylabel('Test Name', fontsize=14)
plt.xticks(rotation=0)

# Fit the layout, then write the figure to disk.
plt.tight_layout()
plt.savefig('technique_heatmap.png')