"""Trainer for supervised training.""" |

import argparse
import os
import sys

import deepspeed
import torch
import transformers
from janus.models import MultiModalityCausalLM, VLChatProcessor

from align_anything.datasets.janus import SupervisedBatch, SupervisedTokenizedDataset
from align_anything.trainers.text_to_text.sft import SupervisedTrainer as SupervisedTextTrainer
from align_anything.utils.device_utils import torch_set_device
from align_anything.utils.multi_process import get_current_device
from align_anything.utils.tools import (
    custom_cfgs_to_dict,
    dict_to_namedtuple,
    read_cfgs,
    seed_everything,
    update_dict,
)
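

# Surface transformers' informational logs (e.g. checkpoint loading) during training.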
transformers.logging.set_verbosity_info()


class SupervisedTrainer(SupervisedTextTrainer):
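    """Supervised finetuning trainer for Janus multimodal models."""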

    def init_datasets(self) -> None:
        """Initialize training and evaluation datasets."""
        # Janus SFT reads pre-tokenized samples for both the train and eval splits.
        self.train_dataloader, self.eval_dataloader = self.get_dataloaders(
            SupervisedTokenizedDataset, SupervisedTokenizedDataset
        )

    def update_configs(self, model_config, args, fields):
        """Cross-fill `fields` between `model_config` and `args`.

        For each field, `args` inherits the value from `model_config` when its own
        value is missing or None; otherwise `model_config` takes the value from `args`.
        """

        def cross_update(a, b, field_name):
            if getattr(b, field_name, None) is None:
                setattr(b, field_name, getattr(a, field_name))
            else:
                setattr(a, field_name, getattr(b, field_name))

        for field in fields:
            cross_update(model_config, args, field)

    def init_models(self) -> None:
        """Initialize model and tokenizer."""
        # Load the Janus multimodal causal LM and move it onto this rank's device.
        self.model = MultiModalityCausalLM.from_pretrained(
            self.cfgs.model_cfgs.model_name_or_path,
        ).to(get_current_device())

        # Optionally cast the whole model to bfloat16 to match the bf16 training config.
        if self.cfgs.train_cfgs.bf16:
            self.model = self.model.to(torch.bfloat16)

        # VLChatProcessor bundles the tokenizer with Janus' image preprocessing.
        self.processor = VLChatProcessor.from_pretrained(
            self.cfgs.model_cfgs.model_name_or_path,
        )
        self.tokenizer = self.processor.tokenizer

    def loss(self, sft_batch: SupervisedBatch) -> dict[str, torch.Tensor]:
        """Loss function for supervised finetuning."""
        # This trainer targets image editing; the task tag lets the Janus forward
        # pass dispatch to the matching branch. The batch already carries the
        # tokenized text and image tensors produced by SupervisedTokenizedDataset.
        sft_batch['task'] = 'image_editing'
        outputs = self.model(vl_chat_processor=self.processor, **sft_batch)
        return {
            'loss': outputs.loss,
        }


def main():
    # Initialize the DeepSpeed distributed backend and bind this process to its device.
    deepspeed.init_distributed()

    current_device = get_current_device()
    torch_set_device(current_device)

    # Read the default training configs from the yaml file for this task.
    task = os.path.join('janus', 'sft_gen')
    dict_cfgs, ds_cfgs = read_cfgs(mode='train', task=task)

    # Collect custom configs from the command line: skip the first unparsed token
    # (typically the launcher-injected --local_rank), read the remaining arguments
    # as `--key value` pairs, and merge them into the default configs.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    _, unparsed_args = parser.parse_known_args()
    keys = [k[2:] for k in unparsed_args[1::2]]
    values = list(unparsed_args[2::2])
    unparsed_args = dict(zip(keys, values))
    for k, v in unparsed_args.items():
        dict_cfgs = update_dict(dict_cfgs, custom_cfgs_to_dict(k, v))

    # Freeze the merged configs into a namedtuple and seed all RNGs for reproducibility.
    cfgs = dict_to_namedtuple(dict_cfgs)
    seed_everything(cfgs.train_cfgs.seed)

    # Finetune and persist the final checkpoint.
    trainer = SupervisedTrainer(cfgs=cfgs, ds_cfgs=ds_cfgs)
    trainer.train()
    trainer.save()
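

# Illustrative launch command (launcher flags and override keys are examples,
# not fixed by this file):
#   deepspeed --num_gpus 8 sft_gen.py --train_cfgs.seed 42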


if __name__ == '__main__':
    sys.exit(main())