WinstonHu's picture
Upload folder xtuner to code/xtuner
e5e24c9 verified
raw
history blame
6.04 kB
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizer
from xtuner._lite import get_logger
from xtuner.utils import IGNORE_INDEX
from ..templates import ChatTemplate, HybridChatTemplate
from .base import BaseMessages
logger = get_logger()
class TextContentItem(BaseModel):
type: Literal["text"] = "text"
text: str
def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
return self.text
class ImageContentItem(BaseModel):
type: Literal["image_url"] = "image_url"
image_url: str
def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
return chat_template.image_token
MultModalContentType = Union[TextContentItem, ImageContentItem]
ContentType = Union[str, List[MultModalContentType]]
class ChatMsg(BaseModel):
role: Literal["assistant", "user", "system"]
content: ContentType
loss: Optional[bool] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.loss is None:
if self.role == "system":
self.loss = False
elif self.role == "user":
self.loss = False
elif self.role == "assistant":
self.loss = True
else:
raise NotImplementedError
def collect_img_urls(self) -> List[str]:
img_urls = []
if isinstance(self.content, list):
for item in self.content:
if isinstance(item, ImageContentItem):
img_urls.append(item.image_url)
return img_urls
def get_prompt(self, chat_template: ChatTemplate) -> str:
if isinstance(self.content, str):
text = self.content
elif isinstance(self.content, list):
text = ""
for i, item in enumerate(self.content):
if i == 0:
text += item.apply_chat_template(chat_template)
else:
text += "\n" + item.apply_chat_template(chat_template)
else:
raise NotImplementedError
if self.role == "system":
prompt = chat_template.decorate_system(text)
elif self.role == "user":
prompt = chat_template.decorate_user(text)
elif self.role == "assistant":
prompt = chat_template.decorate_assistant(text)
else:
raise NotImplementedError
return prompt
def tokenize(
self,
tokenizer: PreTrainedTokenizer,
chat_template: ChatTemplate,
):
decorated = self.get_prompt(chat_template)
token_ids = tokenizer.encode(decorated, add_special_tokens=False)
if self.loss:
label_ids = copy.deepcopy(token_ids)
else:
label_ids = [IGNORE_INDEX] * len(token_ids)
return {
"input_ids": token_ids,
"labels": label_ids,
}
class ChatMessages(BaseMessages):
messages: List[ChatMsg]
def add(self, role, content, loss=False):
self.messages.append(ChatMsg(role=role, content=content, loss=loss))
def pop(self):
return self.messages.pop()
def get_prompt(self, chat_template: ChatTemplate) -> str:
prompt = ""
for msg in self.messages:
prompt += msg.get_prompt(chat_template)
if msg.role == "assistant":
prompt += chat_template.sep
return prompt
def tokenize(
self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate
) -> Dict:
input_ids = tokenizer.encode("", add_special_tokens=True)
labels = [IGNORE_INDEX for _ in input_ids]
image_urls = []
for msg in self.messages:
res = msg.tokenize(tokenizer, chat_template)
token_ids, label_ids = res["input_ids"], res["labels"]
input_ids.extend(token_ids)
labels.extend(label_ids)
image_urls.extend(msg.collect_img_urls())
if msg.role == "assistant":
sep = chat_template.sep
sep_tokens = tokenizer.encode(sep, add_special_tokens=False)
input_ids.extend(sep_tokens)
labels.extend([IGNORE_INDEX] * len(sep_tokens))
if len(input_ids) != len(labels):
logger.error(f"[messages] {self.messages}")
logger.error(f"[input_ids] {input_ids}")
logger.error(f"[labels] {labels}")
raise RuntimeError(
"The lengths of input_ids and labels must be "
f"equal, but found {len(input_ids)} and "
f"{len(labels)}."
)
training_data = {
"input_ids": input_ids,
"labels": labels,
"num_tokens": len(input_ids),
}
if len(image_urls) > 0:
training_data["image_urls"] = image_urls
return training_data
@classmethod
def from_str(cls, prompt: str) -> "ChatMessages":
msg = ChatMsg(role="user", content=prompt)
return cls(messages=[msg])
@classmethod
def from_dict(cls, item: dict) -> "ChatMessages":
"""
item
{
'messages':[
{'role':'user', 'content':'hello'},
{'role':'assistant', 'content':'hello!'},
],
}
"""
return cls(**item)
if __name__ == "__main__":
data = {
"messages": [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hello!"},
]
}
messages = ChatMessages.from_dict(data)
chat_template = ChatTemplate(
system="<|im_start|>system\n{system}<|im_end|>\n",
user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
assistant="{assistant}<|im_end|>\n",
stop_words=["<|im_end|>"],
)
print(messages.get_prompt(chat_template))