This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-25 20:10:53 +02:00
parent a578dd26a0
commit 2723ced901
8 changed files with 307 additions and 229 deletions

View File

@@ -16,3 +16,9 @@ class Model:
supports_tools: bool supports_tools: bool
parameter_count_in_b: float parameter_count_in_b: float
@dataclass
class Technique:
name: str
for_supports_tools: bool
for_not_supports_tools: bool

View File

@@ -24,6 +24,7 @@ from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool, Tool from langchain_core.tools import BaseTool, Tool
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
from textwrap import dedent
from libs.functions import nxhash from libs.functions import nxhash
@@ -87,13 +88,36 @@ _BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydantic = Union[Dict, _BM] _DictOrPydantic = Union[Dict, _BM]
class OllamaError(Exception):
def __init__(self, message):
self.message = message # Store the message
super().__init__(message)
def _is_pydantic_class(obj: Any) -> bool: def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and ( return isinstance(obj, type) and (
is_basemodel_subclass(obj) or BaseModel in obj.__bases__ is_basemodel_subclass(obj) or BaseModel in obj.__bases__
) )
class OllamaFunctions(ChatOllama):
"""Function chat model that uses Ollama API."""
def convert_to_ollama_tool(tool: Any) -> Dict: tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
max_tool_call_fails: int = 5
def __init__(self, max_tool_call_fails, **kwargs: Any) -> None:
super().__init__(**kwargs)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(functions=tools, **kwargs)
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
def _convert_to_ollama_tool(self, tool: Any) -> Dict:
"""Convert a tool to an Ollama tool.""" """Convert a tool to an Ollama tool."""
description = None description = None
if _is_pydantic_class(tool): if _is_pydantic_class(tool):
@@ -120,98 +144,28 @@ def convert_to_ollama_tool(tool: Any) -> Dict:
return definition return definition
def parse_response(message: BaseMessage) -> str:
"""Extract `function_call` from `AIMessage`."""
if isinstance(message, AIMessage):
kwargs = message.additional_kwargs
tool_calls = message.tool_calls
if len(tool_calls) > 0:
tool_call = tool_calls[-1]
args = tool_call.get("args")
return json.dumps(args)
elif "function_call" in kwargs:
if "arguments" in kwargs["function_call"]:
return kwargs["function_call"]["arguments"]
raise ValueError(f"`arguments` missing from `function_call` within AIMessage: {message}")
else:
raise ValueError("`tool_calls` missing from AIMessage: {message}")
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
def _get_parsed_chat_result(self, chat_result_str: str) -> dict:
class OllamaFunctions(ChatOllama):
"""Function chat model that uses Ollama API."""
tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(functions=tools, **kwargs)
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
call_list = []
for c in tool_calls:
call_list.append({
"id": nxhash(c['id'])[-4:],
"tool": c['name'],
"args": c['args']
})
if len(call_list) == 1:
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
else:
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
formated_history = ""
system_msg = ""
for m in messages:
if formated_history != "":
formated_history += "\n\n"
if isinstance(m, SystemMessage):
system_msg += str(m.content)
elif isinstance(m, HumanMessage):
formated_history += "The Human said:\n" + str(m.content)
elif isinstance(m, AIMessage) and m.tool_calls:
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
elif isinstance(m, ToolMessage):
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
elif isinstance(m, AIMessage) and not m.tool_calls:
formated_history += "You said:\n" + str(m.content)
else:
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
return system_msg, formated_history
def _get_parsed_chat_result(self, chat_result_str: str) -> Union[dict, str]:
try: try:
parsed_chat_result = json.loads(chat_result_str) parsed_chat_result = json.loads(chat_result_str)
except json.JSONDecodeError:
parsed_chat_result = chat_result_str
return parsed_chat_result return parsed_chat_result
except json.JSONDecodeError:
raise OllamaError(message="Error. Message is not valid JSON.")
def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType: def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType:
if not parsed_chat_result: if not d:
called_tool_name = None called_tool_name = None
elif "tool" in parsed_chat_result: elif "tool" in d:
called_tool_name = d["tool"] # per spec called_tool_name = d["tool"] # per spec
elif "name" in d: elif "name" in d:
called_tool_name = d["name"] # Phi3 often does this called_tool_name = d["name"] # Phi3 often does this
elif "tool_name" in d: elif "tool_name" in d:
called_tool_name = d["tool_name"] # Phi3 often does this called_tool_name = d["tool_name"] # Phi3 often does this
elif "action" in d: elif "action" in d:
called_tool_name = d["action"] # Phi3 does this called_tool_name = d["action"] # Gemma2 does this
elif "task" in d:
called_tool_name = d["task"] # Gemma2 does this
else: else:
return None return None
@@ -238,17 +192,25 @@ class OllamaFunctions(ChatOllama):
elif "tool_input" in d: elif "tool_input" in d:
response = d["tool_input"] response = d["tool_input"]
else: else:
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}") raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" +
"{\n" +
' "tool": <name of the selected tool>,\n' +
' "tool_input": <parameters for the selected tool, matching the tool\'s JSON schema>\n' +
"}")
try: try:
assert isinstance(response, str) assert isinstance(response, str)
except AssertionError: except AssertionError:
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}") raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" +
"{\n" +
' "tool": <name of the selected tool>,\n' +
' "tool_input": <parameters for the selected tool, matching the tool\'s JSON schema>\n' +
"}")
return response return response
def _extract_tool_args(self, d: dict) -> dict: def _extract_tool_args(self, d: dict) -> dict:
if "tool_input" in parsed_chat_result: if "tool_input" in d:
called_tool_args = d["tool_input"] # per spec called_tool_args = d["tool_input"] # per spec
elif "input" in d: elif "input" in d:
called_tool_args = d["input"] # Phi3 often does this called_tool_args = d["input"] # Phi3 often does this
@@ -258,10 +220,41 @@ class OllamaFunctions(ChatOllama):
called_tool_args = {} called_tool_args = {}
return called_tool_args return called_tool_args
# prepare generation def _get_final_message(self, messages: list, functions_str: str) -> list:
functions_list = [convert_to_ollama_tool(fn) for fn in kwargs.get("functions", [])] def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL) def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
functions_str = json.dumps(functions_list, indent=2) call_list = []
for c in tool_calls:
call_list.append({
"id": nxhash(c['id'])[-4:],
"tool": c['name'],
"args": c['args']
})
if len(call_list) == 1:
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
else:
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
formated_history = ""
system_msg = messages[0]
for m in messages[1:]:
if formated_history != "":
formated_history += "\n\n"
if isinstance(m, SystemMessage):
formated_history += "The system provided the info:\n" + str(m.content)
elif isinstance(m, HumanMessage):
formated_history += "The Human said:\n" + str(m.content)
elif isinstance(m, AIMessage) and m.tool_calls:
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
elif isinstance(m, ToolMessage):
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
elif isinstance(m, AIMessage) and not m.tool_calls:
formated_history += "You said:\n" + str(m.content)
else:
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
return system_msg, formated_history
# prepare generation with history # prepare generation with history
if True in [ isinstance(m, ToolMessage) for m in messages ]: if True in [ isinstance(m, ToolMessage) for m in messages ]:
@@ -283,27 +276,33 @@ class OllamaFunctions(ChatOllama):
) )
final_messages = [ system_message ] + messages final_messages = [ system_message ] + messages
return final_messages
def gen(self, failed_tool_calls: int, messages: list) -> ChatResult:
# prepare generation
functions_list = [_convert_to_ollama_tool(self, fn) for fn in kwargs.get("functions", [])]
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
functions_str = json.dumps(functions_list, indent=2)
# get messages to prompt with
final_messages = _get_final_message(self, messages=messages, functions_str=functions_str)
# genrerate chat result # genrerate chat result
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs) response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
chat_result = response_message.generations[0].text chat_result = response_message.generations[0].text
# chekc for validity try:
if not isinstance(chat_result, str):
raise ValueError("OllamaFunctions does not support non-string output.")
# make str to dict # make str to dict
parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result) parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result)
# if model failed to return vailid json, just retrun the whole thing
if isinstance(parsed_chat_result, str):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=parsed_chat_result))])
# get the called tool from the dict # get the called tool from the dict
called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list) called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list)
if not called_tool: if (not called_tool) or (called_tool == CONVERSATIONAL_RESPONSE_TOOL):
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
elif called_tool == CONVERSATIONAL_RESPONSE_TOOL:
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result)) response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
else: else:
response_msg = AIMessage( response_msg = AIMessage(
@@ -314,8 +313,21 @@ class OllamaFunctions(ChatOllama):
id=f"call_{str(uuid.uuid4()).replace('-', '')}", id=f"call_{str(uuid.uuid4()).replace('-', '')}",
)], )],
) )
return ChatResult(generations=[ChatGeneration(message=response_msg)]) return ChatResult(generations=[ChatGeneration(message=response_msg)])
except OllamaError as e:
if failed_tool_calls < self.max_tool_call_fails:
# retry
messages.append(AIMessage(chat_result))
messages.append(SystemMessage(e.message))
return gen(self, failed_tool_calls+1, messages=messages)
else:
# return error
# return ChatResult(generations=[ChatGeneration(message=SystemMessage(content=e.message))])
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=">>Model failed<<"))])
# inital call with no failed runs
return gen(self, failed_tool_calls=0, messages=messages)
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:

View File

@@ -1,4 +1,5 @@
from libs.classes import Test, Model from os import name
from libs.classes import Technique, Test, Model
from libs.functions import nxhash from libs.functions import nxhash
from typing import Union from typing import Union
@@ -14,6 +15,8 @@ def get_len(collection: Union[list, dict]) -> int:
collection_type = "models" collection_type = "models"
elif isinstance(collection[list(collection.keys())[0]], Test): elif isinstance(collection[list(collection.keys())[0]], Test):
collection_type = "tests" collection_type = "tests"
elif isinstance(collection[list(collection.keys())[0]], Technique):
collection_type = "techniques"
else: else:
raise TypeError("get_len: unsupported collection_type") raise TypeError("get_len: unsupported collection_type")
else: else:
@@ -29,6 +32,9 @@ def get_len(collection: Union[list, dict]) -> int:
case "tests": case "tests":
for test_id in collection: for test_id in collection:
maximum_length = max(maximum_length, len(collection[test_id].name)) maximum_length = max(maximum_length, len(collection[test_id].name))
case "techniques":
for technique_id in collection:
maximum_length = max(maximum_length, len(collection[technique_id].name))
case _: case _:
for model_name in collection: for model_name in collection:
raise TypeError("get_len: unsupported collection_type") raise TypeError("get_len: unsupported collection_type")
@@ -37,7 +43,7 @@ def get_len(collection: Union[list, dict]) -> int:
def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], base_url: str): def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], techniques: dict[int, Technique], base_url: str):
try: try:
print("Trying to load saved_results.json") print("Trying to load saved_results.json")
with open("./saved_results.json", "r") as f: with open("./saved_results.json", "r") as f:
@@ -53,16 +59,25 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
model = models[model_id] model = models[model_id]
for test_id in tests: for test_id in tests:
test = tests[test_id] test = tests[test_id]
for technique_id in techniques:
technique = techniques[technique_id]
if ((model.supports_tools != technique.for_supports_tools) and (model.supports_tools == technique.for_not_supports_tools)):
continue
for seed in seeds: for seed in seeds:
# Init dict # Init dict
combination = { combination = {
'test_id': test_id, 'test_id': test_id,
'model_id': model_id, 'model_id': model_id,
'seed': seed, 'seed': seed,
'technique_id': technique_id
} }
hash_key = str(nxhash(json.dumps(combination, sort_keys=True))) hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
combination['test_name'] = test.name
combination['model_name'] = model.display_name combination.update({
'test_name': test.name,
'model_name': model.display_name,
'technique_name': technique.name,
})
# if hash_key == "DE3D137E": # if hash_key == "DE3D137E":
# pass # pass
@@ -77,6 +92,10 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
("0" * (get_len(seeds) - len(str(seed)))) + ("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" + "\033[0m" +
str(seed) + str(seed) +
"\033[0;35m using technique '\033[0m" +
technique.name +
"\033[0;35m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;35m now runs test '\033[0m" + "\033[0;35m now runs test '\033[0m" +
test.name + test.name +
"\033[0;35m'" + "\033[0;35m'" +
@@ -108,6 +127,10 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
("0" * (get_len(seeds) - len(str(seed)))) + ("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" + "\033[0m" +
str(seed) + str(seed) +
"\033[0;32m using technique '\033[0m" +
technique.name +
"\033[0;32m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;32m finished test '\033[0m" + "\033[0;32m finished test '\033[0m" +
test.name + test.name +
"\033[0;32m'" + "\033[0;32m'" +
@@ -127,6 +150,10 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
("0" * (get_len(seeds) - len(str(seed)))) + ("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" + "\033[0m" +
str(seed) + str(seed) +
"\033[0;34m using technique '\033[0m" +
technique.name +
"\033[0;34m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;34m skipped test '\033[0m" + "\033[0;34m skipped test '\033[0m" +
test.name + test.name +
"\033[0;34m'" + "\033[0;34m'" +

View File

@@ -22,7 +22,9 @@ def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType
model=model.identifier, model=model.identifier,
seed=seed, seed=seed,
base_url=base_url, base_url=base_url,
format="json" format="json",
max_tool_call_fails=3,
temperature=0.0
) )
if tools: if tools:
@@ -75,9 +77,21 @@ def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) ->
tool_msg = selected_tool.invoke(call) tool_msg = selected_tool.invoke(call)
messages.append(tool_msg) messages.append(tool_msg)
ai_msg = llm.invoke(messages) ai_msg = llm.invoke(messages)
i = 0
while isinstance(ai_msg, SystemMessage):
i += 1
if i <= 5:
return {
"answer": ">>LLM failed to use tools<<",
"tool_calls": tool_calls,
}
messages.append(ai_msg)
ai_msg = llm.invoke(messages)
tool_calls.append({ tool_calls.append({
"tool": call["name"], "tool": call["name"],
"args": call["args"], "args": call["args"],
"times_failed": i
}) })
except IndexError: # LLM didnt use a tool -> jsut return the content except IndexError: # LLM didnt use a tool -> jsut return the content
tool_calls = [] tool_calls = []
@@ -103,7 +117,6 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
messages = state["messages"] messages = state["messages"]
last_message = messages[-1] last_message = messages[-1]
nonlocal index nonlocal index
assert isinstance(last_message, AIMessage) # this is just so the type checker is happy
if last_message.tool_calls: if last_message.tool_calls:
index += 1 index += 1
return "tools" return "tools"
@@ -174,9 +187,9 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
should_continue, should_continue,
) )
workflow.add_edge("tools", "agent") workflow.add_edge("tools", "agent")
graph = workflow.compile() graph = workflow.compile()
# compose "history" supprts few shot prompting # compose "history" supprts few shot prompting
start_messages = [SystemMessage(test.runnable_input['system_msg'])] start_messages = [SystemMessage(test.runnable_input['system_msg'])]
try: try:
@@ -187,11 +200,14 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
chunks = [] chunks = []
for chunk in graph.stream( try:
{"messages": start_messages}, for chunk in graph.stream({"messages": start_messages}, stream_mode="values", config={"recursion_limit": 10}):
stream_mode="values",
):
chunks.append(chunk["messages"][-1]) chunks.append(chunk["messages"][-1])
except RecursionError:
return {
"answer": ">>Model did not come to a conclusion (Recusion Error)<<",
"tool_calls": tool_calls
}
return { return {
"answer": chunks[-1].content, "answer": chunks[-1].content,

View File

@@ -81,13 +81,13 @@ models = {
), ),
701: Model( 701: Model(
display_name="Yi 6b", display_name="Yi 6b",
identifier="yi:7b", identifier="yi:6b",
supports_tools=False, supports_tools=False,
parameter_count_in_b=6 parameter_count_in_b=6
), ),
704: Model( 704: Model(
display_name="Yi 6b", display_name="Yi 9b",
identifier="yi:7b", identifier="yi:9b",
supports_tools=False, supports_tools=False,
parameter_count_in_b=6 parameter_count_in_b=6
), ),
@@ -97,12 +97,6 @@ models = {
supports_tools=False, supports_tools=False,
parameter_count_in_b=34 parameter_count_in_b=34
), ),
129: Model(
display_name="Yi 34b",
identifier="yi:34b",
supports_tools=False,
parameter_count_in_b=34
),
853: Model( 853: Model(
display_name="Qwen2 0.5b", display_name="Qwen2 0.5b",
identifier="qwen2:0.5b", identifier="qwen2:0.5b",

View File

@@ -0,0 +1,19 @@
from libs.classes import Technique
techniques = {
190: Technique(
name="Native",
for_supports_tools=True,
for_not_supports_tools=False,
),
903: Technique(
name="Long System Message",
for_supports_tools=False,
for_not_supports_tools=True,
),
# 572: Technique(
# name="Tool to System Messsages",
# for_supports_tools=False,
# for_not_supports_tools=True,
# ),
}

View File

@@ -127,6 +127,7 @@ tests = {
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""") - just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
} }
), ),
# 363: Test(), # 363: Test(),
# 600: Test(), # 600: Test(),
# 221: Test(), # 221: Test(),

View File

@@ -1,7 +1,9 @@
from libs.classes import Technique
from libs.run_tests import run_tests from libs.run_tests import run_tests
from suite_settings.models import models from suite_settings.models import models
from suite_settings.seeds import seeds from suite_settings.seeds import seeds
from suite_settings.tests import tests from suite_settings.tests import tests
from suite_settings.techniques import techniques
def main(): def main():
@@ -10,6 +12,7 @@ def main():
models=models, models=models,
seeds=seeds, seeds=seeds,
tests=tests, tests=tests,
techniques=techniques,
base_url="http://bolt.hs-mittweida.de:11434", base_url="http://bolt.hs-mittweida.de:11434",
) )