from typing import Optional, List, Dict, Any

import torch
from tqdm import tqdm
from transformers import CLIPTokenizer

import constants
from models.neti_clip_text_encoder import NeTICLIPTextModel
from utils.types import NeTIBatch


class PromptManager:
    """ Class for computing all time and space embeddings for a given prompt. """

    def __init__(self, tokenizer: CLIPTokenizer,
                 text_encoder: NeTICLIPTextModel,
                 timesteps: List[int] = constants.SD_INFERENCE_TIMESTEPS,
                 unet_layers: List[str] = constants.UNET_LAYERS,
                 placeholder_token_id: Optional[int] = None,
                 placeholder_token: Optional[str] = None,
                 torch_dtype: torch.dtype = torch.float32):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.timesteps = timesteps
        self.unet_layers = unet_layers
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.dtype = torch_dtype

    def embed_prompt(self, text: str,
                     truncation_idx: Optional[int] = None,
                     num_images_per_prompt: int = 1) -> List[Dict[str, Any]]:
        """
        Compute the conditioning vectors for the given prompt. We assume the prompt contains `{}`
        to indicate where the placeholder token string should be inserted.
        See constants.VALIDATION_PROMPTS for examples.
        """
        text = text.format(self.placeholder_token)
        ids = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
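        # `ids` has shape (1, tokenizer.model_max_length), i.e. (1, 77) for CLIP.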

        print(f"Computing embeddings over {len(self.timesteps)} timesteps and {len(self.unet_layers)} U-Net layers.")
        hidden_states_per_timestep = []
        for timestep in tqdm(self.timesteps):
            # "this_idx" is a running counter used downstream (by the attention
            # processors) to select the context tensor of the current U-Net layer.
            _hs = {"this_idx": 0}
            for layer_idx, unet_layer in enumerate(self.unet_layers):
                batch = NeTIBatch(input_ids=ids.to(device=self.text_encoder.device),
                                  # as_tensor works whether `timesteps` holds Python ints or scalar tensors
                                  timesteps=torch.as_tensor(timestep, device=self.text_encoder.device).unsqueeze(0),
                                  unet_layers=torch.tensor(layer_idx, device=self.text_encoder.device).unsqueeze(0),
                                  placeholder_token_id=self.placeholder_token_id,
                                  truncation_idx=truncation_idx)
                layer_hs, layer_hs_bypass = self.text_encoder(batch=batch)
                layer_hs = layer_hs[0].to(dtype=self.dtype)
                _hs[f"CONTEXT_TENSOR_{layer_idx}"] = layer_hs.repeat(num_images_per_prompt, 1, 1)
                if layer_hs_bypass is not None:
                    layer_hs_bypass = layer_hs_bypass[0].to(dtype=self.dtype)
                    _hs[f"CONTEXT_TENSOR_BYPASS_{layer_idx}"] = layer_hs_bypass.repeat(num_images_per_prompt, 1, 1)
            hidden_states_per_timestep.append(_hs)
        print("Done.")
        return hidden_states_per_timestep
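

# Minimal usage sketch (not part of the original module): shows how PromptManager
# is wired up. The model path, the "<concept>" token, and loading NeTICLIPTextModel
# via the standard `from_pretrained` interface are assumptions here; a trained NeTI
# mapper must also be attached to the text encoder (e.g., via the repo's checkpoint
# utilities) for the outputs to be meaningful.
if __name__ == "__main__":
    tokenizer = CLIPTokenizer.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
    text_encoder = NeTICLIPTextModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
    tokenizer.add_tokens("<concept>")  # hypothetical placeholder token string
    text_encoder.resize_token_embeddings(len(tokenizer))
    placeholder_token_id = tokenizer.convert_tokens_to_ids("<concept>")
    prompt_manager = PromptManager(tokenizer=tokenizer,
                                   text_encoder=text_encoder,
                                   placeholder_token="<concept>",
                                   placeholder_token_id=placeholder_token_id)
    # Returns one dict of per-layer context tensors for each inference timestep.
    embeddings = prompt_manager.embed_prompt("A photo of a {}",
                                             num_images_per_prompt=1)
    print(f"Computed {len(embeddings)} per-timestep embedding dicts.")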