KingHacker9000 committed
Commit 7dd48b3 · Parent: 9e16acb

CPU mode: remove fp16 + autocast, use fp32 everywhere

Files changed (1): code/losses.py (+30 -18)
code/losses.py CHANGED
@@ -15,17 +15,26 @@ class SDSLoss(nn.Module):
         super(SDSLoss, self).__init__()
         self.cfg = cfg
         self.device = device
+        self.fp16 = device.type == "cuda"
+        dtype = torch.float16 if self.fp16 else torch.float32
+        self.pipe = StableDiffusionPipeline.from_pretrained(
+            cfg.diffusion.model,
+            torch_dtype=dtype,
+            token=cfg.token,
+        ).to(device)
+
         self.pipe = StableDiffusionPipeline.from_pretrained(
             cfg.diffusion.model,
             torch_dtype=torch.float32,
             token=cfg.token,
         ).to("cpu")
 
-        #self.pipe.enable_xformers_memory_efficient_attention()
-        self.pipe.enable_attention_slicing(slice_size=1)
-        self.pipe.enable_vae_slicing()
-        self.pipe.enable_vae_tiling()
-        self.pipe.unet.enable_gradient_checkpointing()
+        if self.fp16:
+            # self.pipe.enable_xformers_memory_efficient_attention()
+            self.pipe.enable_attention_slicing(slice_size=1)
+            self.pipe.enable_vae_slicing()
+            self.pipe.enable_vae_tiling()
+            self.pipe.unet.enable_gradient_checkpointing()
 
         alphas_cumprod = torch.tensor(self.pipe.scheduler.alphas_cumprod)
         self.alphas = alphas_cumprod.to(device)
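Note on the hunk above: the pipeline is created at a device-dependent precision (fp16 on CUDA, fp32 otherwise), and the memory savers are enabled only on the fp16/GPU path, where peak VRAM is the constraint. As committed, the second from_pretrained call still overwrites the first, leaving the pipe in fp32 on CPU. A minimal self-contained sketch of the pattern; the model id below is a placeholder, not this repo's cfg.diffusion.model:

    import torch
    from diffusers import StableDiffusionPipeline

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fp16 = device.type == "cuda"
    dtype = torch.float16 if fp16 else torch.float32  # fp16 only pays off on GPU

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # placeholder model id, not cfg.diffusion.model
        torch_dtype=dtype,
    ).to(device)

    if fp16:
        # memory savers: trade compute for lower peak VRAM on the GPU path
        pipe.enable_attention_slicing(slice_size=1)
        pipe.enable_vae_slicing()
        pipe.enable_vae_tiling()
        pipe.unet.enable_gradient_checkpointing()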
@@ -38,21 +47,20 @@ class SDSLoss(nn.Module):
         #self.pipe.enable_model_cpu_offload()
 
         # text-encoder is no longer needed
-        del self.pipe.text_encoder, self.pipe.tokenizer
+        #del self.pipe.text_encoder, self.pipe.tokenizer
 
     def embed_text(self):
         tok = self.pipe.tokenizer
-        txt = tok(self.cfg.caption, padding="max_length",
-                  max_length=tok.model_max_length,
-                  truncation=True, return_tensors="pt")
-        un = tok([""], padding="max_length",
-                 max_length=tok.model_max_length,
-                 return_tensors="pt")
+        txt = tok(self.cfg.caption, padding="max_length",
+                  max_length=tok.model_max_length, truncation=True,
+                  return_tensors="pt")
+        un = tok([""], padding="max_length",
+                 max_length=tok.model_max_length, return_tensors="pt")
 
         with torch.no_grad():
-            te = self.pipe.text_encoder.eval()  # still real tensors
-            em_txt = te(txt.input_ids).last_hidden_state.to(torch.float16)
-            em_un = te(un.input_ids).last_hidden_state.to(torch.float16)
+            te = self.pipe.text_encoder.eval()
+            em_txt = te(txt.input_ids).last_hidden_state.to(torch.float32)
+            em_un = te(un.input_ids).last_hidden_state.to(torch.float32)
 
         self.text_embeddings = (
             torch.cat([em_un, em_txt])
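Note on the hunk above: embed_text builds the classifier-free-guidance pair, one embedding for the caption and one for the empty prompt, concatenated with the unconditional row first. The embeddings are now kept in fp32 to match the CPU pipeline, and the del of the tokenizer and text encoder is commented out so embed_text can still reach self.pipe.tokenizer and self.pipe.text_encoder. Downstream, such a doubled batch is typically split again after the UNet pass; a hedged sketch of that step (not shown in this excerpt, names illustrative):

    import torch

    def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # noise_pred comes from a UNet run on a doubled batch:
        # row 0 = unconditional, row 1 = conditional (the cat order above)
        eps_uncond, eps_text = noise_pred.chunk(2)
        return eps_uncond + guidance_scale * (eps_text - eps_uncond)

    # usage sketch: the noisy latents are doubled first so the UNet batch
    # lines up with the [uncond, cond] embeddings:
    #   latent_in = torch.cat([noisy_latents] * 2)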
@@ -64,10 +72,14 @@ class SDSLoss(nn.Module):
 
     def forward(self, x_aug: torch.Tensor) -> torch.Tensor:
         # ---------------------------------------------------- encode
-        x = (x_aug * 2.0 - 1.0).to(self.device, dtype=torch.float16)
-        with torch.cuda.amp.autocast():
+        x = (x_aug * 2.0 - 1.0).to(self.device, dtype=torch.float32)
+        if self.fp16:
+            with torch.cuda.amp.autocast():
+                latents = self.pipe.vae.encode(x).latent_dist.sample()
+        else:
             latents = self.pipe.vae.encode(x).latent_dist.sample()
-        latents = 0.18215 * latents.to(self.device, dtype=torch.float16)
+
+        latents = 0.18215 * latents
         torch.cuda.empty_cache()
 
         # ---------------------------------------------------- add noise
 
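Note on the hunk above: torch.cuda.amp.autocast is a CUDA-only context, so mixed precision is now gated on self.fp16 and the CPU path encodes in plain fp32 (torch.cuda.empty_cache() is a harmless no-op when CUDA was never initialized). The 0.18215 factor is the Stable Diffusion v1 latent scaling constant, exposed in diffusers as pipe.vae.config.scaling_factor. A hedged self-contained sketch of the gated encode; the function name and signature are illustrative:

    import torch

    def encode_image(pipe, x_aug: torch.Tensor, device: torch.device) -> torch.Tensor:
        x = (x_aug * 2.0 - 1.0).to(device, dtype=torch.float32)  # [0, 1] -> [-1, 1]
        if device.type == "cuda":
            with torch.cuda.amp.autocast():  # fp16 kernels where safe
                latents = pipe.vae.encode(x).latent_dist.sample()
        else:
            latents = pipe.vae.encode(x).latent_dist.sample()  # plain fp32 on CPU
        return 0.18215 * latents  # SD v1 scaling, cf. pipe.vae.config.scaling_factor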