|
|
|
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
from pydantic import BaseModel, field_validator |
|
|
|
|
|
|
|
|
class HybridChatTemplate(BaseModel):
    """Define a Pydantic data model for a hybrid chat template.

    Every template field is a ``str.format`` pattern that must contain a
    named placeholder (for example ``system`` must contain ``{system}``).
    The ``decorate_*`` methods substitute text into the matching pattern.
    Besides the plain system / user / assistant turns, the model can
    optionally describe files, function declarations / calls / results and
    code-interpreter declarations / calls / results.
    """

    # Chat turn templates (required).
    system: str
    user: str
    assistant: str
    # Not used by any method here; presumably consumed by generation code
    # as stop strings -- confirm against callers.
    stop_words: List[str]
    sep: str = "\n"

    # Multimodal placeholder token and the index it maps to
    # (see `mm_token_maps`).
    image_token: str = "<image>"
    image_token_index: int = -100

    # Optional files template (must contain '{files}' when used).
    files: Optional[str] = None

    # Optional function-calling templates.
    functions: Optional[str] = None
    function_call: Optional[str] = None
    function_result: Optional[str] = None

    # Optional code-interpreter templates.
    code_interpreter: Optional[str] = None
    code_interpreter_call: Optional[str] = None
    code_interpreter_result: Optional[str] = None

    # Special tokens; not referenced by any method of this class.
    function_token: Optional[str] = None
    code_interpreter_token: Optional[str] = None
    action_start_token: Optional[str] = None
    action_end_token: Optional[str] = None

    @property
    def mm_token_maps(self) -> Dict[str, int]:
        """Return a dictionary that maps multimodal tokens to their token
        indexes (currently only ``image_token -> image_token_index``)."""
        return {self.image_token: self.image_token_index}

    def decorate_system(self, text: str) -> str:
        """Decorate text with the `system` template."""
        return self.system.format(system=text)

    def decorate_assistant(self, text: str) -> str:
        """Decorate text with the `assistant` template."""
        return self.assistant.format(assistant=text)

    def decorate_user(self, text: str) -> str:
        """Decorate text with the `user` template."""
        return self.user.format(user=text)

    def decorate_files(self, text: str) -> str:
        """Decorate text with the `files` template.

        Requires `files` to be set; calling this while it is ``None`` raises
        ``AttributeError``.
        """
        return self.files.format(files=text)

    def decorate_functions(self, text: str) -> str:
        """Decorate text with the `functions` template.

        Requires `functions` to be set; calling this while it is ``None``
        raises ``AttributeError``.
        """
        return self.functions.format(functions=text)

    def decorate_function_call(self, text: str, func: str) -> str:
        """Decorate the assistant `text` and the call `func` with the
        `function_call` template.

        Requires `function_call` to be set; calling this while it is
        ``None`` raises ``AttributeError``.
        """
        return self.function_call.format(assistant=text, function_call=func)

    def decorate_function_result(self, text: str) -> str:
        """Decorate text with the `function_result` template.

        Requires `function_result` to be set; calling this while it is
        ``None`` raises ``AttributeError``.
        """
        return self.function_result.format(function_result=text)

    def decorate_code_interpreter(self, text: str) -> str:
        """Decorate text with the `code_interpreter` template.

        Requires `code_interpreter` to be set; calling this while it is
        ``None`` raises ``AttributeError``.
        """
        return self.code_interpreter.format(code_interpreter=text)

    def decorate_code_interpreter_call(self, text: str, func: str) -> str:
        """Decorate the assistant `text` and the call `func` with the
        `code_interpreter_call` template.

        Requires `code_interpreter_call` to be set; calling this while it is
        ``None`` raises ``AttributeError``.
        """
        return self.code_interpreter_call.format(
            assistant=text, code_interpreter_call=func
        )

    def decorate_code_interpreter_result(self, text: str) -> str:
        """Decorate text with the `code_interpreter_result` template.

        Requires `code_interpreter_result` to be set; calling this while it
        is ``None`` raises ``AttributeError``.
        """
        return self.code_interpreter_result.format(code_interpreter_result=text)

    @field_validator("system")
    @classmethod
    def check_system(cls, v: str) -> str:
        """Validate that `system` contains '{system}'.

        If not, raises a ValueError.
        """
        if v is not None and "{system}" not in v:
            raise ValueError("system must contain the keyword '{system}'")
        return v

    @field_validator("user")
    @classmethod
    def check_user(cls, v: str) -> str:
        """Validate that `user` contains '{user}'.

        If not, raises a ValueError.
        """
        if v is not None and "{user}" not in v:
            raise ValueError("user must contain the keyword '{user}'")
        return v

    @field_validator("assistant")
    @classmethod
    def check_assistant(cls, v: str) -> str:
        """Validate that `assistant` contains '{assistant}'.

        If not, raises a ValueError.
        """
        if v is not None and "{assistant}" not in v:
            raise ValueError("assistant must contain the keyword '{assistant}'")
        return v

    @field_validator("function_call")
    @classmethod
    def check_function_call(cls, v: Optional[str]) -> Optional[str]:
        """Validate that `function_call` contains both '{assistant}' and
        '{function_call}'.

        Both placeholders are filled by `decorate_function_call`; a missing
        one would silently drop that argument from the rendered text.
        If either is absent, raises a ValueError.
        """
        # Each placeholder is required on its own (previously the
        # '{function_call}' requirement was never enforced).
        if v is not None and ("{function_call}" not in v or "{assistant}" not in v):
            raise ValueError(
                "function_call must contain the keyword '{assistant}' and "
                "'{function_call}'"
            )
        return v

    @field_validator("function_result")
    @classmethod
    def check_function_result(cls, v: Optional[str]) -> Optional[str]:
        """Validate that `function_result` contains '{function_result}'.

        If not, raises a ValueError.
        """
        if v is not None and "{function_result}" not in v:
            raise ValueError(
                "function_result must contain the keyword '{function_result}'"
            )
        return v

    @field_validator("functions")
    @classmethod
    def check_functions(cls, v: Optional[str]) -> Optional[str]:
        """Validate that `functions` contains '{functions}'.

        If not, raises a ValueError.
        """
        if v is not None and "{functions}" not in v:
            raise ValueError("functions must contain the keyword '{functions}'")
        return v

    @field_validator("code_interpreter")
    @classmethod
    def check_code_interpreter(cls, v: Optional[str]) -> Optional[str]:
        """Validate that `code_interpreter` contains '{code_interpreter}'.

        If not, raises a ValueError.
        """
        if v is not None and "{code_interpreter}" not in v:
            raise ValueError(
                "code_interpreter must contain the keyword " "'{code_interpreter}'"
            )
        return v

    @field_validator("code_interpreter_call")
    @classmethod
    def check_code_interpreter_call(cls, v: Optional[str]) -> Optional[str]:
        """Validate that `code_interpreter_call` contains both '{assistant}'
        and '{code_interpreter_call}'.

        Both placeholders are filled by `decorate_code_interpreter_call`; a
        missing one would silently drop that argument from the rendered text.
        If either is absent, raises a ValueError.
        """
        # Each placeholder is required on its own (previously the
        # '{code_interpreter_call}' requirement was never enforced).
        if v is not None and (
            "{code_interpreter_call}" not in v or "{assistant}" not in v
        ):
            raise ValueError(
                "code_interpreter_call must contain the keywords "
                "'{assistant}' and '{code_interpreter_call}'"
            )
        return v

    @field_validator("code_interpreter_result")
    @classmethod
    def check_code_interpreter_result(cls, v: Optional[str]) -> Optional[str]:
        """Validate that `code_interpreter_result` contains
        '{code_interpreter_result}'.

        If not, raises a ValueError.
        """
        if v is not None and "{code_interpreter_result}" not in v:
            raise ValueError(
                "code_interpreter_result must contain the keyword "
                "'{code_interpreter_result}'"
            )
        return v
|
|
|