Update Custom-Advanced-VACE-Node/nodes_utility.py
Browse files
Custom-Advanced-VACE-Node/nodes_utility.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import torch
|
| 2 |
import numpy as np
|
| 3 |
from comfy.utils import common_upscale
|
|
|
|
|
|
|
| 4 |
from .utils import log
|
| 5 |
from einops import rearrange
|
| 6 |
|
|
@@ -12,6 +14,9 @@ except:
|
|
| 12 |
VAE_STRIDE = (4, 8, 8)
|
| 13 |
PATCH_SIZE = (1, 2, 2)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
class WanVideoImageResizeToClosest:
|
| 16 |
@classmethod
|
| 17 |
def INPUT_TYPES(s):
|
|
@@ -681,6 +686,96 @@ class FaceMaskFromPoseKeypoints:
|
|
| 681 |
cv2.fillPoly(canvas, pts=[outer_contour], color=part_color)
|
| 682 |
|
| 683 |
return canvas
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 684 |
|
| 685 |
NODE_CLASS_MAPPINGS = {
|
| 686 |
"WanVideoImageResizeToClosest": WanVideoImageResizeToClosest,
|
|
@@ -694,6 +789,7 @@ NODE_CLASS_MAPPINGS = {
|
|
| 694 |
"NormalizeAudioLoudness": NormalizeAudioLoudness,
|
| 695 |
"WanVideoPassImagesFromSamples": WanVideoPassImagesFromSamples,
|
| 696 |
"FaceMaskFromPoseKeypoints": FaceMaskFromPoseKeypoints,
|
|
|
|
| 697 |
}
|
| 698 |
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 699 |
"WanVideoImageResizeToClosest": "WanVideo Image Resize To Closest",
|
|
@@ -707,4 +803,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|
| 707 |
"NormalizeAudioLoudness": "Normalize Audio Loudness",
|
| 708 |
"WanVideoPassImagesFromSamples": "WanVideo Pass Images From Samples",
|
| 709 |
"FaceMaskFromPoseKeypoints": "Face Mask From Pose Keypoints",
|
|
|
|
| 710 |
}
|
|
|
|
| 1 |
import torch
|
| 2 |
import numpy as np
|
| 3 |
from comfy.utils import common_upscale
|
| 4 |
+
from comfy import model_management
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
from .utils import log
|
| 7 |
from einops import rearrange
|
| 8 |
|
|
|
|
| 14 |
VAE_STRIDE = (4, 8, 8)
|
| 15 |
PATCH_SIZE = (1, 2, 2)
|
| 16 |
|
| 17 |
+
main_device = model_management.get_torch_device()
|
| 18 |
+
offload_device = model_management.unet_offload_device()
|
| 19 |
+
|
| 20 |
class WanVideoImageResizeToClosest:
|
| 21 |
@classmethod
|
| 22 |
def INPUT_TYPES(s):
|
|
|
|
| 686 |
cv2.fillPoly(canvas, pts=[outer_contour], color=part_color)
|
| 687 |
|
| 688 |
return canvas
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
class DrawGaussianNoiseOnImage:
|
| 692 |
+
@classmethod
|
| 693 |
+
def INPUT_TYPES(s):
|
| 694 |
+
return {"required": {
|
| 695 |
+
"image": ("IMAGE", ),
|
| 696 |
+
"mask": ("MASK", ),
|
| 697 |
+
},
|
| 698 |
+
"optional": {
|
| 699 |
+
"device": (["cpu", "gpu"], {"default": "cpu", "tooltip": "Device to use for processing"}),
|
| 700 |
+
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
| 701 |
+
}
|
| 702 |
+
}
|
| 703 |
+
|
| 704 |
+
RETURN_TYPES = ("IMAGE", )
|
| 705 |
+
RETURN_NAMES = ("images",)
|
| 706 |
+
FUNCTION = "apply"
|
| 707 |
+
CATEGORY = "KJNodes/masking"
|
| 708 |
+
DESCRIPTION = "Fills the background (masked area) with Gaussian noise sampled using the mean and variance of the subject (unmasked) region."
|
| 709 |
+
|
| 710 |
+
def apply(self, image, mask, device="cpu", seed=0):
|
| 711 |
+
B, H, W, C = image.shape
|
| 712 |
+
BM, HM, WM = mask.shape
|
| 713 |
+
|
| 714 |
+
processing_device = main_device if device == "gpu" else torch.device("cpu")
|
| 715 |
+
|
| 716 |
+
in_masks = mask.clone().to(processing_device)
|
| 717 |
+
in_images = image.clone().to(processing_device)
|
| 718 |
+
|
| 719 |
+
# Resize mask to match image dimensions
|
| 720 |
+
if HM != H or WM != W:
|
| 721 |
+
in_masks = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)
|
| 722 |
+
|
| 723 |
+
# Match batch sizes
|
| 724 |
+
if B > BM:
|
| 725 |
+
in_masks = in_masks.repeat((B + BM - 1) // BM, 1, 1)[:B]
|
| 726 |
+
elif BM > B:
|
| 727 |
+
in_masks = in_masks[:B]
|
| 728 |
+
|
| 729 |
+
output_images = []
|
| 730 |
+
|
| 731 |
+
# Set random seed for reproducibility
|
| 732 |
+
generator = torch.Generator(device=processing_device).manual_seed(seed)
|
| 733 |
+
|
| 734 |
+
for i in tqdm(range(B), desc="DrawGaussianNoiseOnImage batch"):
|
| 735 |
+
curr_mask = in_masks[i]
|
| 736 |
+
img_idx = min(i, B - 1)
|
| 737 |
+
curr_image = in_images[img_idx]
|
| 738 |
+
|
| 739 |
+
# Expand mask to 3 channels
|
| 740 |
+
mask_expanded = curr_mask.unsqueeze(-1).expand(-1, -1, 3)
|
| 741 |
+
|
| 742 |
+
# Calculate mean and std per channel from the subject region (where mask is 1)
|
| 743 |
+
subject_mask = mask_expanded > 0.5
|
| 744 |
+
|
| 745 |
+
# Initialize noise tensor
|
| 746 |
+
noise = torch.zeros_like(curr_image)
|
| 747 |
+
|
| 748 |
+
for c in range(C):
|
| 749 |
+
channel = curr_image[:, :, c]
|
| 750 |
+
channel_mask = subject_mask[:, :, c]
|
| 751 |
+
|
| 752 |
+
if channel_mask.sum() > 0:
|
| 753 |
+
# Get subject pixels
|
| 754 |
+
subject_pixels = channel[channel_mask]
|
| 755 |
+
|
| 756 |
+
# Calculate statistics
|
| 757 |
+
mean = subject_pixels.mean()
|
| 758 |
+
std = subject_pixels.std()
|
| 759 |
+
|
| 760 |
+
# Generate Gaussian noise for this channel
|
| 761 |
+
noise[:, :, c] = torch.normal(mean=mean.item(), std=std.item(),
|
| 762 |
+
size=(H, W), generator=generator,
|
| 763 |
+
device=processing_device)
|
| 764 |
+
|
| 765 |
+
# Clamp noise to valid range
|
| 766 |
+
noise = torch.clamp(noise, 0.0, 1.0)
|
| 767 |
+
|
| 768 |
+
# Apply: keep subject, fill background with noise
|
| 769 |
+
masked_image = curr_image * mask_expanded + noise * (1 - mask_expanded)
|
| 770 |
+
output_images.append(masked_image)
|
| 771 |
+
|
| 772 |
+
# If no masks were processed, return empty tensor
|
| 773 |
+
if not output_images:
|
| 774 |
+
return (torch.zeros((0, H, W, 3), dtype=image.dtype),)
|
| 775 |
+
|
| 776 |
+
out_rgb = torch.stack(output_images, dim=0).cpu()
|
| 777 |
+
|
| 778 |
+
return (out_rgb, )
|
| 779 |
|
| 780 |
NODE_CLASS_MAPPINGS = {
|
| 781 |
"WanVideoImageResizeToClosest": WanVideoImageResizeToClosest,
|
|
|
|
| 789 |
"NormalizeAudioLoudness": NormalizeAudioLoudness,
|
| 790 |
"WanVideoPassImagesFromSamples": WanVideoPassImagesFromSamples,
|
| 791 |
"FaceMaskFromPoseKeypoints": FaceMaskFromPoseKeypoints,
|
| 792 |
+
"DrawGaussianNoiseOnImage": DrawGaussianNoiseOnImage,
|
| 793 |
}
|
| 794 |
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 795 |
"WanVideoImageResizeToClosest": "WanVideo Image Resize To Closest",
|
|
|
|
| 803 |
"NormalizeAudioLoudness": "Normalize Audio Loudness",
|
| 804 |
"WanVideoPassImagesFromSamples": "WanVideo Pass Images From Samples",
|
| 805 |
"FaceMaskFromPoseKeypoints": "Face Mask From Pose Keypoints",
|
| 806 |
+
"DrawGaussianNoiseOnImage": "Draw Gaussian Noise On Image",
|
| 807 |
}
|