import os, json
import math, random
from multiprocessing import Pool

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from transformers import CLIPTextModel
from transformers import PretrainedConfig

def pad_spec(spec, spec_length, pad_value=0, random_crop=True):  # spec: [3, mel_dim, spec_len]
    assert spec_length % 8 == 0, "spec_length must be divisible by 8"
    if spec.shape[-1] < spec_length:
        # pad spec to spec_length
        spec = F.pad(spec, (0, spec_length - spec.shape[-1]), value=pad_value)
    else:
        # random crop
        if random_crop:
            start = random.randint(0, spec.shape[-1] - spec_length)
            spec = spec[:, :, start:start + spec_length]
        else:
            spec = spec[:, :, :spec_length]
    return spec

def load_spec(spec_path):
    if spec_path.endswith(".pt"):
        spec = torch.load(spec_path, map_location="cpu")
    elif spec_path.endswith(".npy"):
        spec = torch.from_numpy(np.load(spec_path))
    else:
        raise ValueError(f"Unknown spec file type {spec_path}")
    assert len(spec.shape) == 3, f"spec shape must be [3, mel_dim, spec_len], got {spec.shape}"
    if spec.size(0) == 1:
        spec = spec.repeat(3, 1, 1)
    return spec

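# Illustrative usage (not part of the original file; the path and length below are placeholders):
#   spec = load_spec("mel/example.pt")        # -> [3, mel_dim, spec_len]
#   spec = pad_spec(spec, spec_length=1024)   # zero-pad or randomly crop to 1024 frames
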
def random_crop_spec(spec, target_spec_length, pad_value=0, frame_per_sec=100, time_step=5):  # spec: [3, mel_dim, spec_len]
    assert target_spec_length % 8 == 0, "target_spec_length must be divisible by 8"
    spec_length = spec.shape[-1]
    full_s = math.ceil(spec_length / frame_per_sec / time_step) * time_step  # full duration in seconds, rounded up to a multiple of time_step
    # pick a random start second on the time_step grid; the upper bound keeps the crop inside the spectrogram
    max_start_step = max(0, math.ceil(spec_length / frame_per_sec / time_step) - 1)
    start_s = random.randint(0, max_start_step) * time_step
    end_s = min(start_s + math.ceil(target_spec_length / frame_per_sec), full_s)  # end second (may cover slightly more than target_spec_length because of the ceiling)
    spec = spec[:, :, start_s * frame_per_sec : end_s * frame_per_sec]
    if spec.shape[-1] < target_spec_length:
        spec = F.pad(spec, (0, target_spec_length - spec.shape[-1]), value=pad_value)  # pad to target_spec_length
    else:
        spec = spec[:, :, :target_spec_length]  # crop to target_spec_length
    return spec, int(start_s), int(end_s), int(full_s)

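# Illustrative usage (placeholder shapes; frame_per_sec=100 means 100 mel frames per second):
#   spec = torch.randn(3, 256, 2000)   # 20 s of audio at 100 frames/s
#   crop, start_s, end_s, full_s = random_crop_spec(spec, target_spec_length=1024)
#   # crop: [3, 256, 1024]; start_s is a multiple of time_step (5 s); full_s = 20
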
def load_condion_embed(text_embed_path):
    if text_embed_path.endswith(".pt"):
        text_embed_list = torch.load(text_embed_path, map_location="cpu")
    elif text_embed_path.endswith(".npy"):
        text_embed_list = torch.from_numpy(np.load(text_embed_path))
    else:
        raise ValueError(f"Unknown text embedding file type {text_embed_path}")
    if isinstance(text_embed_list, list):
        # randomly pick one embedding when the file stores a list of candidates
        text_embed = random.choice(text_embed_list)
    else:
        text_embed = text_embed_list
    if len(text_embed.shape) == 3:  # [1, text_len, text_dim]
        text_embed = text_embed.squeeze(0)  # -> [text_len, text_dim]
    return text_embed.detach().cpu()

def process_condition_embed(cond_emb, max_length):  # cond_emb: [text_len, text_dim]; zero-pad or truncate along text_len to max_length
    if cond_emb.shape[0] < max_length:
        cond_emb = F.pad(cond_emb, (0, 0, 0, max_length - cond_emb.shape[0]), value=0)
    else:
        cond_emb = cond_emb[:max_length, :]
    return cond_emb

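# Illustrative usage (placeholder path; an embedding file of shape [text_len, text_dim]):
#   cond = load_condion_embed("embeds/example.pt")
#   cond = process_condition_embed(cond, max_length=77)   # -> [77, text_dim]
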
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
    text_encoder_config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path)
    model_class = text_encoder_config.architectures[0]
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    elif "t5" in model_class.lower():
        from transformers import T5EncoderModel
        return T5EncoderModel
    elif "clap" in model_class.lower():
        from transformers import ClapTextModelWithProjection
        return ClapTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")

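# Illustrative usage (the path is a placeholder pointing at a saved text-encoder folder whose
# config.json lists e.g. "CLIPTextModel" under "architectures"):
#   cls = import_model_class_from_model_name_or_path("path/to/text_encoder")
#   text_encoder = cls.from_pretrained("path/to/text_encoder")
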
def str2bool(string):
    str2val = {"True": True, "False": False, "true": True, "false": False, "none": False, "None": False}
    if string in str2val:
        return str2val[string]
    else:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")

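# Illustrative usage (assuming an argparse.ArgumentParser named `parser`, so "--use_ema false" parses to a bool):
#   parser.add_argument("--use_ema", type=str2bool, default=False)
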
def str2str(string):
    if string.lower() in ("none", "null", "false") or string == "":
        return None
    else:
        return string

def json_dump(data_json, json_save_path):
    with open(json_save_path, 'w') as f:
        json.dump(data_json, f, indent=4)

def json_load(json_path):
    with open(json_path, 'r') as f:
        data = json.load(f)
    return data

def load_json_list(path):
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f]

def save_json_list(data, path):
    with open(path, 'w', encoding='utf-8') as f:
        for d in data:
            f.write(json.dumps(d) + '\n')

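# Illustrative usage (JSONL round trip; the filename and fields are placeholders):
#   save_json_list([{"audio": "a.wav", "caption": "a dog barking"}], "meta.jsonl")
#   records = load_json_list("meta.jsonl")   # -> list of dicts, one per line
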
def multiprocess_function(func, func_args, n_jobs=32):
    with Pool(processes=n_jobs) as p:
        with tqdm(total=len(func_args)) as pbar:
            for _ in p.imap_unordered(func, func_args):
                pbar.update()

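# Illustrative usage (`process_one` is a placeholder; it must be a picklable top-level function):
#   multiprocess_function(process_one, ["a.wav", "b.wav"], n_jobs=8)
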
def image_add_color(spec_img):
    # apply the viridis colormap to the first channel of the spectrogram image
    cmap = plt.get_cmap('viridis')
    image = cmap(np.array(spec_img)[:, :, 0])[:, :, :3]  # drop the alpha channel
    image = (image - image.min()) / (image.max() - image.min())
    image = Image.fromarray(np.uint8(image * 255))
    return image

def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
    """
    Convert a PyTorch tensor to a NumPy image.
    """
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    return images

def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images

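# Illustrative usage (a batch of [B, 3, H, W] tensors with values in [0, 1]; shapes are placeholders):
#   imgs = torch.rand(2, 3, 256, 1024)
#   pils = numpy_to_pil(pt_to_numpy(imgs))   # -> list of 2 PIL images
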
### CODE FOR INPAINTING ###

def normalize(images):
    """
    Normalize an image array to [-1, 1].
    """
    if images.min() >= 0:
        return 2.0 * images - 1.0
    else:
        return images

def denormalize(images):
    """
    Denormalize an image array to [0, 1].
    """
    if images.min() < 0:
        return (images / 2 + 0.5).clamp(0, 1)
    else:
        return images.clamp(0, 1)

def prepare_mask_and_masked_image(image, mask):
    """
    Prepare a binary mask and the masked image.
    Parameters:
    - image (torch.Tensor): The input image tensor of shape [3, height, width] with values in the range [0, 1].
    - mask (torch.Tensor): The input mask tensor of shape [1, height, width].
    Returns:
    - tuple: A tuple containing the binary mask and the masked image.
    """
    # Normalize image to [0, 1]
    if image.max() > 1:
        image = (image - image.min()) / (image.max() - image.min())
    # Normalize image from [0, 1] to [-1, 1]
    if image.min() >= 0:
        image = normalize(image)
    # Apply the mask to the image (keep pixels where mask < 0.5)
    masked_image = image * (mask < 0.5)
    return mask, masked_image

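# Illustrative usage (placeholder shapes; mask is 1 in the region to be inpainted):
#   image = torch.rand(3, 256, 1024)
#   mask = torch.zeros(1, 256, 1024); mask[:, :, 400:600] = 1.0
#   mask, masked_image = prepare_mask_and_masked_image(image, mask)
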
def torch_to_pil(image):
    """
    Convert a torch tensor to a PIL image.
    """
    if image.min() < 0:
        image = denormalize(image)
    return transforms.ToPILImage()(image.cpu().detach().squeeze())

# class TextEncoderAdapter(nn.Module):
#     def __init__(self, hidden_size, cross_attention_dim=768):
#         super(TextEncoderAdapter, self).__init__()
#         self.hidden_size = hidden_size
#         self.cross_attention_dim = cross_attention_dim
#         self.proj = nn.Linear(self.hidden_size, self.cross_attention_dim)
#         self.norm = torch.nn.LayerNorm(self.cross_attention_dim)
#     def forward(self, x):
#         x = self.proj(x)
#         x = self.norm(x)
#         return x
#     def save_pretrained(self, save_directory, subfolder=""):
#         if subfolder:
#             save_directory = os.path.join(save_directory, subfolder)
#         os.makedirs(save_directory, exist_ok=True)
#         ckpt_path = os.path.join(save_directory, "adapter.pt")
#         config_path = os.path.join(save_directory, "config.json")
#         config = {"hidden_size": self.hidden_size, "cross_attention_dim": self.cross_attention_dim}
#         json_dump(config, config_path)
#         torch.save(self.state_dict(), ckpt_path)
#         print(f"Saving adapter model to {ckpt_path}")
#     @classmethod
#     def from_pretrained(cls, load_directory, subfolder=""):
#         if subfolder:
#             load_directory = os.path.join(load_directory, subfolder)
#         ckpt_path = os.path.join(load_directory, "adapter.pt")
#         config_path = os.path.join(load_directory, "config.json")
#         config = json_load(config_path)
#         instance = cls(**config)
#         instance.load_state_dict(torch.load(ckpt_path))
#         print(f"Loading adapter model from {ckpt_path}")
#         return instance

class ConditionAdapter(nn.Module):
    def __init__(self, config):
        super(ConditionAdapter, self).__init__()
        self.config = config
        self.proj = nn.Linear(self.config["condition_dim"], self.config["cross_attention_dim"])
        self.norm = torch.nn.LayerNorm(self.config["cross_attention_dim"])
        print(f"INITIATED: ConditionAdapter: {self.config}")

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path):
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        ckpt_path = os.path.join(pretrained_model_name_or_path, "condition_adapter.pt")
        config = json_load(config_path)
        instance = cls(config)
        instance.load_state_dict(torch.load(ckpt_path))
        print(f"LOADED: ConditionAdapter from {pretrained_model_name_or_path}")
        return instance

    def save_pretrained(self, pretrained_model_name_or_path):
        os.makedirs(pretrained_model_name_or_path, exist_ok=True)
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        ckpt_path = os.path.join(pretrained_model_name_or_path, "condition_adapter.pt")
        json_dump(self.config, config_path)
        torch.save(self.state_dict(), ckpt_path)
        print(f"SAVED: ConditionAdapter {self.config['condition_adapter_name']} to {pretrained_model_name_or_path}")

# class TextEncoderWrapper(CLIPTextModel):
#     def __init__(self, text_encoder, text_encoder_adapter):
#         super().__init__(text_encoder.config)
#         self.text_encoder = text_encoder
#         self.adapter = text_encoder_adapter
#     def forward(self, input_ids, **kwargs):
#         outputs = self.text_encoder(input_ids, **kwargs)
#         adapted_output = self.adapter(outputs[0])
#         return [adapted_output]  # to be compatible with last_hidden_state