Phr00t commited on
Commit
5a05226
·
verified ·
1 Parent(s): d9dcbda

Update Custom-Advanced-VACE-Node/nodes_utility.py

Browse files
Custom-Advanced-VACE-Node/nodes_utility.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
2
  import numpy as np
3
  from comfy.utils import common_upscale
 
 
4
  from .utils import log
5
  from einops import rearrange
6
 
@@ -12,6 +14,9 @@ except:
12
  VAE_STRIDE = (4, 8, 8)
13
  PATCH_SIZE = (1, 2, 2)
14
 
 
 
 
15
  class WanVideoImageResizeToClosest:
16
  @classmethod
17
  def INPUT_TYPES(s):
@@ -681,6 +686,96 @@ class FaceMaskFromPoseKeypoints:
681
  cv2.fillPoly(canvas, pts=[outer_contour], color=part_color)
682
 
683
  return canvas
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684
 
685
  NODE_CLASS_MAPPINGS = {
686
  "WanVideoImageResizeToClosest": WanVideoImageResizeToClosest,
@@ -694,6 +789,7 @@ NODE_CLASS_MAPPINGS = {
694
  "NormalizeAudioLoudness": NormalizeAudioLoudness,
695
  "WanVideoPassImagesFromSamples": WanVideoPassImagesFromSamples,
696
  "FaceMaskFromPoseKeypoints": FaceMaskFromPoseKeypoints,
 
697
  }
698
  NODE_DISPLAY_NAME_MAPPINGS = {
699
  "WanVideoImageResizeToClosest": "WanVideo Image Resize To Closest",
@@ -707,4 +803,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
707
  "NormalizeAudioLoudness": "Normalize Audio Loudness",
708
  "WanVideoPassImagesFromSamples": "WanVideo Pass Images From Samples",
709
  "FaceMaskFromPoseKeypoints": "Face Mask From Pose Keypoints",
 
710
  }
 
1
  import torch
2
  import numpy as np
3
  from comfy.utils import common_upscale
4
+ from comfy import model_management
5
+ from tqdm import tqdm
6
  from .utils import log
7
  from einops import rearrange
8
 
 
14
  VAE_STRIDE = (4, 8, 8)
15
  PATCH_SIZE = (1, 2, 2)
16
 
17
+ main_device = model_management.get_torch_device()
18
+ offload_device = model_management.unet_offload_device()
19
+
20
  class WanVideoImageResizeToClosest:
21
  @classmethod
22
  def INPUT_TYPES(s):
 
686
  cv2.fillPoly(canvas, pts=[outer_contour], color=part_color)
687
 
688
  return canvas
689
+
690
+
691
+ class DrawGaussianNoiseOnImage:
692
+ @classmethod
693
+ def INPUT_TYPES(s):
694
+ return {"required": {
695
+ "image": ("IMAGE", ),
696
+ "mask": ("MASK", ),
697
+ },
698
+ "optional": {
699
+ "device": (["cpu", "gpu"], {"default": "cpu", "tooltip": "Device to use for processing"}),
700
+ "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
701
+ }
702
+ }
703
+
704
+ RETURN_TYPES = ("IMAGE", )
705
+ RETURN_NAMES = ("images",)
706
+ FUNCTION = "apply"
707
+ CATEGORY = "KJNodes/masking"
708
+ DESCRIPTION = "Fills the background (masked area) with Gaussian noise sampled using the mean and variance of the subject (unmasked) region."
709
+
710
+ def apply(self, image, mask, device="cpu", seed=0):
711
+ B, H, W, C = image.shape
712
+ BM, HM, WM = mask.shape
713
+
714
+ processing_device = main_device if device == "gpu" else torch.device("cpu")
715
+
716
+ in_masks = mask.clone().to(processing_device)
717
+ in_images = image.clone().to(processing_device)
718
+
719
+ # Resize mask to match image dimensions
720
+ if HM != H or WM != W:
721
+ in_masks = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
722
+
723
+ # Match batch sizes
724
+ if B > BM:
725
+ in_masks = in_masks.repeat((B + BM - 1) // BM, 1, 1)[:B]
726
+ elif BM > B:
727
+ in_masks = in_masks[:B]
728
+
729
+ output_images = []
730
+
731
+ # Set random seed for reproducibility
732
+ generator = torch.Generator(device=processing_device).manual_seed(seed)
733
+
734
+ for i in tqdm(range(B), desc="DrawGaussianNoiseOnImage batch"):
735
+ curr_mask = in_masks[i]
736
+ img_idx = min(i, B - 1)
737
+ curr_image = in_images[img_idx]
738
+
739
+ # Expand mask to 3 channels
740
+ mask_expanded = curr_mask.unsqueeze(-1).expand(-1, -1, 3)
741
+
742
+ # Calculate mean and std per channel from the subject region (where mask is 1)
743
+ subject_mask = mask_expanded > 0.5
744
+
745
+ # Initialize noise tensor
746
+ noise = torch.zeros_like(curr_image)
747
+
748
+ for c in range(C):
749
+ channel = curr_image[:, :, c]
750
+ channel_mask = subject_mask[:, :, c]
751
+
752
+ if channel_mask.sum() > 0:
753
+ # Get subject pixels
754
+ subject_pixels = channel[channel_mask]
755
+
756
+ # Calculate statistics
757
+ mean = subject_pixels.mean()
758
+ std = subject_pixels.std()
759
+
760
+ # Generate Gaussian noise for this channel
761
+ noise[:, :, c] = torch.normal(mean=mean.item(), std=std.item(),
762
+ size=(H, W), generator=generator,
763
+ device=processing_device)
764
+
765
+ # Clamp noise to valid range
766
+ noise = torch.clamp(noise, 0.0, 1.0)
767
+
768
+ # Apply: keep subject, fill background with noise
769
+ masked_image = curr_image * mask_expanded + noise * (1 - mask_expanded)
770
+ output_images.append(masked_image)
771
+
772
+ # If no masks were processed, return empty tensor
773
+ if not output_images:
774
+ return (torch.zeros((0, H, W, 3), dtype=image.dtype),)
775
+
776
+ out_rgb = torch.stack(output_images, dim=0).cpu()
777
+
778
+ return (out_rgb, )
779
 
780
  NODE_CLASS_MAPPINGS = {
781
  "WanVideoImageResizeToClosest": WanVideoImageResizeToClosest,
 
789
  "NormalizeAudioLoudness": NormalizeAudioLoudness,
790
  "WanVideoPassImagesFromSamples": WanVideoPassImagesFromSamples,
791
  "FaceMaskFromPoseKeypoints": FaceMaskFromPoseKeypoints,
792
+ "DrawGaussianNoiseOnImage": DrawGaussianNoiseOnImage,
793
  }
794
  NODE_DISPLAY_NAME_MAPPINGS = {
795
  "WanVideoImageResizeToClosest": "WanVideo Image Resize To Closest",
 
803
  "NormalizeAudioLoudness": "Normalize Audio Loudness",
804
  "WanVideoPassImagesFromSamples": "WanVideo Pass Images From Samples",
805
  "FaceMaskFromPoseKeypoints": "Face Mask From Pose Keypoints",
806
+ "DrawGaussianNoiseOnImage": "Draw Gaussian Noise On Image",
807
  }