mega commit
This commit is contained in:
18
libs/classes.py
Normal file
18
libs/classes.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
@dataclass
class Test:
    """One benchmark case: what to run and how to judge the outcome.

    The test runner visible elsewhere in this commit invokes the callables
    with keyword arguments only:

    * ``runnable(model=..., seed=..., test=..., base_url=...)`` produces the
      answer (a ``str``, or a dict carrying ``tool_calls``).
    * ``validator(test=..., answer=..., base_url=...)`` produces the verdict
      stored under ``validation``.
    """

    # Keeps pytest from trying to collect this class as a test case merely
    # because its name starts with "Test". Not a dataclass field (it has no
    # annotation), so the constructor signature is unchanged.
    __test__ = False

    # Human-readable name; copied into each result as ``test_name``.
    name: str
    # Callable that executes the test against a model.
    runnable: Callable
    # Inputs read by ``runnable``; the visible runnables read the keys
    # ``system_msg``, ``human_msg``, ``tools`` and (optionally) ``fsp_messages``.
    runnable_input: dict
    # Callable that decides whether an answer is correct.
    validator: Callable
    # Inputs for ``validator`` -- schema not visible here; verify against the
    # concrete validators.
    validation_input: dict
|
||||||
|
|
||||||
|
@dataclass
class Model:
    """Description of one LLM that the benchmark can run tests against."""

    # Name shown in console output and stored in results as ``model_name``.
    display_name: str
    # Value handed to the chat model constructor as ``model=...``.
    identifier: str
    # When True the runner uses the plain ``ChatOllama`` client; otherwise it
    # wraps the model in ``OllamaFunctions`` to emulate tool calling.
    supports_tools: bool
    # Model size -- presumably in billions of parameters, judging by the
    # name; confirm against where the models are declared.
    parameter_count_in_b: float
|
||||||
|
|
||||||
5
libs/functions.py
Normal file
5
libs/functions.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
def nxhash(text: str) -> str:  # @BenVida StackOverflow
    """Return a short, deterministic 32-bit hash of *text* as 8 hex digits.

    This is a simple multiplicative/XOR rolling hash, not a cryptographic
    one; it is used to derive stable, compact keys and ids.

    Args:
        text: The string to hash. An empty string hashes to ``"00000000"``.

    Returns:
        The hash as exactly eight upper-case hexadecimal characters.
    """
    value = 0
    for ch in text:
        # Mask after every step so the accumulator stays within 32 bits.
        value = (value * 281 ^ ord(ch) * 997) & 0xFFFFFFFF
    return f"{value:08X}"
|
||||||
322
libs/ollama_functions.py
Normal file
322
libs/ollama_functions.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
from types import NoneType
|
||||||
|
|
||||||
|
from langchain_ollama.chat_models import ChatOllama
|
||||||
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.language_models import LanguageModelInput
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage, ToolCall
|
||||||
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
|
from langchain_core.prompts import SystemMessagePromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
from langchain_core.tools import BaseTool, Tool
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
|
||||||
|
|
||||||
|
from libs.functions import nxhash
|
||||||
|
|
||||||
|
# Prompt used when the conversation contains no tool results yet.
# ``{tools}`` is replaced with the JSON list of tool definitions; the doubled
# braces are literal braces after template formatting.
# NOTE(review): the indentation of the two lines inside the JSON example was
# not recoverable from the source view; two spaces are assumed.
DEFAULT_SYTEM_PROMPT = """You have access to the following tools:

{tools}

You must always select one of the above tools and respond with only a JSON object matching the following schema:

{{
  "tool": <name of the selected tool>,
  "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}
"""


# Prompt used once tool results are part of the conversation: the whole
# exchange is flattened into ``{history}`` and sent as a single system
# message, with the caller's own system text in ``{system_msg}``.
DEFAULT_SYTEM_PROMPT_WITH_HISTORY = """{system_msg}

You continue a chat history either conversationally or with a tool call.

You have access to the following tools:

{tools}

You must either select one of the above tools and respond with only a JSON object matching the following schema:

{{
  "tool": <name of the selected tool>,
  "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}

or answer conversationally normally.

The conversation before consisted of the following messages:

{history}

Now you must answer accordingly either conversationally or with another tool call.

For conversational answers: Answer as if it was a continuous conversation. The Human only sees the conversational responses, and not anything about the tools. Do not mention the tools or the process of using them.
"""

# Pseudo-tool appended to every tool list so the model always has a way to
# answer in plain text while still emitting the required JSON shape.
CONVERSATIONAL_RESPONSE_TOOL = {
    "name": "__conversational_response",
    "description": (
        "Respond conversationally if no other tools should be called for a given query."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "response": {
                "type": "string",
                "description": "Conversational response to the user.",
            },
        },
        "required": ["response"],
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
_BM = TypeVar("_BM", bound=BaseModel)
|
||||||
|
_DictOrPydantic = Union[Dict, _BM]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
|
return isinstance(obj, type) and (
|
||||||
|
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_ollama_tool(tool: Any) -> Dict:
    """Convert a tool to an Ollama tool definition.

    Args:
        tool: One of
            * a dict that already has ``"name"`` and ``"parameters"`` keys
              (returned as a shallow copy),
            * a pydantic model class,
            * a ``BaseTool`` instance,
            * a pydantic model instance exposing ``get_input_schema()``,
              ``get_name()`` and ``description``.

    Returns:
        A dict with ``"name"`` and ``"parameters"`` (the JSON schema) and,
        when the tool provides a non-empty one, ``"description"``.

    Raises:
        ValueError: If *tool* is none of the supported kinds.
    """
    # Ready-made definitions are handled first: a dict can never match any of
    # the class/instance checks below, so the outcome is the same and the
    # common case returns early.
    if isinstance(tool, dict) and "name" in tool and "parameters" in tool:
        return tool.copy()

    description = None
    if _is_pydantic_class(tool):
        schema = tool.construct().schema()
        name = schema["title"]
    elif isinstance(tool, BaseTool):
        schema = tool.tool_call_schema.schema()
        name = tool.get_name()
        description = tool.description
    elif is_basemodel_instance(tool):
        schema = tool.get_input_schema().schema()
        name = tool.get_name()
        description = tool.description
    else:
        raise ValueError(
            f"Cannot convert {tool} to an Ollama tool.\n"
            f"{tool} needs to be a Pydantic class, model, or a dict."
        )

    definition = {"name": name, "parameters": schema}
    if description:
        definition["description"] = description

    return definition
|
||||||
|
|
||||||
|
def parse_response(message: BaseMessage) -> str:
    """Extract the arguments of the call carried by an ``AIMessage`` as a JSON string.

    The last entry of ``message.tool_calls`` takes precedence; its ``args``
    are serialised with ``json.dumps``. If there are no tool calls, the
    legacy ``additional_kwargs["function_call"]["arguments"]`` value is
    returned unchanged.

    Args:
        message: The model response to inspect.

    Returns:
        The call arguments as a string.

    Raises:
        ValueError: If *message* is not an ``AIMessage``, carries neither
            tool calls nor a ``function_call``, or its ``function_call`` has
            no ``arguments``.
    """
    if not isinstance(message, AIMessage):
        raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")

    tool_calls = message.tool_calls
    if len(tool_calls) > 0:
        return json.dumps(tool_calls[-1].get("args"))

    kwargs = message.additional_kwargs
    if "function_call" not in kwargs:
        # The f-prefix was missing here, so the placeholder was emitted
        # literally instead of the offending message.
        raise ValueError(f"`tool_calls` missing from AIMessage: {message}")

    if "arguments" in kwargs["function_call"]:
        return kwargs["function_call"]["arguments"]
    raise ValueError(f"`arguments` missing from `function_call` within AIMessage: {message}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaFunctions(ChatOllama):
    """Chat model that emulates tool calling on top of the Ollama chat API.

    The tool definitions are rendered into a system prompt that instructs the
    model to answer with a JSON object naming a tool and its input. The JSON
    answer is then turned back into an ``AIMessage`` -- either with a
    ``tool_calls`` entry or, for the built-in conversational pseudo-tool,
    with plain text content.
    """

    tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
    tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind *tools* so that ``_generate`` receives them as ``functions``."""
        return self.bind(functions=tools, **kwargs)

    @staticmethod
    def _format_tool_calls_for_history(tool_calls: Sequence[ToolCall]) -> str:
        """Render earlier tool calls as pretty-printed JSON for the history prompt.

        Each call id is shortened to the last four characters of its
        ``nxhash`` so the matching tool reply can refer to it compactly.
        A single call is rendered as an object, several as a list.
        """
        entries = [
            {
                "id": nxhash(call["id"])[-4:],
                "tool": call["name"],
                "args": call["args"],
            }
            for call in tool_calls
        ]
        payload = entries[0] if len(entries) == 1 else entries
        return json.dumps(payload, ensure_ascii=False, indent=2)

    def _split_system_and_history(self, messages: Sequence[BaseMessage]) -> Tuple[str, str]:
        """Split *messages* into the system text and a narrated transcript.

        Returns:
            ``(system_msg, history)``: the concatenated content of all system
            messages, and every other message rendered as a paragraph, with
            paragraphs separated by a blank line.

        Raises:
            TypeError: If a message is not a System/Human/Tool/AI message.
        """
        system_parts: List[str] = []
        paragraphs: List[str] = []

        for message in messages:
            if isinstance(message, SystemMessage):
                # System messages do not belong to the transcript, so they
                # must not contribute a separator to it either.
                system_parts.append(str(message.content))
            elif isinstance(message, HumanMessage):
                paragraphs.append("The Human said:\n" + str(message.content))
            elif isinstance(message, ToolMessage):
                paragraphs.append(
                    "To which the tool ("
                    + nxhash(message.tool_call_id)[-4:]
                    + ") replied with:\n"
                    + str(message.content)
                )
            elif isinstance(message, AIMessage):
                if message.tool_calls:
                    suffix = ":\n" if len(message.tool_calls) == 1 else "s:\n"
                    paragraphs.append(
                        "So you called the tool"
                        + suffix
                        + self._format_tool_calls_for_history(message.tool_calls)
                    )
                else:
                    paragraphs.append("You said:\n" + str(message.content))
            else:
                raise TypeError(
                    "OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got "
                    + str(type(message))
                )

        return "".join(system_parts), "\n\n".join(paragraphs)

    @staticmethod
    def _parse_model_output(raw_output: str) -> Any:
        """Decode *raw_output* as JSON; return it unchanged if it is not valid JSON."""
        try:
            return json.loads(raw_output)
        except json.JSONDecodeError:
            return raw_output

    @staticmethod
    def _find_called_tool(parsed: dict, functions_list: List[dict]) -> Optional[dict]:
        """Return the tool definition the model selected, or None.

        The prompt asks for the key ``"tool"``; ``"name"``, ``"tool_name"``
        and ``"action"`` are accepted as well because some models (Phi3 was
        observed doing this) deviate from the requested schema. None is
        returned when no such key exists or the named tool is unknown.
        """
        for key in ("tool", "name", "tool_name", "action"):
            if key in parsed:
                tool_name = parsed[key]
                break
        else:
            return None
        return next((tool for tool in functions_list if tool["name"] == tool_name), None)

    def _extract_conversational_response(self, parsed: dict, raw_output: str) -> str:
        """Pull the plain-text answer out of the model's JSON object.

        A nested ``response`` under ``tool_input`` / ``input`` / ``args`` is
        preferred; otherwise a top-level ``response``, ``input``, ``args`` or
        ``tool_input`` value is used, in that order.

        Raises:
            ValueError: If no candidate exists or the candidate is not a string.
        """
        failure = f"Failed to parse a response from {self.model} output: {raw_output}"

        for key in ("tool_input", "input", "args"):
            nested = parsed.get(key)
            # Only look inside real objects: on a string, ``in`` would be a
            # substring test and the subsequent lookup would crash.
            if isinstance(nested, dict) and "response" in nested:
                response = nested["response"]
                break
        else:
            for key in ("response", "input", "args", "tool_input"):
                if key in parsed:
                    response = parsed[key]
                    break
            else:
                raise ValueError(failure)

        # Explicit check rather than ``assert``: assertions vanish under -O.
        if not isinstance(response, str):
            raise ValueError(failure)
        return response

    @staticmethod
    def _extract_tool_args(parsed: dict) -> Any:
        """Return the tool arguments from the model's JSON object (``{}`` if absent).

        ``"tool_input"`` is what the prompt asks for; ``"input"`` and
        ``"args"`` are tolerated alternatives.
        """
        for key in ("tool_input", "input", "args"):
            if key in parsed:
                return parsed[key]
        return {}

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run one generation with the bound tools described in the system prompt.

        Args:
            messages: Conversation so far.
            stop: Stop sequences, forwarded to the parent implementation.
            run_manager: Callback manager, forwarded to the parent implementation.
            **kwargs: May contain ``functions`` (the bound tools); everything
                else is forwarded to the parent implementation.

        Returns:
            A ``ChatResult`` with a single ``AIMessage`` that either carries
            one tool call or plain text content.

        Raises:
            ValueError: If the model output is not a string, or a
                conversational answer cannot be extracted from its JSON.
            TypeError: If the history contains an unsupported message type.
        """
        # ``functions`` is consumed here and removed from the forwarded
        # kwargs. NOTE(review): presumably the parent ``_generate`` has no use
        # for it and newer releases may reject unknown keywords -- confirm.
        bound_functions = kwargs.pop("functions", [])
        functions_list = [convert_to_ollama_tool(fn) for fn in bound_functions]
        functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
        functions_str = json.dumps(functions_list, indent=2)

        if any(isinstance(m, ToolMessage) for m in messages):
            # Tool results are present: flatten the whole conversation into
            # one system message, since it is narrated inside the prompt.
            system_msg, formatted_history = self._split_system_and_history(messages)
            template = SystemMessagePromptTemplate.from_template(
                self.tool_system_prompt_template_with_history
            )
            system_message = template.format(
                tools=functions_str,
                history=formatted_history,
                system_msg=system_msg,
            )
            final_messages = [system_message]
        else:
            # No tool results yet: prepend the tool prompt to the messages.
            template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
            system_message = template.format(tools=functions_str)
            final_messages = [system_message] + messages

        response_message = super()._generate(
            final_messages, stop=stop, run_manager=run_manager, **kwargs
        )
        chat_result = response_message.generations[0].text

        if not isinstance(chat_result, str):
            raise ValueError("OllamaFunctions does not support non-string output.")

        parsed_chat_result = self._parse_model_output(chat_result)

        # Invalid JSON (or a bare JSON string): hand the text back as-is.
        if isinstance(parsed_chat_result, str):
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=parsed_chat_result))]
            )
        # Valid JSON that is not an object (list, number, null, ...) cannot
        # name a tool; treat the raw output as a conversational answer.
        if not isinstance(parsed_chat_result, dict):
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=chat_result))]
            )

        called_tool = self._find_called_tool(parsed_chat_result, functions_list)

        if called_tool is None or called_tool == CONVERSATIONAL_RESPONSE_TOOL:
            response_msg = AIMessage(
                content=self._extract_conversational_response(parsed_chat_result, chat_result)
            )
        else:
            response_msg = AIMessage(
                content="",
                tool_calls=[
                    ToolCall(
                        name=called_tool["name"],
                        args=self._extract_tool_args(parsed_chat_result),
                        id=f"call_{uuid.uuid4().hex}",
                    )
                ],
            )

        return ChatResult(generations=[ChatGeneration(message=response_msg)])

    @property
    def _llm_type(self) -> str:
        return "ollama_functions"
|
||||||
@@ -1,30 +1,19 @@
|
|||||||
from libs.test_class import Test
|
from libs.classes import Test, Model
|
||||||
|
from libs.functions import nxhash
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
def padd(list, element):
|
|
||||||
longest = 0
|
|
||||||
for s in list:
|
|
||||||
longest = max(longest, len(str(s)))
|
|
||||||
return str(element).ljust(longest)
|
|
||||||
|
|
||||||
def nxhash(text:str): # @BenVida StackOverflow
|
|
||||||
hash=0
|
|
||||||
for ch in text:
|
|
||||||
hash = ( hash*281 ^ ord(ch)*997) & 0xFFFFFFFF
|
|
||||||
return hex(hash)[2:].upper().zfill(8)
|
|
||||||
|
|
||||||
def get_len(collection: Union[list, dict]) -> int:
|
def get_len(collection: Union[list, dict]) -> int:
|
||||||
maximum_length = 0
|
maximum_length = 0
|
||||||
|
|
||||||
if isinstance(collection, dict):
|
if isinstance(collection, list):
|
||||||
collection_type = "tests"
|
|
||||||
elif isinstance(collection, list):
|
|
||||||
if isinstance(collection[0], str):
|
|
||||||
collection_type = "models"
|
|
||||||
elif isinstance(collection[0], int):
|
|
||||||
collection_type = "seeds"
|
collection_type = "seeds"
|
||||||
|
elif isinstance(collection, dict):
|
||||||
|
if isinstance(collection[list(collection.keys())[0]], Model):
|
||||||
|
collection_type = "models"
|
||||||
|
elif isinstance(collection[list(collection.keys())[0]], Test):
|
||||||
|
collection_type = "tests"
|
||||||
else:
|
else:
|
||||||
raise TypeError("get_len: unsupported collection_type")
|
raise TypeError("get_len: unsupported collection_type")
|
||||||
else:
|
else:
|
||||||
@@ -32,8 +21,8 @@ def get_len(collection: Union[list, dict]) -> int:
|
|||||||
|
|
||||||
match collection_type:
|
match collection_type:
|
||||||
case "models":
|
case "models":
|
||||||
for model_name in collection:
|
for model_id in collection:
|
||||||
maximum_length = max(maximum_length, len(model_name))
|
maximum_length = max(maximum_length, len(collection[model_id].display_name))
|
||||||
case "seeds":
|
case "seeds":
|
||||||
for seed in collection:
|
for seed in collection:
|
||||||
maximum_length = max(maximum_length, len(str(seed)))
|
maximum_length = max(maximum_length, len(str(seed)))
|
||||||
@@ -48,30 +37,32 @@ def get_len(collection: Union[list, dict]) -> int:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url: str):
|
def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], base_url: str):
|
||||||
try:
|
try:
|
||||||
print("Trying to load saved_results.json")
|
print("Trying to load saved_results.json")
|
||||||
with open("./saved_results.json", "r") as f:
|
with open("./saved_results.json", "r") as f:
|
||||||
saved_results = json.load(fp=f)
|
saved_results = json.load(fp=f)
|
||||||
print("Loaded.")
|
print("Loaded.")
|
||||||
except:
|
except FileNotFoundError:
|
||||||
print("saved_results.json not found. Initializing empty.")
|
print("saved_results.json not found. Initializing empty.")
|
||||||
saved_results = {}
|
saved_results = {}
|
||||||
# Get Results
|
# Get Results
|
||||||
run_results = {}
|
run_results = {}
|
||||||
print("Starting to run Tests ... ")
|
print("Starting to run Tests ... ")
|
||||||
for model in models:
|
for model_id in models:
|
||||||
|
model = models[model_id]
|
||||||
for test_id in tests:
|
for test_id in tests:
|
||||||
test = tests[test_id]
|
test = tests[test_id]
|
||||||
for seed in seeds:
|
for seed in seeds:
|
||||||
# Init dict
|
# Init dict
|
||||||
combination = {
|
combination = {
|
||||||
'test_id': test_id,
|
'test_id': test_id,
|
||||||
'model': model,
|
'model_id': model_id,
|
||||||
'seed': seed,
|
'seed': seed,
|
||||||
}
|
}
|
||||||
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
|
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
|
||||||
combination['test_name'] = test.name
|
combination['test_name'] = test.name
|
||||||
|
combination['model_name'] = model.display_name
|
||||||
|
|
||||||
# if hash_key == "DE3D137E":
|
# if hash_key == "DE3D137E":
|
||||||
# pass
|
# pass
|
||||||
@@ -79,9 +70,9 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
if hash_key not in saved_results.keys():
|
if hash_key not in saved_results.keys():
|
||||||
try:
|
try:
|
||||||
print("\033[0;35mModel '\033[0m" +
|
print("\033[0;35mModel '\033[0m" +
|
||||||
model +
|
model.display_name +
|
||||||
"\033[0;35m'" +
|
"\033[0;35m'" +
|
||||||
(" " * (get_len(models) - len(model))) +
|
(" " * (get_len(models) - len(model.display_name))) +
|
||||||
" with seed \033[0m\033[0;30m" +
|
" with seed \033[0m\033[0;30m" +
|
||||||
("0" * (get_len(seeds) - len(str(seed)))) +
|
("0" * (get_len(seeds) - len(str(seed)))) +
|
||||||
"\033[0m" +
|
"\033[0m" +
|
||||||
@@ -96,7 +87,7 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
end=""
|
end=""
|
||||||
)
|
)
|
||||||
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
|
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
|
||||||
if isinstance(answer, str): # tool capabile return tools called as a list[dict]
|
if isinstance(answer, str):
|
||||||
combination['answer'] = answer
|
combination['answer'] = answer
|
||||||
# combination['tool_calls'] = [] # no entry
|
# combination['tool_calls'] = [] # no entry
|
||||||
del answer
|
del answer
|
||||||
@@ -105,15 +96,14 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
combination['tool_calls'] = answer['tool_calls']
|
combination['tool_calls'] = answer['tool_calls']
|
||||||
del answer
|
del answer
|
||||||
else:
|
else:
|
||||||
raise Exception(f"runnable returd unkown type {type(answer)}.")
|
raise Exception(f"runnable returned unkown type {type(answer)}.")
|
||||||
|
|
||||||
|
|
||||||
combination['test'] = test
|
combination['test'] = test
|
||||||
run_results[hash_key] = combination
|
run_results[hash_key] = combination
|
||||||
print("\r\033[0;32mModel '\033[0m" +
|
print("\r\033[0;32mModel '\033[0m" +
|
||||||
model +
|
model.display_name +
|
||||||
"\033[0;32m'" +
|
"\033[0;32m'" +
|
||||||
(" " * (get_len(models) - len(model))) +
|
(" " * (get_len(models) - len(model.display_name))) +
|
||||||
" with seed \033[0m\033[0;30m" +
|
" with seed \033[0m\033[0;30m" +
|
||||||
("0" * (get_len(seeds) - len(str(seed)))) +
|
("0" * (get_len(seeds) - len(str(seed)))) +
|
||||||
"\033[0m" +
|
"\033[0m" +
|
||||||
@@ -127,12 +117,12 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
"\033[0;32m)\033[0m"
|
"\033[0;32m)\033[0m"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...")
|
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ")
|
||||||
else:
|
else:
|
||||||
print("\r\033[0;34mModel '\033[0m" +
|
print("\r\033[0;34mModel '\033[0m" +
|
||||||
model +
|
model.display_name +
|
||||||
"\033[0;34m'" +
|
"\033[0;34m'" +
|
||||||
(" " * (get_len(models) - len(model))) +
|
(" " * (get_len(models) - len(model.display_name))) +
|
||||||
" with seed \033[0m\033[0;30m" +
|
" with seed \033[0m\033[0;30m" +
|
||||||
("0" * (get_len(seeds) - len(str(seed)))) +
|
("0" * (get_len(seeds) - len(str(seed)))) +
|
||||||
"\033[0m" +
|
"\033[0m" +
|
||||||
@@ -148,7 +138,8 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
|
|
||||||
|
|
||||||
# Validate Results
|
# Validate Results
|
||||||
if run_results != {}: print("\nStarting validation of tests ...")
|
if run_results != {}:
|
||||||
|
print("\nStarting validation of tests ...")
|
||||||
for hash_key in run_results:
|
for hash_key in run_results:
|
||||||
result = run_results[hash_key]
|
result = run_results[hash_key]
|
||||||
|
|
||||||
@@ -156,27 +147,28 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
entry = {
|
entry = {
|
||||||
'test_name': result['test_name'],
|
'test_name': result['test_name'],
|
||||||
'test_id': result['test_id'],
|
'test_id': result['test_id'],
|
||||||
'model': result['model'],
|
'model_name': result['model_name'],
|
||||||
|
'model_id': result['model_id'],
|
||||||
'seed': result['seed'],
|
'seed': result['seed'],
|
||||||
'answer': result['answer'],
|
'answer': result['answer'],
|
||||||
'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url),
|
'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m")
|
print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m ")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
entry['tool_calls'] = result['tool_calls']
|
entry['tool_calls'] = result['tool_calls']
|
||||||
except:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
saved_results[hash_key] = entry # add result with validation to saved results
|
saved_results[hash_key] = entry # add result with validation to saved results
|
||||||
|
|
||||||
print("\033[0;36mTest results of model '\033[0m" +
|
print("\033[0;36mTest results of model '\033[0m" +
|
||||||
entry['model'] +
|
entry['model_name'] +
|
||||||
"\033[0;36m'" +
|
"\033[0;36m'" +
|
||||||
(" " * (get_len(models) - len(entry['model']))) +
|
(" " * (get_len(models) - len(entry['model_name']))) +
|
||||||
" with seed \033[0m\033[0;30m" +
|
" with seed \033[0m\033[0;30m" +
|
||||||
("0" * (get_len(seeds) - len(str(entry['seed'])))) +
|
("0" * (get_len(seeds) - len(str(entry['seed'])))) +
|
||||||
"\033[0m" +
|
"\033[0m" +
|
||||||
@@ -188,7 +180,7 @@ def run_tests(models: list[str], seeds: list[int], tests: list[Test], base_url:
|
|||||||
" (\033[0m" +
|
" (\033[0m" +
|
||||||
hash_key +
|
hash_key +
|
||||||
"\033[0;36m) evaluated to \033[0m" +
|
"\033[0;36m) evaluated to \033[0m" +
|
||||||
('\033[0;32mcorrect\033[0m' if entry['validation'] == True else '\033[0;31mincorrect\033[0m')
|
('\033[0;32mcorrect\033[0m' if entry['validation'] else '\033[0;31mincorrect\033[0m')
|
||||||
)
|
)
|
||||||
|
|
||||||
with open("./saved_results.json", "w") as f:
|
with open("./saved_results.json", "w") as f:
|
||||||
|
|||||||
@@ -1,76 +1,100 @@
|
|||||||
|
from types import NoneType
|
||||||
from langchain_ollama.chat_models import ChatOllama
|
from langchain_ollama.chat_models import ChatOllama
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
|
from libs.ollama_functions import OllamaFunctions
|
||||||
from libs.test_class import Test
|
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
|
||||||
|
from libs.classes import Test, Model
|
||||||
from langchain.tools import Tool
|
from langchain.tools import Tool
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from langgraph.graph import StateGraph, MessagesState
|
from langgraph.graph import StateGraph, MessagesState
|
||||||
# from langgraph.prebuilt import ToolNode
|
|
||||||
import json
|
import json
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType = None):
|
||||||
def basic(model: str, seed: int, test: Test, base_url: str) -> str:
|
if model.supports_tools:
|
||||||
system_msg = test.runnable_input['system_msg']
|
|
||||||
human_msg = test.runnable_input['human_msg']
|
|
||||||
|
|
||||||
if system_msg == None: prompt = [ human_msg ]
|
|
||||||
else: prompt = [ system_msg, human_msg ]
|
|
||||||
|
|
||||||
llm = ChatOllama(
|
llm = ChatOllama(
|
||||||
model=model,
|
model=model.identifier,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
base_url=base_url
|
base_url=base_url
|
||||||
)
|
)
|
||||||
ai_msg = llm.invoke(prompt)
|
else:
|
||||||
|
llm = OllamaFunctions(
|
||||||
|
model=model.identifier,
|
||||||
|
seed=seed,
|
||||||
|
base_url=base_url,
|
||||||
|
format="json"
|
||||||
|
)
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
llm = llm.bind_tools(tools=tools)
|
||||||
|
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
|
||||||
|
|
||||||
|
messages = [SystemMessage(test.runnable_input['system_msg'])]
|
||||||
|
try:
|
||||||
|
messages += test.runnable_input['fsp_messages']
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
messages += [ HumanMessage(test.runnable_input['human_msg']) ]
|
||||||
|
|
||||||
|
llm = _get_llm(model=model, base_url=base_url, seed=seed)
|
||||||
|
ai_msg = llm.invoke(messages)
|
||||||
|
assert isinstance(ai_msg.content, str)
|
||||||
return ai_msg.content
|
return ai_msg.content
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def one_tool_call_answer(model: str, seed: int, test: Test, base_url: str) -> str:
|
def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) -> dict:
|
||||||
system_msg = test.runnable_input['system_msg']
|
|
||||||
human_msg = test.runnable_input['human_msg']
|
|
||||||
tools_dict = test.runnable_input['tools']
|
tools_dict = test.runnable_input['tools']
|
||||||
tools = []
|
tools = []
|
||||||
for key in tools_dict:
|
for key in tools_dict:
|
||||||
tools.append(tools_dict[key])
|
tools.append(tools_dict[key])
|
||||||
|
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
|
||||||
|
|
||||||
if system_msg == None: prompt = [ human_msg ]
|
messages = [SystemMessage(test.runnable_input['system_msg'])]
|
||||||
else: prompt = [ system_msg, human_msg ]
|
try:
|
||||||
|
messages += test.runnable_input['fsp_messages']
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
messages += [ HumanMessage(test.runnable_input['human_msg']) ]
|
||||||
|
|
||||||
llm = ChatOllama(
|
ai_msg = llm.invoke(messages)
|
||||||
model=model,
|
|
||||||
seed=seed,
|
|
||||||
base_url=base_url
|
|
||||||
).bind_tools(tools)
|
|
||||||
|
|
||||||
ai_msg = llm.invoke(prompt)
|
messages += [ ai_msg ]
|
||||||
|
|
||||||
prompt.append(ai_msg)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for i in range(len(ai_msg.tool_calls)):
|
assert isinstance(ai_msg, AIMessage)
|
||||||
tool_call = ai_msg.tool_calls[i]
|
calls = ai_msg.tool_calls
|
||||||
selected_tool = tools_dict[tool_call["name"].lower()]
|
for call in calls:
|
||||||
tool_msg = selected_tool.invoke(tool_call)
|
selected_tool = tools_dict[call["name"].lower()]
|
||||||
prompt.append(tool_msg)
|
tool_msg = selected_tool.invoke(call)
|
||||||
ai_msg = llm.invoke(prompt)
|
messages.append(tool_msg)
|
||||||
|
ai_msg = llm.invoke(messages)
|
||||||
tool_calls.append({
|
tool_calls.append({
|
||||||
"tool": tool_call["name"],
|
"tool": call["name"],
|
||||||
"args": tool_call["args"],
|
"args": call["args"],
|
||||||
"index": 0
|
|
||||||
})
|
})
|
||||||
except IndexError: # LLM didnt use a tool -> jsut return the content
|
except IndexError: # LLM didnt use a tool -> jsut return the content
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
|
if len(ai_msg.tool_calls) > 0:
|
||||||
|
to_append_calls = []
|
||||||
|
for call in ai_msg.tool_calls:
|
||||||
|
to_append_calls.append({ "tool": call["name"], "args": call["args"] })
|
||||||
|
return {
|
||||||
|
"answer": ">>LLM did not respond conversationally<<",
|
||||||
|
"tool_calls": tool_calls + to_append_calls,
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
"answer": ai_msg.content,
|
"answer": ai_msg.content,
|
||||||
"tool_calls": tool_calls
|
"tool_calls": tool_calls,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict[str, str|list]:
|
||||||
|
|
||||||
def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
|
|
||||||
|
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
index = -1
|
index = -1
|
||||||
@@ -79,6 +103,7 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
|
|||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
nonlocal index
|
nonlocal index
|
||||||
|
assert isinstance(last_message, AIMessage) # this is just so the type checker is happy
|
||||||
if last_message.tool_calls:
|
if last_message.tool_calls:
|
||||||
index += 1
|
index += 1
|
||||||
return "tools"
|
return "tools"
|
||||||
@@ -113,10 +138,10 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
|
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
|
||||||
except KeyError as e:
|
except KeyError:
|
||||||
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
|
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
tool_result = 'Tool got invalid input:\n' + e
|
tool_result = 'Tool got invalid input:\n' + str(e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tool_result = 'Error: ' + str(e)
|
tool_result = 'Error: ' + str(e)
|
||||||
|
|
||||||
@@ -135,11 +160,7 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
|
|||||||
for key in tools_dict:
|
for key in tools_dict:
|
||||||
tools.append(tools_dict[key])
|
tools.append(tools_dict[key])
|
||||||
tool_node = NxToolNode(tools)
|
tool_node = NxToolNode(tools)
|
||||||
llm = ChatOllama(
|
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
|
||||||
model=model,
|
|
||||||
seed=seed,
|
|
||||||
base_url=base_url
|
|
||||||
).bind_tools(tools)
|
|
||||||
|
|
||||||
workflow = StateGraph(MessagesState)
|
workflow = StateGraph(MessagesState)
|
||||||
|
|
||||||
@@ -156,124 +177,21 @@ def agent_with_tools(model: str, seed: int, test: Test, base_url: str) -> str:
|
|||||||
|
|
||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
|
|
||||||
# example with a single tool call
|
# compose the message "history"; this supports few-shot prompting
|
||||||
start_messages = [
|
start_messages = [SystemMessage(test.runnable_input['system_msg'])]
|
||||||
SystemMessage(content=test.runnable_input['system_msg']),
|
|
||||||
HumanMessage(content=test.runnable_input['human_msg'])
|
|
||||||
]
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
|
|
||||||
for chunk in graph.stream(
|
|
||||||
{"messages": start_messages},
|
|
||||||
stream_mode="values",
|
|
||||||
): chunks.append(chunk["messages"][-1])
|
|
||||||
|
|
||||||
return {
|
|
||||||
"answer": chunks[-1].content,
|
|
||||||
"tool_calls": tool_calls
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def agent_with_tools_fsp(model: str, seed: int, test: Test, base_url: str) -> str:
|
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
index = -1
|
|
||||||
|
|
||||||
def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
|
|
||||||
messages = state["messages"]
|
|
||||||
last_message = messages[-1]
|
|
||||||
nonlocal index
|
|
||||||
if last_message.tool_calls:
|
|
||||||
index += 1
|
|
||||||
return "tools"
|
|
||||||
return "__end__"
|
|
||||||
|
|
||||||
def call_llm(state: MessagesState):
|
|
||||||
messages = state["messages"]
|
|
||||||
response = llm.invoke(messages)
|
|
||||||
return {"messages": [response]}
|
|
||||||
|
|
||||||
class NxToolNode:
|
|
||||||
"""A node that runs the tools requested in the last AIMessage."""
|
|
||||||
|
|
||||||
def __init__(self, tools: list) -> None:
|
|
||||||
self.tools_by_name = {tool.name: tool for tool in tools}
|
|
||||||
|
|
||||||
def __call__(self, inputs: dict):
|
|
||||||
if messages := inputs.get("messages", []):
|
|
||||||
message = messages[-1]
|
|
||||||
else:
|
|
||||||
raise ValueError("No message found in input")
|
|
||||||
outputs = []
|
|
||||||
for tool_call in message.tool_calls:
|
|
||||||
|
|
||||||
nonlocal tool_calls
|
|
||||||
nonlocal index
|
|
||||||
tool_calls.append({
|
|
||||||
"tool": tool_call["name"],
|
|
||||||
"args": tool_call["args"],
|
|
||||||
"index": index
|
|
||||||
})
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
|
start_messages += test.runnable_input['fsp_messages']
|
||||||
except KeyError as e:
|
except KeyError:
|
||||||
tool_result = f'Error: Tool with name `{tool_call["name"]}` does not exist. Available tools are: {[tool.name for tool in tools]}'
|
pass
|
||||||
except ValidationError as e:
|
start_messages += [ HumanMessage(test.runnable_input['human_msg']) ]
|
||||||
tool_result = 'Tool got invalid input:\n' + e
|
|
||||||
except Exception as e:
|
|
||||||
tool_result = 'Error: ' + str(e)
|
|
||||||
|
|
||||||
outputs.append(
|
|
||||||
ToolMessage(
|
|
||||||
content=json.dumps(tool_result),
|
|
||||||
name=tool_call["name"],
|
|
||||||
tool_call_id=tool_call["id"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return {"messages": outputs}
|
|
||||||
|
|
||||||
|
|
||||||
tools_dict = test.runnable_input['tools']
|
|
||||||
tools = []
|
|
||||||
for key in tools_dict:
|
|
||||||
tools.append(tools_dict[key])
|
|
||||||
tool_node = NxToolNode(tools)
|
|
||||||
llm = ChatOllama(
|
|
||||||
model=model,
|
|
||||||
seed=seed,
|
|
||||||
base_url=base_url
|
|
||||||
).bind_tools(tools)
|
|
||||||
|
|
||||||
workflow = StateGraph(MessagesState)
|
|
||||||
|
|
||||||
# Define the two nodes we will cycle between
|
|
||||||
workflow.add_node("agent", call_llm)
|
|
||||||
workflow.add_node("tools", tool_node)
|
|
||||||
|
|
||||||
workflow.add_edge("__start__", "agent")
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"agent",
|
|
||||||
should_continue,
|
|
||||||
)
|
|
||||||
workflow.add_edge("tools", "agent")
|
|
||||||
|
|
||||||
graph = workflow.compile()
|
|
||||||
|
|
||||||
# example with a single tool call
|
|
||||||
start_messages = [ SystemMessage(test.runnable_input['system_msg']) ] + test.runnable_input['fsp_messages'] + [ HumanMessage(test.runnable_input['human_msg']) ]
|
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
||||||
for chunk in graph.stream(
|
for chunk in graph.stream(
|
||||||
{"messages": start_messages},
|
{"messages": start_messages},
|
||||||
stream_mode="values",
|
stream_mode="values",
|
||||||
): chunks.append(chunk["messages"][-1])
|
):
|
||||||
|
chunks.append(chunk["messages"][-1])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"answer": chunks[-1].content,
|
"answer": chunks[-1].content,
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Callable, Any
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Test:
|
|
||||||
name: str
|
|
||||||
runnable: Callable
|
|
||||||
runnable_input: dict
|
|
||||||
validator: Callable
|
|
||||||
validation_input: dict
|
|
||||||
@@ -6,14 +6,14 @@ from typing import Union
|
|||||||
|
|
||||||
@tool
|
@tool
|
||||||
def add(a: float, b: float) -> str:
|
def add(a: float, b: float) -> str:
|
||||||
"""Adds a+b and retuns the sum"""
|
"""Adds a+b and returns the sum"""
|
||||||
af = float(a)
|
af = float(a)
|
||||||
bf = float(b)
|
bf = float(b)
|
||||||
return f"{a} + {b} = {a+b}"
|
return f"{a} + {b} = {a+b}"
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def multiply(a: float, b: float) -> str:
|
def multiply(a: float, b: float) -> str:
|
||||||
"""Multiplies a*b and retuns the product"""
|
"""Multiplies a*b and returns the product"""
|
||||||
af = float(a)
|
af = float(a)
|
||||||
bf = float(b)
|
bf = float(b)
|
||||||
return f"{a} * {b} = {a*b}"
|
return f"{a} * {b} = {a*b}"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from langchain_ollama.chat_models import ChatOllama
|
from langchain_ollama.chat_models import ChatOllama
|
||||||
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate
|
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
from libs.test_class import Test
|
from libs.classes import Test
|
||||||
from re import search
|
from re import search
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
|
||||||
@@ -16,7 +16,7 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
|
|||||||
SystemMessagePromptTemplate.from_template(template=dedent("""You evaluate LLMs. Rate the LLM answer as correct, if the answer is
|
SystemMessagePromptTemplate.from_template(template=dedent("""You evaluate LLMs. Rate the LLM answer as correct, if the answer is
|
||||||
{validation_input}
|
{validation_input}
|
||||||
|
|
||||||
else as incorrect. Only use the rate tool. Do not answer conversationally.""")),
|
else as incorrect. Only use the `rate` tool. You do not have accesss to any other tools. Do not answer conversationally.""")),
|
||||||
HumanMessagePromptTemplate.from_template(template=dedent("""System Message:
|
HumanMessagePromptTemplate.from_template(template=dedent("""System Message:
|
||||||
{system_msg}
|
{system_msg}
|
||||||
|
|
||||||
@@ -50,7 +50,10 @@ def system_human_answer_match(test: Test, answer: str, base_url: str) -> bool:
|
|||||||
elif ret_str.lower() == 'false': return False
|
elif ret_str.lower() == 'false': return False
|
||||||
else: raise Exception(f"rate tool retured {ret_str}")
|
else: raise Exception(f"rate tool retured {ret_str}")
|
||||||
except IndexError as e:
|
except IndexError as e:
|
||||||
print(f"\033[0;31mValidation Error \033[0mof {test.name} <{ai_msg.content[:20]}...> Retrying...")
|
print(f"\033[0;31mValidation Error of\033[0m {test.name} \033[0;31m<\033[0m{ai_msg.content[:20]}\033[0;31m...> Retrying...\033[0m")
|
||||||
|
return system_human_answer_match(test=test, answer=answer, base_url=base_url)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\033[0;31mValidation Error \033[0mof {test.name} \033[0;31m<\033[0m{e}\033[0;31m> Retrying...\033[0m")
|
||||||
return system_human_answer_match(test=test, answer=answer, base_url=base_url)
|
return system_human_answer_match(test=test, answer=answer, base_url=base_url)
|
||||||
|
|
||||||
def regex_match_any(test: Test, answer: str, base_url: str) -> bool:
|
def regex_match_any(test: Test, answer: str, base_url: str) -> bool:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
langchain
|
langchain
|
||||||
langchain-core
|
langchain-core
|
||||||
langchain-ollama
|
langchain-ollama
|
||||||
|
langchain-community
|
||||||
langgraph
|
langgraph
|
||||||
seaborn
|
seaborn
|
||||||
pandas
|
pandas
|
||||||
|
|||||||
@@ -1,13 +1,112 @@
|
|||||||
models = [
|
from libs.classes import Model
|
||||||
"llama3.1", # 8b
|
|
||||||
"llama3.1:70b",
|
models = {
|
||||||
"llama3-groq-tool-use", # latest
|
245: Model(
|
||||||
"llama3-groq-tool-use:70b",
|
display_name="llama3.1 8b",
|
||||||
# "mixtral:8x7b",
|
identifier="llama3.1",
|
||||||
"mixtral:8x22b",
|
supports_tools=True,
|
||||||
# "gemma2:2b",
|
parameter_count_in_b=8
|
||||||
# "phi3", # 3.8b
|
),
|
||||||
# "tinyllama:1.1b",
|
238: Model(
|
||||||
"mistral-nemo:12b",
|
display_name="llama3.1 70b",
|
||||||
"command-r-plus:104b",
|
identifier="llama3.1:70b",
|
||||||
]
|
supports_tools=True,
|
||||||
|
parameter_count_in_b=70
|
||||||
|
),
|
||||||
|
120: Model(
|
||||||
|
display_name="llama3 groq TU 8b",
|
||||||
|
identifier="llama3-groq-tool-use",
|
||||||
|
supports_tools=True,
|
||||||
|
parameter_count_in_b=8
|
||||||
|
),
|
||||||
|
890: Model(
|
||||||
|
display_name="llama3 groq TU 70b",
|
||||||
|
identifier="llama3-groq-tool-use:70b",
|
||||||
|
supports_tools=True,
|
||||||
|
parameter_count_in_b=70
|
||||||
|
),
|
||||||
|
348: Model(
|
||||||
|
display_name="Mixtral MoE 8x7b",
|
||||||
|
identifier="mixtral:8x7b",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=13,
|
||||||
|
),
|
||||||
|
789: Model(
|
||||||
|
display_name="Mixtral MoE 8x22b",
|
||||||
|
identifier="mixtral:8x22b",
|
||||||
|
supports_tools=True,
|
||||||
|
parameter_count_in_b=39
|
||||||
|
),
|
||||||
|
445: Model(
|
||||||
|
display_name="Gemma2 2b",
|
||||||
|
identifier="gemma2:2b",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=2
|
||||||
|
),
|
||||||
|
475: Model(
|
||||||
|
display_name="Gemma2 9b",
|
||||||
|
identifier="gemma2:2b",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=9
|
||||||
|
),
|
||||||
|
626: Model(
|
||||||
|
display_name="Gemma2 27b",
|
||||||
|
identifier="gemma2:2b",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=27
|
||||||
|
),
|
||||||
|
229: Model(
|
||||||
|
display_name="Phi3 3.8b",
|
||||||
|
identifier="phi3",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=3.8
|
||||||
|
),
|
||||||
|
903: Model(
|
||||||
|
display_name="Tinyllama 1.1b",
|
||||||
|
identifier="tinyllama:1.1b",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=1.1
|
||||||
|
),
|
||||||
|
670: Model(
|
||||||
|
display_name="Mistral Nemo 12b",
|
||||||
|
identifier="mistral-nemo:12b",
|
||||||
|
supports_tools=True,
|
||||||
|
parameter_count_in_b=12
|
||||||
|
),
|
||||||
|
404: Model(
|
||||||
|
display_name="Command R+ 104b",
|
||||||
|
identifier="command-r-plus:104b",
|
||||||
|
supports_tools=True,
|
||||||
|
parameter_count_in_b=104
|
||||||
|
),
|
||||||
|
701: Model(
|
||||||
|
display_name="Yi 6b",
|
||||||
|
identifier="yi:7b",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=6
|
||||||
|
),
|
||||||
|
704: Model(
|
||||||
|
display_name="Yi 6b",
|
||||||
|
identifier="yi:7b",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=6
|
||||||
|
),
|
||||||
|
724: Model(
|
||||||
|
display_name="Yi 34b",
|
||||||
|
identifier="yi:34b",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=34
|
||||||
|
),
|
||||||
|
129: Model(
|
||||||
|
display_name="Yi 34b",
|
||||||
|
identifier="yi:34b",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=34
|
||||||
|
),
|
||||||
|
853: Model(
|
||||||
|
display_name="Qwen2 0.5b",
|
||||||
|
identifier="qwen2:0.5b",
|
||||||
|
supports_tools=False,
|
||||||
|
parameter_count_in_b=0.5
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from libs.test_class import Test
|
from libs.classes import Test
|
||||||
from libs.runnables import *
|
from libs.runnables import basic_prompt, one_tool_call_answer, agent_with_tools
|
||||||
from libs.validators import *
|
from libs.validators import regex_match_any, system_human_answer_match
|
||||||
from libs.tools import *
|
from libs.tools import add, multiply, get_current_date_and_time, get_notes_in_timespan, get_notes_containing, write_note
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
|
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage, AIMessage
|
||||||
|
|
||||||
tests = {
|
tests = {
|
||||||
607: Test(
|
607: Test(
|
||||||
name="Healthy Vegetables in Chinese",
|
name="Healthy Vegetables in Chinese",
|
||||||
runnable=basic,
|
runnable=basic_prompt,
|
||||||
runnable_input={
|
runnable_input={
|
||||||
"system_msg": "You are a helpful assistant. You serve people across the globe.",
|
"system_msg": "You are a helpful assistant. You serve people across the globe.",
|
||||||
"human_msg": "什么蔬菜最健康?",
|
"human_msg": "什么蔬菜最健康?",
|
||||||
@@ -67,7 +67,7 @@ tests = {
|
|||||||
"get_current_date_and_time": get_current_date_and_time,
|
"get_current_date_and_time": get_current_date_and_time,
|
||||||
"get_notes_in_timespan": get_notes_in_timespan,
|
"get_notes_in_timespan": get_notes_in_timespan,
|
||||||
"get_notes_containing": get_notes_containing,
|
"get_notes_containing": get_notes_containing,
|
||||||
"Write note": write_note
|
"Write note": write_note,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
validator=system_human_answer_match,
|
validator=system_human_answer_match,
|
||||||
@@ -96,7 +96,7 @@ tests = {
|
|||||||
),
|
),
|
||||||
856: Test(
|
856: Test(
|
||||||
name="Notes from last Saturday TSO FSP",
|
name="Notes from last Saturday TSO FSP",
|
||||||
runnable=agent_with_tools_fsp,
|
runnable=agent_with_tools,
|
||||||
runnable_input={
|
runnable_input={
|
||||||
"system_msg": "You are a helpful assistant. You can use tools to accomplish tasks. Once you've called a tool, the resulting tool_message content can be taken into consideration again. With that you can do \"multiple rounds\" of tool calling. To know the date, use the tool get_current_date_and_time.",
|
"system_msg": "You are a helpful assistant. You can use tools to accomplish tasks. Once you've called a tool, the resulting tool_message content can be taken into consideration again. With that you can do \"multiple rounds\" of tool calling. To know the date, use the tool get_current_date_and_time.",
|
||||||
"fsp_messages": [
|
"fsp_messages": [
|
||||||
|
|||||||
@@ -3,13 +3,14 @@ from suite_settings.models import models
|
|||||||
from suite_settings.seeds import seeds
|
from suite_settings.seeds import seeds
|
||||||
from suite_settings.tests import tests
|
from suite_settings.tests import tests
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
results = run_tests(
|
run_tests(
|
||||||
models=models,
|
models=models,
|
||||||
seeds=seeds,
|
seeds=seeds,
|
||||||
tests=tests,
|
tests=tests,
|
||||||
base_url="http://bolt.hs-mittweida.de:11434"
|
base_url="http://bolt.hs-mittweida.de:11434",
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
48
visualize.py
48
visualize.py
@@ -3,7 +3,6 @@ import matplotlib.pyplot as plt
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
from math import pi
|
|
||||||
|
|
||||||
# Load the JSON data
|
# Load the JSON data
|
||||||
with open('saved_results.json', 'r') as f:
|
with open('saved_results.json', 'r') as f:
|
||||||
@@ -14,7 +13,7 @@ results = []
|
|||||||
for test_hash, test_data in data.items():
|
for test_hash, test_data in data.items():
|
||||||
results.append({
|
results.append({
|
||||||
"hash": test_hash,
|
"hash": test_hash,
|
||||||
"model": test_data['model'],
|
"model": test_data['model_name'],
|
||||||
"seed": test_data['seed'],
|
"seed": test_data['seed'],
|
||||||
"test_name": test_data['test_name'],
|
"test_name": test_data['test_name'],
|
||||||
"validation": test_data['validation']
|
"validation": test_data['validation']
|
||||||
@@ -61,52 +60,7 @@ plt.savefig('validation_results_by_test_name.png')
|
|||||||
|
|
||||||
|
|
||||||
## 3rd Chart
|
## 3rd Chart
|
||||||
# Prepare data for the spider chart
|
|
||||||
models = df['model'].unique()
|
|
||||||
|
|
||||||
# Calculate the pass rate for each model on each test
|
|
||||||
pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0)
|
pass_rate = pd.pivot_table(df, values='validation', index='model', columns='test_name', aggfunc="mean", fill_value=0)
|
||||||
tests = df['test_name'].unique().tolist()
|
|
||||||
|
|
||||||
# Initialize the spider plot
|
|
||||||
num_vars = len(pass_rate)-1
|
|
||||||
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
|
|
||||||
angles += [ angles[0] ]
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
|
|
||||||
|
|
||||||
# Plot each model's performance
|
|
||||||
for model in models:
|
|
||||||
values = pass_rate.loc[model].tolist()
|
|
||||||
values += [ values[0] ]
|
|
||||||
ax.fill(angles, values, alpha=0.25)
|
|
||||||
ax.plot(angles, values, label=model)
|
|
||||||
#
|
|
||||||
|
|
||||||
# Configure the spider chart
|
|
||||||
ax.set_theta_offset(pi / 2)
|
|
||||||
ax.set_theta_direction(-1)
|
|
||||||
|
|
||||||
tests.append(tests[0])
|
|
||||||
tests.pop(0)
|
|
||||||
|
|
||||||
ax.set_xticks(angles[:-1])
|
|
||||||
ax.set_xticklabels(tests)
|
|
||||||
|
|
||||||
ax.set_yticks(np.linspace(0, 1, 5))
|
|
||||||
ax.set_yticklabels([f'{int(i * 100)}%' for i in np.linspace(0, 1, 5)], color="grey", size=8)
|
|
||||||
ax.set_ylim(0, 1)
|
|
||||||
|
|
||||||
plt.title('Model Performance on Each Test')
|
|
||||||
plt.legend(loc='upper right', bbox_to_anchor=(1.1, 1.1))
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig('model_performance_spider_chart.png')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 4th chart
|
|
||||||
# Create a heatmap
|
# Create a heatmap
|
||||||
plt.figure(figsize=(8, 8))
|
plt.figure(figsize=(8, 8))
|
||||||
sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})
|
sns.heatmap(pass_rate*100, annot=True, fmt=".0f", cmap=sns.color_palette("blend:#100,#255,#4a3", as_cmap=True), cbar=True, annot_kws={"size": 10})
|
||||||
|
|||||||
Reference in New Issue
Block a user