Spaces:
Runtime error
Runtime error
Update cappella.py
Browse files- cappella.py +5 -3
cappella.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
from dataclasses import dataclass
|
| 3 |
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
|
|
|
| 4 |
|
| 5 |
@dataclass
|
| 6 |
class CappellaResult:
|
|
@@ -21,7 +22,7 @@ class Cappella:
|
|
| 21 |
It correctly:
|
| 22 |
1. Uses both SDXL tokenizers and text encoders.
|
| 23 |
2. Truncates prompts that are too long (fixes "78 vs 77" error).
|
| 24 |
-
3. Pads prompts
|
| 25 |
4. Returns all 4 required embedding tensors.
|
| 26 |
"""
|
| 27 |
def __init__(self, pipe, device):
|
|
@@ -49,9 +50,10 @@ class Cappella:
|
|
| 49 |
negative_pooled_embeds=neg_pooled
|
| 50 |
)
|
| 51 |
|
| 52 |
-
def _encode_one(self, prompt: str) ->
|
| 53 |
"""
|
| 54 |
-
Runs a single prompt string through both text encoders
|
|
|
|
| 55 |
"""
|
| 56 |
# --- Tokenizer 1 (CLIP-L) ---
|
| 57 |
tok_1_inputs = self.tokenizer(
|
|
|
|
| 1 |
import torch
|
| 2 |
from dataclasses import dataclass
|
| 3 |
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
| 4 |
+
from typing import Tuple
|
| 5 |
|
| 6 |
@dataclass
|
| 7 |
class CappellaResult:
|
|
|
|
| 22 |
It correctly:
|
| 23 |
1. Uses both SDXL tokenizers and text encoders.
|
| 24 |
2. Truncates prompts that are too long (fixes "78 vs 77" error).
|
| 25 |
+
3. Pads prompts (by using max_length) to ensure they are all 77 tokens.
|
| 26 |
4. Returns all 4 required embedding tensors.
|
| 27 |
"""
|
| 28 |
def __init__(self, pipe, device):
|
|
|
|
| 50 |
negative_pooled_embeds=neg_pooled
|
| 51 |
)
|
| 52 |
|
| 53 |
+
def _encode_one(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 54 |
"""
|
| 55 |
+
Runs a single prompt string through both text encoders,
|
| 56 |
+
ensuring truncation and padding to 77 tokens.
|
| 57 |
"""
|
| 58 |
# --- Tokenizer 1 (CLIP-L) ---
|
| 59 |
tok_1_inputs = self.tokenizer(
|