Spaces: Build error
Commit fe3fdf0 · Parent(s): 8dd41a8
feat: Add plugins

- mypy.ini +7 -0
- plugin_options/core.json +6 -0
- plugin_options/core_video.json +5 -0
- plugin_options/plugin_codeformer.json +8 -0
- plugin_options/plugin_dmdnet.json +3 -0
- plugin_options/plugin_faceswap.json +5 -0
- plugin_options/plugin_gfpgan.json +3 -0
- plugin_options/plugin_txt2clip.json +3 -0
- plugins/codeformer_app_cv2.py +300 -0
- plugins/codeformer_face_helper_cv2.py +94 -0
- plugins/core.py +29 -0
- plugins/core_video.py +26 -0
- plugins/plugin_codeformer.py +83 -0
- plugins/plugin_dmdnet.py +835 -0
- plugins/plugin_faceswap.py +86 -0
- plugins/plugin_gfpgan.py +85 -0
- plugins/plugin_txt2clip.py +122 -0
- roop-unleashed.ipynb +3 -0
mypy.ini
ADDED
@@ -0,0 +1,7 @@
+[mypy]
+check_untyped_defs = True
+disallow_any_generics = True
+disallow_untyped_calls = True
+disallow_untyped_defs = True
+ignore_missing_imports = True
+strict_optional = False
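For context: mypy reads mypy.ini from the project root automatically, so a plain `mypy plugins/` run applies these flags. A minimal illustration of what `disallow_untyped_defs` demands (hypothetical module, not part of this commit):

    # example.py (hypothetical) - acceptable under the settings above
    def blend_label(ratio: float) -> str:
        # fully annotated, as required by disallow_untyped_defs
        return f"blend {ratio:.2f}"

    # def broken(x):        # would be rejected: parameters and return type unannotated
    #     return x * 2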
plugin_options/core.json
ADDED
@@ -0,0 +1,6 @@
+{
+    "default_chain": "faceswap",
+    "init_on_start": "faceswap,dmdnet,gfpgan,codeformer",
+    "is_demo_row_render": false,
+    "v": "2.0"
+}
plugin_options/core_video.json
ADDED
@@ -0,0 +1,5 @@
+{
+    "v": "2.0",
+    "video_save_codec": "libx264",
+    "video_save_crf": 14
+}
plugin_options/plugin_codeformer.json
ADDED
@@ -0,0 +1,8 @@
+{
+    "background_enhance": false,
+    "codeformer_fidelity": 0.8,
+    "face_upsample": true,
+    "skip_if_no_face": true,
+    "upscale": 1,
+    "v": "3.0"
+}
plugin_options/plugin_dmdnet.json
ADDED
@@ -0,0 +1,3 @@
+{
+    "v": "1.0"
+}
plugin_options/plugin_faceswap.json
ADDED
@@ -0,0 +1,5 @@
+{
+    "max_distance": 0.65,
+    "swap_mode": "selected",
+    "v": "1.0"
+}
plugin_options/plugin_gfpgan.json
ADDED
@@ -0,0 +1,3 @@
+{
+    "v": "1.4"
+}
plugin_options/plugin_txt2clip.json
ADDED
@@ -0,0 +1,3 @@
+{
+    "v": "1.0"
+}
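These JSON files mirror the `default_options` blocks declared by the plugins added below (compare plugin_options/core.json with plugins/core.py), with `v` tracking the plugin version; they appear to be the persisted per-plugin settings. A minimal sketch of reading one of them with the standard library, for illustration only (the helper name is hypothetical, not part of this commit):

    import json

    def load_plugin_options(name: str) -> dict:
        # hypothetical helper: read persisted options for a plugin by name
        with open(f"plugin_options/{name}.json", "r", encoding="utf-8") as f:
            return json.load(f)

    opts = load_plugin_options("core")
    print(opts["default_chain"])   # "faceswap"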
plugins/codeformer_app_cv2.py
ADDED
@@ -0,0 +1,300 @@
+"""
+Modified version from codeformer-pip project
+
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+https://github.com/kadirnar/codeformer-pip/blob/main/LICENSE
+"""
+
+import os
+
+import cv2
+import torch
+from codeformer.facelib.detection import init_detection_model
+from codeformer.facelib.parsing import init_parsing_model
+from torchvision.transforms.functional import normalize
+
+from codeformer.basicsr.archs.rrdbnet_arch import RRDBNet
+from codeformer.basicsr.utils import img2tensor, imwrite, tensor2img
+from codeformer.basicsr.utils.download_util import load_file_from_url
+from codeformer.basicsr.utils.realesrgan_utils import RealESRGANer
+from codeformer.basicsr.utils.registry import ARCH_REGISTRY
+from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
+from codeformer.facelib.utils.misc import is_gray
+import threading
+
+from plugins.codeformer_face_helper_cv2 import FaceRestoreHelperOptimized
+
+THREAD_LOCK_FACE_HELPER = threading.Lock()
+THREAD_LOCK_FACE_HELPER_CREATE = threading.Lock()
+THREAD_LOCK_FACE_HELPER_PROCERSSING = threading.Lock()
+THREAD_LOCK_CODEFORMER_NET = threading.Lock()
+THREAD_LOCK_CODEFORMER_NET_CREATE = threading.Lock()
+THREAD_LOCK_BGUPSAMPLER = threading.Lock()
+
+pretrain_model_url = {
+    "codeformer": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
+    "detection": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth",
+    "parsing": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth",
+    "realesrgan": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
+}
+
+# download weights
+if not os.path.exists("models/CodeFormer/codeformer.pth"):
+    load_file_from_url(
+        url=pretrain_model_url["codeformer"], model_dir="models/CodeFormer/", progress=True, file_name=None
+    )
+if not os.path.exists("models/CodeFormer/facelib/detection_Resnet50_Final.pth"):
+    load_file_from_url(
+        url=pretrain_model_url["detection"], model_dir="models/CodeFormer/facelib", progress=True, file_name=None
+    )
+if not os.path.exists("models/CodeFormer/facelib/parsing_parsenet.pth"):
+    load_file_from_url(
+        url=pretrain_model_url["parsing"], model_dir="models/CodeFormer/facelib", progress=True, file_name=None
+    )
+if not os.path.exists("models/CodeFormer/realesrgan/RealESRGAN_x2plus.pth"):
+    load_file_from_url(
+        url=pretrain_model_url["realesrgan"], model_dir="models/CodeFormer/realesrgan", progress=True, file_name=None
+    )
+
+
+def imread(img_path):
+    img = cv2.imread(img_path)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+
+# set enhancer with RealESRGAN
+def set_realesrgan():
+    half = True if torch.cuda.is_available() else False
+    model = RRDBNet(
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_block=23,
+        num_grow_ch=32,
+        scale=2,
+    )
+    upsampler = RealESRGANer(
+        scale=2,
+        model_path="models/CodeFormer/realesrgan/RealESRGAN_x2plus.pth",
+        model=model,
+        tile=400,
+        tile_pad=40,
+        pre_pad=0,
+        half=half,
+    )
+    return upsampler
+
+
+upsampler = set_realesrgan()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+codeformers_cache = []
+
+def get_codeformer():
+    if len(codeformers_cache) > 0:
+        with THREAD_LOCK_CODEFORMER_NET:
+            if len(codeformers_cache) > 0:
+                return codeformers_cache.pop()
+
+    with THREAD_LOCK_CODEFORMER_NET_CREATE:
+        codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
+            dim_embd=512,
+            codebook_size=1024,
+            n_head=8,
+            n_layers=9,
+            connect_list=["32", "64", "128", "256"],
+        ).to(device)
+        ckpt_path = "models/CodeFormer/codeformer.pth"
+        checkpoint = torch.load(ckpt_path)["params_ema"]
+        codeformer_net.load_state_dict(checkpoint)
+        codeformer_net.eval()
+        return codeformer_net
+
+
+
+def release_codeformer(codeformer):
+    with THREAD_LOCK_CODEFORMER_NET:
+        codeformers_cache.append(codeformer)
+
+#os.makedirs("output", exist_ok=True)
+
+# ------- face restore thread cache ----------
+
+face_restore_helper_cache = []
+
+detection_model = "retinaface_resnet50"
+
+inited_face_restore_helper_nn = False
+
+import time
+
+def get_face_restore_helper(upscale):
+    global inited_face_restore_helper_nn
+    with THREAD_LOCK_FACE_HELPER:
+        face_helper = FaceRestoreHelperOptimized(
+            upscale,
+            face_size=512,
+            crop_ratio=(1, 1),
+            det_model=detection_model,
+            save_ext="png",
+            use_parse=True,
+            device=device,
+        )
+        #return face_helper
+
+        if inited_face_restore_helper_nn:
+            while len(face_restore_helper_cache) == 0:
+                time.sleep(0.05)
+            face_detector, face_parse = face_restore_helper_cache.pop()
+            face_helper.face_detector = face_detector
+            face_helper.face_parse = face_parse
+            return face_helper
+        else:
+            inited_face_restore_helper_nn = True
+            face_helper.face_detector = init_detection_model(detection_model, half=False, device=face_helper.device)
+            face_helper.face_parse = init_parsing_model(model_name="parsenet", device=face_helper.device)
+            return face_helper
+
+def get_face_restore_helper2(upscale): # still not work well!!!
+    face_helper = FaceRestoreHelperOptimized(
+        upscale,
+        face_size=512,
+        crop_ratio=(1, 1),
+        det_model=detection_model,
+        save_ext="png",
+        use_parse=True,
+        device=device,
+    )
+    #return face_helper
+
+    if len(face_restore_helper_cache) > 0:
+        with THREAD_LOCK_FACE_HELPER:
+            if len(face_restore_helper_cache) > 0:
+                face_detector, face_parse = face_restore_helper_cache.pop()
+                face_helper.face_detector = face_detector
+                face_helper.face_parse = face_parse
+                return face_helper
+
+    with THREAD_LOCK_FACE_HELPER_CREATE:
+        face_helper.face_detector = init_detection_model(detection_model, half=False, device=face_helper.device)
+        face_helper.face_parse = init_parsing_model(model_name="parsenet", device=face_helper.device)
+        return face_helper
+
+def release_face_restore_helper(face_helper):
+    #return
+    #with THREAD_LOCK_FACE_HELPER:
+    face_restore_helper_cache.append((face_helper.face_detector, face_helper.face_parse))
+    #pass
+
+def inference_app(image, background_enhance, face_upsample, upscale, codeformer_fidelity, skip_if_no_face = False):
+    # take the default setting for the demo
+    has_aligned = False
+    only_center_face = False
+    draw_box = False
+
+    #print("Inp:", image, background_enhance, face_upsample, upscale, codeformer_fidelity)
+    if isinstance(image, str):
+        img = cv2.imread(str(image), cv2.IMREAD_COLOR)
+    else:
+        img = image
+    #print("\timage size:", img.shape)
+
+    upscale = int(upscale) # convert type to int
+    if upscale > 4: # avoid memory exceeded due to too large upscale
+        upscale = 4
+    if upscale > 2 and max(img.shape[:2]) > 1000: # avoid memory exceeded due to too large img resolution
+        upscale = 2
+    if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution
+        upscale = 1
+        background_enhance = False
+        #face_upsample = False
+
+    face_helper = get_face_restore_helper(upscale)
+
+    bg_upsampler = upsampler if background_enhance else None
+    face_upsampler = upsampler if face_upsample else None
+
+    if has_aligned:
+        # the input faces are already cropped and aligned
+        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+        face_helper.is_gray = is_gray(img, threshold=5)
+        if face_helper.is_gray:
+            print("\tgrayscale input: True")
+        face_helper.cropped_faces = [img]
+    else:
+        with THREAD_LOCK_FACE_HELPER_PROCERSSING:
+            face_helper.read_image(img)
+            # get face landmarks for each face
+
+            num_det_faces = face_helper.get_face_landmarks_5(
+                only_center_face=only_center_face, resize=640, eye_dist_threshold=5
+            )
+        #print(f"\tdetect {num_det_faces} faces")
+
+        if num_det_faces == 0 and skip_if_no_face:
+            release_face_restore_helper(face_helper)
+            return img
+
+        # align and warp each face
+        face_helper.align_warp_face()
+
+
+
+    # face restoration for each cropped face
+    for idx, cropped_face in enumerate(face_helper.cropped_faces):
+        # prepare data
+        cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
+        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+        cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+        codeformer_net = get_codeformer()
+        try:
+            with torch.no_grad():
+                output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
+                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+            del output
+        except RuntimeError as error:
+            print(f"Failed inference for CodeFormer: {error}")
+            restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+        release_codeformer(codeformer_net)
+
+        restored_face = restored_face.astype("uint8")
+        face_helper.add_restored_face(restored_face)
+
+    # paste_back
+    if not has_aligned:
+        # upsample the background
+        if bg_upsampler is not None:
+            with THREAD_LOCK_BGUPSAMPLER:
+                # Now only support RealESRGAN for upsampling background
+                bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
+        else:
+            bg_img = None
+        face_helper.get_inverse_affine(None)
+        # paste each restored face to the input image
+        if face_upsample and face_upsampler is not None:
+            restored_img = face_helper.paste_faces_to_input_image(
+                upsample_img=bg_img,
+                draw_box=draw_box,
+                face_upsampler=face_upsampler,
+            )
+        else:
+            restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box)
+
+    if image.shape != restored_img.shape:
+        h, w, _ = image.shape
+        restored_img = cv2.resize(restored_img, (w, h), interpolation=cv2.INTER_LINEAR)
+
+
+    release_face_restore_helper(face_helper)
+    # save restored img
+    if isinstance(image, str):
+        save_path = f"output/out.png"
+        imwrite(restored_img, str(save_path))
+        return save_path
+    else:
+        return restored_img
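For orientation, `inference_app` above accepts either a file path or a BGR numpy image and returns the restored image (or an output path when given a path). A minimal usage sketch, assuming the CodeFormer weights are already present under models/CodeFormer/ as the module-level download code expects (illustration only, not part of this commit):

    import cv2
    from plugins.codeformer_app_cv2 import inference_app

    frame = cv2.imread("input.jpg")            # BGR image, as cv2 provides
    restored = inference_app(
        frame,
        background_enhance=False,              # skip the RealESRGAN background pass
        face_upsample=True,
        upscale=1,
        codeformer_fidelity=0.8,               # matches plugin_options/plugin_codeformer.json
        skip_if_no_face=True,
    )
    cv2.imwrite("restored.jpg", restored)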
plugins/codeformer_face_helper_cv2.py
ADDED
@@ -0,0 +1,94 @@
+from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
+
+import numpy as np
+from codeformer.basicsr.utils.misc import get_device
+
+class FaceRestoreHelperOptimized(FaceRestoreHelper):
+    def __init__(
+        self,
+        upscale_factor,
+        face_size=512,
+        crop_ratio=(1, 1),
+        det_model="retinaface_resnet50",
+        save_ext="png",
+        template_3points=False,
+        pad_blur=False,
+        use_parse=False,
+        device=None,
+    ):
+        self.template_3points = template_3points # improve robustness
+        self.upscale_factor = int(upscale_factor)
+        # the cropped face ratio based on the square face
+        self.crop_ratio = crop_ratio # (h, w)
+        assert self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1, "crop ration only supports >=1"
+        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+        self.det_model = det_model
+
+        if self.det_model == "dlib":
+            # standard 5 landmarks for FFHQ faces with 1024 x 1024
+            self.face_template = np.array(
+                [
+                    [686.77227723, 488.62376238],
+                    [586.77227723, 493.59405941],
+                    [337.91089109, 488.38613861],
+                    [437.95049505, 493.51485149],
+                    [513.58415842, 678.5049505],
+                ]
+            )
+            self.face_template = self.face_template / (1024 // face_size)
+        elif self.template_3points:
+            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+        else:
+            # standard 5 landmarks for FFHQ faces with 512 x 512
+            # facexlib
+            self.face_template = np.array(
+                [
+                    [192.98138, 239.94708],
+                    [318.90277, 240.1936],
+                    [256.63416, 314.01935],
+                    [201.26117, 371.41043],
+                    [313.08905, 371.15118],
+                ]
+            )
+
+            # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+            # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+            # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+        self.face_template = self.face_template * (face_size / 512.0)
+        if self.crop_ratio[0] > 1:
+            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
+        if self.crop_ratio[1] > 1:
+            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
+        self.save_ext = save_ext
+        self.pad_blur = pad_blur
+        if self.pad_blur is True:
+            self.template_3points = False
+
+        self.all_landmarks_5 = []
+        self.det_faces = []
+        self.affine_matrices = []
+        self.inverse_affine_matrices = []
+        self.cropped_faces = []
+        self.restored_faces = []
+        self.pad_input_imgs = []
+
+        if device is None:
+            # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            self.device = get_device()
+        else:
+            self.device = device
+
+        # init face detection model
+        # if self.det_model == "dlib":
+        #     self.face_detector, self.shape_predictor_5 = self.init_dlib(
+        #         dlib_model_url["face_detector"], dlib_model_url["shape_predictor_5"]
+        #     )
+        # else:
+        #     self.face_detector = init_detection_model(det_model, half=False, device=self.device)
+
+        # init face parsing model
+        self.use_parse = use_parse
+        #self.face_parse = init_parsing_model(model_name="parsenet", device=self.device)
+
+        # MUST set face_detector and face_parse!!!
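The subclass above copies FaceRestoreHelper's constructor but deliberately skips building the detection and parsing networks, so the caller must attach shared instances itself, as get_face_restore_helper in plugins/codeformer_app_cv2.py does. A condensed sketch of that contract, assuming the codeformer package is installed (illustration only):

    import torch
    from codeformer.facelib.detection import init_detection_model
    from codeformer.facelib.parsing import init_parsing_model
    from plugins.codeformer_face_helper_cv2 import FaceRestoreHelperOptimized

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    helper = FaceRestoreHelperOptimized(upscale_factor=1, use_parse=True, device=device)
    # the optimized helper does NOT create these itself - attach them explicitly
    helper.face_detector = init_detection_model("retinaface_resnet50", half=False, device=device)
    helper.face_parse = init_parsing_model(model_name="parsenet", device=device)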
plugins/core.py
ADDED
@@ -0,0 +1,29 @@
+# Core plugin
+# author: Vladislav Janvarev
+
+from chain_img_processor import ChainImgProcessor
+
+# start function
+def start(core:ChainImgProcessor):
+    manifest = {
+        "name": "Core plugin",
+        "version": "2.0",
+
+        "default_options": {
+            "default_chain": "faceswap",  # default chain to run
+            "init_on_start": "faceswap,txt2clip,gfpgan,codeformer",  # init these processors on start
+            "is_demo_row_render": False,
+        },
+
+    }
+    return manifest
+
+def start_with_options(core:ChainImgProcessor, manifest:dict):
+    options = manifest["options"]
+
+    core.default_chain = options["default_chain"]
+    core.init_on_start = options["init_on_start"]
+
+    core.is_demo_row_render = options["is_demo_row_render"]
+
+    return manifest
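Every plugin in this commit follows the same contract: `start(core)` returns a manifest carrying `default_options` (and, for image plugins, an `img_processor` mapping), and `start_with_options(core, manifest)` runs once the merged options are available. A bare-bones skeleton of that shape, for illustration only (hypothetical plugin, not part of this commit):

    # plugins/plugin_example.py (hypothetical)
    from chain_img_processor import ChainImgProcessor

    def start(core: ChainImgProcessor) -> dict:
        return {
            "name": "Example plugin",
            "version": "1.0",
            "default_options": {
                "strength": 0.5,   # hypothetical option
            },
        }

    def start_with_options(core: ChainImgProcessor, manifest: dict) -> dict:
        options = manifest["options"]
        core.example_strength = options["strength"]   # mirrors how core.py stores its options
        return manifest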
plugins/core_video.py
ADDED
@@ -0,0 +1,26 @@
+# Core plugin
+# author: Vladislav Janvarev
+
+from chain_img_processor import ChainImgProcessor, ChainVideoProcessor
+
+# start function
+def start(core:ChainImgProcessor):
+    manifest = {
+        "name": "Core video plugin",
+        "version": "2.0",
+
+        "default_options": {
+            "video_save_codec": "libx264",  # default codec to save
+            "video_save_crf": 14,  # default crf to save
+        },
+
+    }
+    return manifest
+
+def start_with_options(core:ChainVideoProcessor, manifest:dict):
+    options = manifest["options"]
+
+    core.video_save_codec = options["video_save_codec"]
+    core.video_save_crf = options["video_save_crf"]
+
+    return manifest
plugins/plugin_codeformer.py
ADDED
@@ -0,0 +1,83 @@
+# Codeformer enchance plugin
+# author: Vladislav Janvarev
+
+# CountFloyd 20230717, extended to blend original/destination images
+
+from chain_img_processor import ChainImgProcessor, ChainImgPlugin
+import os
+from PIL import Image
+from numpy import asarray
+
+modname = os.path.basename(__file__)[:-3] # calculating modname
+
+# start function
+def start(core:ChainImgProcessor):
+    manifest = { # plugin settings
+        "name": "Codeformer", # name
+        "version": "3.0", # version
+
+        "default_options": {
+            "background_enhance": True, #
+            "face_upsample": True, #
+            "upscale": 2, #
+            "codeformer_fidelity": 0.8,
+            "skip_if_no_face":False,
+
+        },
+
+        "img_processor": {
+            "codeformer": PluginCodeformer # 1 function - init, 2 - process
+        }
+    }
+    return manifest
+
+def start_with_options(core:ChainImgProcessor, manifest:dict):
+    pass
+
+class PluginCodeformer(ChainImgPlugin):
+    def init_plugin(self):
+        import plugins.codeformer_app_cv2
+        pass
+
+    def process(self, img, params:dict):
+        import copy
+
+        # params can be used to transfer some img info to next processors
+        from plugins.codeformer_app_cv2 import inference_app
+        options = self.core.plugin_options(modname)
+
+        if "face_detected" in params:
+            if not params["face_detected"]:
+                return img
+
+        # don't touch original
+        temp_frame = copy.copy(img)
+        if "processed_faces" in params:
+            for face in params["processed_faces"]:
+                start_x, start_y, end_x, end_y = map(int, face['bbox'])
+                padding_x = int((end_x - start_x) * 0.5)
+                padding_y = int((end_y - start_y) * 0.5)
+                start_x = max(0, start_x - padding_x)
+                start_y = max(0, start_y - padding_y)
+                end_x = max(0, end_x + padding_x)
+                end_y = max(0, end_y + padding_y)
+                temp_face = temp_frame[start_y:end_y, start_x:end_x]
+                if temp_face.size:
+                    temp_face = inference_app(temp_face, options.get("background_enhance"), options.get("face_upsample"),
+                                              options.get("upscale"), options.get("codeformer_fidelity"),
+                                              options.get("skip_if_no_face"))
+                    temp_frame[start_y:end_y, start_x:end_x] = temp_face
+        else:
+            temp_frame = inference_app(temp_frame, options.get("background_enhance"), options.get("face_upsample"),
+                                       options.get("upscale"), options.get("codeformer_fidelity"),
+                                       options.get("skip_if_no_face"))
+
+
+
+        if not "blend_ratio" in params:
+            return temp_frame
+
+
+        temp_frame = Image.blend(Image.fromarray(img), Image.fromarray(temp_frame), params["blend_ratio"])
+        return asarray(temp_frame)
+
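The final step of process() above mixes the untouched frame with the enhanced frame: with blend_ratio r, each output pixel is (1-r)*original + r*enhanced, which is PIL's Image.blend semantics. A self-contained sketch of just that step, using dummy data (illustration only):

    import numpy as np
    from PIL import Image

    original = np.zeros((64, 64, 3), dtype=np.uint8)       # dummy frame
    enhanced = np.full((64, 64, 3), 255, dtype=np.uint8)    # dummy "restored" frame
    blend_ratio = 0.65

    blended = Image.blend(Image.fromarray(original), Image.fromarray(enhanced), blend_ratio)
    out = np.asarray(blended)                                # ~0.65 * 255 = 166 everywhere
    print(out[0, 0])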
plugins/plugin_dmdnet.py
ADDED
@@ -0,0 +1,835 @@
+from chain_img_processor import ChainImgProcessor, ChainImgPlugin
+import os
+from PIL import Image
+from numpy import asarray
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import scipy.io as sio
+import numpy as np
+import torch.nn.utils.spectral_norm as SpectralNorm
+from torchvision.ops import roi_align
+
+from math import sqrt
+import os
+
+import cv2
+import os
+from torchvision.transforms.functional import normalize
+import copy
+import threading
+
+modname = os.path.basename(__file__)[:-3] # calculating modname
+
+oDMDNet = None
+device = None
+
+THREAD_LOCK_DMDNET = threading.Lock()
+
+
+
+# start function
+def start(core:ChainImgProcessor):
+    manifest = { # plugin settings
+        "name": "DMDNet", # name
+        "version": "1.0", # version
+
+        "default_options": {},
+        "img_processor": {
+            "dmdnet": DMDNETPlugin
+        }
+    }
+    return manifest
+
+def start_with_options(core:ChainImgProcessor, manifest:dict):
+    pass
+
+
+class DMDNETPlugin(ChainImgPlugin):
+
+    # https://stackoverflow.com/a/67174339
+    def landmarks106_to_68(self, pt106):
+        map106to68=[1,10,12,14,16,3,5,7,0,23,21,19,32,30,28,26,17,
+                    43,48,49,51,50,
+                    102,103,104,105,101,
+                    72,73,74,86,78,79,80,85,84,
+                    35,41,42,39,37,36,
+                    89,95,96,93,91,90,
+                    52,64,63,71,67,68,61,58,59,53,56,55,65,66,62,70,69,57,60,54
+                    ]
+
+        pt68 = []
+        for i in range(68):
+            index = map106to68[i]
+            pt68.append(pt106[index])
+        return pt68
+
+    def init_plugin(self):
+        global create
+
+        if oDMDNet == None:
+            create(self.device)
+
+
+    def process(self, frame, params:dict):
+        if "face_detected" in params:
+            if not params["face_detected"]:
+                return frame
+
+        temp_frame = copy.copy(frame)
+        if "processed_faces" in params:
+            for face in params["processed_faces"]:
+                start_x, start_y, end_x, end_y = map(int, face['bbox'])
+                # padding_x = int((end_x - start_x) * 0.5)
+                # padding_y = int((end_y - start_y) * 0.5)
+                padding_x = 0
+                padding_y = 0
+
+                start_x = max(0, start_x - padding_x)
+                start_y = max(0, start_y - padding_y)
+                end_x = max(0, end_x + padding_x)
+                end_y = max(0, end_y + padding_y)
+                temp_face = temp_frame[start_y:end_y, start_x:end_x]
+                if temp_face.size:
+                    temp_face = self.enhance_face(temp_face, face)
+                    temp_face = cv2.resize(temp_face, (end_x - start_x,end_y - start_y), interpolation = cv2.INTER_LANCZOS4)
+                    temp_frame[start_y:end_y, start_x:end_x] = temp_face
+
+        temp_frame = Image.blend(Image.fromarray(frame), Image.fromarray(temp_frame), params["blend_ratio"])
+        return asarray(temp_frame)
+
+
+    def enhance_face(self, clip, face):
+        global device
+
+        lm106 = face.landmark_2d_106
+        lq_landmarks = asarray(self.landmarks106_to_68(lm106))
+        lq = read_img_tensor(clip, False)
+
+        LQLocs = get_component_location(lq_landmarks)
+        # generic
+        SpMem256Para, SpMem128Para, SpMem64Para = None, None, None
+
+        with torch.no_grad():
+            with THREAD_LOCK_DMDNET:
+                try:
+                    GenericResult, SpecificResult = oDMDNet(lq = lq.to(device), loc = LQLocs.unsqueeze(0), sp_256 = SpMem256Para, sp_128 = SpMem128Para, sp_64 = SpMem64Para)
+                except Exception as e:
+                    print(f'Error {e} there may be something wrong with the detected component locations.')
+                    return clip
+        save_generic = GenericResult * 0.5 + 0.5
+        save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
+        save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
+
+        check_lq = lq * 0.5 + 0.5
+        check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
+        check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0
+        enhanced_img = np.hstack((check_lq, save_generic))
+        temp_frame = save_generic.astype("uint8")
+        # temp_frame = save_generic.astype("uint8")
+        return temp_frame
+
+
+def create(devicename):
+    global device, oDMDNet
+
+    test = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device(devicename)
+    oDMDNet = DMDNet().to(device)
+    weights = torch.load('./models/DMDNet.pth')
+    oDMDNet.load_state_dict(weights, strict=True)
+
+    oDMDNet.eval()
+    num_params = 0
+    for param in oDMDNet.parameters():
+        num_params += param.numel()
+
+    # print('{:>8s} : {}'.format('Using device', device))
+    # print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6))
+
+
+
+def read_img_tensor(Img=None, return_landmark=True): #rgb -1~1
+    # Img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # BGR or G
+    if Img.ndim == 2:
+        Img = cv2.cvtColor(Img, cv2.COLOR_GRAY2RGB) # GGG
+    else:
+        Img = cv2.cvtColor(Img, cv2.COLOR_BGR2RGB) # RGB
+
+    if Img.shape[0] < 512 or Img.shape[1] < 512:
+        Img = cv2.resize(Img, (512,512), interpolation = cv2.INTER_AREA)
+    # ImgForLands = Img.copy()
+
+    Img = Img.transpose((2, 0, 1))/255.0
+    Img = torch.from_numpy(Img).float()
+    normalize(Img, [0.5,0.5,0.5], [0.5,0.5,0.5], inplace=True)
+    ImgTensor = Img.unsqueeze(0)
+    return ImgTensor
+
+
+def get_component_location(Landmarks, re_read=False):
+    if re_read:
+        ReadLandmark = []
+        with open(Landmarks,'r') as f:
+            for line in f:
+                tmp = [float(i) for i in line.split(' ') if i != '\n']
+                ReadLandmark.append(tmp)
+        ReadLandmark = np.array(ReadLandmark) #
+        Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2
+    Map_LE_B = list(np.hstack((range(17,22), range(36,42))))
+    Map_RE_B = list(np.hstack((range(22,27), range(42,48))))
+    Map_LE = list(range(36,42))
+    Map_RE = list(range(42,48))
+    Map_NO = list(range(29,36))
+    Map_MO = list(range(48,68))
+
+    Landmarks[Landmarks>504]=504
+    Landmarks[Landmarks<8]=8
+
+    #left eye
+    Mean_LE = np.mean(Landmarks[Map_LE],0)
+    L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B,1])
+    L_LE1 = L_LE1 * 1.3
+    L_LE2 = L_LE1 / 1.9
+    L_LE_xy = L_LE1 + L_LE2
+    L_LE_lt = [L_LE_xy/2, L_LE1]
+    L_LE_rb = [L_LE_xy/2, L_LE2]
+    Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int)
+
+    #right eye
+    Mean_RE = np.mean(Landmarks[Map_RE],0)
+    L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B,1])
+    L_RE1 = L_RE1 * 1.3
+    L_RE2 = L_RE1 / 1.9
+    L_RE_xy = L_RE1 + L_RE2
+    L_RE_lt = [L_RE_xy/2, L_RE1]
+    L_RE_rb = [L_RE_xy/2, L_RE2]
+    Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int)
+
+    #nose
+    Mean_NO = np.mean(Landmarks[Map_NO],0)
+    L_NO1 =( np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])) * 1.25
+    L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1
+    L_NO_xy = L_NO1 * 2
+    L_NO_lt = [L_NO_xy/2, L_NO_xy - L_NO2]
+    L_NO_rb = [L_NO_xy/2, L_NO2]
+    Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int)
+
+    #mouth
+    Mean_MO = np.mean(Landmarks[Map_MO],0)
+    L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16)) * 1.1
+    MO_O = Mean_MO - L_MO + 1
+    MO_T = Mean_MO + L_MO
+    MO_T[MO_T>510]=510
+    Location_MO = np.hstack((MO_O, MO_T)).astype(int)
+    return torch.cat([torch.FloatTensor(Location_LE).unsqueeze(0), torch.FloatTensor(Location_RE).unsqueeze(0), torch.FloatTensor(Location_NO).unsqueeze(0), torch.FloatTensor(Location_MO).unsqueeze(0)], dim=0)
+
+
+
+
+def calc_mean_std_4D(feat, eps=1e-5):
+    # eps is a small value added to the variance to avoid divide-by-zero.
+    size = feat.size()
+    assert (len(size) == 4)
+    N, C = size[:2]
+    feat_var = feat.view(N, C, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().view(N, C, 1, 1)
+    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+    return feat_mean, feat_std
+
+def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std_4D(style_feat)
+    content_mean, content_std = calc_mean_std_4D(content_feat)
+    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
+    return nn.Sequential(
+        SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
+        nn.LeakyReLU(0.2),
+        SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
+    )
+
+
+class MSDilateBlock(nn.Module):
+    def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
+        super(MSDilateBlock, self).__init__()
+        self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias)
+        self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias)
+        self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias)
+        self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias)
+        self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias))
+    def forward(self, x):
+        conv1 = self.conv1(x)
+        conv2 = self.conv2(x)
+        conv3 = self.conv3(x)
+        conv4 = self.conv4(x)
+        cat = torch.cat([conv1, conv2, conv3, conv4], 1)
+        out = self.convi(cat) + x
+        return out
+
+
+class AdaptiveInstanceNorm(nn.Module):
+    def __init__(self, in_channel):
+        super().__init__()
+        self.norm = nn.InstanceNorm2d(in_channel)
+
+    def forward(self, input, style):
+        style_mean, style_std = calc_mean_std_4D(style)
+        out = self.norm(input)
+        size = input.size()
+        out = style_std.expand(size) * out + style_mean.expand(size)
+        return out
+
+class NoiseInjection(nn.Module):
+    def __init__(self, channel):
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
+    def forward(self, image, noise):
+        if noise is None:
+            b, c, h, w = image.shape
+            noise = image.new_empty(b, 1, h, w).normal_()
+        return image + self.weight * noise
+
+class StyledUpBlock(nn.Module):
+    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False, noise_inject=False):
+        super().__init__()
+
+        self.noise_inject = noise_inject
+        if upsample:
+            self.conv1 = nn.Sequential(
+                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+                nn.LeakyReLU(0.2),
+            )
+        else:
+            self.conv1 = nn.Sequential(
+                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+                nn.LeakyReLU(0.2),
+                SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+            )
+        self.convup = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+        )
+        if self.noise_inject:
+            self.noise1 = NoiseInjection(out_channel)
+
+        self.lrelu1 = nn.LeakyReLU(0.2)
+
+        self.ScaleModel1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
+        )
+        self.ShiftModel1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
+        )
+
+    def forward(self, input, style):
+        out = self.conv1(input)
+        out = self.lrelu1(out)
+        Shift1 = self.ShiftModel1(style)
+        Scale1 = self.ScaleModel1(style)
+        out = out * Scale1 + Shift1
+        if self.noise_inject:
+            out = self.noise1(out, noise=None)
+        outup = self.convup(out)
+        return outup
+
+
+####################################################################
+###############Face Dictionary Generator
+####################################################################
+def AttentionBlock(in_channel):
+    return nn.Sequential(
+        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
+        nn.LeakyReLU(0.2),
+        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
+    )
+
+class DilateResBlock(nn.Module):
+    def __init__(self, dim, dilation=[5,3] ):
+        super(DilateResBlock, self).__init__()
+        self.Res = nn.Sequential(
+            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[0], dilation[0])),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[1], dilation[1])),
+        )
+    def forward(self, x):
+        out = x + self.Res(x)
+        return out
+
+
+class KeyValue(nn.Module):
+    def __init__(self, indim, keydim, valdim):
+        super(KeyValue, self).__init__()
+        self.Key = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(keydim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.Value = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(valdim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+    def forward(self, x):
+        return self.Key(x), self.Value(x)
+
+class MaskAttention(nn.Module):
+    def __init__(self, indim):
+        super(MaskAttention, self).__init__()
+        self.conv1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.conv2 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.conv3 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.convCat = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim//3 * 3, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+    def forward(self, x, y, z):
+        c1 = self.conv1(x)
+        c2 = self.conv2(y)
+        c3 = self.conv3(z)
+        return self.convCat(torch.cat([c1,c2,c3], dim=1))
+
+class Query(nn.Module):
+    def __init__(self, indim, quedim):
+        super(Query, self).__init__()
+        self.Query = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(quedim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+    def forward(self, x):
+        return self.Query(x)
+
+def roi_align_self(input, location, target_size):
+    return torch.cat([F.interpolate(input[i:i+1,:,location[i,1]:location[i,3],location[i,0]:location[i,2]],(target_size,target_size),mode='bilinear',align_corners=False) for i in range(input.size(0))],0)
+
+class FeatureExtractor(nn.Module):
+    def __init__(self, ngf = 64, key_scale = 4):#
+        super().__init__()
+
+        self.key_scale = 4
+        self.part_sizes = np.array([80,80,50,110]) #
+        self.feature_sizes = np.array([256,128,64]) #
+
+        self.conv1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
+        )
+        self.conv2 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1))
+        )
+        self.res1 = DilateResBlock(ngf, [5,3])
+        self.res2 = DilateResBlock(ngf, [5,3])
+
+
+        self.conv3 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf, ngf*2, 3, 2, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
+        )
+        self.conv4 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1))
+        )
+        self.res3 = DilateResBlock(ngf*2, [3,1])
+        self.res4 = DilateResBlock(ngf*2, [3,1])
+
+        self.conv5 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*4, 3, 2, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
+        )
+        self.conv6 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1))
+        )
+        self.res5 = DilateResBlock(ngf*4, [1,1])
+        self.res6 = DilateResBlock(ngf*4, [1,1])
+
+        self.LE_256_Q = Query(ngf, ngf // self.key_scale)
+        self.RE_256_Q = Query(ngf, ngf // self.key_scale)
+        self.MO_256_Q = Query(ngf, ngf // self.key_scale)
+        self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
+        self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
+        self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
+        self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
+        self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
+        self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
+
+
+    def forward(self, img, locs):
+        le_location = locs[:,0,:].int().cpu().numpy()
+        re_location = locs[:,1,:].int().cpu().numpy()
+        no_location = locs[:,2,:].int().cpu().numpy()
+        mo_location = locs[:,3,:].int().cpu().numpy()
+
+
+        f1_0 = self.conv1(img)
+        f1_1 = self.res1(f1_0)
+        f2_0 = self.conv2(f1_1)
+        f2_1 = self.res2(f2_0)
+
+        f3_0 = self.conv3(f2_1)
+        f3_1 = self.res3(f3_0)
+        f4_0 = self.conv4(f3_1)
+        f4_1 = self.res4(f4_0)
+
+        f5_0 = self.conv5(f4_1)
+        f5_1 = self.res5(f5_0)
+        f6_0 = self.conv6(f5_1)
+        f6_1 = self.res6(f6_0)
+
+
+        ####ROI Align
+        le_part_256 = roi_align_self(f2_1.clone(), le_location//2, self.part_sizes[0]//2)
+        re_part_256 = roi_align_self(f2_1.clone(), re_location//2, self.part_sizes[1]//2)
+        mo_part_256 = roi_align_self(f2_1.clone(), mo_location//2, self.part_sizes[3]//2)
+
+        le_part_128 = roi_align_self(f4_1.clone(), le_location//4, self.part_sizes[0]//4)
+        re_part_128 = roi_align_self(f4_1.clone(), re_location//4, self.part_sizes[1]//4)
+        mo_part_128 = roi_align_self(f4_1.clone(), mo_location//4, self.part_sizes[3]//4)
+
+        le_part_64 = roi_align_self(f6_1.clone(), le_location//8, self.part_sizes[0]//8)
+        re_part_64 = roi_align_self(f6_1.clone(), re_location//8, self.part_sizes[1]//8)
+        mo_part_64 = roi_align_self(f6_1.clone(), mo_location//8, self.part_sizes[3]//8)
+
+
+        le_256_q = self.LE_256_Q(le_part_256)
+        re_256_q = self.RE_256_Q(re_part_256)
+        mo_256_q = self.MO_256_Q(mo_part_256)
+
+        le_128_q = self.LE_128_Q(le_part_128)
+        re_128_q = self.RE_128_Q(re_part_128)
+        mo_128_q = self.MO_128_Q(mo_part_128)
+
+        le_64_q = self.LE_64_Q(le_part_64)
+        re_64_q = self.RE_64_Q(re_part_64)
+        mo_64_q = self.MO_64_Q(mo_part_64)
+
+        return {'f256': f2_1, 'f128': f4_1, 'f64': f6_1,\
+            'le256': le_part_256, 're256': re_part_256, 'mo256': mo_part_256, \
+            'le128': le_part_128, 're128': re_part_128, 'mo128': mo_part_128, \
+            'le64': le_part_64, 're64': re_part_64, 'mo64': mo_part_64, \
+            'le_256_q': le_256_q, 're_256_q': re_256_q, 'mo_256_q': mo_256_q,\
+            'le_128_q': le_128_q, 're_128_q': re_128_q, 'mo_128_q': mo_128_q,\
+            'le_64_q': le_64_q, 're_64_q': re_64_q, 'mo_64_q': mo_64_q}
+
+
+class DMDNet(nn.Module):
+    def __init__(self, ngf = 64, banks_num = 128):
+        super().__init__()
+        self.part_sizes = np.array([80,80,50,110]) # size for 512
+        self.feature_sizes = np.array([256,128,64]) # size for 512
+
+        self.banks_num = banks_num
+        self.key_scale = 4
+
+        self.E_lq = FeatureExtractor(key_scale = self.key_scale)
+        self.E_hq = FeatureExtractor(key_scale = self.key_scale)
+
+        self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
+        self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
+        self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
+
+        self.LE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
+        self.RE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
+        self.MO_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
+
+        self.LE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
+        self.RE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
+        self.MO_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
+
+
+        self.LE_256_Attention = AttentionBlock(64)
+        self.RE_256_Attention = AttentionBlock(64)
+        self.MO_256_Attention = AttentionBlock(64)
+
+        self.LE_128_Attention = AttentionBlock(128)
+        self.RE_128_Attention = AttentionBlock(128)
+        self.MO_128_Attention = AttentionBlock(128)
+
+        self.LE_64_Attention = AttentionBlock(256)
+        self.RE_64_Attention = AttentionBlock(256)
+        self.MO_64_Attention = AttentionBlock(256)
+
+        self.LE_256_Mask = MaskAttention(64)
+        self.RE_256_Mask = MaskAttention(64)
+        self.MO_256_Mask = MaskAttention(64)
+
+        self.LE_128_Mask = MaskAttention(128)
+        self.RE_128_Mask = MaskAttention(128)
+        self.MO_128_Mask = MaskAttention(128)
+
+        self.LE_64_Mask = MaskAttention(256)
+        self.RE_64_Mask = MaskAttention(256)
+        self.MO_64_Mask = MaskAttention(256)
+
+        self.MSDilate = MSDilateBlock(ngf*4, dilation = [4,3,2,1])
+
+        self.up1 = StyledUpBlock(ngf*4, ngf*2, noise_inject=False) #
+        self.up2 = StyledUpBlock(ngf*2, ngf, noise_inject=False) #
+        self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False) #
+        self.up4 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            UpResBlock(ngf),
+            UpResBlock(ngf),
+            SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
+            nn.Tanh()
+        )
+
+        # define generic memory, revise register_buffer to register_parameter for backward update
+        self.register_buffer('le_256_mem_key', torch.randn(128,16,40,40))
+        self.register_buffer('re_256_mem_key', torch.randn(128,16,40,40))
+        self.register_buffer('mo_256_mem_key', torch.randn(128,16,55,55))
+        self.register_buffer('le_256_mem_value', torch.randn(128,64,40,40))
+        self.register_buffer('re_256_mem_value', torch.randn(128,64,40,40))
+        self.register_buffer('mo_256_mem_value', torch.randn(128,64,55,55))
+
+
+        self.register_buffer('le_128_mem_key', torch.randn(128,32,20,20))
+        self.register_buffer('re_128_mem_key', torch.randn(128,32,20,20))
+        self.register_buffer('mo_128_mem_key', torch.randn(128,32,27,27))
|
| 622 |
+
self.register_buffer('le_128_mem_value', torch.randn(128,128,20,20))
|
| 623 |
+
self.register_buffer('re_128_mem_value', torch.randn(128,128,20,20))
|
| 624 |
+
self.register_buffer('mo_128_mem_value', torch.randn(128,128,27,27))
|
| 625 |
+
|
| 626 |
+
self.register_buffer('le_64_mem_key', torch.randn(128,64,10,10))
|
| 627 |
+
self.register_buffer('re_64_mem_key', torch.randn(128,64,10,10))
|
| 628 |
+
self.register_buffer('mo_64_mem_key', torch.randn(128,64,13,13))
|
| 629 |
+
self.register_buffer('le_64_mem_value', torch.randn(128,256,10,10))
|
| 630 |
+
self.register_buffer('re_64_mem_value', torch.randn(128,256,10,10))
|
| 631 |
+
self.register_buffer('mo_64_mem_value', torch.randn(128,256,13,13))
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def readMem(self, k, v, q):
|
| 635 |
+
sim = F.conv2d(q, k)
|
| 636 |
+
score = F.softmax(sim/sqrt(sim.size(1)), dim=1) #B * S * 1 * 1 6*128
|
| 637 |
+
sb,sn,sw,sh = score.size()
|
| 638 |
+
s_m = score.view(sb, -1).unsqueeze(1)#2*1*M
|
| 639 |
+
vb,vn,vw,vh = v.size()
|
| 640 |
+
v_in = v.view(vb, -1).repeat(sb,1,1)#2*M*(c*w*h)
|
| 641 |
+
mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw,vh)
|
| 642 |
+
max_inds = torch.argmax(score, dim=1).squeeze()
|
| 643 |
+
return mem_out, max_inds
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
def memorize(self, img, locs):
|
| 647 |
+
fs = self.E_hq(img, locs)
|
| 648 |
+
LE256_key, LE256_value = self.LE_256_KV(fs['le256'])
|
| 649 |
+
RE256_key, RE256_value = self.RE_256_KV(fs['re256'])
|
| 650 |
+
MO256_key, MO256_value = self.MO_256_KV(fs['mo256'])
|
| 651 |
+
|
| 652 |
+
LE128_key, LE128_value = self.LE_128_KV(fs['le128'])
|
| 653 |
+
RE128_key, RE128_value = self.RE_128_KV(fs['re128'])
|
| 654 |
+
MO128_key, MO128_value = self.MO_128_KV(fs['mo128'])
|
| 655 |
+
|
| 656 |
+
LE64_key, LE64_value = self.LE_64_KV(fs['le64'])
|
| 657 |
+
RE64_key, RE64_value = self.RE_64_KV(fs['re64'])
|
| 658 |
+
MO64_key, MO64_value = self.MO_64_KV(fs['mo64'])
|
| 659 |
+
|
| 660 |
+
Mem256 = {'LE256Key': LE256_key, 'LE256Value': LE256_value, 'RE256Key': RE256_key, 'RE256Value': RE256_value,'MO256Key': MO256_key, 'MO256Value': MO256_value}
|
| 661 |
+
Mem128 = {'LE128Key': LE128_key, 'LE128Value': LE128_value, 'RE128Key': RE128_key, 'RE128Value': RE128_value,'MO128Key': MO128_key, 'MO128Value': MO128_value}
|
| 662 |
+
Mem64 = {'LE64Key': LE64_key, 'LE64Value': LE64_value, 'RE64Key': RE64_key, 'RE64Value': RE64_value,'MO64Key': MO64_key, 'MO64Value': MO64_value}
|
| 663 |
+
|
| 664 |
+
FS256 = {'LE256F':fs['le256'], 'RE256F':fs['re256'], 'MO256F':fs['mo256']}
|
| 665 |
+
FS128 = {'LE128F':fs['le128'], 'RE128F':fs['re128'], 'MO128F':fs['mo128']}
|
| 666 |
+
FS64 = {'LE64F':fs['le64'], 'RE64F':fs['re64'], 'MO64F':fs['mo64']}
|
| 667 |
+
|
| 668 |
+
return Mem256, Mem128, Mem64
|
| 669 |
+
|
| 670 |
+
def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None):
|
| 671 |
+
le_256_q = fs_in['le_256_q']
|
| 672 |
+
re_256_q = fs_in['re_256_q']
|
| 673 |
+
mo_256_q = fs_in['mo_256_q']
|
| 674 |
+
|
| 675 |
+
le_128_q = fs_in['le_128_q']
|
| 676 |
+
re_128_q = fs_in['re_128_q']
|
| 677 |
+
mo_128_q = fs_in['mo_128_q']
|
| 678 |
+
|
| 679 |
+
le_64_q = fs_in['le_64_q']
|
| 680 |
+
re_64_q = fs_in['re_64_q']
|
| 681 |
+
mo_64_q = fs_in['mo_64_q']
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
####for 256
|
| 685 |
+
le_256_mem_g, le_256_inds = self.readMem(self.le_256_mem_key, self.le_256_mem_value, le_256_q)
|
| 686 |
+
re_256_mem_g, re_256_inds = self.readMem(self.re_256_mem_key, self.re_256_mem_value, re_256_q)
|
| 687 |
+
mo_256_mem_g, mo_256_inds = self.readMem(self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q)
|
| 688 |
+
|
| 689 |
+
le_128_mem_g, le_128_inds = self.readMem(self.le_128_mem_key, self.le_128_mem_value, le_128_q)
|
| 690 |
+
re_128_mem_g, re_128_inds = self.readMem(self.re_128_mem_key, self.re_128_mem_value, re_128_q)
|
| 691 |
+
mo_128_mem_g, mo_128_inds = self.readMem(self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q)
|
| 692 |
+
|
| 693 |
+
le_64_mem_g, le_64_inds = self.readMem(self.le_64_mem_key, self.le_64_mem_value, le_64_q)
|
| 694 |
+
re_64_mem_g, re_64_inds = self.readMem(self.re_64_mem_key, self.re_64_mem_value, re_64_q)
|
| 695 |
+
mo_64_mem_g, mo_64_inds = self.readMem(self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q)
|
| 696 |
+
|
| 697 |
+
if sp_256 is not None and sp_128 is not None and sp_64 is not None:
|
| 698 |
+
le_256_mem_s, _ = self.readMem(sp_256['LE256Key'], sp_256['LE256Value'], le_256_q)
|
| 699 |
+
re_256_mem_s, _ = self.readMem(sp_256['RE256Key'], sp_256['RE256Value'], re_256_q)
|
| 700 |
+
mo_256_mem_s, _ = self.readMem(sp_256['MO256Key'], sp_256['MO256Value'], mo_256_q)
|
| 701 |
+
le_256_mask = self.LE_256_Mask(fs_in['le256'],le_256_mem_s,le_256_mem_g)
|
| 702 |
+
le_256_mem = le_256_mask*le_256_mem_s + (1-le_256_mask)*le_256_mem_g
|
| 703 |
+
re_256_mask = self.RE_256_Mask(fs_in['re256'],re_256_mem_s,re_256_mem_g)
|
| 704 |
+
re_256_mem = re_256_mask*re_256_mem_s + (1-re_256_mask)*re_256_mem_g
|
| 705 |
+
mo_256_mask = self.MO_256_Mask(fs_in['mo256'],mo_256_mem_s,mo_256_mem_g)
|
| 706 |
+
mo_256_mem = mo_256_mask*mo_256_mem_s + (1-mo_256_mask)*mo_256_mem_g
|
| 707 |
+
|
| 708 |
+
le_128_mem_s, _ = self.readMem(sp_128['LE128Key'], sp_128['LE128Value'], le_128_q)
|
| 709 |
+
re_128_mem_s, _ = self.readMem(sp_128['RE128Key'], sp_128['RE128Value'], re_128_q)
|
| 710 |
+
mo_128_mem_s, _ = self.readMem(sp_128['MO128Key'], sp_128['MO128Value'], mo_128_q)
|
| 711 |
+
le_128_mask = self.LE_128_Mask(fs_in['le128'],le_128_mem_s,le_128_mem_g)
|
| 712 |
+
le_128_mem = le_128_mask*le_128_mem_s + (1-le_128_mask)*le_128_mem_g
|
| 713 |
+
re_128_mask = self.RE_128_Mask(fs_in['re128'],re_128_mem_s,re_128_mem_g)
|
| 714 |
+
re_128_mem = re_128_mask*re_128_mem_s + (1-re_128_mask)*re_128_mem_g
|
| 715 |
+
mo_128_mask = self.MO_128_Mask(fs_in['mo128'],mo_128_mem_s,mo_128_mem_g)
|
| 716 |
+
mo_128_mem = mo_128_mask*mo_128_mem_s + (1-mo_128_mask)*mo_128_mem_g
|
| 717 |
+
|
| 718 |
+
le_64_mem_s, _ = self.readMem(sp_64['LE64Key'], sp_64['LE64Value'], le_64_q)
|
| 719 |
+
re_64_mem_s, _ = self.readMem(sp_64['RE64Key'], sp_64['RE64Value'], re_64_q)
|
| 720 |
+
mo_64_mem_s, _ = self.readMem(sp_64['MO64Key'], sp_64['MO64Value'], mo_64_q)
|
| 721 |
+
le_64_mask = self.LE_64_Mask(fs_in['le64'],le_64_mem_s,le_64_mem_g)
|
| 722 |
+
le_64_mem = le_64_mask*le_64_mem_s + (1-le_64_mask)*le_64_mem_g
|
| 723 |
+
re_64_mask = self.RE_64_Mask(fs_in['re64'],re_64_mem_s,re_64_mem_g)
|
| 724 |
+
re_64_mem = re_64_mask*re_64_mem_s + (1-re_64_mask)*re_64_mem_g
|
| 725 |
+
mo_64_mask = self.MO_64_Mask(fs_in['mo64'],mo_64_mem_s,mo_64_mem_g)
|
| 726 |
+
mo_64_mem = mo_64_mask*mo_64_mem_s + (1-mo_64_mask)*mo_64_mem_g
|
| 727 |
+
else:
|
| 728 |
+
le_256_mem = le_256_mem_g
|
| 729 |
+
re_256_mem = re_256_mem_g
|
| 730 |
+
mo_256_mem = mo_256_mem_g
|
| 731 |
+
le_128_mem = le_128_mem_g
|
| 732 |
+
re_128_mem = re_128_mem_g
|
| 733 |
+
mo_128_mem = mo_128_mem_g
|
| 734 |
+
le_64_mem = le_64_mem_g
|
| 735 |
+
re_64_mem = re_64_mem_g
|
| 736 |
+
mo_64_mem = mo_64_mem_g
|
| 737 |
+
|
| 738 |
+
le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in['le256'])
|
| 739 |
+
re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in['re256'])
|
| 740 |
+
mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in['mo256'])
|
| 741 |
+
|
| 742 |
+
####for 128
|
| 743 |
+
le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in['le128'])
|
| 744 |
+
re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in['re128'])
|
| 745 |
+
mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in['mo128'])
|
| 746 |
+
|
| 747 |
+
####for 64
|
| 748 |
+
le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in['le64'])
|
| 749 |
+
re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in['re64'])
|
| 750 |
+
mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in['mo64'])
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
EnMem256 = {'LE256Norm': le_256_mem_norm, 'RE256Norm': re_256_mem_norm, 'MO256Norm': mo_256_mem_norm}
|
| 754 |
+
EnMem128 = {'LE128Norm': le_128_mem_norm, 'RE128Norm': re_128_mem_norm, 'MO128Norm': mo_128_mem_norm}
|
| 755 |
+
EnMem64 = {'LE64Norm': le_64_mem_norm, 'RE64Norm': re_64_mem_norm, 'MO64Norm': mo_64_mem_norm}
|
| 756 |
+
Ind256 = {'LE': le_256_inds, 'RE': re_256_inds, 'MO': mo_256_inds}
|
| 757 |
+
Ind128 = {'LE': le_128_inds, 'RE': re_128_inds, 'MO': mo_128_inds}
|
| 758 |
+
Ind64 = {'LE': le_64_inds, 'RE': re_64_inds, 'MO': mo_64_inds}
|
| 759 |
+
return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64
|
| 760 |
+
|
| 761 |
+
def reconstruct(self, fs_in, locs, memstar):
|
| 762 |
+
le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = memstar[0]['LE256Norm'], memstar[0]['RE256Norm'], memstar[0]['MO256Norm']
|
| 763 |
+
le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = memstar[1]['LE128Norm'], memstar[1]['RE128Norm'], memstar[1]['MO128Norm']
|
| 764 |
+
le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = memstar[2]['LE64Norm'], memstar[2]['RE64Norm'], memstar[2]['MO64Norm']
|
| 765 |
+
|
| 766 |
+
le_256_final = self.LE_256_Attention(le_256_mem_norm - fs_in['le256']) * le_256_mem_norm + fs_in['le256']
|
| 767 |
+
re_256_final = self.RE_256_Attention(re_256_mem_norm - fs_in['re256']) * re_256_mem_norm + fs_in['re256']
|
| 768 |
+
mo_256_final = self.MO_256_Attention(mo_256_mem_norm - fs_in['mo256']) * mo_256_mem_norm + fs_in['mo256']
|
| 769 |
+
|
| 770 |
+
le_128_final = self.LE_128_Attention(le_128_mem_norm - fs_in['le128']) * le_128_mem_norm + fs_in['le128']
|
| 771 |
+
re_128_final = self.RE_128_Attention(re_128_mem_norm - fs_in['re128']) * re_128_mem_norm + fs_in['re128']
|
| 772 |
+
mo_128_final = self.MO_128_Attention(mo_128_mem_norm - fs_in['mo128']) * mo_128_mem_norm + fs_in['mo128']
|
| 773 |
+
|
| 774 |
+
le_64_final = self.LE_64_Attention(le_64_mem_norm - fs_in['le64']) * le_64_mem_norm + fs_in['le64']
|
| 775 |
+
re_64_final = self.RE_64_Attention(re_64_mem_norm - fs_in['re64']) * re_64_mem_norm + fs_in['re64']
|
| 776 |
+
mo_64_final = self.MO_64_Attention(mo_64_mem_norm - fs_in['mo64']) * mo_64_mem_norm + fs_in['mo64']
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
le_location = locs[:,0,:]
|
| 780 |
+
re_location = locs[:,1,:]
|
| 781 |
+
mo_location = locs[:,3,:]
|
| 782 |
+
le_location = le_location.cpu().int().numpy()
|
| 783 |
+
re_location = re_location.cpu().int().numpy()
|
| 784 |
+
mo_location = mo_location.cpu().int().numpy()
|
| 785 |
+
|
| 786 |
+
up_in_256 = fs_in['f256'].clone()# * 0
|
| 787 |
+
up_in_128 = fs_in['f128'].clone()# * 0
|
| 788 |
+
up_in_64 = fs_in['f64'].clone()# * 0
|
| 789 |
+
|
| 790 |
+
for i in range(fs_in['f256'].size(0)):
|
| 791 |
+
up_in_256[i:i+1,:,le_location[i,1]//2:le_location[i,3]//2,le_location[i,0]//2:le_location[i,2]//2] = F.interpolate(le_256_final[i:i+1,:,:,:].clone(), (le_location[i,3]//2-le_location[i,1]//2,le_location[i,2]//2-le_location[i,0]//2),mode='bilinear',align_corners=False)
|
| 792 |
+
up_in_256[i:i+1,:,re_location[i,1]//2:re_location[i,3]//2,re_location[i,0]//2:re_location[i,2]//2] = F.interpolate(re_256_final[i:i+1,:,:,:].clone(), (re_location[i,3]//2-re_location[i,1]//2,re_location[i,2]//2-re_location[i,0]//2),mode='bilinear',align_corners=False)
|
| 793 |
+
up_in_256[i:i+1,:,mo_location[i,1]//2:mo_location[i,3]//2,mo_location[i,0]//2:mo_location[i,2]//2] = F.interpolate(mo_256_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//2-mo_location[i,1]//2,mo_location[i,2]//2-mo_location[i,0]//2),mode='bilinear',align_corners=False)
|
| 794 |
+
|
| 795 |
+
up_in_128[i:i+1,:,le_location[i,1]//4:le_location[i,3]//4,le_location[i,0]//4:le_location[i,2]//4] = F.interpolate(le_128_final[i:i+1,:,:,:].clone(), (le_location[i,3]//4-le_location[i,1]//4,le_location[i,2]//4-le_location[i,0]//4),mode='bilinear',align_corners=False)
|
| 796 |
+
up_in_128[i:i+1,:,re_location[i,1]//4:re_location[i,3]//4,re_location[i,0]//4:re_location[i,2]//4] = F.interpolate(re_128_final[i:i+1,:,:,:].clone(), (re_location[i,3]//4-re_location[i,1]//4,re_location[i,2]//4-re_location[i,0]//4),mode='bilinear',align_corners=False)
|
| 797 |
+
up_in_128[i:i+1,:,mo_location[i,1]//4:mo_location[i,3]//4,mo_location[i,0]//4:mo_location[i,2]//4] = F.interpolate(mo_128_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//4-mo_location[i,1]//4,mo_location[i,2]//4-mo_location[i,0]//4),mode='bilinear',align_corners=False)
|
| 798 |
+
|
| 799 |
+
up_in_64[i:i+1,:,le_location[i,1]//8:le_location[i,3]//8,le_location[i,0]//8:le_location[i,2]//8] = F.interpolate(le_64_final[i:i+1,:,:,:].clone(), (le_location[i,3]//8-le_location[i,1]//8,le_location[i,2]//8-le_location[i,0]//8),mode='bilinear',align_corners=False)
|
| 800 |
+
up_in_64[i:i+1,:,re_location[i,1]//8:re_location[i,3]//8,re_location[i,0]//8:re_location[i,2]//8] = F.interpolate(re_64_final[i:i+1,:,:,:].clone(), (re_location[i,3]//8-re_location[i,1]//8,re_location[i,2]//8-re_location[i,0]//8),mode='bilinear',align_corners=False)
|
| 801 |
+
up_in_64[i:i+1,:,mo_location[i,1]//8:mo_location[i,3]//8,mo_location[i,0]//8:mo_location[i,2]//8] = F.interpolate(mo_64_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//8-mo_location[i,1]//8,mo_location[i,2]//8-mo_location[i,0]//8),mode='bilinear',align_corners=False)
|
| 802 |
+
|
| 803 |
+
ms_in_64 = self.MSDilate(fs_in['f64'].clone())
|
| 804 |
+
fea_up1 = self.up1(ms_in_64, up_in_64)
|
| 805 |
+
fea_up2 = self.up2(fea_up1, up_in_128) #
|
| 806 |
+
fea_up3 = self.up3(fea_up2, up_in_256) #
|
| 807 |
+
output = self.up4(fea_up3) #
|
| 808 |
+
return output
|
| 809 |
+
|
| 810 |
+
def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
|
| 811 |
+
return self.memorize(sp_imgs, sp_locs)
|
| 812 |
+
|
| 813 |
+
def forward(self, lq=None, loc=None, sp_256 = None, sp_128 = None, sp_64 = None):
|
| 814 |
+
fs_in = self.E_lq(lq, loc) # low quality images
|
| 815 |
+
GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(fs_in)
|
| 816 |
+
GeOut = self.reconstruct(fs_in, loc, memstar = [GeMemNorm256, GeMemNorm128, GeMemNorm64])
|
| 817 |
+
if sp_256 is not None and sp_128 is not None and sp_64 is not None:
|
| 818 |
+
GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(fs_in, sp_256, sp_128, sp_64)
|
| 819 |
+
GSOut = self.reconstruct(fs_in, loc, memstar = [GSMemNorm256, GSMemNorm128, GSMemNorm64])
|
| 820 |
+
else:
|
| 821 |
+
GSOut = None
|
| 822 |
+
return GeOut, GSOut
|
| 823 |
+
|
| 824 |
+
class UpResBlock(nn.Module):
|
| 825 |
+
def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
|
| 826 |
+
super(UpResBlock, self).__init__()
|
| 827 |
+
self.Model = nn.Sequential(
|
| 828 |
+
SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
|
| 829 |
+
nn.LeakyReLU(0.2),
|
| 830 |
+
SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
|
| 831 |
+
)
|
| 832 |
+
def forward(self, x):
|
| 833 |
+
out = x + self.Model(x)
|
| 834 |
+
return out
|
| 835 |
+
|
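Note on the memory lookup above: readMem treats each of the 128 memory slots as a convolution kernel, correlates the query patch with every key via F.conv2d, softmax-normalises the scores, and blends the stored values with a batched matrix multiply. The following standalone sketch reproduces that lookup on toy tensors; the shapes are chosen to match the 256-scale left-eye buffers, and the helper name read_memory_sketch is illustrative rather than part of the plugin.

    import torch
    import torch.nn.functional as F
    from math import sqrt

    def read_memory_sketch(keys, values, query):
        # keys:   (S, Ck, kh, kw) - one key per memory slot
        # values: (S, Cv, vh, vw) - one value per memory slot
        # query:  (B, Ck, kh, kw) - query patch per batch item
        sim = F.conv2d(query, keys)                        # (B, S, 1, 1) similarity per slot
        score = F.softmax(sim / sqrt(sim.size(1)), dim=1)  # soft attention over the S slots
        b = score.size(0)
        s_m = score.view(b, -1).unsqueeze(1)               # (B, 1, S)
        v_in = values.view(values.size(0), -1).repeat(b, 1, 1)  # (B, S, Cv*vh*vw)
        out = torch.bmm(s_m, v_in).squeeze(1).view(b, *values.shape[1:])
        return out

    keys = torch.randn(128, 16, 40, 40)    # 128 slots, 16-channel keys
    values = torch.randn(128, 64, 40, 40)  # matching 64-channel values
    query = torch.randn(2, 16, 40, 40)     # batch of 2 query patches
    print(read_memory_sketch(keys, values, query).shape)   # torch.Size([2, 64, 40, 40])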
plugins/plugin_faceswap.py
ADDED
@@ -0,0 +1,86 @@
from chain_img_processor import ChainImgProcessor, ChainImgPlugin
from roop.face_helper import get_one_face, get_many_faces, swap_face
import os
from roop.utilities import compute_cosine_distance

modname = os.path.basename(__file__)[:-3] # calculating modname

# start function
def start(core:ChainImgProcessor):
    manifest = { # plugin settings
        "name": "Faceswap", # name
        "version": "1.0", # version

        "default_options": {
            "swap_mode": "selected",
            "max_distance": 0.65, # max distance to detect face similarity
        },
        "img_processor": {
            "faceswap": Faceswap
        }
    }
    return manifest

def start_with_options(core:ChainImgProcessor, manifest:dict):
    pass


class Faceswap(ChainImgPlugin):

    def init_plugin(self):
        pass


    def process(self, frame, params:dict):
        if not "input_face_datas" in params or len(params["input_face_datas"]) < 1:
            params["face_detected"] = False
            return frame

        temp_frame = frame
        params["face_detected"] = True
        params["processed_faces"] = []

        if params["swap_mode"] == "first":
            face = get_one_face(frame)
            if face is None:
                params["face_detected"] = False
                return frame
            params["processed_faces"].append(face)
            frame = swap_face(params["input_face_datas"][0], face, frame)
            return frame

        else:
            faces = get_many_faces(frame)
            if len(faces) < 1:
                params["face_detected"] = False
                return frame

            dist_threshold = params["face_distance_threshold"]

            if params["swap_mode"] == "all":
                for sf in params["input_face_datas"]:
                    for face in faces:
                        params["processed_faces"].append(face)
                        temp_frame = swap_face(sf, face, temp_frame)
                return temp_frame

            elif params["swap_mode"] == "selected":
                for i, tf in enumerate(params["target_face_datas"]):
                    for face in faces:
                        if compute_cosine_distance(tf.embedding, face.embedding) <= dist_threshold:
                            temp_frame = swap_face(params["input_face_datas"][i], face, temp_frame)
                            params["processed_faces"].append(face)
                            break

            elif params["swap_mode"] == "all_female" or params["swap_mode"] == "all_male":
                gender = 'F' if params["swap_mode"] == "all_female" else 'M'
                face_found = False
                for face in faces:
                    if face.sex == gender:
                        face_found = True
                    if face_found:
                        params["processed_faces"].append(face)
                        temp_frame = swap_face(params["input_face_datas"][0], face, temp_frame)
                        face_found = False

        return temp_frame
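Note on the "selected" swap mode above: a detected face is only swapped when its embedding distance to one of the chosen target faces is at or below face_distance_threshold (the max_distance option, 0.65 by default, in plugin_options/plugin_faceswap.json). The project computes this with compute_cosine_distance from roop.utilities; the sketch below only assumes it behaves like an ordinary cosine distance and shows the thresholding on toy embeddings.

    import numpy as np

    def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
        # 1 - cosine similarity: 0 for identical directions, larger for dissimilar faces
        return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    rng = np.random.default_rng(0)
    target = rng.normal(size=512)                         # embedding of the face picked as target
    detected = target + rng.normal(scale=0.1, size=512)   # same person, slightly perturbed

    max_distance = 0.65  # plugin default
    if cosine_distance(target, detected) <= max_distance:
        print("within threshold -> this face would be swapped")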
plugins/plugin_gfpgan.py
ADDED
@@ -0,0 +1,85 @@
from chain_img_processor import ChainImgProcessor, ChainImgPlugin
import os
import gfpgan
import threading
from PIL import Image
from numpy import asarray
import cv2

from roop.utilities import resolve_relative_path, conditional_download
modname = os.path.basename(__file__)[:-3] # calculating modname

model_gfpgan = None
THREAD_LOCK_GFPGAN = threading.Lock()


# start function
def start(core:ChainImgProcessor):
    manifest = { # plugin settings
        "name": "GFPGAN", # name
        "version": "1.4", # version

        "default_options": {},
        "img_processor": {
            "gfpgan": GFPGAN
        }
    }
    return manifest

def start_with_options(core:ChainImgProcessor, manifest:dict):
    pass


class GFPGAN(ChainImgPlugin):

    def init_plugin(self):
        global model_gfpgan

        if model_gfpgan is None:
            model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
            model_gfpgan = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=self.device) # type: ignore[attr-defined]


    def process(self, frame, params:dict):
        import copy

        global model_gfpgan

        if model_gfpgan is None:
            return frame

        if "face_detected" in params:
            if not params["face_detected"]:
                return frame
        # don't touch original
        temp_frame = copy.copy(frame)
        if "processed_faces" in params:
            for face in params["processed_faces"]:
                start_x, start_y, end_x, end_y = map(int, face['bbox'])
                padding_x = int((end_x - start_x) * 0.5)
                padding_y = int((end_y - start_y) * 0.5)
                start_x = max(0, start_x - padding_x)
                start_y = max(0, start_y - padding_y)
                end_x = max(0, end_x + padding_x)
                end_y = max(0, end_y + padding_y)
                temp_face = temp_frame[start_y:end_y, start_x:end_x]
                if temp_face.size:
                    with THREAD_LOCK_GFPGAN:
                        _, _, temp_face = model_gfpgan.enhance(
                            temp_face,
                            paste_back=True
                        )
                    temp_frame[start_y:end_y, start_x:end_x] = temp_face
        else:
            with THREAD_LOCK_GFPGAN:
                _, _, temp_frame = model_gfpgan.enhance(
                    temp_frame,
                    paste_back=True
                )

        if not "blend_ratio" in params:
            return temp_frame

        temp_frame = Image.blend(Image.fromarray(frame), Image.fromarray(temp_frame), params["blend_ratio"])
        return asarray(temp_frame)
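Note on the restoration loop above: only a padded crop around each processed face is handed to GFPGAN and then written back into the frame. A minimal sketch of that crop arithmetic on a dummy frame (no model involved; unlike the plugin, the end coordinates here are additionally clamped to the frame size):

    import numpy as np

    def padded_crop_box(bbox, frame_shape, pad_ratio=0.5):
        # bbox: (start_x, start_y, end_x, end_y); pad by half the box size on every side
        start_x, start_y, end_x, end_y = map(int, bbox)
        pad_x = int((end_x - start_x) * pad_ratio)
        pad_y = int((end_y - start_y) * pad_ratio)
        h, w = frame_shape[:2]
        return (max(0, start_x - pad_x), max(0, start_y - pad_y),
                min(w, end_x + pad_x), min(h, end_y + pad_y))

    frame = np.zeros((720, 1280, 3), dtype=np.uint8)
    x0, y0, x1, y1 = padded_crop_box((600, 200, 700, 330), frame.shape)
    crop = frame[y0:y1, x0:x1]     # this crop would be handed to the restorer
    frame[y0:y1, x0:x1] = crop     # ...and pasted back afterwards
    print((x0, y0, x1, y1), crop.shape)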
plugins/plugin_txt2clip.py
ADDED
@@ -0,0 +1,122 @@
import os
import cv2
import numpy as np
import torch
import threading
from chain_img_processor import ChainImgProcessor, ChainImgPlugin
from torchvision import transforms
from clip.clipseg import CLIPDensePredT
from numpy import asarray


THREAD_LOCK_CLIP = threading.Lock()

modname = os.path.basename(__file__)[:-3] # calculating modname

model_clip = None


# start function
def start(core:ChainImgProcessor):
    manifest = { # plugin settings
        "name": "Text2Clip", # name
        "version": "1.0", # version

        "default_options": {
        },
        "img_processor": {
            "txt2clip": Text2Clip
        }
    }
    return manifest

def start_with_options(core:ChainImgProcessor, manifest:dict):
    pass


class Text2Clip(ChainImgPlugin):

    def load_clip_model(self):
        global model_clip

        if model_clip is None:
            device = torch.device(super().device)
            model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
            model_clip.eval()
            model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)
            model_clip.to(device)


    def init_plugin(self):
        self.load_clip_model()

    def process(self, frame, params:dict):
        if "face_detected" in params:
            if not params["face_detected"]:
                return frame

        return self.mask_original(params["original_frame"], frame, params["clip_prompt"])


    def mask_original(self, img1, img2, keywords):
        global model_clip

        source_image_small = cv2.resize(img1, (256,256))

        img_mask = np.full((source_image_small.shape[0],source_image_small.shape[1]), 0, dtype=np.float32)
        mask_border = 1
        l = 0
        t = 0
        r = 1
        b = 1

        mask_blur = 5
        clip_blur = 5

        img_mask = cv2.rectangle(img_mask, (mask_border+int(l), mask_border+int(t)),
                                 (256 - mask_border-int(r), 256-mask_border-int(b)), (255, 255, 255), -1)
        img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1,mask_blur*2+1), 0)
        img_mask /= 255

        input_image = source_image_small

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((256, 256)),
        ])
        img = transform(input_image).unsqueeze(0)

        thresh = 0.5
        prompts = keywords.split(',')
        with THREAD_LOCK_CLIP:
            with torch.no_grad():
                preds = model_clip(img.repeat(len(prompts),1,1,1), prompts)[0]
        clip_mask = torch.sigmoid(preds[0][0])
        for i in range(len(prompts)-1):
            clip_mask += torch.sigmoid(preds[i+1][0])

        clip_mask = clip_mask.data.cpu().numpy()
        np.clip(clip_mask, 0, 1)

        clip_mask[clip_mask>thresh] = 1.0
        clip_mask[clip_mask<=thresh] = 0.0
        kernel = np.ones((5, 5), np.float32)
        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
        clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1,clip_blur*2+1), 0)

        img_mask *= clip_mask
        img_mask[img_mask<0.0] = 0.0

        img_mask = cv2.resize(img_mask, (img2.shape[1], img2.shape[0]))
        img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])

        target = img2.astype(np.float32)
        result = (1-img_mask) * target
        result += img_mask * img1.astype(np.float32)
        return np.uint8(result)
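Note on the final compositing in mask_original above: the CLIP-derived mask acts as a per-pixel alpha that keeps the original frame where the prompt matched and the processed frame everywhere else. A minimal sketch of that blend, with a synthetic circular mask standing in for the CLIP prediction:

    import cv2
    import numpy as np

    original = np.full((256, 256, 3), 200, dtype=np.uint8)   # stands in for the untouched frame
    processed = np.full((256, 256, 3), 40, dtype=np.uint8)   # stands in for the swapped frame

    # synthetic mask: soft-edged disc, e.g. where a prompt such as "hair" was detected
    mask = np.zeros((256, 256), dtype=np.float32)
    cv2.circle(mask, (128, 128), 80, 1.0, -1)
    mask = cv2.GaussianBlur(mask, (11, 11), 0)
    mask = mask.reshape(256, 256, 1)

    result = (1 - mask) * processed.astype(np.float32) + mask * original.astype(np.float32)
    result = np.uint8(result)
    print(result.shape, result.dtype)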
roop-unleashed.ipynb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1abfc4e80fb1e9e8eb3381f1d46193051b683ea452595a189bb5d647dfe7b6b
size 5953