import torch
from dataclasses import dataclass
from typing import Tuple

from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer


@dataclass
class CappellaResult:
    """
    Holds the 4 tensors required by the SDXL pipeline, all guaranteed
    to have the correct, matching sequence length.
    """
    embeds: torch.Tensor
    pooled_embeds: torch.Tensor
    negative_embeds: torch.Tensor
    negative_pooled_embeds: torch.Tensor


class Cappella:
    """
    A minimal, custom-built prompt encoder for our SDXL pipeline.
    It replaces the 'compel' dependency and is tailored for our exact use case.

    It correctly:
    1. Uses both SDXL tokenizers and text encoders.
    2. Chunks prompts longer than 77 tokens instead of truncating them
       (fixes the "78 vs 77" error).
    3. Pads every chunk to the tokenizer's max_length so each chunk is
       exactly 77 tokens.
    4. Returns all 4 required embedding tensors.
    """

    def __init__(self, pipe, device):
        self.tokenizer: CLIPTokenizer = pipe.tokenizer
        self.tokenizer_2: CLIPTokenizer = pipe.tokenizer_2
        self.text_encoder: CLIPTextModel = pipe.text_encoder
        self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
        self.device = device

    @torch.no_grad()
    def __call__(self, prompt: str, negative_prompt: str) -> CappellaResult:
        """
        Encodes the positive and negative prompts.
        Ensures both embedding tensors have the same sequence length.
        """
        # Encode the positive and negative prompts
        pos_embeds, pos_pooled = self._encode_one(prompt)
        neg_embeds, neg_pooled = self._encode_one(negative_prompt)

        # Ensure embeds and negative_embeds have the same sequence length;
        # they can differ when one prompt needs more chunks than the other.
        seq_len_pos = pos_embeds.shape[1]
        seq_len_neg = neg_embeds.shape[1]

        if seq_len_pos > seq_len_neg:
            # Pad the negative embeds with zeros up to the positive length
            pad_len = seq_len_pos - seq_len_neg
            padding = torch.zeros(
                (neg_embeds.shape[0], pad_len, neg_embeds.shape[2]),
                device=self.device,
                dtype=neg_embeds.dtype,
            )
            neg_embeds = torch.cat([neg_embeds, padding], dim=1)
        elif seq_len_neg > seq_len_pos:
            # Pad the positive embeds with zeros up to the negative length
            pad_len = seq_len_neg - seq_len_pos
            padding = torch.zeros(
                (pos_embeds.shape[0], pad_len, pos_embeds.shape[2]),
                device=self.device,
                dtype=pos_embeds.dtype,
            )
            pos_embeds = torch.cat([pos_embeds, padding], dim=1)

        # pos_embeds and neg_embeds are now guaranteed to have equal sequence lengths
        return CappellaResult(
            embeds=pos_embeds,
            pooled_embeds=pos_pooled,
            negative_embeds=neg_embeds,
            negative_pooled_embeds=neg_pooled,
        )

    def _encode_one(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs a single prompt string through both text encoders.
        Handles prompts longer than 77 tokens by chunking.
        """
        tokenizers = [self.tokenizer, self.tokenizer_2]
        text_encoders = [self.text_encoder, self.text_encoder_2]

        prompt_embeds_list = []
        pooled_prompt_embeds = None

        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            # --- Tokenize ---
            # Tokenize without padding or truncation first
            text_inputs = tokenizer(
                prompt,
                padding=False,
                truncation=False,
                return_tensors="pt",
            )
            input_ids = text_inputs.input_ids.to(self.device)

            # --- Chunking ---
            # Manually chunk the input_ids
            max_length = tokenizer.model_max_length
            bos = tokenizer.bos_token_id
            eos = tokenizer.eos_token_id

            # Subtract 2 to leave room for BOS and EOS in every chunk
            chunk_length = max_length - 2

            # Get all token IDs *except* the outer BOS and EOS
            clean_input_ids = input_ids[0, 1:-1]

            # Split into chunks of at most chunk_length tokens
            chunks = [
                clean_input_ids[i:i + chunk_length]
                for i in range(0, len(clean_input_ids), chunk_length)
            ]

            # --- Prepare Batches ---
            batch_input_ids = []
            for chunk in chunks:
                # Re-add BOS and EOS around the chunk
                chunk_with_bos_eos = torch.cat([
                    torch.tensor([bos], dtype=torch.long, device=self.device),
                    chunk.to(torch.long),
                    torch.tensor([eos], dtype=torch.long, device=self.device),
                ])

                # Pad the chunk to max_length
                pad_len = max_length - len(chunk_with_bos_eos)
                if pad_len > 0:
                    padding = torch.full(
                        (pad_len,),
                        tokenizer.pad_token_id,
                        dtype=torch.long,
                        device=self.device,
                    )
                    chunk_with_bos_eos = torch.cat([chunk_with_bos_eos, padding])

                batch_input_ids.append(chunk_with_bos_eos)

            if not batch_input_ids:
                # Handle an empty prompt: encode [BOS, EOS, PAD, ...] so the
                # pooled output is still read from a valid EOS position.
                empty_ids = torch.full(
                    (max_length,),
                    tokenizer.pad_token_id,
                    dtype=torch.long,
                    device=self.device,
                )
                empty_ids[0] = bos
                empty_ids[1] = eos
                batch_input_ids.append(empty_ids)

            batch_input_ids = torch.stack(batch_input_ids)

            # --- Encode ---
            if text_encoder is self.text_encoder:
                # Text Encoder 1 (CLIP-L): we only need the last_hidden_state
                encoder_output = text_encoder(
                    batch_input_ids,
                    output_hidden_states=False,
                )
                # [num_chunks, 77, 768]
                prompt_embeds = encoder_output.last_hidden_state
                prompt_embeds_list.append(prompt_embeds)
            elif text_encoder is self.text_encoder_2:
                # Text Encoder 2 (OpenCLIP-G): we need hidden_states[-2]
                # plus the pooled output from the FIRST chunk
                encoder_output = text_encoder(
                    batch_input_ids,
                    output_hidden_states=True,
                )
                # [num_chunks, 77, 1280]
                prompt_embeds = encoder_output.hidden_states[-2]
                prompt_embeds_list.append(prompt_embeds)

                # .text_embeds is the pooled (projected) output, [num_chunks, 1280];
                # keep only the first chunk's pooled vector as [1, 1280]
                all_pooled = encoder_output.text_embeds
                pooled_prompt_embeds = all_pooled[0:1]

        # --- Concatenate Chunks ---
        # Reshape each encoder's output from [num_chunks, 77, dim] to
        # [1, num_chunks * 77, dim], then concatenate along the feature dim.
        embeds_1 = prompt_embeds_list[0].reshape(1, -1, prompt_embeds_list[0].shape[-1])
        embeds_2 = prompt_embeds_list[1].reshape(1, -1, prompt_embeds_list[1].shape[-1])
        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)

        # pooled_prompt_embeds is already [1, 1280] from Encoder 2's first chunk
        return prompt_embeds, pooled_prompt_embeds