65 lines
2.5 KiB
Python
65 lines
2.5 KiB
Python
from langchain_ollama.chat_models import ChatOllama
|
|
from langchain_core.prompts import HumanMessagePromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate
|
|
from langchain.tools import tool
|
|
from libs.classes import Test
|
|
from re import search
|
|
from textwrap import dedent
|
|
|
|
def system_human_answer_match(test: Test, answer: str, base_url: str,
                              model: str = "llama3.1:70b",
                              max_retries: "int | None" = None) -> bool:
    """Ask a judge LLM whether *answer* satisfies the test's criteria.

    The judge model (served via Ollama at *base_url*) is shown the original
    system message, the human query and the answer under evaluation, and is
    instructed to reply only by calling the ``rate`` tool with ``True``
    (correct) or ``False`` (incorrect). Replies without a tool call, with a
    different tool, or with an unusable rating are reported on stdout and the
    judge is asked again.

    Args:
        test: Test case; reads ``test.name``,
            ``test.validation_input['criteria']`` and
            ``test.runnable_input['system_msg']`` / ``['human_msg']``.
        answer: The answer produced by the LLM under test.
        base_url: Base URL of the Ollama server hosting the judge model.
        model: Ollama model name of the judge (e.g. ``"llama3.1:70b"`` or
            ``"llama3-groq-tool-use:70b"``); it must support tool calling.
        max_retries: Number of additional attempts after an unusable reply.
            ``None`` (the default) keeps retrying until a verdict arrives.

    Returns:
        The verdict the judge passed to the ``rate`` tool.

    Raises:
        RuntimeError: If *max_retries* is set and every attempt failed.
    """

    @tool
    def rate(rating: bool) -> bool:
        """Rate answer as correct (True) or as incorrect (False)."""
        return rating

    # The prompt and the tool-bound client are identical for every attempt,
    # so build them once rather than on each retry.
    prompt = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(template=dedent("""\
            You evaluate LLMs. Rate the LLM answer as correct, if the answer is
            {validation_input}

            else as incorrect. Only use the `rate` tool. You do not have access to any other tools. Do not answer conversationally.""")),
        HumanMessagePromptTemplate.from_template(template=dedent("""\
            System Message:
            {system_msg}

            Human query:
            {human_msg}

            LLM answer:
            {answer}
            """)),
    ]).invoke({
        "validation_input": test.validation_input['criteria'],
        "system_msg": test.runnable_input['system_msg'],
        "human_msg": test.runnable_input['human_msg'],
        "answer": answer,
    })

    llm = ChatOllama(model=model, base_url=base_url).bind_tools([rate])

    # Iterative retry loop: the previous version recursed on every failure,
    # growing the call stack without bound for a persistently failing judge.
    failures = 0
    while True:
        ai_msg = llm.invoke(prompt)
        try:
            # IndexError here means the judge replied without any tool call.
            tool_call = ai_msg.tool_calls[0]
            if tool_call['name'] != "rate":
                raise ValueError(f"Verification model tried to call tool `{tool_call['name']}` not `rate`")
            # Run the tool on the judge's call; its result is read back as text.
            ret_str = rate.invoke(tool_call).content
            verdict = ret_str.lower()
            if verdict == 'true':
                return True
            if verdict == 'false':
                return False
            raise ValueError(f"rate tool returned {ret_str}")
        except IndexError:
            error = (f"\033[0;31mValidation Error of\033[0m {test.name} "
                     f"\033[0;31m<\033[0m{ai_msg.content[:20]}\033[0;31m...>")
        except Exception as e:
            error = (f"\033[0;31mValidation Error \033[0mof {test.name} "
                     f"\033[0;31m<\033[0m{e}\033[0;31m>")

        failures += 1
        if max_retries is not None and failures > max_retries:
            raise RuntimeError(f"Validation of {test.name} failed after {failures} attempts")
        print(f"{error} Retrying...\033[0m")
|
|
|
|
def regex_match_any(test: "Test", answer: str, base_url: str) -> bool:
    """Return True if *answer* matches at least one of the test's regex patterns.

    Each pattern in ``test.validation_input['patterns']`` is tried with
    ``re.search``, so a match anywhere in *answer* counts. Evaluation stops at
    the first matching pattern; an empty pattern list yields False.

    Args:
        test: Test case providing ``validation_input['patterns']``.
        answer: The LLM answer to check.
        base_url: Unused here; presumably accepted so all validators share
            the same ``(test, answer, base_url)`` call signature.

    Returns:
        Whether any pattern matched.
    """
    return any(search(pattern, answer) for pattern in test.validation_input['patterns'])
|