Spaces:
Sleeping
Sleeping
Fix: LlamaRMSNorm to accept extra kwargs from timm
Browse files- models/ACMDM.py +2 -1
models/ACMDM.py
CHANGED
|
@@ -424,7 +424,8 @@ class TimestepEmbedder(nn.Module):
|
|
| 424 |
|
| 425 |
|
| 426 |
class LlamaRMSNorm(nn.Module):
|
| 427 |
-
def __init__(self, hidden_size, eps=1e-6):
|
|
|
|
| 428 |
super().__init__()
|
| 429 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 430 |
self.variance_epsilon = eps
|
|
|
|
| 424 |
|
| 425 |
|
| 426 |
class LlamaRMSNorm(nn.Module):
|
| 427 |
+
def __init__(self, hidden_size, eps=1e-6, **kwargs):
|
| 428 |
+
# Accept and ignore extra kwargs (like 'device') that timm may pass
|
| 429 |
super().__init__()
|
| 430 |
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 431 |
self.variance_epsilon = eps
|