import json
import uuid
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
    Tuple,
)
from types import NoneType

from langchain_ollama.chat_models import ChatOllama
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass

from libs.functions import nxhash


# NOTE: the "SYTEM" misspelling is kept because these names are module-level API.
DEFAULT_SYTEM_PROMPT = """You have access to the following tools:

{tools}

You must always select one of the above tools and respond with only a JSON object matching the following schema:

{{
  "tool": <name of the selected tool>,
  "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}
"""

# NOTE(review): this prompt invites plain conversational answers, but the parser
# in OllamaFunctionsBase only accepts JSON; a plain-text reply is treated as a
# failure and retried. Confirm whether that is intended.
DEFAULT_SYTEM_PROMPT_WITH_HISTORY = """{system_msg}

You continue a chat history either conversationally or with a tool call.

You have access to the following tools:

{tools}

You must either select one of the above tools and respond with only a JSON object matching the following schema:

{{
  "tool": <name of the selected tool>,
  "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}

or answer conversationally normally.

The conversation before consisted of the following messages:

{history}

Now you must answer accordingly either conversationally or with another tool call.
For conversational answers: Answer as if it was a continuous conversation. The Human only sees the conversational responses, and not anything about the tools. Do not mention the tools or the process of using them.
"""

# Pseudo-tool always offered to the model so it can answer without calling a real tool.
CONVERSATIONAL_RESPONSE_TOOL = {
    "name": "__conversational_response",
    "description": (
        "Respond conversationally if no other tools should be called for a given query."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "response": {
                "type": "string",
                "description": "Conversational response to the user.",
            },
        },
        "required": ["response"],
    },
}

_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydantic = Union[Dict, _BM]

# Schema reminder fed back to the model when its reply cannot be interpreted.
_TOOL_SCHEMA_HINT = (
    "{\n"
    '  "tool": <name of the selected tool>,\n'
    "  \"tool_input\": <parameters for the selected tool, matching the tool's JSON schema>\n"
    "}"
)
_PARSE_ERROR_MESSAGE = (
    "Error: Failed to parse response. Make sure to follow the schema\n" + _TOOL_SCHEMA_HINT
)

# Keys that different models use for the tool name, in lookup priority order:
# "tool" per spec, "name"/"tool_name" (seen with Phi3), "action"/"task" (seen with Gemma2).
_TOOL_NAME_KEYS = ("tool", "name", "tool_name", "action", "task")
# Keys that may hold the tool arguments, in lookup priority order.
_TOOL_ARGS_KEYS = ("tool_input", "input", "args")
# Fallback keys that may hold a bare conversational response, in priority order.
_RESPONSE_FALLBACK_KEYS = ("response", "input", "args", "tool_input")

# Sentinel meaning "no value found" (None could be a legitimate parsed value).
_MISSING = object()


class OllamaError(Exception):
    """Recoverable error in a model reply; its message is shown to the model on retry."""

    def __init__(self, message):
        self.message = message  # Store the message
        super().__init__(message)


def _is_pydantic_class(obj: Any) -> bool:
    """Return True if *obj* is a pydantic model class."""
    return isinstance(obj, type) and (
        is_basemodel_subclass(obj) or BaseModel in obj.__bases__
    )


def _convert_to_ollama_tool(tool: Any) -> Dict:
    """Convert a tool to the ``{"name", "parameters"[, "description"]}`` dict shown to the model.

    Raises:
        ValueError: if *tool* is not a pydantic class/instance, a BaseTool,
            or a dict that already has "name" and "parameters".
    """
    description = None
    if _is_pydantic_class(tool):
        schema = tool.construct().schema()
        name = schema["title"]
    elif isinstance(tool, BaseTool):
        schema = tool.tool_call_schema.schema()
        name = tool.get_name()
        description = tool.description
    elif is_basemodel_instance(tool):
        schema = tool.get_input_schema().schema()
        name = tool.get_name()
        description = tool.description
    elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
        return tool.copy()
    else:
        raise ValueError(
            f"Cannot convert {tool} to an Ollama tool.\n"
            f"{tool} needs to be a Pydantic class, model, or a dict."
        )
    definition = {"name": name, "parameters": schema}
    if description:
        definition["description"] = description
    return definition


def _parse_chat_result(chat_result_str: str) -> dict:
    """Parse the raw model output into a JSON object.

    Raises:
        OllamaError: if the text is not valid JSON or is JSON but not an object.
    """
    try:
        parsed = json.loads(chat_result_str)
    except json.JSONDecodeError as err:
        raise OllamaError(message="Error. Message is not valid JSON.") from err
    # Scalars/null would break the key lookups below with a TypeError.
    if not isinstance(parsed, dict):
        raise OllamaError(_PARSE_ERROR_MESSAGE)
    return parsed


def _get_called_tool(d: dict, functions_list: List[dict]) -> Optional[dict]:
    """Return the tool definition named in *d*, or None if no name or no such tool."""
    called_tool_name = next((d[key] for key in _TOOL_NAME_KEYS if key in d), None)
    if called_tool_name is None:
        return None
    # None also covers "a tool was named, but it does not exist".
    return next((tool for tool in functions_list if tool["name"] == called_tool_name), None)


def _extract_conversational_response(d: dict) -> str:
    """Extract the conversational response text from a parsed reply.

    Raises:
        OllamaError: if no string response can be found.
    """
    response: Any = _MISSING
    # Preferred: {"tool_input": {"response": "..."}} (or "input"/"args").
    for key in _TOOL_ARGS_KEYS:
        nested = d.get(key)
        # The isinstance guard prevents substring tests / str indexing on string values.
        if isinstance(nested, dict) and "response" in nested:
            response = nested["response"]
            break
    if response is _MISSING:
        for key in _RESPONSE_FALLBACK_KEYS:
            if key in d:
                response = d[key]
                break
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(response, str):
        raise OllamaError(_PARSE_ERROR_MESSAGE)
    return response


def _extract_tool_args(d: dict) -> dict:
    """Return the tool arguments from a parsed reply ({} when absent or null).

    Raises:
        OllamaError: if the arguments are present but not a JSON object.
    """
    called_tool_args = next((d[key] for key in _TOOL_ARGS_KEYS if key in d), None)
    if called_tool_args is None:
        return {}
    if not isinstance(called_tool_args, dict):
        # A ToolCall with non-dict args would fail message validation outside
        # the retry handler; surface it as a recoverable error instead.
        raise OllamaError(_PARSE_ERROR_MESSAGE)
    return called_tool_args


def _parse_response_message(chat_result: str, functions_list: List[dict]) -> AIMessage:
    """Turn raw model output into an AIMessage (either conversational or with one tool call).

    Raises:
        OllamaError: if the output cannot be interpreted.
    """
    parsed = _parse_chat_result(chat_result)
    called_tool = _get_called_tool(parsed, functions_list)
    if (not called_tool) or (called_tool == CONVERSATIONAL_RESPONSE_TOOL):
        return AIMessage(content=_extract_conversational_response(parsed))
    return AIMessage(
        content="",
        tool_calls=[
            ToolCall(
                name=called_tool["name"],
                args=_extract_tool_args(parsed),
                id=f"call_{uuid.uuid4().hex}",
            )
        ],
    )


class OllamaFunctionsBase(ChatOllama):
    """Function chat model that uses Ollama API.

    Tool calling is emulated by prompting the model to reply with JSON. Subclasses
    decide how the conversation is rendered by implementing ``_get_final_message``.
    """

    tool_system_prompt_template: str = DEFAULT_SYTEM_PROMPT
    tool_system_prompt_template_with_history: str = DEFAULT_SYTEM_PROMPT_WITH_HISTORY
    # Retries after an unparseable reply; total attempts = max_tool_call_fails + 1.
    max_tool_call_fails: int = 5

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind *tools*; they reach ``_generate`` as the ``functions`` kwarg."""
        return self.bind(functions=tools, **kwargs)

    def _get_final_message(self, messages: list, functions_str: str) -> list:
        """Build the message list sent to the model (implemented by subclasses)."""
        raise NotImplementedError

    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> ChatResult:
        """Generate a reply, retrying with error feedback when the model's JSON is unusable.

        Returns a ChatResult holding either a conversational AIMessage, an AIMessage
        with a single tool call, or ``AIMessage(">>Model failed<<")`` once retries
        are exhausted.
        """
        # Consume "functions" here: forwarding it would pass an unknown keyword
        # through ChatOllama to the Ollama client.
        functions = kwargs.pop("functions", None) or []
        functions_list = [_convert_to_ollama_tool(fn) for fn in functions]
        functions_list.append(CONVERSATIONAL_RESPONSE_TOOL)
        functions_str = json.dumps(functions_list, indent=2)

        # Work on a copy so retry feedback never leaks into the caller's list.
        working_messages = list(messages)
        for _attempt in range(self.max_tool_call_fails + 1):
            final_messages = self._get_final_message(messages=working_messages, functions_str=functions_str)
            response_message = super()._generate(final_messages, stop=stop, run_manager=run_manager, **kwargs)
            chat_result = response_message.generations[0].text
            try:
                response_msg = _parse_response_message(chat_result, functions_list)
            except OllamaError as e:
                # Show the model its bad output and the error, then try again.
                working_messages.append(AIMessage(content=chat_result))
                working_messages.append(SystemMessage(content=e.message))
                continue
            return ChatResult(generations=[ChatGeneration(message=response_msg)])
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=">>Model failed<<"))])


def _format_tools_for_history(tool_calls: List[ToolCall]) -> str:
    """Render tool calls as JSON (a single object for one call, a list otherwise)."""
    call_list = [
        {"id": nxhash(c["id"])[-4:], "tool": c["name"], "args": c["args"]}
        for c in tool_calls
    ]
    payload = call_list[0] if len(call_list) == 1 else call_list
    return json.dumps(obj=payload, ensure_ascii=False, indent=2)


def _get_system_msg_and_formatted_history(messages: list) -> Tuple[str, str]:
    """Split off a leading SystemMessage's text and render the rest as prose history.

    If the first message is not a SystemMessage, the system text is "" and every
    message is kept in the history.

    Raises:
        TypeError: for message types other than System/Human/Tool/AI messages.
    """
    if messages and isinstance(messages[0], SystemMessage):
        system_msg = str(messages[0].content)
        history_messages = messages[1:]
    else:
        system_msg = ""
        history_messages = messages

    entries = []
    for m in history_messages:
        if isinstance(m, SystemMessage):
            entries.append("The system provided the info:\n" + str(m.content))
        elif isinstance(m, HumanMessage):
            entries.append("The Human said:\n" + str(m.content))
        elif isinstance(m, AIMessage) and m.tool_calls:
            entries.append(
                "So you called the tool"
                + (":\n" if len(m.tool_calls) == 1 else "s:\n")
                + _format_tools_for_history(m.tool_calls)
            )
        elif isinstance(m, ToolMessage):
            entries.append("To which the tool (" + nxhash(m.tool_call_id)[-4:] + ") replied with:\n" + str(m.content))
        elif isinstance(m, AIMessage):
            entries.append("You said:\n" + str(m.content))
        else:
            raise TypeError("OllamaFunctions only supports SystemMessage HumanMessage ToolMessage AIMessage but got " + str(type(m)))
    return system_msg, "\n\n".join(entries)


class OllamaFunctionsLSM(OllamaFunctionsBase):
    """Function chat model that uses Ollama API.

    Once a tool has replied, the whole conversation is folded into a single
    system prompt built from ``tool_system_prompt_template_with_history``.
    """

    def _get_final_message(self, messages: list, functions_str: str) -> list:
        """Return the messages to send: one history prompt, or tool prompt + messages."""
        # prepare generation with history
        if any(isinstance(m, ToolMessage) for m in messages):
            system_msg, formatted_history = _get_system_msg_and_formatted_history(messages)
            system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template_with_history)
            system_message = system_message_prompt_template.format(
                tools=functions_str, history=formatted_history, system_msg=system_msg
            )
            return [system_message]
        # prepare generation without history
        system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
        system_message = system_message_prompt_template.format(tools=functions_str)
        return [system_message] + messages

    @property
    def _llm_type(self) -> str:
        return "ollama_functions_lsm"


class OllamaFunctionsT2S(OllamaFunctionsBase):
    """Function chat model that uses Ollama API.

    Keeps the chat as separate messages: ToolMessages become SystemMessages and
    AIMessage tool calls become JSON text in the schema the model must follow.
    """

    def _get_final_message(self, messages: list, functions_str: str) -> list:
        """Return the tool system prompt followed by the (transformed) messages."""
        system_message_prompt_template = SystemMessagePromptTemplate.from_template(self.tool_system_prompt_template)
        system_message = system_message_prompt_template.format(tools=functions_str)

        # prepare generation without history
        if not any(isinstance(m, ToolMessage) for m in messages):
            return [system_message] + messages

        # prepare generation with history
        # Tool names by call id, used when a ToolMessage has no name set.
        tool_names_by_id = {
            call["id"]: call["name"]
            for m in messages
            if isinstance(m, AIMessage)
            for call in m.tool_calls
        }
        transformed_messages = []
        for m in messages:
            if isinstance(m, ToolMessage):
                tool_name = m.name or tool_names_by_id.get(m.tool_call_id, m.name)
                transformed_messages.append(SystemMessage(content=(
                    f"The Tool '{tool_name}' replied with:" + "\n" + str(m.content)
                )))
            elif isinstance(m, AIMessage) and m.tool_calls:
                calls = [
                    {"tool": call["name"], "tool_input": call["args"]}
                    for call in m.tool_calls
                ]
                payload = calls[0] if len(calls) == 1 else calls
                transformed_messages.append(AIMessage(content=json.dumps(payload)))
            else:
                # Human/System messages and plain AI replies pass through unchanged.
                transformed_messages.append(m)
        return [system_message] + transformed_messages

    @property
    def _llm_type(self) -> str:
        return "ollama_functions_t2s"