This commit is contained in:
Lennart J. Kurzweg (Nx2)
2024-08-26 21:20:47 +02:00
parent 2723ced901
commit 5d7ce3cf71
12 changed files with 2055 additions and 2350 deletions

View File

@@ -20,11 +20,10 @@ from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool, Tool
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
from textwrap import dedent
from libs.functions import nxhash
@@ -98,14 +97,15 @@ def _is_pydantic_class(obj: Any) -> bool:
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
)
class OllamaFunctions(ChatOllama):
class OllamaFunctionsBase(ChatOllama):
"""Function chat model that uses Ollama API."""
tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
max_tool_call_fails: int = 5
def __init__(self, max_tool_call_fails, **kwargs: Any) -> None:
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def bind_tools(
@@ -115,6 +115,8 @@ class OllamaFunctions(ChatOllama):
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.bind(functions=tools, **kwargs)
def _get_final_message(self, messages: list, functions_str: str) -> list:
raise NotImplementedError
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
def _convert_to_ollama_tool(self, tool: Any) -> Dict:
@@ -177,11 +179,11 @@ class OllamaFunctions(ChatOllama):
return called_tool
def _extract_conversaional_response(self, d: dict) -> str:
if ("tool_input" in d and "response" in d["tool_input"]):
if ("tool_input" in d and d["tool_input"] and "response" in d["tool_input"]):
response = d["tool_input"]["response"]
elif ("input" in d and "response" in d["input"]):
elif ("input" in d and d["input"] and "response" in d["input"]):
response = d["input"]["response"]
elif ("args" in d and "response" in d["args"]):
elif ("args" in d and d["args"] and "response" in d["args"]):
response = d["args"]["response"]
elif "response" in d:
response = d["response"]
@@ -220,66 +222,6 @@ class OllamaFunctions(ChatOllama):
called_tool_args = {}
return called_tool_args
def _get_final_message(self, messages: list, functions_str: str) -> list:
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
call_list = []
for c in tool_calls:
call_list.append({
"id": nxhash(c['id'])[-4:],
"tool": c['name'],
"args": c['args']
})
if len(call_list) == 1:
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
else:
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
formated_history = ""
system_msg = messages[0]
for m in messages[1:]:
if formated_history != "":
formated_history += "\n\n"
if isinstance(m, SystemMessage):
formated_history += "The system provided the info:\n" + str(m.content)
elif isinstance(m, HumanMessage):
formated_history += "The Human said:\n" + str(m.content)
elif isinstance(m, AIMessage) and m.tool_calls:
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
elif isinstance(m, ToolMessage):
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
elif isinstance(m, AIMessage) and not m.tool_calls:
formated_history += "You said:\n" + str(m.content)
else:
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
return system_msg, formated_history
# prepare generation with history
if True in [ isinstance(m, ToolMessage) for m in messages ]:
system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
system_message = system_message_prompt_template.format(
tools=functions_str,
history=formated_history,
system_msg=system_msg
)
final_messages = [ system_message ]
# prepare generation without history
else:
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
system_message = system_message_prompt_template.format(
tools=functions_str
)
final_messages = [ system_message ] + messages
return final_messages
def gen(self, failed_tool_calls: int, messages: list) -> ChatResult:
@@ -289,7 +231,7 @@ class OllamaFunctions(ChatOllama):
functions_str = json.dumps(functions_list, indent=2)
# get messages to prompt with
final_messages = _get_final_message(self, messages=messages, functions_str=functions_str)
final_messages = self._get_final_message(messages=messages, functions_str=functions_str)
# genrerate chat result
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
@@ -329,6 +271,125 @@ class OllamaFunctions(ChatOllama):
return gen(self, failed_tool_calls=0, messages=messages)
class OllamaFunctionsLSM(OllamaFunctionsBase):
"""Function chat model that uses Ollama API."""
def _get_final_message(self, messages: list, functions_str: str) -> list:
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
call_list = []
for c in tool_calls:
call_list.append({
"id": nxhash(c['id'])[-4:],
"tool": c['name'],
"args": c['args']
})
if len(call_list) == 1:
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
else:
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
formated_history = ""
system_msg = messages[0]
for m in messages[1:]:
if formated_history != "":
formated_history += "\n\n"
if isinstance(m, SystemMessage):
formated_history += "The system provided the info:\n" + str(m.content)
elif isinstance(m, HumanMessage):
formated_history += "The Human said:\n" + str(m.content)
elif isinstance(m, AIMessage) and m.tool_calls:
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
elif isinstance(m, ToolMessage):
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
elif isinstance(m, AIMessage) and not m.tool_calls:
formated_history += "You said:\n" + str(m.content)
else:
try:
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
except NameError:
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage.")
return system_msg, formated_history
# prepare generation with history
if True in [ isinstance(m, ToolMessage) for m in messages ]:
system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
system_message = system_message_prompt_template.format(
tools=functions_str,
history=formated_history,
system_msg=system_msg
)
final_messages = [ system_message ]
# prepare generation without history
else:
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
system_message = system_message_prompt_template.format(
tools=functions_str
)
final_messages = [ system_message ] + messages
return final_messages
@property
def _llm_type(self) -> str:
return "ollama_functions"
return "ollama_functions_lsm"
class OllamaFunctionsT2S(OllamaFunctionsBase):
"""Function chat model that uses Ollama API."""
def _get_final_message(self, messages: list, functions_str: str) -> list:
# prepare generation with history
if True in [ isinstance(m, ToolMessage) for m in messages ]:
transformed_messages = []
for m in messages:
if isinstance(m, ToolMessage):
transformed_messages.append(SystemMessage(content=(
f"The Tool '{m.name}' replied with:" + "\n" + str(m.content)
)))
elif isinstance(m, AIMessage):
if m.tool_calls:
l = []
for call in m.tool_calls:
l.append({
"tool": call['name'],
"tool_input": call['args']
})
if len(l) == 1:
transformed_messages.append(AIMessage(content=json.dumps(l[0])))
else:
transformed_messages.append(AIMessage(content=json.dumps(l)))
else:
transformed_messages.append(m)
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
system_message = system_message_prompt_template.format(tools=functions_str)
final_messages = [ system_message ] + transformed_messages
# prepare generation without history
else:
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
system_message = system_message_prompt_template.format(
tools=functions_str
)
final_messages = [ system_message ] + messages
return final_messages
@property
def _llm_type(self) -> str:
return "ollama_functions_t2s"