cacheing, tests as dict, new tests
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
from libs.test_class import Test
|
from libs.test_class import Test
|
||||||
from libs.validators import system_human_answer_match
|
from typing import Union
|
||||||
from libs.runnables import basic
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -16,18 +15,36 @@ def nxhash(text:str): # @BenVida StackOverflow
|
|||||||
hash = ( hash*281 ^ ord(ch)*997) & 0xFFFFFFFF
|
hash = ( hash*281 ^ ord(ch)*997) & 0xFFFFFFFF
|
||||||
return hex(hash)[2:].upper().zfill(8)
|
return hex(hash)[2:].upper().zfill(8)
|
||||||
|
|
||||||
def get_len(l: list) -> int:
|
def get_len(collection: Union[list, dict]) -> int:
|
||||||
m = 0
|
maximum_length = 0
|
||||||
for e in l:
|
|
||||||
if isinstance(e, Test):
|
if isinstance(collection, dict):
|
||||||
m = max(m, len(e.name))
|
collection_type = "tests"
|
||||||
elif isinstance(e, str):
|
elif isinstance(collection, list):
|
||||||
m = max(m, len(e))
|
if isinstance(collection[0], str):
|
||||||
elif isinstance(e, int):
|
collection_type = "models"
|
||||||
m = max(m, len(str(e)))
|
elif isinstance(collection[0], int):
|
||||||
|
collection_type = "seeds"
|
||||||
else:
|
else:
|
||||||
raise Exception(f"get_len() only supports lits of Test, str or int but got {type(e)}")
|
raise TypeError("get_len: unsupported collection_type")
|
||||||
return m
|
else:
|
||||||
|
raise TypeError("get_len: unsupported collection_type")
|
||||||
|
|
||||||
|
match collection_type:
|
||||||
|
case "models":
|
||||||
|
for model_name in collection:
|
||||||
|
maximum_length = max(maximum_length, len(model_name))
|
||||||
|
case "seeds":
|
||||||
|
for seed in collection:
|
||||||
|
maximum_length = max(maximum_length, len(str(seed)))
|
||||||
|
case "tests":
|
||||||
|
for test_id in collection:
|
||||||
|
maximum_length = max(maximum_length, len(collection[test_id].name))
|
||||||
|
case _:
|
||||||
|
for model_name in collection:
|
||||||
|
raise TypeError("get_len: unsupported collection_type")
|
||||||
|
|
||||||
|
return maximum_length
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -44,29 +61,63 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
run_results = {}
|
run_results = {}
|
||||||
print("Starting to run Tests ... ")
|
print("Starting to run Tests ... ")
|
||||||
for model in models:
|
for model in models:
|
||||||
for seed in seeds:
|
for test_id in tests:
|
||||||
for test in tests:
|
test = tests[test_id]
|
||||||
|
for seed in seeds:
|
||||||
# Init dict
|
# Init dict
|
||||||
combination = {
|
combination = {
|
||||||
'test_name': test.name,
|
'test_id': test_id,
|
||||||
'model': model,
|
'model': model,
|
||||||
'seed': seed,
|
'seed': seed,
|
||||||
}
|
}
|
||||||
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
|
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
|
||||||
|
combination['test_name'] = test.name
|
||||||
|
|
||||||
|
# if hash_key == "DE3D137E":
|
||||||
|
# pass
|
||||||
|
|
||||||
if hash_key not in saved_results.keys():
|
if hash_key not in saved_results.keys():
|
||||||
try:
|
try:
|
||||||
combination['answer'] = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
|
print("\033[0;35mModel '\033[0m" +
|
||||||
|
model +
|
||||||
|
"\033[0;35m'" +
|
||||||
|
(" " * (get_len(models) - len(model))) +
|
||||||
|
" with seed \033[0m\033[0;30m" +
|
||||||
|
("0" * (get_len(seeds) - len(str(seed)))) +
|
||||||
|
"\033[0m" +
|
||||||
|
str(seed) +
|
||||||
|
"\033[0;35m now runs test '\033[0m" +
|
||||||
|
test.name +
|
||||||
|
"\033[0;35m'" +
|
||||||
|
(" " * (get_len(tests) - len(test.name))) +
|
||||||
|
" (\033[0m" +
|
||||||
|
hash_key +
|
||||||
|
"\033[0;35m)\033[0m",
|
||||||
|
end=""
|
||||||
|
)
|
||||||
|
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
|
||||||
|
if isinstance(answer, str): # tool capabile return tools called as a list[dict]
|
||||||
|
combination['answer'] = answer
|
||||||
|
# combination['tool_calls'] = [] # no entry
|
||||||
|
del answer
|
||||||
|
elif isinstance(answer, dict): # calls
|
||||||
|
combination['answer'] = answer['answer']
|
||||||
|
combination['tool_calls'] = answer['tool_calls']
|
||||||
|
del answer
|
||||||
|
else:
|
||||||
|
raise Exception(f"runnable returd unkown type {type(answer)}.")
|
||||||
|
|
||||||
|
|
||||||
combination['test'] = test
|
combination['test'] = test
|
||||||
run_results[hash_key] = combination
|
run_results[hash_key] = combination
|
||||||
print("\033[0;32mModel '\033[0m" +
|
print("\r\033[0;32mModel '\033[0m" +
|
||||||
model +
|
model +
|
||||||
"\033[0;32m'" +
|
"\033[0;32m'" +
|
||||||
(" " * (get_len(models) - len(model))) +
|
(" " * (get_len(models) - len(model))) +
|
||||||
" with seed \033[0m" +
|
" with seed \033[0m\033[0;30m" +
|
||||||
|
("0" * (get_len(seeds) - len(str(seed)))) +
|
||||||
|
"\033[0m" +
|
||||||
str(seed) +
|
str(seed) +
|
||||||
(" " * (get_len(seeds) - len(str(seed)))) +
|
|
||||||
"\033[0;32m finished test '\033[0m" +
|
"\033[0;32m finished test '\033[0m" +
|
||||||
test.name +
|
test.name +
|
||||||
"\033[0;32m'" +
|
"\033[0;32m'" +
|
||||||
@@ -76,15 +127,16 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
"\033[0;32m)\033[0m"
|
"\033[0;32m)\033[0m"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("\033[0;31mError: <\033[0m " + str(e) + "\033[0;31m>\033[0m trying to continue...")
|
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...")
|
||||||
else:
|
else:
|
||||||
print("\033[0;34mModel '\033[0m" +
|
print("\r\033[0;34mModel '\033[0m" +
|
||||||
model +
|
model +
|
||||||
"\033[0;34m'" +
|
"\033[0;34m'" +
|
||||||
(" " * (get_len(models) - len(model))) +
|
(" " * (get_len(models) - len(model))) +
|
||||||
" with seed \033[0m" +
|
" with seed \033[0m\033[0;30m" +
|
||||||
|
("0" * (get_len(seeds) - len(str(seed)))) +
|
||||||
|
"\033[0m" +
|
||||||
str(seed) +
|
str(seed) +
|
||||||
(" " * (get_len(seeds) - len(str(seed)))) +
|
|
||||||
"\033[0;34m skipped test '\033[0m" +
|
"\033[0;34m skipped test '\033[0m" +
|
||||||
test.name +
|
test.name +
|
||||||
"\033[0;34m'" +
|
"\033[0;34m'" +
|
||||||
@@ -100,25 +152,37 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
for hash_key in run_results:
|
for hash_key in run_results:
|
||||||
result = run_results[hash_key]
|
result = run_results[hash_key]
|
||||||
|
|
||||||
entry = {
|
try:
|
||||||
'test_name': result['test_name'],
|
entry = {
|
||||||
'model': result['model'],
|
'test_name': result['test_name'],
|
||||||
'seed': result['seed'],
|
'test_id': result['test_id'],
|
||||||
'answer': result['answer'],
|
'model': result['model'],
|
||||||
'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url)
|
'seed': result['seed'],
|
||||||
}
|
'answer': result['answer'],
|
||||||
|
'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
entry['tool_calls'] = result['tool_calls']
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
saved_results[hash_key] = entry # add result with validation to saved results
|
saved_results[hash_key] = entry # add result with validation to saved results
|
||||||
|
|
||||||
print("\033[0;36mTest results of model '\033[0m" +
|
print("\033[0;36mTest results of model '\033[0m" +
|
||||||
model +
|
entry['model'] +
|
||||||
"\033[0;36m'" +
|
"\033[0;36m'" +
|
||||||
(" " * (get_len(models) - len(entry['model']))) +
|
(" " * (get_len(models) - len(entry['model']))) +
|
||||||
" with seed \033[0m" +
|
" with seed \033[0m\033[0;30m" +
|
||||||
str(seed) +
|
("0" * (get_len(seeds) - len(str(entry['seed'])))) +
|
||||||
(" " * (get_len(seeds) - len(str(entry['seed'])))) +
|
"\033[0m" +
|
||||||
|
str(entry['seed']) +
|
||||||
"\033[0;36m on test '\033[0m" +
|
"\033[0;36m on test '\033[0m" +
|
||||||
test.name +
|
entry['test_name'] +
|
||||||
"\033[0;36m'" +
|
"\033[0;36m'" +
|
||||||
(" " * (get_len(tests) - len(entry['test_name']))) +
|
(" " * (get_len(tests) - len(entry['test_name']))) +
|
||||||
" (\033[0m" +
|
" (\033[0m" +
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
from langchain_ollama.chat_models import ChatOllama
|
from langchain_ollama.chat_models import ChatOllama
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
|
||||||
from libs.test_class import Test
|
from libs.test_class import Test
|
||||||
from langchain.tools import Tool
|
from langchain.tools import Tool
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from langgraph.graph import StateGraph, MessagesState
|
||||||
|
# from langgraph.prebuilt import ToolNode
|
||||||
|
import json
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
|
||||||
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||||
|
|
||||||
system_msg = test.runnable_input['system_msg']
|
system_msg = test.runnable_input['system_msg']
|
||||||
human_msg = test.runnable_input['human_msg']
|
human_msg = test.runnable_input['human_msg']
|
||||||
|
|
||||||
@@ -19,8 +25,9 @@ def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
|||||||
ai_msg = llm.invoke(prompt)
|
ai_msg = llm.invoke(prompt)
|
||||||
return ai_msg.content
|
return ai_msg.content
|
||||||
|
|
||||||
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
|
|
||||||
|
|
||||||
|
|
||||||
|
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||||
system_msg = test.runnable_input['system_msg']
|
system_msg = test.runnable_input['system_msg']
|
||||||
human_msg = test.runnable_input['human_msg']
|
human_msg = test.runnable_input['human_msg']
|
||||||
tools_dict = test.runnable_input['tools']
|
tools_dict = test.runnable_input['tools']
|
||||||
@@ -42,11 +49,233 @@ def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> st
|
|||||||
prompt.append(ai_msg)
|
prompt.append(ai_msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_call = ai_msg.tool_calls[0]
|
tool_calls = []
|
||||||
selected_tool = tools_dict[tool_call["name"].lower()]
|
for i in range(len(ai_msg.tool_calls)):
|
||||||
tool_msg = selected_tool.invoke(tool_call)
|
tool_call = ai_msg.tool_calls[i]
|
||||||
prompt.append(tool_msg)
|
selected_tool = tools_dict[tool_call["name"].lower()]
|
||||||
ai_msg = llm.invoke(prompt)
|
tool_msg = selected_tool.invoke(tool_call)
|
||||||
except IndexError:
|
prompt.append(tool_msg)
|
||||||
pass
|
ai_msg = llm.invoke(prompt)
|
||||||
return ai_msg.content
|
tool_calls.append({
|
||||||
|
"tool": tool_call["name"],
|
||||||
|
"args": tool_call["args"],
|
||||||
|
"index": 0
|
||||||
|
})
|
||||||
|
except IndexError: # LLM didnt use a tool -> jsut return the content
|
||||||
|
tool_calls = []
|
||||||
|
return {
|
||||||
|
"answer": ai_msg.content,
|
||||||
|
"tool_calls": tool_calls
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
index = -1
|
||||||
|
|
||||||
|
def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
|
||||||
|
messages = state["messages"]
|
||||||
|
last_message = messages[-1]
|
||||||
|
nonlocal index
|
||||||
|
if last_message.tool_calls:
|
||||||
|
index += 1
|
||||||
|
return "tools"
|
||||||
|
return "__end__"
|
||||||
|
|
||||||
|
def call_llm(state: MessagesState):
|
||||||
|
messages = state["messages"]
|
||||||
|
response = llm.invoke(messages)
|
||||||
|
return {"messages": [response]}
|
||||||
|
|
||||||
|
class NxToolNode:
|
||||||
|
"""A node that runs the tools requested in the last AIMessage."""
|
||||||
|
|
||||||
|
def __init__(self, tools: list) -> None:
|
||||||
|
self.tools_by_name = {tool.name: tool for tool in tools}
|
||||||
|
|
||||||
|
def __call__(self, inputs: dict):
|
||||||
|
if messages := inputs.get("messages", []):
|
||||||
|
message = messages[-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("No message found in input")
|
||||||
|
outputs = []
|
||||||
|
for tool_call in message.tool_calls:
|
||||||
|
|
||||||
|
nonlocal tool_calls
|
||||||
|
nonlocal index
|
||||||
|
tool_calls.append({
|
||||||
|
"tool": tool_call["name"],
|
||||||
|
"args": tool_call["args"],
|
||||||
|
"index": index
|
||||||
|
})
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
|
||||||
|
except KeyError as e:
|
||||||
|
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
|
||||||
|
except ValidationError as e:
|
||||||
|
tool_result = 'Tool got invalid input:\n' + e
|
||||||
|
except Exception as e:
|
||||||
|
tool_result = 'Error: ' + str(e)
|
||||||
|
|
||||||
|
outputs.append(
|
||||||
|
ToolMessage(
|
||||||
|
content=json.dumps(tool_result),
|
||||||
|
name=tool_call["name"],
|
||||||
|
tool_call_id=tool_call["id"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return {"messages": outputs}
|
||||||
|
|
||||||
|
|
||||||
|
tools_dict = test.runnable_input['tools']
|
||||||
|
tools = []
|
||||||
|
for key in tools_dict:
|
||||||
|
tools.append(tools_dict[key])
|
||||||
|
tool_node = NxToolNode(tools)
|
||||||
|
llm = ChatOllama(
|
||||||
|
model=model,
|
||||||
|
seed=seed,
|
||||||
|
base_url=base_url
|
||||||
|
).bind_tools(tools)
|
||||||
|
|
||||||
|
workflow = StateGraph(MessagesState)
|
||||||
|
|
||||||
|
# Define the two nodes we will cycle between
|
||||||
|
workflow.add_node("agent", call_llm)
|
||||||
|
workflow.add_node("tools", tool_node)
|
||||||
|
|
||||||
|
workflow.add_edge("__start__", "agent")
|
||||||
|
workflow.add_conditional_edges(
|
||||||
|
"agent",
|
||||||
|
should_continue,
|
||||||
|
)
|
||||||
|
workflow.add_edge("tools", "agent")
|
||||||
|
|
||||||
|
graph = workflow.compile()
|
||||||
|
|
||||||
|
# example with a single tool call
|
||||||
|
start_messages = [
|
||||||
|
SystemMessage(content=test.runnable_input['system_msg']),
|
||||||
|
HumanMessage(content=test.runnable_input['human_msg'])
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
for chunk in graph.stream(
|
||||||
|
{"messages": start_messages},
|
||||||
|
stream_mode="values",
|
||||||
|
): chunks.append(chunk["messages"][-1])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"answer": chunks[-1].content,
|
||||||
|
"tool_calls": tool_calls
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def agent_with_tools_fsp(model: str, seed: int, test: Test, base_url: str) -> str:
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
index = -1
|
||||||
|
|
||||||
|
def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
|
||||||
|
messages = state["messages"]
|
||||||
|
last_message = messages[-1]
|
||||||
|
nonlocal index
|
||||||
|
if last_message.tool_calls:
|
||||||
|
index += 1
|
||||||
|
return "tools"
|
||||||
|
return "__end__"
|
||||||
|
|
||||||
|
def call_llm(state: MessagesState):
|
||||||
|
messages = state["messages"]
|
||||||
|
response = llm.invoke(messages)
|
||||||
|
return {"messages": [response]}
|
||||||
|
|
||||||
|
class NxToolNode:
|
||||||
|
"""A node that runs the tools requested in the last AIMessage."""
|
||||||
|
|
||||||
|
def __init__(self, tools: list) -> None:
|
||||||
|
self.tools_by_name = {tool.name: tool for tool in tools}
|
||||||
|
|
||||||
|
def __call__(self, inputs: dict):
|
||||||
|
if messages := inputs.get("messages", []):
|
||||||
|
message = messages[-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("No message found in input")
|
||||||
|
outputs = []
|
||||||
|
for tool_call in message.tool_calls:
|
||||||
|
|
||||||
|
nonlocal tool_calls
|
||||||
|
nonlocal index
|
||||||
|
tool_calls.append({
|
||||||
|
"tool": tool_call["name"],
|
||||||
|
"args": tool_call["args"],
|
||||||
|
"index": index
|
||||||
|
})
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
|
||||||
|
except KeyError as e:
|
||||||
|
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
|
||||||
|
except ValidationError as e:
|
||||||
|
tool_result = 'Tool got invalid input:\n' + e
|
||||||
|
except Exception as e:
|
||||||
|
tool_result = 'Error: ' + str(e)
|
||||||
|
|
||||||
|
outputs.append(
|
||||||
|
ToolMessage(
|
||||||
|
content=json.dumps(tool_result),
|
||||||
|
name=tool_call["name"],
|
||||||
|
tool_call_id=tool_call["id"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return {"messages": outputs}
|
||||||
|
|
||||||
|
|
||||||
|
tools_dict = test.runnable_input['tools']
|
||||||
|
tools = []
|
||||||
|
for key in tools_dict:
|
||||||
|
tools.append(tools_dict[key])
|
||||||
|
tool_node = NxToolNode(tools)
|
||||||
|
llm = ChatOllama(
|
||||||
|
model=model,
|
||||||
|
seed=seed,
|
||||||
|
base_url=base_url
|
||||||
|
).bind_tools(tools)
|
||||||
|
|
||||||
|
workflow = StateGraph(MessagesState)
|
||||||
|
|
||||||
|
# Define the two nodes we will cycle between
|
||||||
|
workflow.add_node("agent", call_llm)
|
||||||
|
workflow.add_node("tools", tool_node)
|
||||||
|
|
||||||
|
workflow.add_edge("__start__", "agent")
|
||||||
|
workflow.add_conditional_edges(
|
||||||
|
"agent",
|
||||||
|
should_continue,
|
||||||
|
)
|
||||||
|
workflow.add_edge("tools", "agent")
|
||||||
|
|
||||||
|
graph = workflow.compile()
|
||||||
|
|
||||||
|
# example with a single tool call
|
||||||
|
start_messages = [ SystemMessage(test.runnable_input['system_msg']) ] + test.runnable_input['fsp_messages'] + [ HumanMessage(test.runnable_input['human_msg']) ]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
for chunk in graph.stream(
|
||||||
|
{"messages": start_messages},
|
||||||
|
stream_mode="values",
|
||||||
|
): chunks.append(chunk["messages"][-1])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"answer": chunks[-1].content,
|
||||||
|
"tool_calls": tool_calls
|
||||||
|
}
|
||||||
|
|||||||
147
libs/tools.py
147
libs/tools.py
@@ -1,11 +1,150 @@
|
|||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from re import search
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def add(a: float, b: float) -> float:
|
def add(a: float, b: float) -> str:
|
||||||
"""Adds a+b and retuns the sum"""
|
"""Adds a+b and retuns the sum"""
|
||||||
return a+b
|
af = float(a)
|
||||||
|
bf = float(b)
|
||||||
|
return f"{a} + {b} = {a+b}"
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def multiply(a: float, b: float) -> float:
|
def multiply(a: float, b: float) -> str:
|
||||||
"""Multiplies a*b and retuns the product"""
|
"""Multiplies a*b and retuns the product"""
|
||||||
return a*b
|
af = float(a)
|
||||||
|
bf = float(b)
|
||||||
|
return f"{a} * {b} = {a*b}"
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_current_date_and_time() -> str:
|
||||||
|
"""Return current Date and time"""
|
||||||
|
return "Thursday the 8th of August 2024 18:03"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Entry:
|
||||||
|
time: datetime
|
||||||
|
content: str
|
||||||
|
|
||||||
|
note_entries = [
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/08/03 14:58", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Granny Petra says I should call Wolfgang to ask him when Susanne comes back when he comes back from his holidays."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/08/07 09:15", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Mom says to buy some fresh flowers for the living room before Aunt Linda visits."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/08/06 18:30", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Pick up the dry cleaning on Thursday; they close early on Fridays."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/08/05 11:45", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Ask Dr. Mills about the side effects of the new medication he got me."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/08/04 16:00", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Call the plumber to fix the leak in the upstairs bathroom."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/08/03 08:00", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Schedule a car service appointment before the road trip to the mountains."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/08/02 20:10", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Check if the library has a copy of the new mystery novel everyone is talking about."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/08/01 14:30", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Send a thank-you note to Mrs. Jenkins for the lovely dinner last weekend."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/07/31 12:05", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Email the project update to the team by the end of the week."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/07/30 07:50", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Pick up a birthday card for Uncle George before the family gathering on Sunday."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2024/07/29 15:20", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Research local yoga classes; consider signing up for the weekend session."
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2023/08/01 07:21", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Talk to Joffrey for the insurance!"
|
||||||
|
),
|
||||||
|
Entry(
|
||||||
|
time=datetime.strptime("2023/08/01 23:10", "%Y/%m/%d %H:%M"),
|
||||||
|
content="Went out with Charlotte for our anniversary. Pizza at Cavalinos. She loved the Necklace!"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_notes_in_timespan(begin: str, to: str) -> str:
|
||||||
|
"""Recieves the Notes saved in a time span.
|
||||||
|
|
||||||
|
aguments:
|
||||||
|
begin: str # start of the time span (incluive) %Y/%m/%d
|
||||||
|
to: str # end of the timespan (incluive) %Y/%m/%d
|
||||||
|
|
||||||
|
exaples:
|
||||||
|
{"begin": "2012/08/31", "to": "2012/09/06"} # 7 days from the 31st 00:00 till the 6th 23:59
|
||||||
|
{"begin": "2019/04/14", "to": "2019/04/14"} # All notes from the 19th of April 2019"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
begin_d = datetime.strptime(begin, "%Y/%m/%d")
|
||||||
|
to_d = datetime.strptime(to+" 23:59", "%Y/%m/%d %H:%M")
|
||||||
|
except: return "Error: Invalid input. Date format is %Y/%m/%d"
|
||||||
|
|
||||||
|
try: assert begin_d < to_d
|
||||||
|
except: return "Error: from time has to be before to time."
|
||||||
|
|
||||||
|
filtered_entries = [entry for entry in note_entries if begin_d <= entry.time <= to_d]
|
||||||
|
|
||||||
|
if filtered_entries == []:
|
||||||
|
return "No entries were found for that time period."
|
||||||
|
|
||||||
|
ret = ""
|
||||||
|
is_first = True
|
||||||
|
for entry in filtered_entries:
|
||||||
|
ret += '' if is_first else '\n\n'
|
||||||
|
ret += f"{datetime.strftime(entry.time, '%Y/%m/%d %H:%M')} {entry.content}"
|
||||||
|
is_first = False
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def get_notes_containing(patterns: Union[list[str], str]) -> str:
|
||||||
|
"""Recieves the Notes matching any of the RegEx patterns.
|
||||||
|
|
||||||
|
aguments:
|
||||||
|
patterns: Union[list[str], str] # case insensitive pattern(s) notes are to be mached against
|
||||||
|
|
||||||
|
exaples:
|
||||||
|
{"patterns": [ "Aunt(ie)?", "Sabine" ]} # Looks for Notes related to Aunt Sabine"""
|
||||||
|
|
||||||
|
if isinstance(patterns, list): big_pattern = '|'.join(f"({s})" for s in patterns)
|
||||||
|
elif isinstance(patterns, str): big_pattern = patterns
|
||||||
|
else: return f"Error: Invalid Input type. `patterns` can either be a list of strings or a single string. But got {type(patterns)}."
|
||||||
|
|
||||||
|
filtered_entries = [entry for entry in note_entries if search(big_pattern.lower(), entry.content.lower())]
|
||||||
|
|
||||||
|
ret = ""
|
||||||
|
is_first = True
|
||||||
|
for entry in filtered_entries:
|
||||||
|
ret += '' if is_first else '\n\n'
|
||||||
|
ret += f"{datetime.strftime(entry.time, '%Y/%m/%d %H:%M')} {entry.content}"
|
||||||
|
is_first = False
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def write_note(content: str) -> str:
|
||||||
|
"""Write a not with the current time to the database."""
|
||||||
|
return content
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ from langchain_ollama.chat_models import ChatOllama
|
|||||||
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate
|
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
from libs.test_class import Test
|
from libs.test_class import Test
|
||||||
|
from re import search
|
||||||
|
from textwrap import dedent
|
||||||
|
|
||||||
def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
|
def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
|
||||||
|
|
||||||
@@ -11,27 +13,19 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
|
|||||||
return rating
|
return rating
|
||||||
|
|
||||||
prompt = ChatPromptTemplate.from_messages([
|
prompt = ChatPromptTemplate.from_messages([
|
||||||
SystemMessagePromptTemplate.from_template(template="""You evaluate LLMs. Rate the LLM answer as correct, if the answer is
|
SystemMessagePromptTemplate.from_template(template=dedent("""You evaluate LLMs. Rate the LLM answer as correct, if the answer is
|
||||||
{validation_input}
|
{validation_input}
|
||||||
|
|
||||||
else as incorrect. Only use the rate tool. Do not answer conversationally."""),
|
else as incorrect. Only use the rate tool. Do not answer conversationally.""")),
|
||||||
# SystemMessagePromptTemplate.from_template(template="""You are a rating machine. You are given 3 things: The system message, the Human query, and the AI response. You evaluate the response as correct if
|
HumanMessagePromptTemplate.from_template(template=dedent("""System Message:
|
||||||
# {validation_input}
|
{system_msg}
|
||||||
|
|
||||||
# If the answer does not match these criteria, rate the answer as incorrect. If the answer is a "refusal" or a "declaration of incapability", the answer is automatically incorrect.
|
Human query:
|
||||||
|
{human_msg}
|
||||||
|
|
||||||
# **Only use the rate tool. Do not under any circumstances answer conversationally**.
|
LLM answer:
|
||||||
# DO NOT ANSWER WITH <I'm sorry but I do not have the capability to perform this task for you...> or anything like it.
|
{answer}
|
||||||
# Use the rate tool!"""),
|
"""))
|
||||||
HumanMessagePromptTemplate.from_template(template="""System Message:
|
|
||||||
{system_msg}
|
|
||||||
|
|
||||||
Human query:
|
|
||||||
{human_msg}
|
|
||||||
|
|
||||||
LLM answer:
|
|
||||||
{answer}
|
|
||||||
""")
|
|
||||||
]).invoke({
|
]).invoke({
|
||||||
"validation_input": test.validation_input['criteria'],
|
"validation_input": test.validation_input['criteria'],
|
||||||
"system_msg": test.runnable_input['system_msg'],
|
"system_msg": test.runnable_input['system_msg'],
|
||||||
@@ -48,7 +42,10 @@ LLM answer:
|
|||||||
ai_msg = llm.invoke(prompt)
|
ai_msg = llm.invoke(prompt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ret_str = rate.invoke(ai_msg.tool_calls[0]).content
|
tool_call = ai_msg.tool_calls[0]
|
||||||
|
if tool_call['name'] != "rate":
|
||||||
|
raise Exception(f"Verificaiton model tried to tool `{tool_call['name']}` not `rate`")
|
||||||
|
ret_str = rate.invoke(tool_call).content
|
||||||
if ret_str.lower() == 'true': return True
|
if ret_str.lower() == 'true': return True
|
||||||
elif ret_str.lower() == 'false': return False
|
elif ret_str.lower() == 'false': return False
|
||||||
else: raise Exception(f"rate tool retured {ret_str}")
|
else: raise Exception(f"rate tool retured {ret_str}")
|
||||||
@@ -56,8 +53,6 @@ LLM answer:
|
|||||||
print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
|
print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
|
||||||
return system_human_answer_match(test=test, answer=answer, base_url=base_url)
|
return system_human_answer_match(test=test, answer=answer, base_url=base_url)
|
||||||
|
|
||||||
from re import search
|
|
||||||
|
|
||||||
def regex_match_any(test: Test, answer: str, base_url: str) -> bool:
|
def regex_match_any(test: Test, answer: str, base_url: str) -> bool:
|
||||||
match = False
|
match = False
|
||||||
for pattern in test.validation_input['patterns']:
|
for pattern in test.validation_input['patterns']:
|
||||||
|
|||||||
Reference in New Issue
Block a user