m2
This commit is contained in:
@@ -24,6 +24,7 @@ from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool, Tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
|
||||
from textwrap import dedent
|
||||
|
||||
from libs.functions import nxhash
|
||||
|
||||
@@ -87,65 +88,24 @@ _BM = TypeVar("_BM", bound=BaseModel)
|
||||
_DictOrPydantic = Union[Dict, _BM]
|
||||
|
||||
|
||||
class OllamaError(Exception):
|
||||
def __init__(self, message):
|
||||
self.message = message # Store the message
|
||||
super().__init__(message)
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and (
|
||||
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
|
||||
)
|
||||
|
||||
|
||||
def convert_to_ollama_tool(tool: Any) -> Dict:
|
||||
"""Convert a tool to an Ollama tool."""
|
||||
description = None
|
||||
if _is_pydantic_class(tool):
|
||||
schema = tool.construct().schema()
|
||||
name = schema["title"]
|
||||
elif isinstance(tool, BaseTool):
|
||||
schema = tool.tool_call_schema.schema()
|
||||
name = tool.get_name()
|
||||
description = tool.description
|
||||
elif is_basemodel_instance(tool):
|
||||
schema = tool.get_input_schema().schema()
|
||||
name = tool.get_name()
|
||||
description = tool.description
|
||||
elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
|
||||
return tool.copy()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"""Cannot convert {tool} to an Ollama tool.
|
||||
{tool} needs to be a Pydantic class, model, or a dict."""
|
||||
)
|
||||
definition = {"name": name, "parameters": schema}
|
||||
if description:
|
||||
definition["description"] = description
|
||||
|
||||
return definition
|
||||
|
||||
def parse_response(message: BaseMessage) -> str:
|
||||
"""Extract `function_call` from `AIMessage`."""
|
||||
if isinstance(message, AIMessage):
|
||||
kwargs = message.additional_kwargs
|
||||
tool_calls = message.tool_calls
|
||||
if len(tool_calls) > 0:
|
||||
tool_call = tool_calls[-1]
|
||||
args = tool_call.get("args")
|
||||
return json.dumps(args)
|
||||
elif "function_call" in kwargs:
|
||||
if "arguments" in kwargs["function_call"]:
|
||||
return kwargs["function_call"]["arguments"]
|
||||
raise ValueError(f"`arguments` missing from `function_call` within AIMessage: {message}")
|
||||
else:
|
||||
raise ValueError("`tool_calls` missing from AIMessage: {message}")
|
||||
raise ValueError(f"`message` is not an instance of `AIMessage`: {message}")
|
||||
|
||||
|
||||
|
||||
class OllamaFunctions(ChatOllama):
|
||||
"""Function chat model that uses Ollama API."""
|
||||
|
||||
tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
|
||||
tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
|
||||
max_tool_call_fails: int = 5
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
def __init__(self, max_tool_call_fails, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def bind_tools(
|
||||
@@ -157,61 +117,55 @@ class OllamaFunctions(ChatOllama):
|
||||
|
||||
|
||||
def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
|
||||
def _convert_to_ollama_tool(self, tool: Any) -> Dict:
|
||||
"""Convert a tool to an Ollama tool."""
|
||||
description = None
|
||||
if _is_pydantic_class(tool):
|
||||
schema = tool.construct().schema()
|
||||
name = schema["title"]
|
||||
elif isinstance(tool, BaseTool):
|
||||
schema = tool.tool_call_schema.schema()
|
||||
name = tool.get_name()
|
||||
description = tool.description
|
||||
elif is_basemodel_instance(tool):
|
||||
schema = tool.get_input_schema().schema()
|
||||
name = tool.get_name()
|
||||
description = tool.description
|
||||
elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
|
||||
return tool.copy()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"""Cannot convert {tool} to an Ollama tool.
|
||||
{tool} needs to be a Pydantic class, model, or a dict."""
|
||||
)
|
||||
definition = {"name": name, "parameters": schema}
|
||||
if description:
|
||||
definition["description"] = description
|
||||
|
||||
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
|
||||
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
|
||||
call_list = []
|
||||
for c in tool_calls:
|
||||
call_list.append({
|
||||
"id": nxhash(c['id'])[-4:],
|
||||
"tool": c['name'],
|
||||
"args": c['args']
|
||||
})
|
||||
if len(call_list) == 1:
|
||||
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
|
||||
else:
|
||||
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
|
||||
formated_history = ""
|
||||
system_msg = ""
|
||||
for m in messages:
|
||||
return definition
|
||||
|
||||
if formated_history != "":
|
||||
formated_history += "\n\n"
|
||||
|
||||
if isinstance(m, SystemMessage):
|
||||
system_msg += str(m.content)
|
||||
elif isinstance(m, HumanMessage):
|
||||
formated_history += "The Human said:\n" + str(m.content)
|
||||
elif isinstance(m, AIMessage) and m.tool_calls:
|
||||
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
|
||||
elif isinstance(m, ToolMessage):
|
||||
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
|
||||
elif isinstance(m, AIMessage) and not m.tool_calls:
|
||||
formated_history += "You said:\n" + str(m.content)
|
||||
else:
|
||||
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
|
||||
|
||||
return system_msg, formated_history
|
||||
|
||||
|
||||
def _get_parsed_chat_result(self, chat_result_str: str) -> Union[dict, str]:
|
||||
def _get_parsed_chat_result(self, chat_result_str: str) -> dict:
|
||||
try:
|
||||
parsed_chat_result = json.loads(chat_result_str)
|
||||
return parsed_chat_result
|
||||
except json.JSONDecodeError:
|
||||
parsed_chat_result = chat_result_str
|
||||
return parsed_chat_result
|
||||
raise OllamaError(message="Error. Message is not valid JSON.")
|
||||
|
||||
def _get_called_tool(self, d: dict, functions_list: list[dict]) -> dict|NoneType:
|
||||
if not parsed_chat_result:
|
||||
if not d:
|
||||
called_tool_name = None
|
||||
elif "tool" in parsed_chat_result:
|
||||
elif "tool" in d:
|
||||
called_tool_name = d["tool"] # per spec
|
||||
elif "name" in d:
|
||||
called_tool_name = d["name"] # Phi3 often does this
|
||||
elif "tool_name" in d:
|
||||
called_tool_name = d["tool_name"] # Phi3 often does this
|
||||
elif "action" in d:
|
||||
called_tool_name = d["action"] # Phi3 does this
|
||||
called_tool_name = d["action"] # Gemma2 does this
|
||||
elif "task" in d:
|
||||
called_tool_name = d["task"] # Gemma2 does this
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -238,17 +192,25 @@ class OllamaFunctions(ChatOllama):
|
||||
elif "tool_input" in d:
|
||||
response = d["tool_input"]
|
||||
else:
|
||||
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}")
|
||||
raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" +
|
||||
"{\n" +
|
||||
' "tool": <name of the selected tool>,\n' +
|
||||
' "tool_input": <parameters for the selected tool, matching the tool\'s JSON schema>\n' +
|
||||
"}")
|
||||
|
||||
try:
|
||||
assert isinstance(response, str)
|
||||
except AssertionError:
|
||||
raise ValueError(f"Failed to parse a response from {self.model} output: {chat_result}")
|
||||
raise OllamaError("Error: Failed to parse response. Make sure to follow the schema\n" +
|
||||
"{\n" +
|
||||
' "tool": <name of the selected tool>,\n' +
|
||||
' "tool_input": <parameters for the selected tool, matching the tool\'s JSON schema>\n' +
|
||||
"}")
|
||||
|
||||
return response
|
||||
|
||||
def _extract_tool_args(self, d: dict) -> dict:
|
||||
if "tool_input" in parsed_chat_result:
|
||||
if "tool_input" in d:
|
||||
called_tool_args = d["tool_input"] # per spec
|
||||
elif "input" in d:
|
||||
called_tool_args = d["input"] # Phi3 often does this
|
||||
@@ -258,64 +220,114 @@ class OllamaFunctions(ChatOllama):
|
||||
called_tool_args = {}
|
||||
return called_tool_args
|
||||
|
||||
# prepare generation
|
||||
functions_list = [convert_to_ollama_tool(fn) for fn in kwargs.get("functions", [])]
|
||||
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
|
||||
functions_str = json.dumps(functions_list, indent=2)
|
||||
def _get_final_message(self, messages: list, functions_str: str) -> list:
|
||||
def _get_system_msg_and_formatted_history(self, messages: list) -> Tuple[str, str]:
|
||||
def _format_tools_for_history(tool_calls: list[ToolCall]) -> str:
|
||||
call_list = []
|
||||
for c in tool_calls:
|
||||
call_list.append({
|
||||
"id": nxhash(c['id'])[-4:],
|
||||
"tool": c['name'],
|
||||
"args": c['args']
|
||||
})
|
||||
if len(call_list) == 1:
|
||||
return json.dumps(obj=call_list[0], ensure_ascii=False, indent=2)
|
||||
else:
|
||||
return json.dumps(obj=call_list, ensure_ascii=False, indent=2)
|
||||
formated_history = ""
|
||||
system_msg = messages[0]
|
||||
for m in messages[1:]:
|
||||
|
||||
# prepare generation with history
|
||||
if True in [ isinstance(m, ToolMessage) for m in messages ]:
|
||||
system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
|
||||
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
|
||||
system_message = system_message_prompt_template.format(
|
||||
tools=functions_str,
|
||||
history=formated_history,
|
||||
system_msg=system_msg
|
||||
)
|
||||
final_messages = [ system_message ]
|
||||
if formated_history != "":
|
||||
formated_history += "\n\n"
|
||||
|
||||
# prepare generation without history
|
||||
else:
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
|
||||
system_message = system_message_prompt_template.format(
|
||||
tools=functions_str
|
||||
)
|
||||
final_messages = [ system_message ] + messages
|
||||
if isinstance(m, SystemMessage):
|
||||
formated_history += "The system provided the info:\n" + str(m.content)
|
||||
elif isinstance(m, HumanMessage):
|
||||
formated_history += "The Human said:\n" + str(m.content)
|
||||
elif isinstance(m, AIMessage) and m.tool_calls:
|
||||
formated_history += "So you called the tool" + (":\n" if len(m.tool_calls) == 1 else "s:\n") + _format_tools_for_history(m.tool_calls)
|
||||
elif isinstance(m, ToolMessage):
|
||||
formated_history += "To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content)
|
||||
elif isinstance(m, AIMessage) and not m.tool_calls:
|
||||
formated_history += "You said:\n" + str(m.content)
|
||||
else:
|
||||
raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
|
||||
|
||||
# genrerate chat result
|
||||
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
chat_result = response_message.generations[0].text
|
||||
return system_msg, formated_history
|
||||
|
||||
# chekc for validity
|
||||
if not isinstance(chat_result, str):
|
||||
raise ValueError("OllamaFunctions does not support non-string output.")
|
||||
# prepare generation with history
|
||||
if True in [ isinstance(m, ToolMessage) for m in messages ]:
|
||||
system_msg, formated_history = _get_system_msg_and_formatted_history(self, messages=messages)
|
||||
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
|
||||
system_message = system_message_prompt_template.format(
|
||||
tools=functions_str,
|
||||
history=formated_history,
|
||||
system_msg=system_msg
|
||||
)
|
||||
final_messages = [ system_message ]
|
||||
|
||||
# make str to dict
|
||||
parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result)
|
||||
# if model failed to return vailid json, just retrun the whole thing
|
||||
if isinstance(parsed_chat_result, str):
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=parsed_chat_result))])
|
||||
# prepare generation without history
|
||||
else:
|
||||
system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
|
||||
system_message = system_message_prompt_template.format(
|
||||
tools=functions_str
|
||||
)
|
||||
final_messages = [ system_message ] + messages
|
||||
|
||||
return final_messages
|
||||
|
||||
|
||||
# get the called tool from the dict
|
||||
called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list)
|
||||
|
||||
|
||||
if not called_tool:
|
||||
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
|
||||
elif called_tool == CONVERSATIONAL_RESPONSE_TOOL:
|
||||
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
|
||||
else:
|
||||
response_msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[ToolCall(
|
||||
name=called_tool['name'],
|
||||
args=_extract_tool_args(self, d=parsed_chat_result),
|
||||
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
|
||||
)],
|
||||
)
|
||||
def gen(self, failed_tool_calls: int, messages: list) -> ChatResult:
|
||||
|
||||
return ChatResult(generations=[ChatGeneration(message=response_msg)])
|
||||
# prepare generation
|
||||
functions_list = [_convert_to_ollama_tool(self, fn) for fn in kwargs.get("functions", [])]
|
||||
functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
|
||||
functions_str = json.dumps(functions_list, indent=2)
|
||||
|
||||
# get messages to prompt with
|
||||
final_messages = _get_final_message(self, messages=messages, functions_str=functions_str)
|
||||
|
||||
# genrerate chat result
|
||||
response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
chat_result = response_message.generations[0].text
|
||||
|
||||
try:
|
||||
# make str to dict
|
||||
parsed_chat_result = _get_parsed_chat_result(self, chat_result_str=chat_result)
|
||||
|
||||
# get the called tool from the dict
|
||||
called_tool = _get_called_tool(self, d=parsed_chat_result, functions_list=functions_list)
|
||||
|
||||
if (not called_tool) or (called_tool == CONVERSATIONAL_RESPONSE_TOOL):
|
||||
response_msg = AIMessage(content=_extract_conversaional_response(self, d=parsed_chat_result))
|
||||
else:
|
||||
response_msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[ToolCall(
|
||||
name=called_tool['name'],
|
||||
args=_extract_tool_args(self, d=parsed_chat_result),
|
||||
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
|
||||
)],
|
||||
)
|
||||
return ChatResult(generations=[ChatGeneration(message=response_msg)])
|
||||
except OllamaError as e:
|
||||
if failed_tool_calls < self.max_tool_call_fails:
|
||||
# retry
|
||||
messages.append(AIMessage(chat_result))
|
||||
messages.append(SystemMessage(e.message))
|
||||
return gen(self, failed_tool_calls+1, messages=messages)
|
||||
else:
|
||||
# return error
|
||||
# return ChatResult(generations=[ChatGeneration(message=SystemMessage(content=e.message))])
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=">>Model failed<<"))])
|
||||
|
||||
# inital call with no failed runs
|
||||
return gen(self, failed_tool_calls=0, messages=messages)
|
||||
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user