m2
This commit is contained in:
@@ -16,3 +16,9 @@ class Model:
|
|||||||
supports_tools: bool
|
supports_tools: bool
|
||||||
parameter_count_in_b: float
|
parameter_count_in_b: float
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Technique:
|
||||||
|
name: str
|
||||||
|
for_supports_tools: bool
|
||||||
|
for_not_supports_tools: bool
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from langchain_core.pydantic_v1 import BaseModel
|
|||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
from langchain_core.tools import BaseTool, Tool
|
from langchain_core.tools import BaseTool, Tool
|
||||||
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
|
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
|
||||||
|
from textwrap import dedent
|
||||||
|
|
||||||
from libs.functions import nxhash
|
from libs.functions import nxhash
|
||||||
|
|
||||||
@@ -87,65 +88,24 @@ _BM = TypeVar("_BM", bound=BaseModel)
|
|||||||
_DictOrPydantic = Union[Dict, _BM]
|
_DictOrPydantic = Union[Dict, _BM]
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaError(Exception):
|
||||||
|
def __init__(self, message):
|
||||||
|
self.message = message # Store the message
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
def _is_pydantic_class(obj: Any) -> bool:
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
return isinstance(obj, type) and (
|
return isinstance(obj, type) and (
|
||||||
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
|
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_to_ollama_tool(tool: Any) -> Dict:
|
|
||||||
"""Convert a tool to an Ollama tool."""
|
|
||||||
description = None
|
|
||||||
if _is_pydantic_class(tool):
|
|
||||||
schema = tool.construct().schema()
|
|
||||||
name = schema["title"]
|
|
||||||
elif isinstance(tool, BaseTool):
|
|
||||||
schema = tool.tool_call_schema.schema()
|
|
||||||
name = tool.get_name()
|
|
||||||
description = tool.description
|
|
||||||
elif is_basemodel_instance(tool):
|
|
||||||
schema = tool.get_input_schema().schema()
|
|
||||||
name = tool.get_name()
|
|
||||||
description = tool.description
|
|
||||||
elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
|
|
||||||
return tool.copy()
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"""Cannot convert {tool} to an Ollama tool.
|
|
||||||
{tool} needs to be a Pydantic class, model, or a dict."""
|
|
||||||
)
|
|
||||||
definition = {"name": name, "parameters": schema}
|
|
||||||
if description:
|
|
||||||
definition["description"] = description
|
|
||||||
|
|
||||||
return definition
|
|
||||||
|
|
||||||
def parse_response(message: BaseMessage) -> str:
|
|
||||||
"""Extract `function_call` from `AIMessage`."""
|
|
||||||
if isinstance(message, AIMessage):
|
|
||||||
kwargs = message.additional_kwargs
|
|
||||||
tool_calls = message.tool_calls
|
|
||||||
if len(tool_calls) > 0:
|
|
||||||
tool_call = tool_calls[-1]
|
|
||||||
args = tool_call.get("args")
|
|
||||||
return json.dumps(args)
|
|
||||||
elif "function_call" in kwargs:
|
|
||||||
if "arguments" in kwargs["function_call"]:
|
|
||||||
return kwargs["function_call"]["arguments"]
|
|
||||||
raise ValueError(f"`arguments` missing from `function_call` within AIMessage: {message}")
|
|
||||||
else:
|
|
||||||
raise ValueError("`tool_calls` missing from AIMessage: {message}")
|
|
||||||
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OllamaFunctions(ChatOllama):
|
class OllamaFunctions(ChatOllama):
|
||||||
"""Function chat model that uses Ollama API."""
|
"""Function chat model that uses Ollama API."""
|
||||||
|
|
||||||
tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
|
tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
|
||||||
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
|
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
|
||||||
|
max_tool_call_fails: int = 5
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, max_tool_call_fails, **kwargs: Any) -> None:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def bind_tools(
|
def bind_tools(
|
||||||
@@ -157,61 +117,55 @@ class OllamaFunctions(ChatOllama):
|
|||||||
|
|
||||||
|
|
||||||
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
|
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
|
||||||
|
def _convert_to_ollama_tool(self, tool: Any) -> Dict:
|
||||||
|
"""Convert a tool to an Ollama tool."""
|
||||||
|
description = None
|
||||||
|
if _is_pydantic_class(tool):
|
||||||
|
schema = tool.construct().schema()
|
||||||
|
name = schema["title"]
|
||||||
|
elif isinstance(tool, BaseTool):
|
||||||
|
schema = tool.tool_call_schema.schema()
|
||||||
|
name = tool.get_name()
|
||||||
|
description = tool.description
|
||||||
|
elif is_basemodel_instance(tool):
|
||||||
|
schema = tool.get_input_schema().schema()
|
||||||
|
name = tool.get_name()
|
||||||
|
description = tool.description
|
||||||
|
elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
|
||||||
|
return tool.copy()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"""Cannot convert {tool} to an Ollama tool.
|
||||||
|
{tool} needs to be a Pydantic class, model, or a dict."""
|
||||||
|
)
|
||||||
|
definition = {"name": name, "parameters": schema}
|
||||||
|
if description:
|
||||||
|
definition["description"] = description
|
||||||
|
|
||||||
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
|
return definition
|
||||||
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
|
|
||||||
call_list = []
|
|
||||||
for c in tool_calls:
|
|
||||||
call_list.append({
|
|
||||||
"id": nxhash(c['id'])[-4:],
|
|
||||||
"tool": c['name'],
|
|
||||||
"args": c['args']
|
|
||||||
})
|
|
||||||
if len(call_list) == 1:
|
|
||||||
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
|
|
||||||
else:
|
|
||||||
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
|
|
||||||
formated_history = ""
|
|
||||||
system_msg = ""
|
|
||||||
for m in messages:
|
|
||||||
|
|
||||||
if formated_history != "":
|
|
||||||
formated_history += "\n\n"
|
|
||||||
|
|
||||||
if isinstance(m, SystemMessage):
|
|
||||||
system_msg += str(m.content)
|
|
||||||
elif isinstance(m, HumanMessage):
|
|
||||||
formated_history += "The Human said:\n" + str(m.content)
|
|
||||||
elif isinstance(m, AIMessage) and m.tool_calls:
|
|
||||||
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
|
|
||||||
elif isinstance(m, ToolMessage):
|
|
||||||
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
|
|
||||||
elif isinstance(m, AIMessage) and not m.tool_calls:
|
|
||||||
formated_history += "You said:\n" + str(m.content)
|
|
||||||
else:
|
|
||||||
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
|
|
||||||
|
|
||||||
return system_msg, formated_history
|
|
||||||
|
|
||||||
|
|
||||||
def _get_parsed_chat_result(self, chat_result_str: str) -> Union[dict, str]:
|
def _get_parsed_chat_result(self, chat_result_str: str) -> dict:
|
||||||
try:
|
try:
|
||||||
parsed_chat_result = json.loads(chat_result_str)
|
parsed_chat_result = json.loads(chat_result_str)
|
||||||
|
return parsed_chat_result
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
parsed_chat_result = chat_result_str
|
raise OllamaError(message="Error. Message is not valid JSON.")
|
||||||
return parsed_chat_result
|
|
||||||
|
|
||||||
def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType:
|
def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType:
|
||||||
if not parsed_chat_result:
|
if not d:
|
||||||
called_tool_name = None
|
called_tool_name = None
|
||||||
elif "tool" in parsed_chat_result:
|
elif "tool" in d:
|
||||||
called_tool_name = d["tool"] # per spec
|
called_tool_name = d["tool"] # per spec
|
||||||
elif "name" in d:
|
elif "name" in d:
|
||||||
called_tool_name = d["name"] # Phi3 often does this
|
called_tool_name = d["name"] # Phi3 often does this
|
||||||
elif "tool_name" in d:
|
elif "tool_name" in d:
|
||||||
called_tool_name = d["tool_name"] # Phi3 often does this
|
called_tool_name = d["tool_name"] # Phi3 often does this
|
||||||
elif "action" in d:
|
elif "action" in d:
|
||||||
called_tool_name = d["action"] # Phi3 does this
|
called_tool_name = d["action"] # Gemma2 does this
|
||||||
|
elif "task" in d:
|
||||||
|
called_tool_name = d["task"] # Gemma2 does this
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -238,17 +192,25 @@ class OllamaFunctions(ChatOllama):
|
|||||||
elif "tool_input" in d:
|
elif "tool_input" in d:
|
||||||
response = d["tool_input"]
|
response = d["tool_input"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}")
|
raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" +
|
||||||
|
"{\n" +
|
||||||
|
' "tool": <name of the selected tool>,\n' +
|
||||||
|
' "tool_input": <parameters for the selected tool, matching the tool\'s JSON schema>\n' +
|
||||||
|
"}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
assert isinstance(response, str)
|
assert isinstance(response, str)
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}")
|
raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" +
|
||||||
|
"{\n" +
|
||||||
|
' "tool": <name of the selected tool>,\n' +
|
||||||
|
' "tool_input": <parameters for the selected tool, matching the tool\'s JSON schema>\n' +
|
||||||
|
"}")
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _extract_tool_args(self, d: dict) -> dict:
|
def _extract_tool_args(self, d: dict) -> dict:
|
||||||
if "tool_input" in parsed_chat_result:
|
if "tool_input" in d:
|
||||||
called_tool_args = d["tool_input"] # per spec
|
called_tool_args = d["tool_input"] # per spec
|
||||||
elif "input" in d:
|
elif "input" in d:
|
||||||
called_tool_args = d["input"] # Phi3 often does this
|
called_tool_args = d["input"] # Phi3 often does this
|
||||||
@@ -258,64 +220,114 @@ class OllamaFunctions(ChatOllama):
|
|||||||
called_tool_args = {}
|
called_tool_args = {}
|
||||||
return called_tool_args
|
return called_tool_args
|
||||||
|
|
||||||
# prepare generation
|
def _get_final_message(self, messages: list, functions_str: str) -> list:
|
||||||
functions_list = [convert_to_ollama_tool(fn) for fn in kwargs.get("functions", [])]
|
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
|
||||||
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
|
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
|
||||||
functions_str = json.dumps(functions_list, indent=2)
|
call_list = []
|
||||||
|
for c in tool_calls:
|
||||||
|
call_list.append({
|
||||||
|
"id": nxhash(c['id'])[-4:],
|
||||||
|
"tool": c['name'],
|
||||||
|
"args": c['args']
|
||||||
|
})
|
||||||
|
if len(call_list) == 1:
|
||||||
|
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
|
||||||
|
else:
|
||||||
|
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
|
||||||
|
formated_history = ""
|
||||||
|
system_msg = messages[0]
|
||||||
|
for m in messages[1:]:
|
||||||
|
|
||||||
# prepare generation with history
|
if formated_history != "":
|
||||||
if True in [ isinstance(m, ToolMessage) for m in messages ]:
|
formated_history += "\n\n"
|
||||||
system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
|
|
||||||
|
|
||||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
|
|
||||||
system_message = system_message_prompt_template.format(
|
|
||||||
tools=functions_str,
|
|
||||||
history=formated_history,
|
|
||||||
system_msg=system_msg
|
|
||||||
)
|
|
||||||
final_messages = [ system_message ]
|
|
||||||
|
|
||||||
# prepare generation without history
|
if isinstance(m, SystemMessage):
|
||||||
else:
|
formated_history += "The system provided the info:\n" + str(m.content)
|
||||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
|
elif isinstance(m, HumanMessage):
|
||||||
system_message = system_message_prompt_template.format(
|
formated_history += "The Human said:\n" + str(m.content)
|
||||||
tools=functions_str
|
elif isinstance(m, AIMessage) and m.tool_calls:
|
||||||
)
|
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
|
||||||
final_messages = [ system_message ] + messages
|
elif isinstance(m, ToolMessage):
|
||||||
|
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
|
||||||
|
elif isinstance(m, AIMessage) and not m.tool_calls:
|
||||||
|
formated_history += "You said:\n" + str(m.content)
|
||||||
|
else:
|
||||||
|
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
|
||||||
|
|
||||||
# genrerate chat result
|
return system_msg, formated_history
|
||||||
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
|
|
||||||
chat_result = response_message.generations[0].text
|
|
||||||
|
|
||||||
# chekc for validity
|
# prepare generation with history
|
||||||
if not isinstance(chat_result, str):
|
if True in [ isinstance(m, ToolMessage) for m in messages ]:
|
||||||
raise ValueError("OllamaFunctions does not support non-string output.")
|
system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
|
||||||
|
|
||||||
|
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
|
||||||
|
system_message = system_message_prompt_template.format(
|
||||||
|
tools=functions_str,
|
||||||
|
history=formated_history,
|
||||||
|
system_msg=system_msg
|
||||||
|
)
|
||||||
|
final_messages = [ system_message ]
|
||||||
|
|
||||||
# make str to dict
|
# prepare generation without history
|
||||||
parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result)
|
else:
|
||||||
# if model failed to return vailid json, just retrun the whole thing
|
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
|
||||||
if isinstance(parsed_chat_result, str):
|
system_message = system_message_prompt_template.format(
|
||||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=parsed_chat_result))])
|
tools=functions_str
|
||||||
|
)
|
||||||
|
final_messages = [ system_message ] + messages
|
||||||
|
|
||||||
|
return final_messages
|
||||||
|
|
||||||
|
|
||||||
# get the called tool from the dict
|
|
||||||
called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list)
|
|
||||||
|
|
||||||
if not called_tool:
|
def gen(self, failed_tool_calls: int, messages: list) -> ChatResult:
|
||||||
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
|
|
||||||
elif called_tool == CONVERSATIONAL_RESPONSE_TOOL:
|
|
||||||
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
|
|
||||||
else:
|
|
||||||
response_msg = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[ToolCall(
|
|
||||||
name=called_tool['name'],
|
|
||||||
args=_extract_tool_args(self, d=parsed_chat_result),
|
|
||||||
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
|
|
||||||
)],
|
|
||||||
)
|
|
||||||
|
|
||||||
return ChatResult(generations=[ChatGeneration(message=response_msg)])
|
# prepare generation
|
||||||
|
functions_list = [_convert_to_ollama_tool(self, fn) for fn in kwargs.get("functions", [])]
|
||||||
|
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
|
||||||
|
functions_str = json.dumps(functions_list, indent=2)
|
||||||
|
|
||||||
|
# get messages to prompt with
|
||||||
|
final_messages = _get_final_message(self, messages=messages, functions_str=functions_str)
|
||||||
|
|
||||||
|
# genrerate chat result
|
||||||
|
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||||
|
chat_result = response_message.generations[0].text
|
||||||
|
|
||||||
|
try:
|
||||||
|
# make str to dict
|
||||||
|
parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result)
|
||||||
|
|
||||||
|
# get the called tool from the dict
|
||||||
|
called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list)
|
||||||
|
|
||||||
|
if (not called_tool) or (called_tool == CONVERSATIONAL_RESPONSE_TOOL):
|
||||||
|
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
|
||||||
|
else:
|
||||||
|
response_msg = AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[ToolCall(
|
||||||
|
name=called_tool['name'],
|
||||||
|
args=_extract_tool_args(self, d=parsed_chat_result),
|
||||||
|
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
|
||||||
|
)],
|
||||||
|
)
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=response_msg)])
|
||||||
|
except OllamaError as e:
|
||||||
|
if failed_tool_calls < self.max_tool_call_fails:
|
||||||
|
# retry
|
||||||
|
messages.append(AIMessage(chat_result))
|
||||||
|
messages.append(SystemMessage(e.message))
|
||||||
|
return gen(self, failed_tool_calls+1, messages=messages)
|
||||||
|
else:
|
||||||
|
# return error
|
||||||
|
# return ChatResult(generations=[ChatGeneration(message=SystemMessage(content=e.message))])
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=">>Model failed<<"))])
|
||||||
|
|
||||||
|
# inital call with no failed runs
|
||||||
|
return gen(self, failed_tool_calls=0, messages=messages)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from libs.classes import Test, Model
|
from os import name
|
||||||
|
from libs.classes import Technique, Test, Model
|
||||||
from libs.functions import nxhash
|
from libs.functions import nxhash
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -14,6 +15,8 @@ def get_len(collection: Union[list, dict]) -> int:
|
|||||||
collection_type = "models"
|
collection_type = "models"
|
||||||
elif isinstance(collection[list(collection.keys())[0]], Test):
|
elif isinstance(collection[list(collection.keys())[0]], Test):
|
||||||
collection_type = "tests"
|
collection_type = "tests"
|
||||||
|
elif isinstance(collection[list(collection.keys())[0]], Technique):
|
||||||
|
collection_type = "techniques"
|
||||||
else:
|
else:
|
||||||
raise TypeError("get_len: unsupported collection_type")
|
raise TypeError("get_len: unsupported collection_type")
|
||||||
else:
|
else:
|
||||||
@@ -29,6 +32,9 @@ def get_len(collection: Union[list, dict]) -> int:
|
|||||||
case "tests":
|
case "tests":
|
||||||
for test_id in collection:
|
for test_id in collection:
|
||||||
maximum_length = max(maximum_length, len(collection[test_id].name))
|
maximum_length = max(maximum_length, len(collection[test_id].name))
|
||||||
|
case "techniques":
|
||||||
|
for technique_id in collection:
|
||||||
|
maximum_length = max(maximum_length, len(collection[technique_id].name))
|
||||||
case _:
|
case _:
|
||||||
for model_name in collection:
|
for model_name in collection:
|
||||||
raise TypeError("get_len: unsupported collection_type")
|
raise TypeError("get_len: unsupported collection_type")
|
||||||
@@ -37,7 +43,7 @@ def get_len(collection: Union[list, dict]) -> int:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], base_url: str):
|
def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], techniques: dict[int, Technique], base_url: str):
|
||||||
try:
|
try:
|
||||||
print("Trying to load saved_results.json")
|
print("Trying to load saved_results.json")
|
||||||
with open("./saved_results.json", "r") as f:
|
with open("./saved_results.json", "r") as f:
|
||||||
@@ -53,88 +59,109 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
|
|||||||
model = models[model_id]
|
model = models[model_id]
|
||||||
for test_id in tests:
|
for test_id in tests:
|
||||||
test = tests[test_id]
|
test = tests[test_id]
|
||||||
for seed in seeds:
|
for technique_id in techniques:
|
||||||
# Init dict
|
technique = techniques[technique_id]
|
||||||
combination = {
|
if ((model.supports_tools != technique.for_supports_tools) and (model.supports_tools == technique.for_not_supports_tools)):
|
||||||
'test_id': test_id,
|
continue
|
||||||
'model_id': model_id,
|
for seed in seeds:
|
||||||
'seed': seed,
|
# Init dict
|
||||||
}
|
combination = {
|
||||||
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
|
'test_id': test_id,
|
||||||
combination['test_name'] = test.name
|
'model_id': model_id,
|
||||||
combination['model_name'] = model.display_name
|
'seed': seed,
|
||||||
|
'technique_id': technique_id
|
||||||
# if hash_key == "DE3D137E":
|
}
|
||||||
# pass
|
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
|
||||||
|
|
||||||
if hash_key not in saved_results.keys():
|
combination.update({
|
||||||
try:
|
'test_name': test.name,
|
||||||
print("\033[0;35mModel '\033[0m" +
|
'model_name': model.display_name,
|
||||||
model.display_name +
|
'technique_name': technique.name,
|
||||||
"\033[0;35m'" +
|
})
|
||||||
(" " * (get_len(models) - len(model.display_name))) +
|
|
||||||
" with seed \033[0m\033[0;30m" +
|
# if hash_key == "DE3D137E":
|
||||||
("0" * (get_len(seeds) - len(str(seed)))) +
|
# pass
|
||||||
"\033[0m" +
|
|
||||||
str(seed) +
|
if hash_key not in saved_results.keys():
|
||||||
"\033[0;35m now runs test '\033[0m" +
|
try:
|
||||||
test.name +
|
print("\033[0;35mModel '\033[0m" +
|
||||||
"\033[0;35m'" +
|
model.display_name +
|
||||||
(" " * (get_len(tests) - len(test.name))) +
|
"\033[0;35m'" +
|
||||||
" (\033[0m" +
|
(" " * (get_len(models) - len(model.display_name))) +
|
||||||
hash_key +
|
" with seed \033[0m\033[0;30m" +
|
||||||
"\033[0;35m)\033[0m",
|
("0" * (get_len(seeds) - len(str(seed)))) +
|
||||||
end=""
|
"\033[0m" +
|
||||||
)
|
str(seed) +
|
||||||
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
|
"\033[0;35m using technique '\033[0m" +
|
||||||
if isinstance(answer, str):
|
technique.name +
|
||||||
combination['answer'] = answer
|
"\033[0;35m'" +
|
||||||
# combination['tool_calls'] = [] # no entry
|
(" " * (get_len(techniques) - len(technique.name))) +
|
||||||
del answer
|
"\033[0;35m now runs test '\033[0m" +
|
||||||
elif isinstance(answer, dict): # calls
|
test.name +
|
||||||
combination['answer'] = answer['answer']
|
"\033[0;35m'" +
|
||||||
combination['tool_calls'] = answer['tool_calls']
|
(" " * (get_len(tests) - len(test.name))) +
|
||||||
del answer
|
" (\033[0m" +
|
||||||
else:
|
hash_key +
|
||||||
raise Exception(f"runnable returned unkown type {type(answer)}.")
|
"\033[0;35m)\033[0m",
|
||||||
|
end=""
|
||||||
|
)
|
||||||
|
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
|
||||||
|
if isinstance(answer, str):
|
||||||
|
combination['answer'] = answer
|
||||||
|
# combination['tool_calls'] = [] # no entry
|
||||||
|
del answer
|
||||||
|
elif isinstance(answer, dict): # calls
|
||||||
|
combination['answer'] = answer['answer']
|
||||||
|
combination['tool_calls'] = answer['tool_calls']
|
||||||
|
del answer
|
||||||
|
else:
|
||||||
|
raise Exception(f"runnable returned unkown type {type(answer)}.")
|
||||||
|
|
||||||
combination['test'] = test
|
combination['test'] = test
|
||||||
run_results[hash_key] = combination
|
run_results[hash_key] = combination
|
||||||
print("\r\033[0;32mModel '\033[0m" +
|
print("\r\033[0;32mModel '\033[0m" +
|
||||||
|
model.display_name +
|
||||||
|
"\033[0;32m'" +
|
||||||
|
(" " * (get_len(models) - len(model.display_name))) +
|
||||||
|
" with seed \033[0m\033[0;30m" +
|
||||||
|
("0" * (get_len(seeds) - len(str(seed)))) +
|
||||||
|
"\033[0m" +
|
||||||
|
str(seed) +
|
||||||
|
"\033[0;32m using technique '\033[0m" +
|
||||||
|
technique.name +
|
||||||
|
"\033[0;32m'" +
|
||||||
|
(" " * (get_len(techniques) - len(technique.name))) +
|
||||||
|
"\033[0;32m finished test '\033[0m" +
|
||||||
|
test.name +
|
||||||
|
"\033[0;32m'" +
|
||||||
|
(" " * (get_len(tests) - len(test.name))) +
|
||||||
|
" (\033[0m" +
|
||||||
|
hash_key +
|
||||||
|
"\033[0;32m)\033[0m"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ")
|
||||||
|
else:
|
||||||
|
print("\r\033[0;34mModel '\033[0m" +
|
||||||
model.display_name +
|
model.display_name +
|
||||||
"\033[0;32m'" +
|
"\033[0;34m'" +
|
||||||
(" " * (get_len(models) - len(model.display_name))) +
|
(" " * (get_len(models) - len(model.display_name))) +
|
||||||
" with seed \033[0m\033[0;30m" +
|
" with seed \033[0m\033[0;30m" +
|
||||||
("0" * (get_len(seeds) - len(str(seed)))) +
|
("0" * (get_len(seeds) - len(str(seed)))) +
|
||||||
"\033[0m" +
|
"\033[0m" +
|
||||||
str(seed) +
|
str(seed) +
|
||||||
"\033[0;32m finished test '\033[0m" +
|
"\033[0;34m using technique '\033[0m" +
|
||||||
|
technique.name +
|
||||||
|
"\033[0;34m'" +
|
||||||
|
(" " * (get_len(techniques) - len(technique.name))) +
|
||||||
|
"\033[0;34m skipped test '\033[0m" +
|
||||||
test.name +
|
test.name +
|
||||||
"\033[0;32m'" +
|
"\033[0;34m'" +
|
||||||
(" " * (get_len(tests) - len(test.name))) +
|
(" " * (get_len(tests) - len(test.name))) +
|
||||||
" (\033[0m" +
|
" (\033[0m" +
|
||||||
hash_key +
|
hash_key +
|
||||||
"\033[0;32m)\033[0m"
|
"\033[0;34m) becasue its results exists in saved_results.json\033[0m"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ")
|
|
||||||
else:
|
|
||||||
print("\r\033[0;34mModel '\033[0m" +
|
|
||||||
model.display_name +
|
|
||||||
"\033[0;34m'" +
|
|
||||||
(" " * (get_len(models) - len(model.display_name))) +
|
|
||||||
" with seed \033[0m\033[0;30m" +
|
|
||||||
("0" * (get_len(seeds) - len(str(seed)))) +
|
|
||||||
"\033[0m" +
|
|
||||||
str(seed) +
|
|
||||||
"\033[0;34m skipped test '\033[0m" +
|
|
||||||
test.name +
|
|
||||||
"\033[0;34m'" +
|
|
||||||
(" " * (get_len(tests) - len(test.name))) +
|
|
||||||
" (\033[0m" +
|
|
||||||
hash_key +
|
|
||||||
"\033[0;34m) becasue its results exists in saved_results.json\033[0m"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Validate Results
|
# Validate Results
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType
|
|||||||
model=model.identifier,
|
model=model.identifier,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
format="json"
|
format="json",
|
||||||
|
max_tool_call_fails=3,
|
||||||
|
temperature=0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
@@ -75,9 +77,21 @@ def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) ->
|
|||||||
tool_msg = selected_tool.invoke(call)
|
tool_msg = selected_tool.invoke(call)
|
||||||
messages.append(tool_msg)
|
messages.append(tool_msg)
|
||||||
ai_msg = llm.invoke(messages)
|
ai_msg = llm.invoke(messages)
|
||||||
|
i = 0
|
||||||
|
while isinstance(ai_msg, SystemMessage):
|
||||||
|
i += 1
|
||||||
|
if i <= 5:
|
||||||
|
return {
|
||||||
|
"answer": ">>LLM failed to use tools<<",
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
}
|
||||||
|
messages.append(ai_msg)
|
||||||
|
ai_msg = llm.invoke(messages)
|
||||||
|
|
||||||
tool_calls.append({
|
tool_calls.append({
|
||||||
"tool": call["name"],
|
"tool": call["name"],
|
||||||
"args": call["args"],
|
"args": call["args"],
|
||||||
|
"times_failed": i
|
||||||
})
|
})
|
||||||
except IndexError: # LLM didnt use a tool -> jsut return the content
|
except IndexError: # LLM didnt use a tool -> jsut return the content
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
@@ -103,7 +117,6 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
|
|||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
nonlocal index
|
nonlocal index
|
||||||
assert isinstance(last_message, AIMessage) # this is just so the type checker is happy
|
|
||||||
if last_message.tool_calls:
|
if last_message.tool_calls:
|
||||||
index += 1
|
index += 1
|
||||||
return "tools"
|
return "tools"
|
||||||
@@ -174,9 +187,9 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
|
|||||||
should_continue,
|
should_continue,
|
||||||
)
|
)
|
||||||
workflow.add_edge("tools", "agent")
|
workflow.add_edge("tools", "agent")
|
||||||
|
|
||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
|
|
||||||
|
|
||||||
# compose "history" supprts few shot prompting
|
# compose "history" supprts few shot prompting
|
||||||
start_messages = [SystemMessage(test.runnable_input['system_msg'])]
|
start_messages = [SystemMessage(test.runnable_input['system_msg'])]
|
||||||
try:
|
try:
|
||||||
@@ -187,11 +200,14 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
|
|||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
||||||
for chunk in graph.stream(
|
try:
|
||||||
{"messages": start_messages},
|
for chunk in graph.stream({"messages": start_messages}, stream_mode="values", config={"recursion_limit": 10}):
|
||||||
stream_mode="values",
|
chunks.append(chunk["messages"][-1])
|
||||||
):
|
except RecursionError:
|
||||||
chunks.append(chunk["messages"][-1])
|
return {
|
||||||
|
"answer": ">>Model did not come to a conclusion (Recusion Error)<<",
|
||||||
|
"tool_calls": tool_calls
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"answer": chunks[-1].content,
|
"answer": chunks[-1].content,
|
||||||
|
|||||||
@@ -81,13 +81,13 @@ models = {
|
|||||||
),
|
),
|
||||||
701: Model(
|
701: Model(
|
||||||
display_name="Yi 6b",
|
display_name="Yi 6b",
|
||||||
identifier="yi:7b",
|
identifier="yi:6b",
|
||||||
supports_tools=False,
|
supports_tools=False,
|
||||||
parameter_count_in_b=6
|
parameter_count_in_b=6
|
||||||
),
|
),
|
||||||
704: Model(
|
704: Model(
|
||||||
display_name="Yi 6b",
|
display_name="Yi 9b",
|
||||||
identifier="yi:7b",
|
identifier="yi:9b",
|
||||||
supports_tools=False,
|
supports_tools=False,
|
||||||
parameter_count_in_b=6
|
parameter_count_in_b=6
|
||||||
),
|
),
|
||||||
@@ -97,12 +97,6 @@ models = {
|
|||||||
supports_tools=False,
|
supports_tools=False,
|
||||||
parameter_count_in_b=34
|
parameter_count_in_b=34
|
||||||
),
|
),
|
||||||
129: Model(
|
|
||||||
display_name="Yi 34b",
|
|
||||||
identifier="yi:34b",
|
|
||||||
supports_tools=False,
|
|
||||||
parameter_count_in_b=34
|
|
||||||
),
|
|
||||||
853: Model(
|
853: Model(
|
||||||
display_name="Qwen2 0.5b",
|
display_name="Qwen2 0.5b",
|
||||||
identifier="qwen2:0.5b",
|
identifier="qwen2:0.5b",
|
||||||
|
|||||||
19
suite_settings/techniques.py
Normal file
19
suite_settings/techniques.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from libs.classes import Technique
|
||||||
|
|
||||||
|
techniques = {
|
||||||
|
190: Technique(
|
||||||
|
name="Native",
|
||||||
|
for_supports_tools=True,
|
||||||
|
for_not_supports_tools=False,
|
||||||
|
),
|
||||||
|
903: Technique(
|
||||||
|
name="Long System Message",
|
||||||
|
for_supports_tools=False,
|
||||||
|
for_not_supports_tools=True,
|
||||||
|
),
|
||||||
|
# 572: Technique(
|
||||||
|
# name="Tool to System Messsages",
|
||||||
|
# for_supports_tools=False,
|
||||||
|
# for_not_supports_tools=True,
|
||||||
|
# ),
|
||||||
|
}
|
||||||
@@ -121,12 +121,13 @@ tests = {
|
|||||||
"Write note": write_note
|
"Write note": write_note
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
validator=system_human_answer_match,
|
validator=system_human_answer_match,
|
||||||
validation_input={
|
validation_input={
|
||||||
"criteria": dedent("""- containing the information that the Human should call Wolfgang
|
"criteria": dedent("""- containing the information that the Human should call Wolfgang
|
||||||
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
|
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
|
|
||||||
# 363: Test(),
|
# 363: Test(),
|
||||||
# 600: Test(),
|
# 600: Test(),
|
||||||
# 221: Test(),
|
# 221: Test(),
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
from libs.classes import Technique
|
||||||
from libs.run_tests import run_tests
|
from libs.run_tests import run_tests
|
||||||
from suite_settings.models import models
|
from suite_settings.models import models
|
||||||
from suite_settings.seeds import seeds
|
from suite_settings.seeds import seeds
|
||||||
from suite_settings.tests import tests
|
from suite_settings.tests import tests
|
||||||
|
from suite_settings.techniques import techniques
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -10,6 +12,7 @@ def main():
|
|||||||
models=models,
|
models=models,
|
||||||
seeds=seeds,
|
seeds=seeds,
|
||||||
tests=tests,
|
tests=tests,
|
||||||
|
techniques=techniques,
|
||||||
base_url="http://bolt.hs-mittweida.de:11434",
|
base_url="http://bolt.hs-mittweida.de:11434",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user