mf1
This commit is contained in:
@@ -20,11 +20,10 @@ from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage, ToolCall
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.prompts import SystemMessagePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool, Tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
|
||||
from textwrap import dedent
|
||||
|
||||
from libs.functions import nxhash
|
||||
|
||||
@@ -98,14 +97,15 @@ def _is_pydantic_class(obj: Any) -> bool:
|
||||
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
|
||||
)
|
||||
|
||||
class OllamaFunctions(ChatOllama):
|
||||
class OllamaFunctionsBase(ChatOllama):
|
||||
"""Function chat model that uses Ollama API."""
|
||||
|
||||
tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
|
||||
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
|
||||
max_tool_call_fails: int = 5
|
||||
|
||||
def __init__(self, max_tool_call_fails, **kwargs: Any) -> None:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def bind_tools(
|
||||
@@ -115,6 +115,8 @@ class OllamaFunctions(ChatOllama):
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
return self.bind(functions=tools, **kwargs)
|
||||
|
||||
def _get_final_message(self, messages: list, functions_str: str) -> list:
|
||||
raise NotImplementedError
|
||||
|
||||
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
|
||||
def _convert_to_ollama_tool(self, tool: Any) -> Dict:
|
||||
@@ -177,11 +179,11 @@ class OllamaFunctions(ChatOllama):
|
||||
return called_tool
|
||||
|
||||
def _extract_conversaional_response(self, d: dict) -> str:
|
||||
if ("tool_input" in d and "response" in d["tool_input"]):
|
||||
if ("tool_input" in d and d["tool_input"] and "response" in d["tool_input"]):
|
||||
response = d["tool_input"]["response"]
|
||||
elif ("input" in d and "response" in d["input"]):
|
||||
elif ("input" in d and d["input"] and "response" in d["input"]):
|
||||
response = d["input"]["response"]
|
||||
elif ("args" in d and "response" in d["args"]):
|
||||
elif ("args" in d and d["args"] and "response" in d["args"]):
|
||||
response = d["args"]["response"]
|
||||
elif "response" in d:
|
||||
response = d["response"]
|
||||
@@ -220,66 +222,6 @@ class OllamaFunctions(ChatOllama):
|
||||
called_tool_args = {}
|
||||
return called_tool_args
|
||||
|
||||
def _get_final_message(self, messages: list, functions_str: str) -> list:
|
||||
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
|
||||
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
|
||||
call_list = []
|
||||
for c in tool_calls:
|
||||
call_list.append({
|
||||
"id": nxhash(c['id'])[-4:],
|
||||
"tool": c['name'],
|
||||
"args": c['args']
|
||||
})
|
||||
if len(call_list) == 1:
|
||||
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
|
||||
else:
|
||||
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
|
||||
formated_history = ""
|
||||
system_msg = messages[0]
|
||||
for m in messages[1:]:
|
||||
|
||||
if formated_history != "":
|
||||
formated_history += "\n\n"
|
||||
|
||||
if isinstance(m, SystemMessage):
|
||||
formated_history += "The system provided the info:\n" + str(m.content)
|
||||
elif isinstance(m, HumanMessage):
|
||||
formated_history += "The Human said:\n" + str(m.content)
|
||||
elif isinstance(m, AIMessage) and m.tool_calls:
|
||||
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
|
||||
elif isinstance(m, ToolMessage):
|
||||
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
|
||||
elif isinstance(m, AIMessage) and not m.tool_calls:
|
||||
formated_history += "You said:\n" + str(m.content)
|
||||
else:
|
||||
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
|
||||
|
||||
return system_msg, formated_history
|
||||
|
||||
# prepare generation with history
|
||||
if True in [ isinstance(m, ToolMessage) for m in messages ]:
|
||||
system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
|
||||
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
|
||||
system_message = system_message_prompt_template.format(
|
||||
tools=functions_str,
|
||||
history=formated_history,
|
||||
system_msg=system_msg
|
||||
)
|
||||
final_messages = [ system_message ]
|
||||
|
||||
# prepare generation without history
|
||||
else:
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
|
||||
system_message = system_message_prompt_template.format(
|
||||
tools=functions_str
|
||||
)
|
||||
final_messages = [ system_message ] + messages
|
||||
|
||||
return final_messages
|
||||
|
||||
|
||||
|
||||
|
||||
def gen(self, failed_tool_calls: int, messages: list) -> ChatResult:
|
||||
|
||||
@@ -289,7 +231,7 @@ class OllamaFunctions(ChatOllama):
|
||||
functions_str = json.dumps(functions_list, indent=2)
|
||||
|
||||
# get messages to prompt with
|
||||
final_messages = _get_final_message(self, messages=messages, functions_str=functions_str)
|
||||
final_messages = self._get_final_message(messages=messages, functions_str=functions_str)
|
||||
|
||||
# genrerate chat result
|
||||
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
@@ -329,6 +271,125 @@ class OllamaFunctions(ChatOllama):
|
||||
return gen(self, failed_tool_calls=0, messages=messages)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class OllamaFunctionsLSM(OllamaFunctionsBase):
|
||||
"""Function chat model that uses Ollama API."""
|
||||
|
||||
def _get_final_message(self, messages: list, functions_str: str) -> list:
|
||||
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
|
||||
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
|
||||
call_list = []
|
||||
for c in tool_calls:
|
||||
call_list.append({
|
||||
"id": nxhash(c['id'])[-4:],
|
||||
"tool": c['name'],
|
||||
"args": c['args']
|
||||
})
|
||||
if len(call_list) == 1:
|
||||
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
|
||||
else:
|
||||
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
|
||||
formated_history = ""
|
||||
system_msg = messages[0]
|
||||
for m in messages[1:]:
|
||||
|
||||
if formated_history != "":
|
||||
formated_history += "\n\n"
|
||||
|
||||
if isinstance(m, SystemMessage):
|
||||
formated_history += "The system provided the info:\n" + str(m.content)
|
||||
elif isinstance(m, HumanMessage):
|
||||
formated_history += "The Human said:\n" + str(m.content)
|
||||
elif isinstance(m, AIMessage) and m.tool_calls:
|
||||
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
|
||||
elif isinstance(m, ToolMessage):
|
||||
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
|
||||
elif isinstance(m, AIMessage) and not m.tool_calls:
|
||||
formated_history += "You said:\n" + str(m.content)
|
||||
else:
|
||||
try:
|
||||
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
|
||||
except NameError:
|
||||
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage.")
|
||||
|
||||
|
||||
return system_msg, formated_history
|
||||
|
||||
# prepare generation with history
|
||||
if True in [ isinstance(m, ToolMessage) for m in messages ]:
|
||||
system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
|
||||
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
|
||||
system_message = system_message_prompt_template.format(
|
||||
tools=functions_str,
|
||||
history=formated_history,
|
||||
system_msg=system_msg
|
||||
)
|
||||
final_messages = [ system_message ]
|
||||
|
||||
# prepare generation without history
|
||||
else:
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
|
||||
system_message = system_message_prompt_template.format(
|
||||
tools=functions_str
|
||||
)
|
||||
final_messages = [ system_message ] + messages
|
||||
|
||||
return final_messages
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "ollama_functions"
|
||||
return "ollama_functions_lsm"
|
||||
|
||||
|
||||
|
||||
class OllamaFunctionsT2S(OllamaFunctionsBase):
|
||||
"""Function chat model that uses Ollama API."""
|
||||
|
||||
def _get_final_message(self, messages: list, functions_str: str) -> list:
|
||||
# prepare generation with history
|
||||
if True in [ isinstance(m, ToolMessage) for m in messages ]:
|
||||
|
||||
transformed_messages = []
|
||||
for m in messages:
|
||||
if isinstance(m, ToolMessage):
|
||||
transformed_messages.append(SystemMessage(content=(
|
||||
f"The Tool '{m.name}' replied with:" + "\n" + str(m.content)
|
||||
)))
|
||||
elif isinstance(m, AIMessage):
|
||||
if m.tool_calls:
|
||||
l = []
|
||||
for call in m.tool_calls:
|
||||
l.append({
|
||||
"tool": call['name'],
|
||||
"tool_input": call['args']
|
||||
})
|
||||
if len(l) == 1:
|
||||
transformed_messages.append(AIMessage(content=json.dumps(l[0])))
|
||||
else:
|
||||
transformed_messages.append(AIMessage(content=json.dumps(l)))
|
||||
else:
|
||||
transformed_messages.append(m)
|
||||
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
|
||||
system_message = system_message_prompt_template.format(tools=functions_str)
|
||||
|
||||
final_messages = [ system_message ] + transformed_messages
|
||||
|
||||
# prepare generation without history
|
||||
else:
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
|
||||
system_message = system_message_prompt_template.format(
|
||||
tools=functions_str
|
||||
)
|
||||
final_messages = [ system_message ] + messages
|
||||
|
||||
return final_messages
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "ollama_functions_t2s"
|
||||
|
||||
@@ -79,8 +79,8 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
|
||||
'technique_name': technique.name,
|
||||
})
|
||||
|
||||
# if hash_key == "DE3D137E":
|
||||
# pass
|
||||
if hash_key == "0DEB2030":
|
||||
pass
|
||||
|
||||
if hash_key not in saved_results.keys():
|
||||
try:
|
||||
@@ -105,7 +105,7 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
|
||||
"\033[0;35m)\033[0m",
|
||||
end=""
|
||||
)
|
||||
answer = test.runnable(model=model, seed=seed, test=test, base_url=base_url)
|
||||
answer = test.runnable(model=model, seed=seed, test=test, technique=technique, base_url=base_url)
|
||||
if isinstance(answer, str):
|
||||
combination['answer'] = answer
|
||||
# combination['tool_calls'] = [] # no entry
|
||||
@@ -172,13 +172,15 @@ def run_tests(models: dict[int, Model], seeds: list[int], tests: dict[int, Test]
|
||||
|
||||
try:
|
||||
entry = {
|
||||
'test_name': result['test_name'],
|
||||
'test_id': result['test_id'],
|
||||
'model_name': result['model_name'],
|
||||
'model_id': result['model_id'],
|
||||
'seed': result['seed'],
|
||||
'answer': result['answer'],
|
||||
'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url),
|
||||
'test_name': result['test_name'],
|
||||
'test_id': result['test_id'],
|
||||
'model_name': result['model_name'],
|
||||
'model_id': result['model_id'],
|
||||
'technique_name': result['technique_name'],
|
||||
'technique_id': result['technique_id'],
|
||||
'seed': result['seed'],
|
||||
'answer': result['answer'],
|
||||
'validation': result['test'].validator(test=result['test'], answer=result['answer'], base_url=base_url),
|
||||
}
|
||||
except Exception as e:
|
||||
print("\033[0;31mError validating entry (\033[0m" + hash_key + "\033[0;31m). <\033[0m" + str(e) + "\033[0;31m> Continuing...\033[0m ")
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from types import NoneType
|
||||
from langchain_ollama.chat_models import ChatOllama
|
||||
from libs.ollama_functions import OllamaFunctions
|
||||
from libs.ollama_functions import OllamaFunctionsLSM, OllamaFunctionsT2S
|
||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
|
||||
from libs.classes import Test, Model
|
||||
from libs.classes import Technique, Test, Model
|
||||
from langchain.tools import Tool
|
||||
from typing import Literal
|
||||
|
||||
@@ -10,22 +10,31 @@ from langgraph.graph import StateGraph, MessagesState
|
||||
import json
|
||||
from pydantic import ValidationError
|
||||
|
||||
def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType = None):
|
||||
if model.supports_tools:
|
||||
from suite_settings.techniques import techniques
|
||||
|
||||
def _get_llm(model: Model, base_url: str, seed: int, technique: Technique, tools: list[Tool]|NoneType = None):
|
||||
if technique == techniques[1]: # Native
|
||||
llm = ChatOllama(
|
||||
model=model.identifier,
|
||||
seed=seed,
|
||||
base_url=base_url
|
||||
)
|
||||
else:
|
||||
llm = OllamaFunctions(
|
||||
elif technique == techniques[903]: # Long System Message
|
||||
llm = OllamaFunctionsLSM(
|
||||
model=model.identifier,
|
||||
seed=seed,
|
||||
base_url=base_url,
|
||||
format="json",
|
||||
max_tool_call_fails=3,
|
||||
temperature=0.0
|
||||
)
|
||||
elif technique == techniques[572]: # ToolMessages to SystemMessages
|
||||
llm = OllamaFunctionsT2S(
|
||||
model=model.identifier,
|
||||
seed=seed,
|
||||
base_url=base_url,
|
||||
format="json",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unkown Technique in _get_llm()")
|
||||
|
||||
if tools:
|
||||
llm = llm.bind_tools(tools=tools)
|
||||
@@ -33,7 +42,7 @@ def _get_llm(model: Model, base_url: str, seed: int, tools: list[Tool]|NoneType
|
||||
return llm
|
||||
|
||||
|
||||
def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
|
||||
def basic_prompt(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> str:
|
||||
|
||||
messages = [SystemMessage(test.runnable_input['system_msg'])]
|
||||
try:
|
||||
@@ -42,20 +51,20 @@ def basic_prompt(model: Model, seed: int, test: Test, base_url: str) -> str:
|
||||
pass
|
||||
messages += [ HumanMessage(test.runnable_input['human_msg']) ]
|
||||
|
||||
llm = _get_llm(model=model, base_url=base_url, seed=seed)
|
||||
llm = _get_llm(model=model, base_url=base_url, technique=technique, seed=seed)
|
||||
ai_msg = llm.invoke(messages)
|
||||
assert isinstance(ai_msg.content, str)
|
||||
return ai_msg.content
|
||||
|
||||
|
||||
|
||||
def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) -> dict:
|
||||
def one_tool_call_answer(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> dict:
|
||||
|
||||
tools_dict = test.runnable_input['tools']
|
||||
tools = []
|
||||
for key in tools_dict:
|
||||
tools.append(tools_dict[key])
|
||||
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
|
||||
llm = _get_llm(model=model, base_url=base_url, seed=seed, technique=technique, tools=tools)
|
||||
|
||||
messages = [SystemMessage(test.runnable_input['system_msg'])]
|
||||
try:
|
||||
@@ -108,7 +117,7 @@ def one_tool_call_answer(model: Model, seed: int, test: Test, base_url: str) ->
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
|
||||
def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict[str, str|list]:
|
||||
def agent_with_tools(model: Model, seed: int, test: Test, technique: Technique, base_url: str) -> dict[str, str|list]:
|
||||
|
||||
tool_calls = []
|
||||
index = -1
|
||||
@@ -173,7 +182,7 @@ def agent_with_tools(model: Model, seed: int, test: Test, base_url: str) -> dict
|
||||
for key in tools_dict:
|
||||
tools.append(tools_dict[key])
|
||||
tool_node = NxToolNode(tools)
|
||||
llm = _get_llm(model=model, base_url=base_url, seed=seed, tools=tools)
|
||||
llm = _get_llm(model=model, base_url=base_url, seed=seed, technique=technique, tools=tools)
|
||||
|
||||
workflow = StateGraph(MessagesState)
|
||||
|
||||
|
||||
@@ -141,6 +141,9 @@ def get_notes_containing(patterns: Union[list[str], str]) -> str:
|
||||
ret += f"{datetime.strftime(entry.time, '%Y/%m/%d %H:%M')} {entry.content}"
|
||||
is_first = False
|
||||
|
||||
if ret == "":
|
||||
ret = "No matching notes were found. Try diffrent patterns."
|
||||
|
||||
return ret
|
||||
|
||||
@tool
|
||||
|
||||
Reference in New Issue
Block a user