yuhueng committed on
Commit 196bcc7 · verified · 1 Parent(s): 47dcd62

test: Testing system prompt

Files changed (1)
1. app.py +10 -10
app.py CHANGED
@@ -15,26 +15,20 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.float16,
 )
 
-# --- 1. Configuration ---
 REPO_ID = "govtech/lionguard-v1"
 EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
 FILENAME = "models/lionguard-binary.onnx"
 
-# --- 2. Load Models ---
 embedder = SentenceTransformer(EMBEDDING_MODEL)
 
 model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
 session = ort.InferenceSession(model_path)
 
-# --- 3. The Inference Logic ---
 def check_safety(text):
-    # Generate embedding (Normalize is important for BGE models)
     embedding = embedder.encode([text], normalize_embeddings=True)
 
-    # Prepare input for ONNX
     input_name = session.get_inputs()[0].name
-
-    # Run prediction
+
     pred = session.run(None, {input_name: embedding.astype(np.float32)})[0]
 
     return "Unsafe" if pred[0] == 1 else "Safe"
@@ -44,9 +38,15 @@ def check_safety(text):
 def inference(prompt: str, max_tokens: int = 256) -> str:
     model.to("cuda") # Move to GPU inside decorated function
 
-    messages = [
-        {"role" : "user", "content" : prompt}
-    ]
+    SYSTEM_PROMPT = """You are having a casual conversation with a user in Singapore.
+    Keep responses helpful and friendly. Avoid sensitive topics like politics, religion, or race.
+    If asked about harmful activities, politely decline."""
+
+    messages = [
+        {"role": "system", "content": SYSTEM_PROMPT},
+        {"role": "user", "content": prompt}
+    ]
+
     text = tokenizer.apply_chat_template(
         messages,
         tokenize = False,
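
The diff leaves LionGuard's check_safety and the LLM's inference as separate functions; nothing in this commit shows how they are composed. Below is a minimal sketch of one plausible wiring, assuming the app calls the guardrail before generation — the guarded_chat name and the refusal string are illustrative, not part of the commit:

def guarded_chat(user_input: str) -> str:
    # Illustrative helper (not in app.py): screen the raw prompt with the
    # LionGuard ONNX classifier first, so unsafe inputs never reach the
    # GPU-backed chat model.
    if check_safety(user_input) == "Unsafe":
        return "Sorry, I can't help with that."
    # inference() now prepends SYSTEM_PROMPT via tokenizer.apply_chat_template().
    return inference(user_input)

Under that assumption, the new SYSTEM_PROMPT steers tone for requests that pass the check, while check_safety blocks unsafe ones outright.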