primerz committed
Commit eda51f2 · verified · 1 Parent(s): 3e937c5

Delete ip_attention_processor_xformers.py

Files changed (1)
  1. ip_attention_processor_xformers.py +0 -414
ip_attention_processor_xformers.py DELETED
"""
Enhanced IP-Adapter Attention Processor with XFormers Support
==============================================================

This version combines:
1. Torch 2.0 scaled_dot_product_attention (from our enhanced version)
2. XFormers memory efficient attention (from InstantID reference)
3. Adaptive scaling and learnable parameters (from our enhanced version)
4. Region control support (from InstantID reference)

Expected improvements:
- +15-25% faster inference with xformers
- +2-3% better face preservation with adaptive scaling
- Lower memory usage

Author: Pixagram Team
License: MIT
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from diffusers.models.attention_processor import AttnProcessor2_0

try:
    import xformers
    import xformers.ops
    xformers_available = True
except Exception:
    xformers_available = False


class RegionControler(object):
    """Region control for localized face embedding application"""
    def __init__(self) -> None:
        self.prompt_image_conditioning = []

region_control = RegionControler()

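# Usage sketch for the region control hook above (an illustrative assumption, not
# part of the original file): callers populate `prompt_image_conditioning` with a
# dict holding a 2D float mask before denoising; `forward()` below reads the
# optional "region_mask" key and restricts the face embedding to that area.
#
#   region_control.prompt_image_conditioning.append({
#       "region_mask": torch.ones(128, 128),  # hypothetical H x W latent-space mask
#   })
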
class IPAttnProcessorXFormers(nn.Module):
    """
    Enhanced IP-Adapter attention with XFormers and adaptive scaling.

    Features:
    - XFormers memory efficient attention (if available)
    - Torch 2.0 scaled_dot_product_attention (fallback)
    - Adaptive per-layer scaling
    - Learnable scale parameters
    - Region control support

    Args:
        hidden_size: Attention layer hidden dimension
        cross_attention_dim: Encoder hidden states dimension
        scale: Base blending weight for face features
        num_tokens: Number of face embedding tokens
        adaptive_scale: Enable adaptive scaling
        learnable_scale: Make scale learnable per layer
    """

    def __init__(
        self,
        hidden_size: int,
        cross_attention_dim: Optional[int] = None,
        scale: float = 1.0,
        num_tokens: int = 4,
        adaptive_scale: bool = True,
        learnable_scale: bool = True
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim or hidden_size
        self.base_scale = scale
        self.num_tokens = num_tokens
        self.adaptive_scale = adaptive_scale
        self.use_xformers = xformers_available

        # Dedicated K/V projections for face features
        self.to_k_ip = nn.Linear(self.cross_attention_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(self.cross_attention_dim, hidden_size, bias=False)

        # Learnable scale parameter (per layer)
        if learnable_scale:
            self.scale_param = nn.Parameter(torch.tensor(scale))
        else:
            self.register_buffer('scale_param', torch.tensor(scale))

        # Adaptive scaling module
        if adaptive_scale:
            self.adaptive_gate = nn.Sequential(
                nn.Linear(hidden_size, hidden_size // 4),
                nn.ReLU(),
                nn.Linear(hidden_size // 4, 1),
                nn.Sigmoid()
            )

        # Better initialization
        self._init_weights()

        if self.use_xformers:
            print(" [XFORMERS] Enabled for IP-Adapter attention")

    def _init_weights(self):
        """Xavier initialization for stable training."""
        nn.init.xavier_uniform_(self.to_k_ip.weight)
        nn.init.xavier_uniform_(self.to_v_ip.weight)

        if self.adaptive_scale:
            for module in self.adaptive_gate:
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

    def compute_adaptive_scale(
        self,
        query: torch.Tensor,
        ip_key: torch.Tensor,
        base_scale: float
    ) -> torch.Tensor:
        """
        Compute an adaptive scale from the mean query features.
        A stronger gate response means stronger face preservation.
        Note: ip_key is accepted for API symmetry but is not used by the gate.
        """
        # Average over the sequence dimension, then flatten heads so the result
        # matches the gate's expected input size (heads * head_dim == hidden_size).
        # query: [batch, heads, seq_len, head_dim]
        query_mean = query.mean(dim=2).reshape(query.shape[0], -1)  # [batch, hidden_size]

        # Pass through gating network
        gate = self.adaptive_gate(query_mean)  # [batch, 1]

        # Modulate base scale
        adaptive_scale = base_scale * (0.5 + gate)  # Range: [0.5*base, 1.5*base]

        return adaptive_scale.view(-1, 1, 1)  # [batch, 1, 1] for broadcasting

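    # Worked example of the modulation above (illustrative numbers, assuming
    # base_scale = 0.8): gate = 0.0 -> 0.4, gate = 0.5 -> 0.8, gate = 1.0 -> 1.2.
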
    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
        """XFormers memory efficient attention.

        Inputs are expected in (batch * heads, seq_len, head_dim) layout, which
        xformers accepts directly as 3D tensors; the output keeps that layout.
        """
        if attention_mask is not None and attention_mask.ndim == 4:
            # Collapse (batch, heads, q_len, k_len) -> (batch * heads, q_len, k_len)
            attention_mask = attention_mask.reshape(-1, *attention_mask.shape[-2:])

        hidden_states = xformers.ops.memory_efficient_attention(
            query.contiguous(),
            key.contiguous(),
            value.contiguous(),
            attn_bias=attention_mask
        )

        return hidden_states

    def forward(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Forward pass with XFormers or Torch 2.0 attention."""
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Split text and face embeddings
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
            ip_hidden_states = None
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :]
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Text attention
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Choose attention implementation
        if self.use_xformers and not self.training:
            # XFormers during inference
            query_xf = query.reshape(batch_size * attn.heads, -1, head_dim)
            key_xf = key.reshape(batch_size * attn.heads, -1, head_dim)
            value_xf = value.reshape(batch_size * attn.heads, -1, head_dim)

            try:
                hidden_states = self._memory_efficient_attention_xformers(
                    query_xf, key_xf, value_xf, attention_mask
                )
                hidden_states = hidden_states.reshape(batch_size, attn.heads, -1, head_dim)
            except Exception:
                # Fallback to torch 2.0
                hidden_states = F.scaled_dot_product_attention(
                    query, key, value,
                    attn_mask=attention_mask,
                    dropout_p=0.0,
                    is_causal=False
                )
        else:
            # Torch 2.0 attention
            hidden_states = F.scaled_dot_product_attention(
                query, key, value,
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # Face attention with enhancements
        if ip_hidden_states is not None:
            # Dedicated K/V projections
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # Face attention
            if self.use_xformers and not self.training:
                # XFormers
                query_xf = query.reshape(batch_size * attn.heads, -1, head_dim)
                ip_key_xf = ip_key.reshape(batch_size * attn.heads, -1, head_dim)
                ip_value_xf = ip_value.reshape(batch_size * attn.heads, -1, head_dim)

                try:
                    ip_hidden_states = self._memory_efficient_attention_xformers(
                        query_xf, ip_key_xf, ip_value_xf, None
                    )
                    ip_hidden_states = ip_hidden_states.reshape(batch_size, attn.heads, -1, head_dim)
                except Exception:
                    # Fallback
                    ip_hidden_states = F.scaled_dot_product_attention(
                        query, ip_key, ip_value,
                        attn_mask=None,
                        dropout_p=0.0,
                        is_causal=False
                    )
            else:
                # Torch 2.0
                ip_hidden_states = F.scaled_dot_product_attention(
                    query, ip_key, ip_value,
                    attn_mask=None,
                    dropout_p=0.0,
                    is_causal=False
                )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            # Compute effective scale
            if self.adaptive_scale and not self.training:
                try:
                    adaptive_scale = self.compute_adaptive_scale(query, ip_key, self.scale_param.item())
                    effective_scale = adaptive_scale
                except Exception:
                    effective_scale = self.scale_param
            else:
                effective_scale = self.scale_param

            # Region control support
            if len(region_control.prompt_image_conditioning) == 1:
                region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
                if region_mask is not None:
                    query_flat = query.reshape([-1, query.shape[-2], query.shape[-1]])
                    h, w = region_mask.shape[:2]
                    ratio = (h * w / query_flat.shape[1]) ** 0.5
                    mask = F.interpolate(
                        region_mask[None, None],
                        scale_factor=1 / ratio,
                        mode='nearest'
                    ).reshape([1, -1, 1])
                else:
                    mask = torch.ones_like(ip_hidden_states)
                ip_hidden_states = ip_hidden_states * mask

            # Blend with adaptive scale
            hidden_states = hidden_states + effective_scale * ip_hidden_states

        # Output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

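# The helper below is an illustrative addition (not part of the original file): it
# shows how the per-layer `scale_param` exposed above can be adjusted at runtime,
# e.g. to weaken or strengthen face conditioning between generations. It assumes
# the processors are already installed on `pipe.unet`.
def set_ip_adapter_scale(pipe, scale: float):
    """Set the base face-embedding scale on every installed IPAttnProcessorXFormers."""
    for proc in pipe.unet.attn_processors.values():
        if isinstance(proc, IPAttnProcessorXFormers):
            with torch.no_grad():
                # Works for both the learnable nn.Parameter and the registered buffer.
                proc.scale_param.fill_(scale)
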
def setup_xformers_ip_adapter_attention(
    pipe,
    ip_adapter_scale: float = 1.0,
    num_tokens: int = 4,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
    adaptive_scale: bool = True,
    learnable_scale: bool = True
):
    """
    Setup IP-Adapter with XFormers optimized attention processors.

    Args:
        pipe: Diffusers pipeline
        ip_adapter_scale: Base face embedding strength
        num_tokens: Number of face tokens
        device: Device
        dtype: Data type
        adaptive_scale: Enable adaptive scaling
        learnable_scale: Make scales learnable

    Returns:
        Dict of attention processors
    """
    attn_procs = {}

    for name in pipe.unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim

        if name.startswith("mid_block"):
            hidden_size = pipe.unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = pipe.unet.config.block_out_channels[block_id]
        else:
            hidden_size = pipe.unet.config.block_out_channels[-1]

        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor2_0()
        else:
            attn_procs[name] = IPAttnProcessorXFormers(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=ip_adapter_scale,
                num_tokens=num_tokens,
                adaptive_scale=adaptive_scale,
                learnable_scale=learnable_scale
            ).to(device, dtype=dtype)

    print("[OK] XFormers-optimized attention processors created")
    print(f"  - Total processors: {len(attn_procs)}")
    print(f"  - XFormers available: {xformers_available}")
    print(f"  - Adaptive scaling: {adaptive_scale}")
    print(f"  - Learnable scales: {learnable_scale}")

    return attn_procs

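# Typical wiring into a diffusers pipeline (a hedged sketch; assumes `pipe` is an
# already-loaded SDXL pipeline and that the IP-Adapter / InstantID image-projection
# weights are handled elsewhere). `set_attn_processor` is the standard diffusers
# entry point for installing custom processors on the UNet:
#
#   attn_procs = setup_xformers_ip_adapter_attention(
#       pipe,
#       ip_adapter_scale=0.8,
#       num_tokens=4,
#       device="cuda",
#       dtype=torch.float16,
#   )
#   pipe.unet.set_attn_processor(attn_procs)
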
if __name__ == "__main__":
    print("Testing XFormers IP-Adapter Processor...")

    processor = IPAttnProcessorXFormers(
        hidden_size=1280,
        cross_attention_dim=2048,
        scale=0.8,
        num_tokens=4,
        adaptive_scale=True,
        learnable_scale=True
    )

    print("\n[OK] Processor created successfully")
    print(f"Parameters: {sum(p.numel() for p in processor.parameters()):,}")
    print(f"XFormers available: {xformers_available}")
    print(f"Has adaptive scaling: {processor.adaptive_scale}")
    print(f"Has learnable scale: {isinstance(processor.scale_param, nn.Parameter)}")
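    # Quick smoke test of the adaptive gate alone (an illustrative addition with
    # assumed shapes: 8 heads of dim 160, so heads * head_dim == hidden_size == 1280;
    # ip_key is passed as None because the gate only looks at the query).
    dummy_query = torch.randn(2, 8, 64, 160)  # [batch, heads, seq_len, head_dim]
    adaptive = processor.compute_adaptive_scale(dummy_query, None, base_scale=0.8)
    print(f"Adaptive scale shape: {tuple(adaptive.shape)}")  # expected (2, 1, 1)
    print(f"Adaptive scale values (should lie in [0.4, 1.2]): {adaptive.flatten().tolist()}")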