Commit 7dd48b3
1 Parent(s): 9e16acb

CPU mode: remove fp16 + autocast, use fp32 everywhere

Files changed: code/losses.py (+30 -18)
code/losses.py CHANGED
@@ -15,17 +15,26 @@ class SDSLoss(nn.Module):
         super(SDSLoss, self).__init__()
         self.cfg = cfg
         self.device = device
+        self.fp16 = device.type == "cuda"
+        dtype = torch.float16 if self.fp16 else torch.float32
+        self.pipe = StableDiffusionPipeline.from_pretrained(
+            cfg.diffusion.model,
+            torch_dtype=dtype,
+            token=cfg.token,
+        ).to(device)
+
         self.pipe = StableDiffusionPipeline.from_pretrained(
             cfg.diffusion.model,
             torch_dtype=torch.float32,
             token=cfg.token,
         ).to("cpu")
 
-
-
-
-
-
+        if self.fp16:
+            # self.pipe.enable_xformers_memory_efficient_attention()
+            self.pipe.enable_attention_slicing(slice_size=1)
+            self.pipe.enable_vae_slicing()
+            self.pipe.enable_vae_tiling()
+            self.pipe.unet.enable_gradient_checkpointing()
 
         alphas_cumprod = torch.tensor(self.pipe.scheduler.alphas_cumprod)
         self.alphas = alphas_cumprod.to(device)
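The init change is a device-dependent dtype switch: half precision only when CUDA is present, fp32 otherwise. Note that as committed, self.pipe is assigned twice; the fp32 CPU pipeline created second is the one that remains bound, which is what the commit title describes. A minimal runnable sketch of the dtype-selection pattern (the load_pipe helper and its signature are illustrative, not part of this repo):

import torch
from diffusers import StableDiffusionPipeline

def load_pipe(device: torch.device, model_id: str, token=None):
    # Half precision is only worthwhile on CUDA; CPU half kernels are
    # slow or unsupported, hence fp32 everywhere in CPU mode.
    use_fp16 = device.type == "cuda"
    dtype = torch.float16 if use_fp16 else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=dtype, token=token
    ).to(device)
    return pipe, use_fp16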
@@ -38,21 +47,20 @@ class SDSLoss(nn.Module):
         #self.pipe.enable_model_cpu_offload()
 
         # text-encoder is no longer needed
-        del self.pipe.text_encoder, self.pipe.tokenizer
+        #del self.pipe.text_encoder, self.pipe.tokenizer
 
     def embed_text(self):
         tok = self.pipe.tokenizer
-        txt
-
-
-        un
-
-            return_tensors="pt")
+        txt = tok(self.cfg.caption, padding="max_length",
+                  max_length=tok.model_max_length, truncation=True,
+                  return_tensors="pt")
+        un = tok([""], padding="max_length",
+                 max_length=tok.model_max_length, return_tensors="pt")
 
         with torch.no_grad():
-            te = self.pipe.text_encoder.eval()
-            em_txt = te(txt.input_ids
-            em_un = te(un .input_ids
+            te = self.pipe.text_encoder.eval()
+            em_txt = te(txt.input_ids).last_hidden_state.to(torch.float32)
+            em_un = te(un .input_ids).last_hidden_state.to(torch.float32)
 
         self.text_embeddings = (
             torch.cat([em_un, em_txt])
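embed_text builds the classifier-free-guidance pair: an empty-prompt embedding concatenated ahead of the caption embedding, both cast to fp32. This is also why the del above is now commented out: the tokenizer and text encoder must still be attached when embed_text runs. A self-contained sketch of the step, with caption standing in for cfg.caption:

import torch

def embed_text(pipe, caption: str) -> torch.Tensor:
    tok = pipe.tokenizer
    txt = tok(caption, padding="max_length",
              max_length=tok.model_max_length, truncation=True,
              return_tensors="pt")
    un = tok([""], padding="max_length",
             max_length=tok.model_max_length, return_tensors="pt")
    with torch.no_grad():
        te = pipe.text_encoder.eval()
        em_txt = te(txt.input_ids).last_hidden_state.to(torch.float32)
        em_un = te(un.input_ids).last_hidden_state.to(torch.float32)
    # Unconditional first, conditional second; downstream guidance code
    # typically recovers the two halves with .chunk(2).
    return torch.cat([em_un, em_txt])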
@@ -64,10 +72,14 @@ class SDSLoss(nn.Module):
 
     def forward(self, x_aug: torch.Tensor) -> torch.Tensor:
         # ---------------------------------------------------- encode
-        x = (x_aug * 2.0 - 1.0).to(self.device, dtype=torch.
-
+        x = (x_aug * 2.0 - 1.0).to(self.device, dtype=torch.float32)
+        if self.fp16:
+            with torch.cuda.amp.autocast():
+                latents = self.pipe.vae.encode(x).latent_dist.sample()
+        else:
             latents = self.pipe.vae.encode(x).latent_dist.sample()
-
+
+        latents = 0.18215 * latents
         torch.cuda.empty_cache()
 
         # ---------------------------------------------------- add noise
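In forward, the CPU path now encodes in fp32 with no autocast, and 0.18215 is Stable Diffusion's VAE scaling constant (exposed as pipe.vae.config.scaling_factor in diffusers). A sketch of the fp32 branch, assuming x_aug holds images in [0, 1] and the encode_latents helper name is illustrative:

import torch

def encode_latents(pipe, x_aug: torch.Tensor) -> torch.Tensor:
    # Map [0, 1] images to the [-1, 1] range the VAE expects.
    x = (x_aug * 2.0 - 1.0).to(pipe.device, dtype=torch.float32)
    latents = pipe.vae.encode(x).latent_dist.sample()
    return 0.18215 * latents  # scale into the UNet's latent space

The torch.cuda.empty_cache() call that follows is harmless on CPU: it returns immediately when CUDA has not been initialized.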