sourxbhh commited on
Commit
005d0a4
·
1 Parent(s): fd1e351

Fix: LlamaRMSNorm to accept extra kwargs from timm

Browse files
Files changed (1) hide show
  1. models/ACMDM.py +2 -1
models/ACMDM.py CHANGED
@@ -424,7 +424,8 @@ class TimestepEmbedder(nn.Module):
424
 
425
 
426
  class LlamaRMSNorm(nn.Module):
427
- def __init__(self, hidden_size, eps=1e-6):
 
428
  super().__init__()
429
  self.weight = nn.Parameter(torch.ones(hidden_size))
430
  self.variance_epsilon = eps
 
424
 
425
 
426
  class LlamaRMSNorm(nn.Module):
427
+ def __init__(self, hidden_size, eps=1e-6, **kwargs):
428
+ # Accept and ignore extra kwargs (like 'device') that timm may pass
429
  super().__init__()
430
  self.weight = nn.Parameter(torch.ones(hidden_size))
431
  self.variance_epsilon = eps