Ashish Ajin commited on
Commit
2584547
·
1 Parent(s): ce5011c

Handle missing xformers CUDA

Browse files
Files changed (1) hide show
  1. code/losses.py +5 -1
code/losses.py CHANGED
@@ -30,7 +30,11 @@ class SDSLoss(nn.Module):
30
  self.pipe.to(device)
31
  self.pipe.text_encoder.to(device)
32
 
33
- self.pipe.enable_xformers_memory_efficient_attention()
 
 
 
 
34
  # Use additional VRAM to disable memory-saving features for speed
35
  if hasattr(self.pipe, "disable_attention_slicing"):
36
  self.pipe.disable_attention_slicing()
 
30
  self.pipe.to(device)
31
  self.pipe.text_encoder.to(device)
32
 
33
+ if self.fp16:
34
+ try:
35
+ self.pipe.enable_xformers_memory_efficient_attention()
36
+ except Exception as e:
37
+ print(f"WARNING: xFormers memory efficient attention could not be enabled: {e}")
38
  # Use additional VRAM to disable memory-saving features for speed
39
  if hasattr(self.pipe, "disable_attention_slicing"):
40
  self.pipe.disable_attention_slicing()