366degrees committed on
Commit
b8181bc
·
verified ·
1 Parent(s): 08d197a

Update inference_handler.py

Browse files
Files changed (1) hide show
  1. inference_handler.py +36 -42
inference_handler.py CHANGED
@@ -1,42 +1,36 @@
1
- from typing import Dict, Any
2
- import torch
3
- from transformers import AutoConfig, AutoTokenizer
4
- from snp_universal_embedding import CustomSNPModel, CustomSNPConfig
5
-
6
class EndpointHandler:
    """Request handler that tokenizes text and returns pooled model output.

    The tokenizer and the custom SNP model are loaded once at construction
    time; each call then runs a single forward pass without gradients.
    """

    def __init__(self, model_dir: str):
        """Load the tokenizer, config and weights found under ``model_dir``.

        Args:
            model_dir: Directory holding the tokenizer files, the model
                config and a ``pytorch_model.bin`` checkpoint.
        """
        print(f"Loading model from {model_dir}")

        # Tokenizer first, so a tokenizer problem surfaces before the
        # weights are read.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        print("✅ Tokenizer loaded.")

        # Build the model from its config, then overlay the checkpoint.
        snp_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        self.model = CustomSNPModel(snp_config)
        weights = torch.load(f"{model_dir}/pytorch_model.bin", map_location="cpu")
        # strict=False: key mismatches between checkpoint and model do not raise.
        self.model.load_state_dict(weights, strict=False)
        self.model.eval()
        print("✅ Custom SNP model loaded and ready.")

    @staticmethod
    def _pool(result):
        """Reduce a model output to a nested list of floats.

        Outputs exposing ``last_hidden_state`` and tuple outputs are averaged
        over dim 1 (presumably the token axis -- confirm against the model);
        anything else is converted with ``tolist()`` unchanged.
        """
        if hasattr(result, "last_hidden_state"):
            return result.last_hidden_state.mean(dim=1).tolist()
        if isinstance(result, tuple):
            return result[0].mean(dim=1).tolist()
        return result.tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the text carried by one inference request.

        Args:
            data: Request payload. The value under ``"inputs"`` is used when
                it is truthy, otherwise the payload itself. A dict with a
                ``"text"`` key supplies the text; any other value is
                converted with ``str()``.

        Returns:
            ``{"embeddings": ...}`` holding the pooled output as nested lists.
        """
        payload = data.get("inputs") or data
        has_text_field = isinstance(payload, dict) and "text" in payload
        text = payload["text"] if has_text_field else str(payload)

        batch = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pooled = self._pool(self.model(**batch))

        return {"embeddings": pooled}
 
1
+ from typing import Dict, Any
2
+ import torch
3
+ from transformers import AutoConfig, AutoTokenizer
4
+ from snp_universal_embedding import CustomSNPModel, CustomSNPConfig
5
+
6
class EndpointHandler:
    """Request handler that tokenizes text and returns pooled model output.

    Construction loads the tokenizer and the custom SNP model from a model
    directory; calling the instance embeds the text of a single request.
    """

    def __init__(self, model_dir: str) -> None:
        """Load the tokenizer, config and weights found under ``model_dir``.

        Args:
            model_dir: Directory holding the tokenizer files, the model
                config and a ``pytorch_model.bin`` checkpoint.
        """
        print(f"Loading model from {model_dir}")

        # __call__ tokenizes through self.tokenizer, so it must be created
        # here; without it every request fails with AttributeError.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)

        config = CustomSNPConfig.from_pretrained(model_dir)
        self.model = CustomSNPModel(config)
        # NOTE(review): torch.load unpickles the checkpoint; model_dir should
        # only ever point at trusted files.
        state = torch.load(f"{model_dir}/pytorch_model.bin", map_location="cpu")
        # strict=False: key mismatches between checkpoint and model do not raise.
        self.model.load_state_dict(state, strict=False)
        self.model.eval()
        print("✅ Custom SNP model loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the text carried by one inference request.

        Args:
            data: Request payload. The value under ``"inputs"`` is used when
                it is truthy, otherwise the payload itself. A dict with a
                ``"text"`` key supplies the text; any other value is
                converted with ``str()``.

        Returns:
            ``{"embeddings": ...}`` holding the pooled output as nested lists.
        """
        inputs = data.get("inputs") or data
        if isinstance(inputs, dict) and "text" in inputs:
            text = inputs["text"]
        else:
            text = str(inputs)

        encoded = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(**encoded)
            # Average over dim 1 (presumably the token axis -- confirm
            # against the model) when a hidden-state tensor is available.
            if hasattr(outputs, "last_hidden_state"):
                emb = outputs.last_hidden_state.mean(dim=1).tolist()
            elif isinstance(outputs, tuple):
                emb = outputs[0].mean(dim=1).tolist()
            else:
                emb = outputs.tolist()

        return {"embeddings": emb}