Instructions to use Susav/PolarSparsity with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Susav/PolarSparsity with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="Susav/PolarSparsity")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Susav/PolarSparsity", dtype="auto")

- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use Susav/PolarSparsity with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "Susav/PolarSparsity"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "Susav/PolarSparsity",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/Susav/PolarSparsity
- SGLang
How to use Susav/PolarSparsity with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "Susav/PolarSparsity" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "Susav/PolarSparsity",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "Susav/PolarSparsity" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "Susav/PolarSparsity",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use Susav/PolarSparsity with Docker Model Runner:
docker model run hf.co/Susav/PolarSparsity
| import math | |
| import torch | |
| import torch.nn as nn | |
| from functools import partial | |
| from einops import rearrange | |
| from transformers import GPT2Config | |
| from collections import namedtuple | |
| from HybridTensor.modules.SelectiveMHA import SMHA, SelectMHA, ParallelSelectMHA, MHARouter, ParallelMHARouter | |
| from HybridTensor.modules.SelectiveMLP import SelectiveMLP, ParallelSelectiveMLP, MLPRouter, ParallelMLPRouter | |
| from HybridTensor.modules.SelectiveBlock import SelectBlock | |
| # from HybridTensor.modules.SelectiveBlock_v1 import SelectBlock | |
| import torch.nn.functional as F | |
| from flash_attn.utils.distributed import ( | |
| all_gather, | |
| all_gather_raw, | |
| get_dim_for_local_rank, | |
| sync_shared_params, | |
| ) | |
| from collections.abc import Sequence | |
| from flash_attn.modules.mha import MHA, ParallelMHA | |
| from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP, GatedMlp, ParallelGatedMlp, Mlp, ParallelMLP | |
| from flash_attn.ops.activations import sqrelu_fwd | |
| from flash_attn.modules.block import Block | |
| try: | |
| from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm | |
| except ImportError: | |
| layer_norm_fn, RMSNorm = None, None | |
| from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings | |
| from flash_attn.utils.distributed import sync_shared_params, all_gather_raw | |
| from flash_attn.utils.pretrained import state_dict_from_pretrained | |
| from flash_attn.utils.generation import GenerationMixin | |
| from flash_attn.models.opt import remap_state_dict_hf_opt | |
| try: | |
| from flash_attn.ops.fused_dense import ColumnParallelLinear | |
| except ImportError: | |
| ColumnParallelLinear = None | |
| try: | |
| from flash_attn.ops.triton.mlp import FusedDenseSqreluDense | |
| except ImportError: | |
| FusedDenseSqreluDense = None | |
| try: | |
| from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm | |
| except ImportError: | |
| layer_norm_fn, RMSNorm = None, None | |
| from HybridTensor.models.helper import remap_state_dict_gpt2, shard_state_dict_tp | |
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Build a partially-applied attention (mixer) constructor for one layer.

    Args:
        config: model config. ``hidden_size``, ``num_attention_heads``,
            ``scale_attn_weights``, ``scale_attn_by_inverse_layer_idx`` and
            ``attn_pdrop`` are read directly; every other option is read with a
            ``getattr`` default.
        layer_idx: layer index; required when ``scale_attn_by_inverse_layer_idx``
            is set, and forwarded to the mixer.
        process_group: tensor-parallel process group, or ``None`` for the
            single-device path.
        device, dtype: forwarded to the mixer constructor.

    Returns:
        ``functools.partial`` over ``SMHA`` (serial), or over
        ``ParallelSelectMHA`` / ``ParallelMHA`` (tensor parallel). The remaining
        arguments are supplied by whoever instantiates it.

    Raises:
        AssertionError: for dwconv or non-fused bias under tensor parallelism,
            or a missing ``layer_idx`` when per-layer scaling is requested.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    # muP option: scale q.k by 1/d instead of the usual 1/sqrt(d).
    attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
    softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))
    softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)

    dwconv = getattr(config, "attn_dwconv", False)
    if dwconv:
        assert process_group is None, "TensorParallel MHA does not support dwconv yet"
    qkv_proj_bias = getattr(config, "qkv_proj_bias", True)
    out_proj_bias = getattr(config, "out_proj_bias", True)
    rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim)
    rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
    rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
    rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
    use_alibi = getattr(config, "use_alibi", False)
    # NOTE(review): config.use_triton used to be read here but was never forwarded
    # to the mixer; pass it explicitly if SMHA is meant to honor it.
    window_size = getattr(config, "window_size", (-1, -1))
    use_flash_attn = getattr(config, "use_flash_attn", False)
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    if not fused_bias_fc:
        assert process_group is None, "TensorParallel MHA requires fused_bias_fc"

    att_sparse = getattr(config, "att_sparse", False)
    num_heads = getattr(config, "num_attention_heads", None)
    n_head_kv = getattr(config, "n_head_kv", num_heads)
    # Head sparsity is turned off whenever the KV-head count differs from the query-head count.
    if num_heads != n_head_kv:
        att_sparse = False

    if process_group is None:
        # The serial path always uses SMHA; att_sparse only affects the tensor-parallel choice.
        mha_cls = SMHA
    else:
        mha_cls = ParallelSelectMHA if att_sparse else ParallelMHA

    serial_kwargs = (
        {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {}
    )
    parallel_kwargs = (
        {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", False),
        }
        if process_group is not None
        else {}
    )
    # Unlike n_head_kv above, this forwards None when the config does not set n_head_kv.
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(
        mha_cls,
        num_heads=config.num_attention_heads,
        num_heads_kv=num_heads_kv,
        qkv_proj_bias=qkv_proj_bias,
        out_proj_bias=out_proj_bias,
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=rotary_emb_dim,
        rotary_emb_base=rotary_emb_base,
        rotary_emb_scale_base=rotary_emb_scale_base,
        rotary_emb_interleaved=rotary_emb_interleaved,
        use_alibi=use_alibi,
        window_size=window_size,
        use_flash_attn=use_flash_attn,
        **serial_kwargs,
        **parallel_kwargs,
        **factory_kwargs,
    )
    return mixer_cls
def create_mlp_cls_old(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Legacy MLP factory: return a partial for a fused (optionally sparse) MLP.

    Only the fused-MLP configuration is accepted; presumably superseded by
    ``create_mlp_cls``, which also covers gated and non-fused MLPs.

    Args:
        config: model config (``n_inner``/``hidden_size``, ``activation_function``
            and optional ``fused_mlp``, ``mlp_sparse``, ``use_heuristic``,
            ``mlp_checkpoint_lvl``, ``sequence_parallel``).
        layer_idx: layer index, used when ``mlp_checkpoint_lvl`` is per-layer.
        process_group: tensor-parallel group or ``None``.
        device, dtype: forwarded to the MLP constructor.

    Returns:
        ``functools.partial`` of ``SelectiveMLP``/``FusedMLP`` (serial) or
        ``ParallelSelectiveMLP``/``ParallelFusedMLP`` (tensor parallel).

    Raises:
        AssertionError: if ``fused_mlp`` is not set or the activation is unsupported.
        RuntimeError: if ``fused_mlp`` is falsy and assertions are disabled.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    if config.n_inner is None:
        inner_dim = 4 * config.hidden_size
    else:
        inner_dim = config.n_inner

    fused_mlp = getattr(config, "fused_mlp", False)
    fusable_activations = [
        "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu",
    ]
    if fused_mlp:
        assert config.activation_function in fusable_activations
    assert fused_mlp == True, "Not supported not fused mlp for now"

    mlp_sparse = getattr(config, "mlp_sparse", False)
    use_heuristic = getattr(config, "use_heuristic", True)
    checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
    # A sequence means one checkpoint level per layer.
    if isinstance(checkpoint_lvl, Sequence):
        assert layer_idx is not None
        checkpoint_lvl = checkpoint_lvl[layer_idx]

    if not fused_mlp:
        # Only reachable when assertions are stripped (python -O).
        raise RuntimeError("MLP type not supported")
    if FusedMLP is None:
        raise ImportError("fused_dense is not installed")

    gelu_variants = ("gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh")
    # Every non-GELU activation (sqrelu included) is mapped to "relu" on this path.
    activation = "gelu_approx" if config.activation_function in gelu_variants else "relu"

    if process_group is None:
        base_cls = SelectiveMLP if mlp_sparse else FusedMLP
        parallel_kwargs = {}
    else:
        base_cls = ParallelSelectiveMLP if mlp_sparse else ParallelFusedMLP
        parallel_kwargs = {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }
    sparsity_kwargs = {"use_heuristic": use_heuristic} if mlp_sparse else {}

    return partial(
        base_cls,
        hidden_features=inner_dim,
        activation=activation,
        checkpoint_lvl=checkpoint_lvl,
        **parallel_kwargs,
        **factory_kwargs,
        **sparsity_kwargs,
    )
def _resolve_mlp_checkpoint_lvl(config, layer_idx):
    """Return the MLP checkpoint level for ``layer_idx``.

    ``config.mlp_checkpoint_lvl`` (default 0) is either one value for all layers
    or a list/tuple with one entry per layer; in the latter case ``layer_idx``
    must be provided.
    """
    mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
    if isinstance(mlp_checkpoint_lvl, (list, tuple)):
        assert layer_idx is not None
        mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
    return mlp_checkpoint_lvl


def _mlp_parallel_kwargs(config, process_group):
    """Return tensor-parallel constructor kwargs, or ``{}`` when ``process_group`` is None."""
    if process_group is None:
        return {}
    return {
        "process_group": process_group,
        "sequence_parallel": getattr(config, "sequence_parallel", True),
    }


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Create a partially-applied MLP constructor for one layer.

    Selection order:
      * ``activation_function`` in {"glu", "swiglu", "geglu"}: ``GatedMlp`` /
        ``ParallelGatedMlp``. ``mlp_sparse`` is ignored.
      * ``config.fused_mlp``: ``SelectiveMLP``/``FusedMLP`` (or their parallel
        variants), choosing the sparse class when ``mlp_sparse`` is set.
      * ``config.fused_dense_sqrelu_dense``: ``FusedDenseSqreluDense``
        (serial only).
      * otherwise: a plain ``Mlp`` / ``ParallelMLP``. ``mlp_sparse`` is NOT
        honored on this path; a dense MLP is built.

    Args:
        config: model config (``activation_function``, ``n_inner``, ``hidden_size``
            and optional flags read via ``getattr``).
        layer_idx: layer index, needed when ``mlp_checkpoint_lvl`` is per-layer.
        process_group: tensor-parallel group or ``None``.
        device, dtype: forwarded to the MLP constructor.

    Returns:
        A ``functools.partial`` over the selected MLP class.

    Raises:
        AssertionError: on unsupported activation / flag combinations, or when
            tensor parallelism is requested with ``fused_dense_sqrelu_dense``.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    parallel_kwargs = _mlp_parallel_kwargs(config, process_group)

    # Gated activations (no sparse implementation for these yet).
    if config.activation_function in ["glu", "swiglu", "geglu"]:
        activation = (
            F.sigmoid if config.activation_function == "glu"
            else (F.silu if config.activation_function == "swiglu" else F.gelu)
        )
        return partial(
            GatedMlp if process_group is None else ParallelGatedMlp,
            hidden_features=config.n_inner,
            activation=activation,
            bias1=mlp_fc1_bias,
            bias2=mlp_fc2_bias,
            multiple_of=getattr(config, "mlp_multiple_of", 128),
            **parallel_kwargs,
            **factory_kwargs,
        )

    fused_mlp = getattr(config, "fused_mlp", False)
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == "sqrelu", (
            "fused_dense_sqrelu_dense only supports approximate activation_function sqrelu"
        )
    assert not (fused_dense_sqrelu_dense and fused_mlp)

    if fused_mlp:
        assert config.activation_function in [
            "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu"
        ]
        mlp_checkpoint_lvl = _resolve_mlp_checkpoint_lvl(config, layer_idx)
        # Fused kernels take the activation by name; every non-GELU activation
        # (sqrelu included) is mapped to "relu".
        if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]:
            activation = "gelu_approx"
        else:
            activation = "relu"
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
        mlp_sparse = getattr(config, "mlp_sparse", False)
        if process_group is None:
            mlp_cls = SelectiveMLP if mlp_sparse else FusedMLP
        else:
            mlp_cls = ParallelSelectiveMLP if mlp_sparse else ParallelFusedMLP
        sparsity_kwargs = (
            {"use_heuristic": getattr(config, "use_heuristic", True)} if mlp_sparse else {}
        )
        return partial(
            mlp_cls,
            hidden_features=inner_dim,
            activation=activation,
            checkpoint_lvl=mlp_checkpoint_lvl,
            bias1=mlp_fc1_bias,
            bias2=mlp_fc2_bias,
            **parallel_kwargs,
            **factory_kwargs,
            **sparsity_kwargs,
        )

    if fused_dense_sqrelu_dense:
        # fused_mlp is necessarily False on this path, so tensor parallelism is unsupported.
        assert process_group is None, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
        assert FusedDenseSqreluDense is not None
        return partial(
            FusedDenseSqreluDense,
            hidden_features=config.n_inner,
            checkpoint_lvl=_resolve_mlp_checkpoint_lvl(config, layer_idx),
            **factory_kwargs,
        )

    # Plain (non-fused, dense) MLP.
    assert config.activation_function in [
        "gelu", "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu"
    ]
    if config.activation_function == "relu":
        activation = partial(F.relu, inplace=True)
    elif config.activation_function == "sqrelu":
        activation = sqrelu_fwd
    else:
        approximate = "tanh" if config.activation_function in [
            "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"
        ] else "none"
        activation = partial(F.gelu, approximate=approximate)
    return partial(
        Mlp if process_group is None else ParallelMLP,
        hidden_features=config.n_inner,
        activation=activation,
        bias1=mlp_fc1_bias,
        bias2=mlp_fc2_bias,
        **parallel_kwargs,
        **factory_kwargs,
    )
def create_mlp_router_cls(config, sp_config=None, layer_idx=None, process_group=None, device=None, dtype=None):
    """Build a partially-applied MLP router constructor.

    Args:
        config: model config. The router's ``out_dim`` is ``n_inner`` (or
            ``4 * hidden_size`` when ``n_inner`` is None). ``act_th`` comes from
            ``config.mlp_act_th`` (default 0.5).
        sp_config: sparsity config. ``low_rank_dim`` comes from
            ``sp_config.mlp_low_rank_dim`` (default 1024); ``None`` means defaults.
        layer_idx: currently unused; every layer gets the same settings.
        process_group: tensor-parallel group; selects ``ParallelMLPRouter`` when set.
        device, dtype: forwarded to the router constructor.

    Returns:
        ``functools.partial`` over ``MLPRouter`` or ``ParallelMLPRouter``.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    if config.n_inner is None:
        out_dim = 4 * config.hidden_size
    else:
        out_dim = config.n_inner

    # NOTE: per-layer settings (e.g. an mlp_low_rank_dim_{layer_idx} entry in
    # sp_config) are not looked up yet.
    low_rank_dim = getattr(sp_config, "mlp_low_rank_dim", 1024)
    # The activation threshold is taken from the model config, not sp_config.
    act_th = getattr(config, "mlp_act_th", 0.5)

    if process_group is None:
        router_cls = MLPRouter
        parallel_kwargs = {}
    else:
        router_cls = ParallelMLPRouter
        parallel_kwargs = {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }

    return partial(
        router_cls,
        low_rank_dim=low_rank_dim,
        out_dim=out_dim,
        act_th=act_th,
        **parallel_kwargs,
        **factory_kwargs,
    )
def create_mha_router_cls(config, sp_config=None, layer_idx=None, process_group=None, device=None, dtype=None):
    """Build a partially-applied attention-head router constructor.

    Args:
        config: model config. ``out_dim`` is ``n_head_kv`` when the config sets
            it, otherwise ``num_attention_heads``.
        sp_config: sparsity config. ``attn_low_rank_dim`` (default 128) becomes
            ``low_rank_dim`` and ``attn_topk`` (default 0.5) becomes ``top_k``;
            ``None`` means defaults.
        layer_idx: currently unused; every layer gets the same settings.
        process_group: tensor-parallel group; selects ``ParallelMHARouter`` when set.
        device, dtype: forwarded to the router constructor.

    Returns:
        ``functools.partial`` over ``MHARouter`` or ``ParallelMHARouter``.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    # One output per KV head; equals the query-head count unless n_head_kv overrides it.
    out_dim = getattr(config, "n_head_kv", config.num_attention_heads)

    low_rank_dim = getattr(sp_config, "attn_low_rank_dim", 128)
    # NOTE: a per-layer top-k (e.g. attn_topk_{layer_idx} in sp_config) is not looked up yet.
    attn_topk = getattr(sp_config, "attn_topk", 0.5)

    if process_group is None:
        router_cls = MHARouter
        parallel_kwargs = {}
    else:
        router_cls = ParallelMHARouter
        parallel_kwargs = {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }

    return partial(
        router_cls,
        low_rank_dim=low_rank_dim,
        out_dim=out_dim,
        top_k=attn_topk,
        **parallel_kwargs,
        **factory_kwargs,
    )
def create_block(config, sp_config, layer_idx=None, process_group=None, device=None, dtype=None):
    """Construct one transformer layer.

    A ``SelectBlock`` with attention and/or MLP routers is built when
    ``config.att_sparse`` or ``config.mlp_sparse`` is set; otherwise a dense
    flash-attn ``Block`` is built.

    Args:
        config: model config.
        sp_config: sparsity config, forwarded to the router factories.
        layer_idx: index of this layer. Layer 0 uses ``embd_pdrop`` for its first
            residual dropout; the index is also stored on the returned block.
        process_group: tensor-parallel group or ``None``.
        device, dtype: forwarded to every sub-module constructor.

    Returns:
        The constructed block, with ``block.layer_idx`` set.

    Raises:
        ImportError: if ``config.rms_norm`` is set but the Triton ``RMSNorm``
            could not be imported.
        RuntimeError: if ``config.parallel_block`` is set (not implemented).
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    sequence_parallel = getattr(config, "sequence_parallel", True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)

    use_rms_norm = getattr(config, "rms_norm", False)
    if use_rms_norm and RMSNorm is None:
        # RMSNorm comes from an optional Triton import; fail clearly instead of
        # the opaque TypeError that partial(None, ...) would raise.
        raise ImportError("RMSNorm requires Triton (flash_attn.ops.triton.layer_norm), which is not installed")
    norm_cls = partial(
        RMSNorm if use_rms_norm else nn.LayerNorm,
        eps=config.layer_norm_epsilon,
        **factory_kwargs,
    )

    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, "prenorm", True)

    if getattr(config, "parallel_block", False):
        raise RuntimeError("ParallelBlock not implemented")

    mlp_sparse = getattr(config, "mlp_sparse", False)
    att_sparse = getattr(config, "att_sparse", False)

    # Keyword arguments shared by the sparse and dense block types.
    block_kwargs = dict(
        norm_cls=norm_cls,
        prenorm=prenorm,
        resid_dropout1=resid_dropout1,
        resid_dropout2=config.resid_pdrop,
        fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
        residual_in_fp32=residual_in_fp32,
        sequence_parallel=sequence_parallel and process_group is not None,
        mark_shared_params=process_group is not None,
    )

    if mlp_sparse or att_sparse:
        mha_router_cls = (
            create_mha_router_cls(config, sp_config, layer_idx, process_group=process_group, **factory_kwargs)
            if att_sparse
            else None
        )
        mlp_router_cls = (
            create_mlp_router_cls(config, sp_config, layer_idx, process_group=process_group, **factory_kwargs)
            if mlp_sparse
            else None
        )
        block = SelectBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            mlp_router=mlp_router_cls,
            mha_router=mha_router_cls,
            **block_kwargs,
        )
    else:
        block = Block(config.hidden_size, mixer_cls, mlp_cls, **block_kwargs)

    block.layer_idx = layer_idx
    return block
class GPTPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        sp_config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """Instantiate ``cls`` and load the pretrained weights for ``model_name``.

        Download and cache the pre-trained model file if needed.

        Args:
            model_name: checkpoint name; must start with ``"gpt2"`` or
                ``"facebook/opt"``.
            config: model config, passed to ``cls`` and to the state-dict remapper.
            sp_config: sparsity config, passed to ``cls``.
            *args, **kwargs: extra constructor arguments for ``cls``.
            strict: forwarded to ``load_state_dict``.
            device, dtype: used to instantiate the model; ``dtype`` is also
                applied when loading the checkpoint.
            world_size, rank: when ``world_size > 1`` the state dict is passed
                through ``shard_state_dict_tp`` for this rank before loading.

        Returns:
            The instantiated model with weights loaded.

        Raises:
            NotImplementedError: if ``model_name`` is not a supported family.
                This is checked before the model is built or any weights are loaded.
        """
        # Resolve the remapper first so unsupported names fail before we allocate
        # the model or fetch the checkpoint.
        if model_name.startswith("gpt2"):
            remap_state_dict = remap_state_dict_gpt2
        elif model_name.startswith("facebook/opt"):
            remap_state_dict = remap_state_dict_hf_opt
        else:
            raise NotImplementedError(f"Model {model_name} not supported")

        model = cls(config, sp_config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        state_dict = remap_state_dict(state_dict, config)
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        model.load_state_dict(state_dict, strict=strict)
        return model
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
    """Initialize ``module``'s parameters in place using the GPT-2 scheme.

    * ``nn.Linear`` / ``nn.Embedding`` weights: N(0, initializer_range**2).
    * ``nn.Linear`` biases (when present): zeros.
    * With ``rescale_prenorm_residual``: parameters whose name relative to
      ``module`` is exactly ``out_proj.weight`` or ``fc2.weight`` are re-drawn with
      std ``initializer_range / sqrt(2 * n_layer)``.

    Intended for ``nn.Module.apply``, so every submodule is visited in turn.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, std=initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    if not rescale_prenorm_residual:
        return

    # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
    # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
    # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
    # > -- GPT-2 :: https://openai.com/blog/better-language-models/
    #
    # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
    residual_output_names = ("out_proj.weight", "fc2.weight")
    for param_name, param in module.named_parameters():
        if param_name not in residual_output_names:
            continue
        # Two residual branches (attention + MLP) per block, hence 2 * n_layer.
        nn.init.normal_(param, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
class GPTModel(GPTPreTrainedModel):
    """Transformer backbone: embeddings, a stack of blocks and (with prenorm) a final dropout + norm.

    Each layer comes from ``create_block``. It is a sparse ``SelectBlock`` when
    ``config.mlp_sparse`` / ``config.att_sparse`` request it, and a dense
    ``Block`` otherwise.
    """

    def __init__(self, config: GPT2Config, sp_config=None, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        # Keep in sync with the activations create_mlp_cls can build.
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        if use_rms_norm and RMSNorm is None:
            # RMSNorm is an optional Triton import; fail clearly rather than calling None.
            raise ImportError("RMSNorm requires Triton (flash_attn.ops.triton.layer_norm), which is not installed")
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(
                    config, sp_config, layer_idx=i, process_group=process_group, **factory_kwargs
                )
                for i in range(config.num_hidden_layers)
            ]
        )

        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")

        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
        # ln_f only exists with prenorm; without this guard prenorm=False + TP hit AttributeError.
        if process_group is not None and self.prenorm:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()
        # Same defaults as create_block, so configs without these fields still load.
        self.sparse = bool(
            getattr(config, "mlp_sparse", False) or getattr(config, "att_sparse", False)
        )

    def tie_weights(self):
        """Synchronize parameters marked as shared across the tensor-parallel group (if any)."""
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return ``{layer_index: cache}`` from each layer's ``allocate_inference_cache``."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        """Run embeddings, all layers and (with prenorm) the final dropout/add/norm.

        Args:
            input_ids: token ids; dimension 1 is used as ``seqlen`` under sequence parallelism.
            position_ids: optional position ids forwarded to the embeddings.
            inference_params: generation state forwarded to every mixer (may be None).

        Returns:
            The final hidden states.
        """
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(
            input_ids, position_ids=position_ids, **embedding_kwargs
        )
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        # Always present (possibly None) so every mixer sees the key.
        mixer_kwargs["inference_params"] = inference_params

        for layer in self.layers:
            if self.prenorm:
                hidden_states, residual = layer(
                    hidden_states,
                    residual,
                    mixer_kwargs=mixer_kwargs,
                )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)

        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Blocks may return a differently-shaped view; align with the residual.
                if residual is not None and hidden_states.shape != residual.shape:
                    hidden_states = hidden_states.view(residual.shape)
                # Set prenorm=False here since we don't need the residual
                hidden_states = layer_norm_fn(
                    hidden_states,
                    self.ln_f.weight,
                    self.ln_f.bias,
                    residual=residual,
                    x1=None,
                    eps=self.ln_f.eps,
                    dropout_p=self.drop_f.p if self.training else 0.0,
                    prenorm=False,
                    is_rms_norm=isinstance(self.ln_f, RMSNorm),
                )
        return hidden_states
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    """Causal LM: a ``GPTModel`` backbone followed by a (optionally tied) vocabulary projection."""

    # Return type of ``forward``; defined once instead of re-creating the class per call.
    CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])

    def __init__(self, config: GPT2Config, sp_config=None, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(
            config, sp_config, process_group=process_group, **factory_kwargs
        )
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = (
            config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        )
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(
                config.n_embd, embed_dim, bias=False, **factory_kwargs
            )
        else:
            self.project_out = None
        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
        self.output_scale = mup_output_multiplier * mup_width_scale
        if process_group is None:
            self.lm_head = nn.Linear(
                embed_dim, vocab_size, bias=False, **factory_kwargs
            )
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=False,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        self.norm_head = getattr(config, "norm_head", False)
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Tie ``lm_head`` to the input embeddings (if enabled) and sync shared TP params."""
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight  # llama does not use tied weights
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the backbone."""
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens

        Returns a ``CausalLMOutput`` namedtuple with a single ``logits`` field.
        """
        assert (
            input_ids.ndim == 2
        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b = input_ids.shape[0]
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if inference_params is not None:
            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        if self.output_scale != 1.0:
            hidden_states = hidden_states * self.output_scale

        # ColumnParallelLinear is None when fused_dense is unavailable, and
        # isinstance(x, None) raises TypeError, so test for None first.
        lm_head_is_parallel = (
            ColumnParallelLinear is not None and isinstance(self.lm_head, ColumnParallelLinear)
        )
        if not self.norm_head:
            lm_logits = self.lm_head(hidden_states)
        else:
            lm_head_weight = F.normalize(self.lm_head.weight)
            if lm_head_is_parallel and self.lm_head.sequence_parallel:
                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
        # During inference, we want the full logit for sampling
        if lm_head_is_parallel and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        return self.CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict``, first remapping old-layout checkpoints.

        When ``"transformer.ln_0.weight"`` is present, norm weights/biases are
        shifted to the current layer layout. Note: this remapping edits the
        passed-in dict in place.
        """
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(
                f"transformer.layers.{n_layers - 1}.norm2.weight"
            )
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(
                        f"transformer.layers.{l - 1}.norm2.weight"
                    )
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight
            state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)