366degrees committed on
Commit
4a40b5a
·
verified ·
1 Parent(s): b8181bc

Update inference_handler.py

Browse files
Files changed (1) hide show
  1. inference_handler.py +40 -10
inference_handler.py CHANGED
@@ -1,20 +1,51 @@
1
  from typing import Dict, Any
2
  import torch
3
- from transformers import AutoConfig, AutoTokenizer
4
- from snp_universal_embedding import CustomSNPModel, CustomSNPConfig
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  class EndpointHandler:
7
- def __init__(self, model_dir):
8
  print(f"Loading model from {model_dir}")
9
- config = CustomSNPConfig.from_pretrained(model_dir)
10
- self.model = CustomSNPModel(config)
11
- state = torch.load(f"{model_dir}/pytorch_model.bin", map_location="cpu")
12
- self.model.load_state_dict(state, strict=False)
13
  self.model.eval()
14
- print("✅ Custom SNP model loaded successfully!")
15
 
16
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
17
- """Called for each inference request"""
18
  inputs = data.get("inputs") or data
19
  if isinstance(inputs, dict) and "text" in inputs:
20
  text = inputs["text"]
@@ -25,7 +56,6 @@ class EndpointHandler:
25
 
26
  with torch.no_grad():
27
  outputs = self.model(**encoded)
28
- # Get mean pooled embedding
29
  if hasattr(outputs, "last_hidden_state"):
30
  emb = outputs.last_hidden_state.mean(dim=1).tolist()
31
  elif isinstance(outputs, tuple):
 
1
  from typing import Dict, Any
2
  import torch
3
+ from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel, PretrainedConfig
4
+ import torch.nn as nn
5
 
6
+ # ============================================================
7
+ # Register Custom SNP Architecture
8
+ # ============================================================
9
class CustomSNPConfig(PretrainedConfig):
    """Configuration for the custom SNP model.

    Declares ``model_type = "custom_snp"``, the identifier this file uses
    when registering the configuration with ``AutoConfig``.
    """

    model_type = "custom_snp"
11
+
12
+
13
class CustomSNPModel(PreTrainedModel):
    """Custom SNP model: a linear encoder, two Linear+Tanh heads, and a 6-way projection.

    The layer width is read from ``config.hidden_size`` and falls back to
    768 when the configuration does not define it.
    """

    config_class = CustomSNPConfig

    def __init__(self, config):
        """Build the layers from ``config``.

        Args:
            config: Model configuration. Only ``hidden_size`` is read here
                (default 768 when absent).
        """
        super().__init__(config)
        hidden_size = getattr(config, "hidden_size", 768)
        self.encoder = nn.Linear(hidden_size, hidden_size)
        self.mirror_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.prism_head = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())
        self.projection = nn.Linear(hidden_size, 6)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run encoder -> mirror_head -> prism_head -> projection.

        Args:
            input_ids: Tensor that is cast to float and fed straight into
                ``self.encoder``. Because the encoder is an ``nn.Linear``,
                the last dimension must equal ``hidden_size``.
                NOTE(review): confirm the tokenizer output is padded or
                truncated to that length before it reaches this method.
            attention_mask: Accepted for call compatibility; not used.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            The projection output, a tensor whose last dimension is 6.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        # Previously a missing input_ids set x to None and failed deep inside
        # torch with an opaque error; fail early with a clear message instead.
        if input_ids is None:
            raise ValueError("CustomSNPModel.forward requires `input_ids`.")

        x = self.encoder(input_ids.float())
        x = self.mirror_head(x)
        x = self.prism_head(x)
        return self.projection(x)
29
+
30
# Register classes so Transformers recognizes "custom_snp":
# - AutoConfig maps the "custom_snp" model type to CustomSNPConfig.
# - AutoModel maps CustomSNPConfig to CustomSNPModel.
# These calls run at import time, before the handler uses the Auto* loaders.
AutoConfig.register("custom_snp", CustomSNPConfig)
AutoModel.register(CustomSNPConfig, CustomSNPModel)
33
+
34
+
35
+ # ============================================================
36
+ # Endpoint Handler
37
+ # ============================================================
38
  class EndpointHandler:
39
+ def __init__(self, model_dir: str):
40
  print(f"Loading model from {model_dir}")
41
+
42
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
43
+ config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
44
+ self.model = AutoModel.from_pretrained(model_dir, config=config, trust_remote_code=True)
45
  self.model.eval()
46
+ print("✅ Custom SNP model loaded successfully.")
47
 
48
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
49
  inputs = data.get("inputs") or data
50
  if isinstance(inputs, dict) and "text" in inputs:
51
  text = inputs["text"]
 
56
 
57
  with torch.no_grad():
58
  outputs = self.model(**encoded)
 
59
  if hasattr(outputs, "last_hidden_state"):
60
  emb = outputs.last_hidden_state.mean(dim=1).tolist()
61
  elif isinstance(outputs, tuple):