Merge branch 'master' of ssh://git.nx2.site:20022/nx2/test-small-llms
This commit is contained in:
@@ -4,19 +4,22 @@ from re import search
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
from langchain_core.tools import Tool
|
||||||
|
from langchain_experimental.utilities import PythonREPL
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def add(a: float, b: float) -> str:
|
def add(a: float, b: float) -> str:
|
||||||
"""Adds a+b and returns the sum"""
|
"""Adds a+b and returns the sum"""
|
||||||
af = float(a)
|
af = float(a)
|
||||||
bf = float(b)
|
bf = float(b)
|
||||||
return f"{a} + {b} = {a+b}"
|
return f"{af} + {af} = {af+bf}"
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def multiply(a: float, b: float) -> str:
|
def multiply(a: float, b: float) -> str:
|
||||||
"""Multiplies a*b and returns the product"""
|
"""Multiplies a*b and returns the product"""
|
||||||
af = float(a)
|
af = float(a)
|
||||||
bf = float(b)
|
bf = float(b)
|
||||||
return f"{a} * {b} = {a*b}"
|
return f"{af} * {bf} = {af*bf}"
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_current_date_and_time() -> str:
|
def get_current_date_and_time() -> str:
|
||||||
@@ -99,10 +102,13 @@ def get_notes_in_timespan(begin: str, to: str) -> str:
|
|||||||
try:
|
try:
|
||||||
begin_d = datetime.strptime(begin, "%Y/%m/%d")
|
begin_d = datetime.strptime(begin, "%Y/%m/%d")
|
||||||
to_d = datetime.strptime(to+" 23:59", "%Y/%m/%d %H:%M")
|
to_d = datetime.strptime(to+" 23:59", "%Y/%m/%d %H:%M")
|
||||||
except: return "Error: Invalid input. Date format is %Y/%m/%d"
|
except ValueError:
|
||||||
|
return "Error: Invalid input. Date format is %Y/%m/%d"
|
||||||
|
|
||||||
try: assert begin_d < to_d
|
try:
|
||||||
except: return "Error: from time has to be before to time."
|
assert begin_d < to_d
|
||||||
|
except AssertionError:
|
||||||
|
return "Error: from time has to be before to time."
|
||||||
|
|
||||||
filtered_entries = [entry for entry in note_entries if begin_d <= entry.time <= to_d]
|
filtered_entries = [entry for entry in note_entries if begin_d <= entry.time <= to_d]
|
||||||
|
|
||||||
@@ -128,9 +134,12 @@ def get_notes_containing(patterns: Union[list[str], str]) -> str:
|
|||||||
exaples:
|
exaples:
|
||||||
{"patterns": [ "Aunt(ie)?", "Sabine" ]} # Looks for Notes related to Aunt Sabine"""
|
{"patterns": [ "Aunt(ie)?", "Sabine" ]} # Looks for Notes related to Aunt Sabine"""
|
||||||
|
|
||||||
if isinstance(patterns, list): big_pattern = '|'.join(f"({s})" for s in patterns)
|
if isinstance(patterns, list):
|
||||||
elif isinstance(patterns, str): big_pattern = patterns
|
big_pattern = '|'.join(f"({s})" for s in patterns)
|
||||||
else: return f"Error: Invalid Input type. `patterns` can either be a list of strings or a single string. But got {type(patterns)}."
|
elif isinstance(patterns, str):
|
||||||
|
big_pattern = patterns
|
||||||
|
else:
|
||||||
|
return f"Error: Invalid Input type. `patterns` can either be a list of strings or a single string. But got {type(patterns)}."
|
||||||
|
|
||||||
filtered_entries = [entry for entry in note_entries if search(big_pattern.lower(), entry.content.lower())]
|
filtered_entries = [entry for entry in note_entries if search(big_pattern.lower(), entry.content.lower())]
|
||||||
|
|
||||||
@@ -147,7 +156,29 @@ def get_notes_containing(patterns: Union[list[str], str]) -> str:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def write_note(content: str) -> str:
|
def write_note(command: str) -> str:
|
||||||
"""Write a not with the current time to the database."""
|
"""Write a not with the current time to the database."""
|
||||||
return content
|
return command
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def save_python_repl(command: str):
|
||||||
|
"""Simulates the normal python repl but with certain patterns blocked for savety reasons"""
|
||||||
|
python_repl = PythonREPL()
|
||||||
|
blocked_patterns = [
|
||||||
|
"^os\\.",
|
||||||
|
"^subprocess\\.",
|
||||||
|
"^with open\\(",
|
||||||
|
]
|
||||||
|
valid = True
|
||||||
|
for pattern in blocked_patterns:
|
||||||
|
if search(pattern, command):
|
||||||
|
valid = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if valid:
|
||||||
|
return python_repl.run(command)
|
||||||
|
else:
|
||||||
|
return f"Command not executed, becaise the blocked pattern `{pattern}` was found in the command."
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from libs.classes import Test
|
from libs.classes import Test
|
||||||
from libs.runnables import basic_prompt, one_tool_call_answer, agent_with_tools
|
from libs.runnables import basic_prompt, one_tool_call_answer, agent_with_tools
|
||||||
from libs.validators import regex_match_any, system_human_answer_match
|
from libs.validators import regex_match_any, system_human_answer_match
|
||||||
from libs.tools import add, multiply, get_current_date_and_time, get_notes_in_timespan, get_notes_containing, write_note
|
from libs.tools import add, multiply, get_current_date_and_time, get_notes_in_timespan, get_notes_containing, write_note, save_python_repl
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
|
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
|
||||||
|
|
||||||
@@ -49,6 +49,19 @@ tests = {
|
|||||||
validator=regex_match_any,
|
validator=regex_match_any,
|
||||||
validation_input={"patterns": ["6134205", "6.134.205", "6,134,205"]},
|
validation_input={"patterns": ["6134205", "6.134.205", "6,134,205"]},
|
||||||
),
|
),
|
||||||
|
363: Test(
|
||||||
|
name="Complex Multiplication Python",
|
||||||
|
runnable=one_tool_call_answer,
|
||||||
|
runnable_input={
|
||||||
|
"system_msg": 'You are a helpful assistant.',
|
||||||
|
"human_msg": 'Is 31515261 divisible by 425? If not, whats the remainder?',
|
||||||
|
"tools": { "python_repl": save_python_repl },
|
||||||
|
},
|
||||||
|
validator=regex_match_any,
|
||||||
|
validation_input={
|
||||||
|
"patterns": [ "236", "two ?hundred and thirty ?six", "two ?hundred thirty ?six" ]
|
||||||
|
}
|
||||||
|
),
|
||||||
283: Test(
|
283: Test(
|
||||||
name="Notes from last Saturday",
|
name="Notes from last Saturday",
|
||||||
runnable=agent_with_tools,
|
runnable=agent_with_tools,
|
||||||
@@ -65,7 +78,7 @@ tests = {
|
|||||||
validator=system_human_answer_match,
|
validator=system_human_answer_match,
|
||||||
validation_input={
|
validation_input={
|
||||||
"criteria": dedent("""- containing the information that the Human should call Wolfgang
|
"criteria": dedent("""- containing the information that the Human should call Wolfgang
|
||||||
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
|
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
260: Test(
|
260: Test(
|
||||||
@@ -119,7 +132,6 @@ tests = {
|
|||||||
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
|
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
# 363: Test(),
|
|
||||||
# 600: Test(),
|
# 600: Test(),
|
||||||
# 221: Test(),
|
# 221: Test(),
|
||||||
# 985: Test(),
|
# 985: Test(),
|
||||||
|
|||||||
92
visualize.py
92
visualize.py
@@ -1,9 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
|
||||||
|
from suite_settings.models import models
|
||||||
|
from suite_settings.techniques import techniques
|
||||||
|
from suite_settings.tests import tests
|
||||||
|
|
||||||
# Load the JSON data
|
# Load the JSON data
|
||||||
with open('saved_results.json', 'r') as f:
|
with open('saved_results.json', 'r') as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
@@ -13,9 +16,12 @@ results = []
|
|||||||
for test_hash, test_data in data.items():
|
for test_hash, test_data in data.items():
|
||||||
results.append({
|
results.append({
|
||||||
"hash": test_hash,
|
"hash": test_hash,
|
||||||
"model": test_data['model_name'],
|
"model_name": models[test_data['model_id']].display_name,
|
||||||
|
"model_size": models[test_data['model_id']].parameter_count_in_b,
|
||||||
|
"technique_name": techniques[test_data['technique_id']].name,
|
||||||
|
"model_technique": f"{models[test_data['model_id']].display_name}:{ techniques[test_data['technique_id']].name}",
|
||||||
"seed": test_data['seed'],
|
"seed": test_data['seed'],
|
||||||
"test_name": test_data['test_name'],
|
"test_name": tests[test_data['test_id']].name,
|
||||||
"validation": test_data['validation']
|
"validation": test_data['validation']
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -23,21 +29,34 @@ df = pd.DataFrame(results)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
df['technique_name'] = pd.Categorical(df['technique_name'], categories=[techniques[1].name, techniques[572].name, techniques[903].name],ordered=True)
|
||||||
|
df['test_name'] = pd.Categorical(df['test_name'], categories=[tests[607].name, tests[693].name, tests[120].name, tests[283].name, tests[260].name, tests[856].name],ordered=True)
|
||||||
|
sorted_df = df.sort_values('model_size')
|
||||||
|
|
||||||
|
# Perform the groupby and unstack operation
|
||||||
|
result_df = (
|
||||||
|
sorted_df.groupby(['model_name', 'validation']).size()
|
||||||
|
.unstack(fill_value=0) # Unstack and fill NaN with 0
|
||||||
|
)
|
||||||
|
|
||||||
## 1st Chart
|
## 1st Chart
|
||||||
# Count the number of validation results for each model
|
# Count the number of validation results for each technique_name
|
||||||
validation_counts = df.groupby(['model', 'validation']).size().unstack().fillna(0)
|
validation_counts = result_df.loc[sorted_df['model_name'].drop_duplicates()]
|
||||||
validation_counts.columns = ['Failed', 'Passed']
|
validation_counts.columns = ['Failed', 'Passed']
|
||||||
|
|
||||||
# Plot the validation results by model
|
# Plot the validation results by technique_name
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
validation_counts.plot(kind='bar', stacked=True, color=['red', 'green'], ax=plt.gca())
|
validation_counts.plot(kind='bar', stacked=True, color=['red', 'green'], ax=plt.gca())
|
||||||
plt.title('Validation Results by Model')
|
plt.title('Validation Results by Model and Technique')
|
||||||
plt.xlabel('Model')
|
plt.xlabel('Model and Technique')
|
||||||
plt.ylabel('Number of Tests')
|
plt.ylabel('Number of Tests')
|
||||||
plt.xticks(rotation=45, ha='right')
|
plt.xticks(rotation=45, ha='right')
|
||||||
plt.legend(title='Validation')
|
plt.legend(title='Validation')
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig('validation_results_by_model.png')
|
plt.savefig('model-bar-chart.png')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -54,25 +73,46 @@ plt.xlabel('Number of Tests')
|
|||||||
plt.ylabel('Test Name')
|
plt.ylabel('Test Name')
|
||||||
plt.legend(title='Validation')
|
plt.legend(title='Validation')
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig('validation_results_by_test_name.png')
|
plt.savefig('test-bar-chart.png')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
sorted_df = df.sort_values('model_size' )
|
||||||
|
|
||||||
## 3rd Chart
|
# Get the unique order of 'model_technique' based on sorted_df
|
||||||
pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0)
|
ordered_techniques = sorted_df['model_technique'].unique()
|
||||||
# Create a heatmap
|
|
||||||
plt.figure(figsize=(8, 8))
|
|
||||||
sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})
|
|
||||||
|
|
||||||
|
# Create the pivot table with the correct order of model_technique
|
||||||
|
pass_rate = pd.pivot_table(
|
||||||
|
sorted_df,
|
||||||
|
values='validation',
|
||||||
|
index='model_technique',
|
||||||
|
columns='test_name',
|
||||||
|
aggfunc="mean",
|
||||||
|
fill_value=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reorder the rows in the pivot table based on the ordered techniques
|
||||||
|
pass_rate = pass_rate.loc[ordered_techniques]
|
||||||
|
|
||||||
|
# Plot the heatmap
|
||||||
|
plt.figure(figsize=(8, 10))
|
||||||
|
sns.heatmap(
|
||||||
|
pass_rate * 100,
|
||||||
|
annot=True,
|
||||||
|
fmt=".0f",
|
||||||
|
cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True),
|
||||||
|
cbar=True,
|
||||||
|
annot_kws={"size": 10}
|
||||||
|
)
|
||||||
# Add percentage sign to annotations
|
# Add percentage sign to annotations
|
||||||
for text in plt.gca().texts:
|
for text in plt.gca().texts:
|
||||||
text.set_text(f"{text.get_text()}%")
|
text.set_text(f"{text.get_text()}%")
|
||||||
|
|
||||||
# Customize the plot with labels and a title
|
# Customize the plot with labels and a title
|
||||||
plt.title('Model Performance on Each Test', fontsize=16)
|
plt.title('Model Technique Performance on Each Test', fontsize=16)
|
||||||
plt.xlabel('Test Name', fontsize=14)
|
plt.xlabel('Test Name', fontsize=14)
|
||||||
plt.ylabel('Model', fontsize=14)
|
plt.ylabel('Model and Technique', fontsize=14)
|
||||||
|
|
||||||
# Rotate x-axis labels by 45 degrees
|
# Rotate x-axis labels by 45 degrees
|
||||||
plt.xticks(rotation=45, ha='right')
|
plt.xticks(rotation=45, ha='right')
|
||||||
@@ -81,6 +121,22 @@ plt.xticks(rotation=45, ha='right')
|
|||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
# Save the heatmap
|
# Save the heatmap
|
||||||
plt.savefig('model_performance_heatmap.png')
|
plt.savefig('modelTechnique_heatmap.png')
|
||||||
|
|
||||||
|
|
||||||
|
## 4th Chart: Technique Performance on Each Test (Aggregated Heatmap)
|
||||||
|
technique_pass_rate = pd.pivot_table(sorted_df, values='validation', index='test_name', columns='technique_name', aggfunc="mean", fill_value=0)
|
||||||
|
plt.figure(figsize=(8, 4))
|
||||||
|
sns.heatmap(technique_pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})
|
||||||
|
|
||||||
|
# Add percentage sign to annotations
|
||||||
|
for text in plt.gca().texts:
|
||||||
|
text.set_text(f"{text.get_text()}%")
|
||||||
|
|
||||||
|
# Customize the plot with labels and a title
|
||||||
|
plt.title('Technique Performance on Each Test', fontsize=16)
|
||||||
|
plt.ylabel('Test Name', fontsize=14)
|
||||||
|
plt.xlabel('Technique', fontsize=14)
|
||||||
|
plt.xticks(rotation=0)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('technique_heatmap.png')
|
||||||
|
|||||||
Reference in New Issue
Block a user