Ashish Ajin
commited on
Commit
·
2584547
1
Parent(s):
ce5011c
Handle missing xformers CUDA
Browse files- code/losses.py +5 -1
code/losses.py
CHANGED
|
@@ -30,7 +30,11 @@ class SDSLoss(nn.Module):
|
|
| 30 |
self.pipe.to(device)
|
| 31 |
self.pipe.text_encoder.to(device)
|
| 32 |
|
| 33 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# Use additional VRAM to disable memory-saving features for speed
|
| 35 |
if hasattr(self.pipe, "disable_attention_slicing"):
|
| 36 |
self.pipe.disable_attention_slicing()
|
|
|
|
| 30 |
self.pipe.to(device)
|
| 31 |
self.pipe.text_encoder.to(device)
|
| 32 |
|
| 33 |
+
if self.fp16:
|
| 34 |
+
try:
|
| 35 |
+
self.pipe.enable_xformers_memory_efficient_attention()
|
| 36 |
+
except Exception as e:
|
| 37 |
+
print(f"WARNING: xFormers memory efficient attention could not be enabled: {e}")
|
| 38 |
# Use additional VRAM to disable memory-saving features for speed
|
| 39 |
if hasattr(self.pipe, "disable_attention_slicing"):
|
| 40 |
self.pipe.disable_attention_slicing()
|