Image Feature Extraction
Transformers
JAX
Safetensors
MLX
PyTorch
aimv2_vision_model
vision
custom_code
Eval Results (legacy)
Instructions to use apple/aimv2-large-patch14-336 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use apple/aimv2-large-patch14-336 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-feature-extraction", model="apple/aimv2-large-patch14-336", trust_remote_code=True)

# Load model directly
from transformers import AutoImageProcessor, AutoModel

processor = AutoImageProcessor.from_pretrained("apple/aimv2-large-patch14-336", trust_remote_code=True)
model = AutoModel.from_pretrained("apple/aimv2-large-patch14-336", trust_remote_code=True)
```
- MLX
How to use apple/aimv2-large-patch14-336 with MLX:
```bash
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir aimv2-large-patch14-336 apple/aimv2-large-patch14-336
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
| from typing import Any, Optional, Tuple, Union | |
| import flax.linen as nn | |
| import jax | |
| import jax.numpy as jnp | |
| from .configuration_aimv2 import AIMv2Config | |
| from flax.core import frozen_dict | |
| from transformers import FlaxPreTrainedModel | |
| from transformers.modeling_flax_outputs import FlaxBaseModelOutput | |
| __all__ = ["FlaxAIMv2Model"] | |
class FlaxRMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale.

    Attributes:
        eps: Constant added to the mean of squares before taking the
            reciprocal square root.
    """

    eps: float = 1e-6

    # The ``scale`` parameter is declared inline, which Flax only permits
    # inside a compact method (or ``setup``).
    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Normalize ``x`` and multiply by the learned ``scale`` vector.

        Args:
            x: Input array; normalization runs over its last axis.

        Returns:
            An array with the same shape and dtype as ``x``.
        """
        dim = x.shape[-1]
        scale = self.param("scale", nn.initializers.ones_init(), (dim,))
        # Statistics are computed in float32, then cast back to the input dtype.
        output = self._norm(x.astype(jnp.float32)).astype(x.dtype)
        output = output * scale.astype(x.dtype)
        return output

    def _norm(self, x: jax.Array) -> jax.Array:
        """Scale ``x`` by ``rsqrt(mean(x ** 2, axis=-1) + eps)``."""
        return x * jax.lax.rsqrt(jnp.power(x, 2).mean(-1, keepdims=True) + self.eps)
class FlaxAIMv2SwiGLUFFN(nn.Module):
    """Gated feed-forward block: ``fc2(silu(fc1(x)) * fc3(x))``.

    Attributes:
        config: Model configuration; ``intermediate_size``, ``hidden_size``
            and ``use_bias`` are read from it.
        dtype: Computation dtype handed to every ``nn.Dense`` layer.
    """

    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    # Dense layers are instantiated inline, so the method must be compact.
    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the gated MLP to ``x`` and project back to ``hidden_size``."""
        hidden_features = self.config.intermediate_size
        in_features = self.config.hidden_size
        bias = self.config.use_bias

        # Two parallel projections: ``fc1`` feeds the SiLU gate, ``fc3`` is
        # the value branch that the gate multiplies.
        x1 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc1")(x)
        x2 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc3")(x)
        x = nn.silu(x1) * x2
        x = nn.Dense(in_features, use_bias=bias, dtype=self.dtype, name="fc2")(x)
        return x
class FlaxAIMv2PatchEmbed(nn.Module):
    """Turn an image batch into a sequence of normalized patch embeddings.

    Attributes:
        config: Model configuration; ``patch_size``, ``hidden_size`` and
            ``rms_norm_eps`` are read from it.
        dtype: Computation dtype handed to the convolution.
    """

    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    # The convolution and the norm are instantiated inline, so the method
    # must be compact.
    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Embed ``x`` (channels-first, ``(N, C, H, W)``) as ``(N, H' * W', F)``."""
        patch_size = (self.config.patch_size, self.config.patch_size)
        x = x.transpose(0, 2, 3, 1)  # (N C H W) -> (N H W C)
        # Kernel size equals stride, so each output position sees exactly one
        # non-overlapping patch.
        x = nn.Conv(
            self.config.hidden_size,
            kernel_size=patch_size,
            strides=patch_size,
            padding=(0, 0),
            dtype=self.dtype,
            name="proj",
        )(x)
        x = jax.lax.collapse(x, 1, 3)  # (N, H * W, F)
        x = FlaxRMSNorm(self.config.rms_norm_eps, name="norm")(x)
        return x
class FlaxAIMv2ViTPreprocessor(nn.Module):
    """Patchify the input image and add a learned positional embedding.

    Attributes:
        config: Model configuration; ``image_size``, ``patch_size`` and
            ``hidden_size`` are read from it.
        dtype: Computation dtype forwarded to the patch embedder.
    """

    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    # The patch embedder and ``pos_embed`` are declared inline, so the method
    # must be compact.
    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Return patch tokens for ``x`` with positional embeddings added."""
        tokens = FlaxAIMv2PatchEmbed(self.config, dtype=self.dtype, name="patchifier")(
            x
        )
        _, N, _ = tokens.shape
        pos_embed = self.param(
            "pos_embed",
            nn.initializers.normal(stddev=0.02),
            (1, self.num_patches, self.config.hidden_size),
        )
        # Only the first N positions are used, so inputs producing fewer
        # tokens than ``num_patches`` still line up with the table.
        tokens = tokens + pos_embed[:, :N].astype(tokens.dtype)
        return tokens

    # Exposed as a property because ``__call__`` reads it as an attribute
    # when building the ``pos_embed`` shape.
    @property
    def num_patches(self) -> int:
        """Number of patches in a ``image_size`` x ``image_size`` input."""
        return (self.config.image_size // self.config.patch_size) ** 2
class FlaxAIMv2Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Attributes:
        config: Model configuration; ``hidden_size``, ``num_attention_heads``,
            ``qkv_bias``, ``use_bias``, ``attention_dropout`` and
            ``projection_dropout`` are read from it.
        dtype: Computation dtype handed to the dense layers and the
            attention-weight computation.
    """

    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    # Dense and Dropout layers are instantiated inline, so the method must be
    # compact.
    @nn.compact
    def __call__(
        self,
        x: jax.Array,
        mask: Optional[jax.Array] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ) -> Tuple[jax.Array, Optional[jax.Array]]:
        """Run self-attention over the token sequence ``x``.

        Args:
            x: Tokens of shape ``(B, N, C)``.
            mask: Optional mask forwarded unchanged to
                ``nn.dot_product_attention_weights``.
            deterministic: When true, both dropout layers are disabled.
            output_attentions: When true, the attention weights are returned
                as the second tuple element; otherwise it is ``None``.

        Returns:
            ``(output, attention_weights_or_None)`` where ``output`` has shape
            ``(B, N, C)``.
        """
        B, N, C = x.shape
        dim, num_heads = self.config.hidden_size, self.config.num_attention_heads

        qkv = nn.Dense(
            dim * 3, use_bias=self.config.qkv_bias, dtype=self.dtype, name="qkv"
        )(x)
        # (B, N, 3 * dim) -> (3, B, H, N, C // H)
        qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads).transpose(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # The Flax helper expects the head axis after the sequence axis.
        attn_weights = nn.dot_product_attention_weights(
            q.swapaxes(-3, -2),  # (B, N, H, C // H)
            k.swapaxes(-3, -2),
            mask=mask,
            deterministic=deterministic,
            dtype=self.dtype,
        )
        attn_weights = nn.Dropout(
            self.config.attention_dropout, deterministic=deterministic, name="attn_drop"
        )(attn_weights)

        # Weighted sum of values, then merge the heads back into the channel axis.
        x = (attn_weights @ v).swapaxes(1, 2).reshape(B, N, C)
        x = nn.Dense(dim, use_bias=self.config.use_bias, dtype=self.dtype, name="proj")(
            x
        )
        x = nn.Dropout(
            self.config.projection_dropout,
            deterministic=deterministic,
            name="proj_drop",
        )(x)
        return (x, attn_weights) if output_attentions else (x, None)
class FlaxAIMv2Block(nn.Module):
    """Pre-norm transformer block: attention and SwiGLU MLP, each with a residual.

    Attributes:
        config: Model configuration; ``rms_norm_eps`` is read here and the
            whole object is forwarded to the sub-modules.
        dtype: Computation dtype forwarded to the attention and MLP modules.
    """

    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the two norms, the attention module and the MLP."""
        eps = self.config.rms_norm_eps
        self.attn = FlaxAIMv2Attention(self.config, dtype=self.dtype, name="attn")
        self.norm_1 = FlaxRMSNorm(eps, name="norm_1")
        self.mlp = FlaxAIMv2SwiGLUFFN(self.config, dtype=self.dtype, name="mlp")
        self.norm_2 = FlaxRMSNorm(eps, name="norm_2")

    def __call__(
        self,
        x: jax.Array,
        mask: Optional[jax.Array] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ) -> Tuple[jax.Array, Optional[jax.Array]]:
        """Apply the block to ``x``.

        Returns:
            ``(tokens, attention)``; ``attention`` is whatever the attention
            module returned as its second element.
        """
        # Attention sub-layer on the normalized input, added back to the input.
        attn_out, attn_weights = self.attn(
            self.norm_1(x),
            mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        after_attn = x + attn_out

        # Feed-forward sub-layer on the normalized attention output.
        after_mlp = after_attn + self.mlp(self.norm_2(after_attn))
        return after_mlp, attn_weights
class FlaxAIMv2Transformer(nn.Module):
    """Stack of ``num_hidden_layers`` blocks followed by a final RMS norm.

    Attributes:
        config: Model configuration; ``num_hidden_layers`` and
            ``rms_norm_eps`` are read here.
        dtype: Computation dtype forwarded to every block.
    """

    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    # Blocks and the final norm are instantiated inline, so the method must
    # be compact.
    @nn.compact
    def __call__(
        self,
        tokens: jax.Array,
        mask: Optional[jax.Array] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> Tuple[
        jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
    ]:
        """Run ``tokens`` through every block and the post-trunk norm.

        Returns:
            ``(tokens, hidden_states, attentions)``. ``hidden_states`` holds
            each block's output (collected before the final norm) when
            ``output_hidden_states`` is true, else ``None``. ``attentions``
            holds each block's attention weights when ``output_attentions``
            is true, else ``None``.
        """
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None

        for blk_id in range(self.config.num_hidden_layers):
            # Explicit names keep the parameter tree keyed as ``layers_<i>``.
            tokens, attention = FlaxAIMv2Block(
                self.config, dtype=self.dtype, name=f"layers_{blk_id}"
            )(
                tokens,
                mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            if output_hidden_states:
                hidden_states += (tokens,)
            if output_attentions:
                attentions += (attention,)

        tokens = FlaxRMSNorm(self.config.rms_norm_eps, name="post_trunk_norm")(tokens)
        return tokens, hidden_states, attentions
class FlaxAIMv2Module(nn.Module):
    """Full AIMv2 vision network: preprocessor followed by the transformer trunk.

    Attributes:
        config: Model configuration forwarded to both sub-modules.
        dtype: Computation dtype forwarded to both sub-modules.
    """

    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    # Sub-modules are instantiated inline, so the method must be compact.
    @nn.compact
    def __call__(
        self,
        x: jax.Array,
        mask: Optional[jax.Array] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> Tuple[
        jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
    ]:
        """Embed the image batch ``x`` and run it through the trunk.

        Returns:
            The ``(tokens, hidden_states, attentions)`` triple produced by
            ``FlaxAIMv2Transformer``.
        """
        x = FlaxAIMv2ViTPreprocessor(
            self.config, dtype=self.dtype, name="preprocessor"
        )(x)
        x, hidden_states, attentions = FlaxAIMv2Transformer(
            self.config, dtype=self.dtype, name="trunk"
        )(
            x,
            mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        return x, hidden_states, attentions
class FlaxAIMv2PretrainedModel(FlaxPreTrainedModel):
    """``FlaxPreTrainedModel`` wrapper that wires up ``FlaxAIMv2Module``."""

    config_class = AIMv2Config
    base_model_prefix = "aimv2"
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: AIMv2Config,
        input_shape: Optional[Tuple[int, int, int, int]] = None,  # [B, C, H, W]
        dtype: jnp.dtype = jnp.float32,
        **kwargs: Any,
    ):
        """Build the Flax module and hand it to the base class.

        Args:
            config: Model configuration.
            input_shape: Shape of the dummy input used for initialization.
                Defaults to one 3-channel square image of side
                ``config.image_size``.
            dtype: Computation dtype passed to the module and the base class.
            **kwargs: Forwarded to ``FlaxPreTrainedModel.__init__``.
        """
        default_shape = (1, 3, config.image_size, config.image_size)
        resolved_shape = default_shape if input_shape is None else input_shape
        flax_module = FlaxAIMv2Module(config, dtype=dtype)
        super().__init__(
            config,
            module=flax_module,
            input_shape=resolved_shape,
            dtype=dtype,
            **kwargs,
        )

    def init_weights(
        self,
        rng: jax.Array,
        input_shape: Tuple[int, ...],
        params: Optional[frozen_dict.FrozenDict] = None,
    ) -> frozen_dict.FrozenDict:
        """Initialize the module on a dummy input and return its ``"params"``.

        The ``params`` argument is accepted but never consulted; the returned
        parameters always come from a fresh ``module.init`` call.
        """
        dummy_pixels = jnp.empty(input_shape)
        variables = self.module.init(rng, dummy_pixels, deterministic=True)
        return variables["params"]
class FlaxAIMv2Model(FlaxAIMv2PretrainedModel):
    """User-facing AIMv2 vision model returning token features."""

    def __call__(
        self,
        pixel_values: jax.Array,
        params: Optional[frozen_dict.FrozenDict] = None,
        mask: Optional[jax.Array] = None,
        dropout_rng: Optional[jax.Array] = None,
        deterministic: bool = True,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[
        Tuple[jax.Array],
        Tuple[jax.Array, Tuple[jax.Array, ...]],
        Tuple[jax.Array, Tuple[jax.Array, ...], Tuple[jax.Array, ...]],
        FlaxBaseModelOutput,
    ]:
        """Run the model on ``pixel_values``.

        Args:
            pixel_values: Image batch passed to the module as its first input.
            params: Parameters to apply; defaults to ``self.params``.
            mask: Optional attention mask forwarded to the module.
            dropout_rng: PRNG key for the ``"dropout"`` collection. It is
                only used when ``deterministic`` is false.
            deterministic: When true, dropout is disabled and no rng is passed.
            output_attentions: Whether to return attention weights; defaults
                to ``config.output_attentions``.
            output_hidden_states: Whether to return per-layer hidden states;
                defaults to ``config.output_hidden_states``.
            return_dict: Whether to return a ``FlaxBaseModelOutput`` instead
                of a tuple; defaults to ``config.use_return_dict``.

        Returns:
            A ``FlaxBaseModelOutput`` when ``return_dict`` is true. Otherwise
            a tuple starting with the last hidden state, followed by the
            hidden states and attentions when those were requested.
        """
        if params is None:
            params = self.params
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Never put a ``None`` key into ``rngs``: Flax rejects it as an
        # invalid rng collection. Omitting it lets dropout-free
        # configurations run without a key.
        rngs = None
        if not deterministic and dropout_rng is not None:
            rngs = {"dropout": dropout_rng}

        x, hidden_states, attentions = self.module.apply(
            {"params": params},
            pixel_values,
            mask,
            rngs=rngs,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if not return_dict:
            res = (x,)
            res += (hidden_states,) if output_hidden_states else ()
            res += (attentions,) if output_attentions else ()
            return res

        return FlaxBaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states,
            attentions=attentions,
        )