This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-25 20:10:53 +02:00
parent a578dd26a0
commit 2723ced901
8 changed files with 307 additions and 229 deletions

View File

@@ -16,3 +16,9 @@ class Model:
supports_tools: bool supports_tools: bool
parameter_count_in_b: float parameter_count_in_b: float
@dataclass
class Technique:
name: str
for_supports_tools: bool
for_not_supports_tools: bool

View File

@@ -24,6 +24,7 @@ from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool, Tool from langchain_core.tools import BaseTool, Tool
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
from textwrap import dedent
from libs.functions import nxhash from libs.functions import nxhash
@@ -87,65 +88,24 @@ _BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydantic = Union[Dict, _BM] _DictOrPydantic = Union[Dict, _BM]
class OllamaError(Exception):
def __init__(self, message):
self.message = message # Store the message
super().__init__(message)
def _is_pydantic_class(obj: Any) -> bool: def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and ( return isinstance(obj, type) and (
is_basemodel_subclass(obj) or BaseModel in obj.__bases__ is_basemodel_subclass(obj) or BaseModel in obj.__bases__
) )
def convert_to_ollama_tool(tool: Any) -> Dict:
"""Convert a tool to an Ollama tool."""
description = None
if _is_pydantic_class(tool):
schema = tool.construct().schema()
name = schema["title"]
elif isinstance(tool, BaseTool):
schema = tool.tool_call_schema.schema()
name = tool.get_name()
description = tool.description
elif is_basemodel_instance(tool):
schema = tool.get_input_schema().schema()
name = tool.get_name()
description = tool.description
elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
return tool.copy()
else:
raise ValueError(
f"""Cannot convert {tool} to an Ollama tool.
{tool} needs to be a Pydantic class, model, or a dict."""
)
definition = {"name": name, "parameters": schema}
if description:
definition["description"] = description
return definition
def parse_response(message: BaseMessage) -> str:
"""Extract `function_call` from `AIMessage`."""
if isinstance(message, AIMessage):
kwargs = message.additional_kwargs
tool_calls = message.tool_calls
if len(tool_calls) > 0:
tool_call = tool_calls[-1]
args = tool_call.get("args")
return json.dumps(args)
elif "function_call" in kwargs:
if "arguments" in kwargs["function_call"]:
return kwargs["function_call"]["arguments"]
raise ValueError(f"`arguments` missing from `function_call` within AIMessage: {message}")
else:
raise ValueError("`tool_calls` missing from AIMessage: {message}")
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
class OllamaFunctions(ChatOllama): class OllamaFunctions(ChatOllama):
"""Function chat model that uses Ollama API.""" """Function chat model that uses Ollama API."""
tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
max_tool_call_fails: int = 5
def __init__(self, **kwargs: Any) -> None: def __init__(self, max_tool_call_fails, **kwargs: Any) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
def bind_tools( def bind_tools(
@@ -157,61 +117,55 @@ class OllamaFunctions(ChatOllama):
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult: def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
def _convert_to_ollama_tool(self, tool: Any) -> Dict:
"""Convert a tool to an Ollama tool."""
description = None
if _is_pydantic_class(tool):
schema = tool.construct().schema()
name = schema["title"]
elif isinstance(tool, BaseTool):
schema = tool.tool_call_schema.schema()
name = tool.get_name()
description = tool.description
elif is_basemodel_instance(tool):
schema = tool.get_input_schema().schema()
name = tool.get_name()
description = tool.description
elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
return tool.copy()
else:
raise ValueError(
f"""Cannot convert {tool} to an Ollama tool.
{tool} needs to be a Pydantic class, model, or a dict."""
)
definition = {"name": name, "parameters": schema}
if description:
definition["description"] = description
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]: return definition
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
call_list = []
for c in tool_calls:
call_list.append({
"id": nxhash(c['id'])[-4:],
"tool": c['name'],
"args": c['args']
})
if len(call_list) == 1:
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
else:
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
formated_history = ""
system_msg = ""
for m in messages:
if formated_history != "":
formated_history += "\n\n"
if isinstance(m, SystemMessage):
system_msg += str(m.content)
elif isinstance(m, HumanMessage):
formated_history += "The Human said:\n" + str(m.content)
elif isinstance(m, AIMessage) and m.tool_calls:
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
elif isinstance(m, ToolMessage):
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
elif isinstance(m, AIMessage) and not m.tool_calls:
formated_history += "You said:\n" + str(m.content)
else:
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
return system_msg, formated_history
def _get_parsed_chat_result(self, chat_result_str: str) -> Union[dict, str]: def _get_parsed_chat_result(self, chat_result_str: str) -> dict:
try: try:
parsed_chat_result = json.loads(chat_result_str) parsed_chat_result = json.loads(chat_result_str)
return parsed_chat_result
except json.JSONDecodeError: except json.JSONDecodeError:
parsed_chat_result = chat_result_str raise OllamaError(message="Error. Message is not valid JSON.")
return parsed_chat_result
def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType: def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType:
if not parsed_chat_result: if not d:
called_tool_name = None called_tool_name = None
elif "tool" in parsed_chat_result: elif "tool" in d:
called_tool_name = d["tool"] # per spec called_tool_name = d["tool"] # per spec
elif "name" in d: elif "name" in d:
called_tool_name = d["name"] # Phi3 often does this called_tool_name = d["name"] # Phi3 often does this
elif "tool_name" in d: elif "tool_name" in d:
called_tool_name = d["tool_name"] # Phi3 often does this called_tool_name = d["tool_name"] # Phi3 often does this
elif "action" in d: elif "action" in d:
called_tool_name = d["action"] # Phi3 does this called_tool_name = d["action"] # Gemma2 does this
elif "task" in d:
called_tool_name = d["task"] # Gemma2 does this
else: else:
return None return None
@@ -238,17 +192,25 @@ class OllamaFunctions(ChatOllama):
elif "tool_input" in d: elif "tool_input" in d:
response = d["tool_input"] response = d["tool_input"]
else: else:
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}") raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" +
"{\n" +
' "tool": <name of the selected tool>,\n' +
' "tool_input": <parameters for the selected tool, matching the tool\'s JSON schema>\n' +
"}")
try: try:
assert isinstance(response, str) assert isinstance(response, str)
except AssertionError: except AssertionError:
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}") raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" +
"{\n" +
' "tool": <name of the selected tool>,\n' +
' "tool_input": <parameters for the selected tool, matching the tool\'s JSON schema>\n' +
"}")
return response return response
def _extract_tool_args(self, d: dict) -> dict: def _extract_tool_args(self, d: dict) -> dict:
if "tool_input" in parsed_chat_result: if "tool_input" in d:
called_tool_args = d["tool_input"] # per spec called_tool_args = d["tool_input"] # per spec
elif "input" in d: elif "input" in d:
called_tool_args = d["input"] # Phi3 often does this called_tool_args = d["input"] # Phi3 often does this
@@ -258,64 +220,114 @@ class OllamaFunctions(ChatOllama):
called_tool_args = {} called_tool_args = {}
return called_tool_args return called_tool_args
# prepare generation def _get_final_message(self, messages: list, functions_str: str) -> list:
functions_list = [convert_to_ollama_tool(fn) for fn in kwargs.get("functions", [])] def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL) def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
functions_str = json.dumps(functions_list, indent=2) call_list = []
for c in tool_calls:
call_list.append({
"id": nxhash(c['id'])[-4:],
"tool": c['name'],
"args": c['args']
})
if len(call_list) == 1:
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
else:
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
formated_history = ""
system_msg = messages[0]
for m in messages[1:]:
# prepare generation with history if formated_history != "":
if True in [ isinstance(m, ToolMessage) for m in messages ]: formated_history += "\n\n"
system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
system_message = system_message_prompt_template.format(
tools=functions_str,
history=formated_history,
system_msg=system_msg
)
final_messages = [ system_message ]
# prepare generation without history if isinstance(m, SystemMessage):
else: formated_history += "The system provided the info:\n" + str(m.content)
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template) elif isinstance(m, HumanMessage):
system_message = system_message_prompt_template.format( formated_history += "The Human said:\n" + str(m.content)
tools=functions_str elif isinstance(m, AIMessage) and m.tool_calls:
) formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
final_messages = [ system_message ] + messages elif isinstance(m, ToolMessage):
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
elif isinstance(m, AIMessage) and not m.tool_calls:
formated_history += "You said:\n" + str(m.content)
else:
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
# genrerate chat result return system_msg, formated_history
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
chat_result = response_message.generations[0].text
# chekc for validity # prepare generation with history
if not isinstance(chat_result, str): if True in [ isinstance(m, ToolMessage) for m in messages ]:
raise ValueError("OllamaFunctions does not support non-string output.") system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
system_message = system_message_prompt_template.format(
tools=functions_str,
history=formated_history,
system_msg=system_msg
)
final_messages = [ system_message ]
# make str to dict # prepare generation without history
parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result) else:
# if model failed to return vailid json, just retrun the whole thing system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
if isinstance(parsed_chat_result, str): system_message = system_message_prompt_template.format(
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=parsed_chat_result))]) tools=functions_str
)
final_messages = [ system_message ] + messages
return final_messages
# get the called tool from the dict
called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list)
if not called_tool: def gen(self, failed_tool_calls: int, messages: list) -> ChatResult:
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
elif called_tool == CONVERSATIONAL_RESPONSE_TOOL:
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
else:
response_msg = AIMessage(
content="",
tool_calls=[ToolCall(
name=called_tool['name'],
args=_extract_tool_args(self, d=parsed_chat_result),
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
)],
)
return ChatResult(generations=[ChatGeneration(message=response_msg)]) # prepare generation
functions_list = [_convert_to_ollama_tool(self, fn) for fn in kwargs.get("functions", [])]
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
functions_str = json.dumps(functions_list, indent=2)
# get messages to prompt with
final_messages = _get_final_message(self, messages=messages, functions_str=functions_str)
# genrerate chat result
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
chat_result = response_message.generations[0].text
try:
# make str to dict
parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result)
# get the called tool from the dict
called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list)
if (not called_tool) or (called_tool == CONVERSATIONAL_RESPONSE_TOOL):
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
else:
response_msg = AIMessage(
content="",
tool_calls=[ToolCall(
name=called_tool['name'],
args=_extract_tool_args(self, d=parsed_chat_result),
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
)],
)
return ChatResult(generations=[ChatGeneration(message=response_msg)])
except OllamaError as e:
if failed_tool_calls < self.max_tool_call_fails:
# retry
messages.append(AIMessage(chat_result))
messages.append(SystemMessage(e.message))
return gen(self, failed_tool_calls+1, messages=messages)
else:
# return error
# return ChatResult(generations=[ChatGeneration(message=SystemMessage(content=e.message))])
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=">>Model failed<<"))])
# inital call with no failed runs
return gen(self, failed_tool_calls=0, messages=messages)
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:

View File

@@ -1,4 +1,5 @@
from libs.classes import Test, Model from os import name
from libs.classes import Technique, Test, Model
from libs.functions import nxhash from libs.functions import nxhash
from typing import Union from typing import Union
@@ -14,6 +15,8 @@ def get_len(collection: Union[list, dict]) -> int:
collection_type = "models" collection_type = "models"
elif isinstance(collection[list(collection.keys())[0]], Test): elif isinstance(collection[list(collection.keys())[0]], Test):
collection_type = "tests" collection_type = "tests"
elif isinstance(collection[list(collection.keys())[0]], Technique):
collection_type = "techniques"
else: else:
raise TypeError("get_len: unsupported collection_type") raise TypeError("get_len: unsupported collection_type")
else: else:
@@ -29,6 +32,9 @@ def get_len(collection: Union[list, dict]) -> int:
case "tests": case "tests":
for test_id in collection: for test_id in collection:
maximum_length = max(maximum_length, len(collection[test_id].name)) maximum_length = max(maximum_length, len(collection[test_id].name))
case "techniques":
for technique_id in collection:
maximum_length = max(maximum_length, len(collection[technique_id].name))
case _: case _:
for model_name in collection: for model_name in collection:
raise TypeError("get_len: unsupported collection_type") raise TypeError("get_len: unsupported collection_type")
@@ -37,7 +43,7 @@ def get_len(collection: Union[list, dict]) -> int:
def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], base_url: str): def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], techniques: dict[int, Technique], base_url: str):
try: try:
print("Trying to load saved_results.json") print("Trying to load saved_results.json")
with open("./saved_results.json", "r") as f: with open("./saved_results.json", "r") as f:
@@ -53,88 +59,109 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
model = models[model_id] model = models[model_id]
for test_id in tests: for test_id in tests:
test = tests[test_id] test = tests[test_id]
for seed in seeds: for technique_id in techniques:
# Init dict technique = techniques[technique_id]
combination = { if ((model.supports_tools != technique.for_supports_tools) and (model.supports_tools == technique.for_not_supports_tools)):
'test_id': test_id, continue
'model_id': model_id, for seed in seeds:
'seed': seed, # Init dict
} combination = {
hash_key = str(nxhash(json.dumps(combination, sort_keys=True))) 'test_id': test_id,
combination['test_name'] = test.name 'model_id': model_id,
combination['model_name'] = model.display_name 'seed': seed,
'technique_id': technique_id
# if hash_key == "DE3D137E": }
# pass hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
if hash_key not in saved_results.keys(): combination.update({
try: 'test_name': test.name,
print("\033[0;35mModel '\033[0m" + 'model_name': model.display_name,
model.display_name + 'technique_name': technique.name,
"\033[0;35m'" + })
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" + # if hash_key == "DE3D137E":
("0" * (get_len(seeds) - len(str(seed)))) + # pass
"\033[0m" +
str(seed) + if hash_key not in saved_results.keys():
"\033[0;35m now runs test '\033[0m" + try:
test.name + print("\033[0;35mModel '\033[0m" +
"\033[0;35m'" + model.display_name +
(" " * (get_len(tests) - len(test.name))) + "\033[0;35m'" +
" (\033[0m" + (" " * (get_len(models) - len(model.display_name))) +
hash_key + " with seed \033[0m\033[0;30m" +
"\033[0;35m)\033[0m", ("0" * (get_len(seeds) - len(str(seed)))) +
end="" "\033[0m" +
) str(seed) +
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url) "\033[0;35m using technique '\033[0m" +
if isinstance(answer, str): technique.name +
combination['answer'] = answer "\033[0;35m'" +
# combination['tool_calls'] = [] # no entry (" " * (get_len(techniques) - len(technique.name))) +
del answer "\033[0;35m now runs test '\033[0m" +
elif isinstance(answer, dict): # calls test.name +
combination['answer'] = answer['answer'] "\033[0;35m'" +
combination['tool_calls'] = answer['tool_calls'] (" " * (get_len(tests) - len(test.name))) +
del answer " (\033[0m" +
else: hash_key +
raise Exception(f"runnable returned unkown type {type(answer)}.") "\033[0;35m)\033[0m",
end=""
)
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
if isinstance(answer, str):
combination['answer'] = answer
# combination['tool_calls'] = [] # no entry
del answer
elif isinstance(answer, dict): # calls
combination['answer'] = answer['answer']
combination['tool_calls'] = answer['tool_calls']
del answer
else:
raise Exception(f"runnable returned unkown type {type(answer)}.")
combination['test'] = test combination['test'] = test
run_results[hash_key] = combination run_results[hash_key] = combination
print("\r\033[0;32mModel '\033[0m" + print("\r\033[0;32mModel '\033[0m" +
model.display_name +
"\033[0;32m'" +
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
str(seed) +
"\033[0;32m using technique '\033[0m" +
technique.name +
"\033[0;32m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;32m finished test '\033[0m" +
test.name +
"\033[0;32m'" +
(" " * (get_len(tests) - len(test.name))) +
" (\033[0m" +
hash_key +
"\033[0;32m)\033[0m"
)
except Exception as e:
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ")
else:
print("\r\033[0;34mModel '\033[0m" +
model.display_name + model.display_name +
"\033[0;32m'" + "\033[0;34m'" +
(" " * (get_len(models) - len(model.display_name))) + (" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" + " with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) + ("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" + "\033[0m" +
str(seed) + str(seed) +
"\033[0;32m finished test '\033[0m" + "\033[0;34m using technique '\033[0m" +
technique.name +
"\033[0;34m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;34m skipped test '\033[0m" +
test.name + test.name +
"\033[0;32m'" + "\033[0;34m'" +
(" " * (get_len(tests) - len(test.name))) + (" " * (get_len(tests) - len(test.name))) +
" (\033[0m" + " (\033[0m" +
hash_key + hash_key +
"\033[0;32m)\033[0m" "\033[0;34m) becasue its results exists in saved_results.json\033[0m"
) )
except Exception as e:
print("\r\033[0;31mError: <\033[0m" + str(e) + "\033[0;31m> at (\033[0m" + hash_key + "\033[0;31m). Continuing...\033[0m ")
else:
print("\r\033[0;34mModel '\033[0m" +
model.display_name +
"\033[0;34m'" +
(" " * (get_len(models) - len(model.display_name))) +
" with seed \033[0m\033[0;30m" +
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
str(seed) +
"\033[0;34m skipped test '\033[0m" +
test.name +
"\033[0;34m'" +
(" " * (get_len(tests) - len(test.name))) +
" (\033[0m" +
hash_key +
"\033[0;34m) becasue its results exists in saved_results.json\033[0m"
)
# Validate Results # Validate Results

View File

@@ -22,7 +22,9 @@ def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType
model=model.identifier, model=model.identifier,
seed=seed, seed=seed,
base_url=base_url, base_url=base_url,
format="json" format="json",
max_tool_call_fails=3,
temperature=0.0
) )
if tools: if tools:
@@ -75,9 +77,21 @@ def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) ->
tool_msg = selected_tool.invoke(call) tool_msg = selected_tool.invoke(call)
messages.append(tool_msg) messages.append(tool_msg)
ai_msg = llm.invoke(messages) ai_msg = llm.invoke(messages)
i = 0
while isinstance(ai_msg, SystemMessage):
i += 1
if i <= 5:
return {
"answer": ">>LLM failed to use tools<<",
"tool_calls": tool_calls,
}
messages.append(ai_msg)
ai_msg = llm.invoke(messages)
tool_calls.append({ tool_calls.append({
"tool": call["name"], "tool": call["name"],
"args": call["args"], "args": call["args"],
"times_failed": i
}) })
except IndexError: # LLM didnt use a tool -> jsut return the content except IndexError: # LLM didnt use a tool -> jsut return the content
tool_calls = [] tool_calls = []
@@ -103,7 +117,6 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
messages = state["messages"] messages = state["messages"]
last_message = messages[-1] last_message = messages[-1]
nonlocal index nonlocal index
assert isinstance(last_message, AIMessage) # this is just so the type checker is happy
if last_message.tool_calls: if last_message.tool_calls:
index += 1 index += 1
return "tools" return "tools"
@@ -174,9 +187,9 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
should_continue, should_continue,
) )
workflow.add_edge("tools", "agent") workflow.add_edge("tools", "agent")
graph = workflow.compile() graph = workflow.compile()
# compose "history" supprts few shot prompting # compose "history" supprts few shot prompting
start_messages = [SystemMessage(test.runnable_input['system_msg'])] start_messages = [SystemMessage(test.runnable_input['system_msg'])]
try: try:
@@ -187,11 +200,14 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
chunks = [] chunks = []
for chunk in graph.stream( try:
{"messages": start_messages}, for chunk in graph.stream({"messages": start_messages}, stream_mode="values", config={"recursion_limit": 10}):
stream_mode="values", chunks.append(chunk["messages"][-1])
): except RecursionError:
chunks.append(chunk["messages"][-1]) return {
"answer": ">>Model did not come to a conclusion (Recusion Error)<<",
"tool_calls": tool_calls
}
return { return {
"answer": chunks[-1].content, "answer": chunks[-1].content,

View File

@@ -81,13 +81,13 @@ models = {
), ),
701: Model( 701: Model(
display_name="Yi 6b", display_name="Yi 6b",
identifier="yi:7b", identifier="yi:6b",
supports_tools=False, supports_tools=False,
parameter_count_in_b=6 parameter_count_in_b=6
), ),
704: Model( 704: Model(
display_name="Yi 6b", display_name="Yi 9b",
identifier="yi:7b", identifier="yi:9b",
supports_tools=False, supports_tools=False,
parameter_count_in_b=6 parameter_count_in_b=6
), ),
@@ -97,12 +97,6 @@ models = {
supports_tools=False, supports_tools=False,
parameter_count_in_b=34 parameter_count_in_b=34
), ),
129: Model(
display_name="Yi 34b",
identifier="yi:34b",
supports_tools=False,
parameter_count_in_b=34
),
853: Model( 853: Model(
display_name="Qwen2 0.5b", display_name="Qwen2 0.5b",
identifier="qwen2:0.5b", identifier="qwen2:0.5b",

View File

@@ -0,0 +1,19 @@
from libs.classes import Technique
techniques = {
190: Technique(
name="Native",
for_supports_tools=True,
for_not_supports_tools=False,
),
903: Technique(
name="Long System Message",
for_supports_tools=False,
for_not_supports_tools=True,
),
# 572: Technique(
# name="Tool to System Messsages",
# for_supports_tools=False,
# for_not_supports_tools=True,
# ),
}

View File

@@ -121,12 +121,13 @@ tests = {
"Write note": write_note "Write note": write_note
} }
}, },
validator=system_human_answer_match, validator=system_human_answer_match,
validation_input={ validation_input={
"criteria": dedent("""- containing the information that the Human should call Wolfgang "criteria": dedent("""- containing the information that the Human should call Wolfgang
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""") - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
} }
), ),
# 363: Test(), # 363: Test(),
# 600: Test(), # 600: Test(),
# 221: Test(), # 221: Test(),

View File

@@ -1,7 +1,9 @@
from libs.classes import Technique
from libs.run_tests import run_tests from libs.run_tests import run_tests
from suite_settings.models import models from suite_settings.models import models
from suite_settings.seeds import seeds from suite_settings.seeds import seeds
from suite_settings.tests import tests from suite_settings.tests import tests
from suite_settings.techniques import techniques
def main(): def main():
@@ -10,6 +12,7 @@ def main():
models=models, models=models,
seeds=seeds, seeds=seeds,
tests=tests, tests=tests,
techniques=techniques,
base_url="http://bolt.hs-mittweida.de:11434", base_url="http://bolt.hs-mittweida.de:11434",
) )