import torch from ._ops import ops def gemm_int4_forward( input: torch.Tensor, weight: torch.Tensor, zeros: torch.Tensor, absmax: torch.Tensor, blocksize: int, ) -> torch.Tensor: original_dtype = input.dtype if original_dtype != torch.bfloat16: input = input.to(torch.bfloat16) output = ops.gemm_int4_forward(input, weight, zeros, absmax, blocksize) if original_dtype != torch.bfloat16: output = output.to(original_dtype) return output