This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-25 20:10:53 +02:00
parent a578dd26a0
commit 2723ced901
8 changed files with 307 additions and 229 deletions

View File

@@ -16,3 +16,9 @@ class Model:
supports_tools: bool
parameter_count_in_b: float
@dataclass
class Technique:
name: str
for_supports_tools: bool
for_not_supports_tools: bool

View File

@@ -24,6 +24,7 @@ from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool, Tool
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
from textwrap import dedent
from libs.functions import nxhash
@@ -87,13 +88,36 @@ _BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydantic = Union[Dict, _BM]
class OllamaError(Exception):
def __init__(self, message):
self.message = message # Store the message
super().__init__(message)
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and (
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
)
class OllamaFunctions(ChatOllama):
"""Function chat model that uses Ollama API."""
def convert_to_ollama_tool(tool: Any) -> Dict:
tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
max_tool_call_fails: int = 5
def __init__(self, max_tool_call_fails, **kwargs: Any) -> None:
super().__init__(**kwargs)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(functions=tools, **kwargs)
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
def _convert_to_ollama_tool(self, tool: Any) -> Dict:
"""Convert a tool to an Ollama tool."""
description = None
if _is_pydantic_class(tool):
@@ -120,98 +144,28 @@ def convert_to_ollama_tool(tool: Any) -> Dict:
return definition
def parse_response(message: BaseMessage) -> str:
"""Extract `function_call` from `AIMessage`."""
if isinstance(message, AIMessage):
kwargs = message.additional_kwargs
tool_calls = message.tool_calls
if len(tool_calls) > 0:
tool_call = tool_calls[-1]
args = tool_call.get("args")
return json.dumps(args)
elif "function_call" in kwargs:
if "arguments" in kwargs["function_call"]:
return kwargs["function_call"]["arguments"]
raise ValueError(f"`arguments` missing from `function_call` within AIMessage: {message}")
else:
raise ValueError("`tool_calls` missing from AIMessage: {message}")
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
class OllamaFunctions(ChatOllama):
"""Function chat model that uses Ollama API."""
tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(functions=tools, **kwargs)
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
call_list = []
for c in tool_calls:
call_list.append({
"id": nxhash(c['id'])[-4:],
"tool": c['name'],
"args": c['args']
})
if len(call_list) == 1:
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
else:
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
formated_history = ""
system_msg = ""
for m in messages:
if formated_history != "":
formated_history += "\n\n"
if isinstance(m, SystemMessage):
system_msg += str(m.content)
elif isinstance(m, HumanMessage):
formated_history += "The Human said:\n" + str(m.content)
elif isinstance(m, AIMessage) and m.tool_calls:
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
elif isinstance(m, ToolMessage):
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
elif isinstance(m, AIMessage) and not m.tool_calls:
formated_history += "You said:\n" + str(m.content)
else:
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
return system_msg, formated_history
def _get_parsed_chat_result(self, chat_result_str: str) -> Union[dict, str]:
def _get_parsed_chat_result(self, chat_result_str: str) -> dict:
try:
parsed_chat_result = json.loads(chat_result_str)
except json.JSONDecodeError:
parsed_chat_result = chat_result_str
return parsed_chat_result
except json.JSONDecodeError:
raise OllamaError(message="Error. Message is not valid JSON.")
def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType:
if not parsed_chat_result:
if not d:
called_tool_name = None
elif "tool" in parsed_chat_result:
elif "tool" in d:
called_tool_name = d["tool"] # per spec
elif "name" in d:
called_tool_name = d["name"] # Phi3 often does this
elif "tool_name" in d:
called_tool_name = d["tool_name"] # Phi3 often does this
elif "action" in d:
called_tool_name = d["action"] # Phi3 does this
called_tool_name = d["action"] # Gemma2 does this
elif "task" in d:
called_tool_name = d["task"] # Gemma2 does this
else:
return None
@@ -238,17 +192,25 @@ class OllamaFunctions(ChatOllama):
elif "tool_input" in d:
response = d["tool_input"]
else:
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}")
raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" +
"{\n" +
' "tool": <name of the selected tool>,\n' +
' "tool_input": <parameters for the selected tool, matching the tool\'s JSON schema>\n' +
"}")
try:
assert isinstance(response, str)
except AssertionError:
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}")
raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" +
"{\n" +
' "tool": <name of the selected tool>,\n' +
' "tool_input": <parameters for the selected tool, matching the tool\'s JSON schema>\n' +
"}")
return response
def _extract_tool_args(self, d: dict) -> dict:
if "tool_input" in parsed_chat_result:
if "tool_input" in d:
called_tool_args = d["tool_input"] # per spec
elif "input" in d:
called_tool_args = d["input"] # Phi3 often does this
@@ -258,10 +220,41 @@ class OllamaFunctions(ChatOllama):
called_tool_args = {}
return called_tool_args
# prepare generation
functions_list = [convert_to_ollama_tool(fn) for fn in kwargs.get("functions", [])]
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
functions_str = json.dumps(functions_list, indent=2)
def _get_final_message(self, messages: list, functions_str: str) -> list:
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
call_list = []
for c in tool_calls:
call_list.append({
"id": nxhash(c['id'])[-4:],
"tool": c['name'],
"args": c['args']
})
if len(call_list) == 1:
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
else:
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
formated_history = ""
system_msg = messages[0]
for m in messages[1:]:
if formated_history != "":
formated_history += "\n\n"
if isinstance(m, SystemMessage):
formated_history += "The system provided the info:\n" + str(m.content)
elif isinstance(m, HumanMessage):
formated_history += "The Human said:\n" + str(m.content)
elif isinstance(m, AIMessage) and m.tool_calls:
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
elif isinstance(m, ToolMessage):
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
elif isinstance(m, AIMessage) and not m.tool_calls:
formated_history += "You said:\n" + str(m.content)
else:
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
return system_msg, formated_history
# prepare generation with history
if True in [ isinstance(m, ToolMessage) for m in messages ]:
@@ -283,27 +276,33 @@ class OllamaFunctions(ChatOllama):
)
final_messages = [ system_message ] + messages
return final_messages
def gen(self, failed_tool_calls: int, messages: list) -> ChatResult:
# prepare generation
functions_list = [_convert_to_ollama_tool(self, fn) for fn in kwargs.get("functions", [])]
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
functions_str = json.dumps(functions_list, indent=2)
# get messages to prompt with
final_messages = _get_final_message(self, messages=messages, functions_str=functions_str)
# genrerate chat result
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
chat_result = response_message.generations[0].text
# chekc for validity
if not isinstance(chat_result, str):
raise ValueError("OllamaFunctions does not support non-string output.")
try:
# make str to dict
parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result)
# if model failed to return vailid json, just retrun the whole thing
if isinstance(parsed_chat_result, str):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=parsed_chat_result))])
# get the called tool from the dict
called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list)
if not called_tool:
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
elif called_tool == CONVERSATIONAL_RESPONSE_TOOL:
if (not called_tool) or (called_tool == CONVERSATIONAL_RESPONSE_TOOL):
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
else:
response_msg = AIMessage(
@@ -314,8 +313,21 @@ class OllamaFunctions(ChatOllama):
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
)],
)
return ChatResult(generations=[ChatGeneration(message=response_msg)])
except OllamaError as e:
if failed_tool_calls < self.max_tool_call_fails:
# retry
messages.append(AIMessage(chat_result))
messages.append(SystemMessage(e.message))
return gen(self, failed_tool_calls+1, messages=messages)
else:
# return error
# return ChatResult(generations=[ChatGeneration(message=SystemMessage(content=e.message))])
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=">>Model failed<<"))])
# inital call with no failed runs
return gen(self, failed_tool_calls=0, messages=messages)
@property
def _llm_type(self) -> str:

View File

@@ -1,4 +1,5 @@
from libs.classes import Test, Model
from os import name
from libs.classes import Technique, Test, Model
from libs.functions import nxhash
from typing import Union
@@ -14,6 +15,8 @@ def get_len(collection: Union[list, dict]) -> int:
collection_type = "models"
elif isinstance(collection[list(collection.keys())[0]], Test):
collection_type = "tests"
elif isinstance(collection[list(collection.keys())[0]], Technique):
collection_type = "techniques"
else:
raise TypeError("get_len: unsupported collection_type")
else:
@@ -29,6 +32,9 @@ def get_len(collection: Union[list, dict]) -> int:
case "tests":
for test_id in collection:
maximum_length = max(maximum_length, len(collection[test_id].name))
case "techniques":
for technique_id in collection:
maximum_length = max(maximum_length, len(collection[technique_id].name))
case _:
for model_name in collection:
raise TypeError("get_len: unsupported collection_type")
@@ -37,7 +43,7 @@ def get_len(collection: Union[list, dict]) -> int:
def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], base_url: str):
def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test], techniques: dict[int, Technique], base_url: str):
try:
print("Trying to load saved_results.json")
with open("./saved_results.json", "r") as f:
@@ -53,16 +59,25 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
model = models[model_id]
for test_id in tests:
test = tests[test_id]
for technique_id in techniques:
technique = techniques[technique_id]
if ((model.supports_tools != technique.for_supports_tools) and (model.supports_tools == technique.for_not_supports_tools)):
continue
for seed in seeds:
# Init dict
combination = {
'test_id': test_id,
'model_id': model_id,
'seed': seed,
'technique_id': technique_id
}
hash_key = str(nxhash(json.dumps(combination, sort_keys=True)))
combination['test_name'] = test.name
combination['model_name'] = model.display_name
combination.update({
'test_name': test.name,
'model_name': model.display_name,
'technique_name': technique.name,
})
# if hash_key == "DE3D137E":
# pass
@@ -77,6 +92,10 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
str(seed) +
"\033[0;35m using technique '\033[0m" +
technique.name +
"\033[0;35m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;35m now runs test '\033[0m" +
test.name +
"\033[0;35m'" +
@@ -108,6 +127,10 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
str(seed) +
"\033[0;32m using technique '\033[0m" +
technique.name +
"\033[0;32m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;32m finished test '\033[0m" +
test.name +
"\033[0;32m'" +
@@ -127,6 +150,10 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
("0" * (get_len(seeds) - len(str(seed)))) +
"\033[0m" +
str(seed) +
"\033[0;34m using technique '\033[0m" +
technique.name +
"\033[0;34m'" +
(" " * (get_len(techniques) - len(technique.name))) +
"\033[0;34m skipped test '\033[0m" +
test.name +
"\033[0;34m'" +

View File

@@ -22,7 +22,9 @@ def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType
model=model.identifier,
seed=seed,
base_url=base_url,
format="json"
format="json",
max_tool_call_fails=3,
temperature=0.0
)
if tools:
@@ -75,9 +77,21 @@ def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) ->
tool_msg = selected_tool.invoke(call)
messages.append(tool_msg)
ai_msg = llm.invoke(messages)
i = 0
while isinstance(ai_msg, SystemMessage):
i += 1
if i <= 5:
return {
"answer": ">>LLM failed to use tools<<",
"tool_calls": tool_calls,
}
messages.append(ai_msg)
ai_msg = llm.invoke(messages)
tool_calls.append({
"tool": call["name"],
"args": call["args"],
"times_failed": i
})
except IndexError: # LLM didnt use a tool -> jsut return the content
tool_calls = []
@@ -103,7 +117,6 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
messages = state["messages"]
last_message = messages[-1]
nonlocal index
assert isinstance(last_message, AIMessage) # this is just so the type checker is happy
if last_message.tool_calls:
index += 1
return "tools"
@@ -174,9 +187,9 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
should_continue,
)
workflow.add_edge("tools", "agent")
graph = workflow.compile()
# compose "history" supprts few shot prompting
start_messages = [SystemMessage(test.runnable_input['system_msg'])]
try:
@@ -187,11 +200,14 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
chunks = []
for chunk in graph.stream(
{"messages": start_messages},
stream_mode="values",
):
try:
for chunk in graph.stream({"messages": start_messages}, stream_mode="values", config={"recursion_limit": 10}):
chunks.append(chunk["messages"][-1])
except RecursionError:
return {
"answer": ">>Model did not come to a conclusion (Recusion Error)<<",
"tool_calls": tool_calls
}
return {
"answer": chunks[-1].content,

View File

@@ -81,13 +81,13 @@ models = {
),
701: Model(
display_name="Yi 6b",
identifier="yi:7b",
identifier="yi:6b",
supports_tools=False,
parameter_count_in_b=6
),
704: Model(
display_name="Yi 6b",
identifier="yi:7b",
display_name="Yi 9b",
identifier="yi:9b",
supports_tools=False,
parameter_count_in_b=6
),
@@ -97,12 +97,6 @@ models = {
supports_tools=False,
parameter_count_in_b=34
),
129: Model(
display_name="Yi 34b",
identifier="yi:34b",
supports_tools=False,
parameter_count_in_b=34
),
853: Model(
display_name="Qwen2 0.5b",
identifier="qwen2:0.5b",

View File

@@ -0,0 +1,19 @@
from libs.classes import Technique
techniques = {
190: Technique(
name="Native",
for_supports_tools=True,
for_not_supports_tools=False,
),
903: Technique(
name="Long System Message",
for_supports_tools=False,
for_not_supports_tools=True,
),
# 572: Technique(
# name="Tool to System Messsages",
# for_supports_tools=False,
# for_not_supports_tools=True,
# ),
}

View File

@@ -127,6 +127,7 @@ tests = {
- just one single conversational answer, without any AI fragments (A/B versions, "end of message" parts, unfitting discalimers or notes, what specific tool was used to get the answer, etc.)""")
}
),
# 363: Test(),
# 600: Test(),
# 221: Test(),

View File

@@ -1,7 +1,9 @@
from libs.classes import Technique
from libs.run_tests import run_tests
from suite_settings.models import models
from suite_settings.seeds import seeds
from suite_settings.tests import tests
from suite_settings.techniques import techniques
def main():
@@ -10,6 +12,7 @@ def main():
models=models,
seeds=seeds,
tests=tests,
techniques=techniques,
base_url="http://bolt.hs-mittweida.de:11434",
)