366degrees commited on
Commit
08d197a
·
verified ·
1 Parent(s): c155ae4

Upload inference_handler.py

Browse files
Files changed (1) hide show
  1. inference_handler.py +42 -0
inference_handler.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ import torch
3
+ from transformers import AutoConfig, AutoTokenizer
4
+ from snp_universal_embedding import CustomSNPModel, CustomSNPConfig
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, model_dir: str):
8
+ print(f"Loading model from {model_dir}")
9
+
10
+ # --- Tokenizer ---
11
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
12
+ print("✅ Tokenizer loaded.")
13
+
14
+ # --- Config & Model ---
15
+ config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
16
+ self.model = CustomSNPModel(config)
17
+ state = torch.load(f"{model_dir}/pytorch_model.bin", map_location="cpu")
18
+ self.model.load_state_dict(state, strict=False)
19
+ self.model.eval()
20
+ print("✅ Custom SNP model loaded and ready.")
21
+
22
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
23
+ """Called for each inference request"""
24
+ inputs = data.get("inputs") or data
25
+ if isinstance(inputs, dict) and "text" in inputs:
26
+ text = inputs["text"]
27
+ else:
28
+ text = str(inputs)
29
+
30
+ encoded = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
31
+
32
+ with torch.no_grad():
33
+ outputs = self.model(**encoded)
34
+ # Get mean pooled embedding
35
+ if hasattr(outputs, "last_hidden_state"):
36
+ emb = outputs.last_hidden_state.mean(dim=1).tolist()
37
+ elif isinstance(outputs, tuple):
38
+ emb = outputs[0].mean(dim=1).tolist()
39
+ else:
40
+ emb = outputs.tolist()
41
+
42
+ return {"embeddings": emb}