gabrielchua committed on
Commit
cc30f3f
·
1 Parent(s): 945bc2f

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +1 -106
utils.py CHANGED
@@ -4,14 +4,11 @@ utils.py
4
 
5
  # Standard imports
6
  import os
7
- from typing import List, Tuple
8
 
9
  # Third party imports
10
  import numpy as np
11
- from google import genai
12
  from openai import OpenAI
13
- from sentence_transformers import SentenceTransformer
14
- from transformers import AutoModel
15
 
16
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
17
 
@@ -24,14 +21,11 @@ def get_embeddings(
24
  ) -> List[List[float]]:
25
  """
26
  Generate embeddings for a list of texts using OpenAI API synchronously.
27
-
28
  Args:
29
  texts: List of strings to embed.
30
  model: OpenAI embedding model to use (default: text-embedding-3-large).
31
-
32
  Returns:
33
  A list of embeddings (each embedding is a list of floats).
34
-
35
  Raises:
36
  Exception: If the OpenAI API call fails.
37
  """
@@ -45,102 +39,3 @@ def get_embeddings(
45
  # Extract embeddings from response
46
  embeddings = np.array([data.embedding for data in response.data])
47
  return embeddings
48
-
49
-
50
- MODEL_CONFIGS = {
51
- "lionguard-2": {
52
- "label": "LionGuard 2",
53
- "repo_id": "govtech/lionguard-2",
54
- "embedding_strategy": "openai",
55
- "embedding_model": "text-embedding-3-large",
56
- },
57
- "lionguard-2-lite": {
58
- "label": "LionGuard 2 Lite",
59
- "repo_id": "govtech/lionguard-2-lite",
60
- "embedding_strategy": "sentence_transformer",
61
- "embedding_model": "google/embeddinggemma-300m",
62
- },
63
- "lionguard-2.1": {
64
- "label": "LionGuard 2.1",
65
- "repo_id": "govtech/lionguard-2.1",
66
- "embedding_strategy": "gemini",
67
- "embedding_model": "gemini-embedding-001",
68
- },
69
- }
70
-
71
- DEFAULT_MODEL_KEY = "lionguard-2.1"
72
- MODEL_CACHE = {}
73
- EMBEDDING_MODEL_CACHE = {}
74
- current_model_choice = DEFAULT_MODEL_KEY
75
- GEMINI_CLIENT = None
76
-
77
-
78
- def resolve_model_key(model_key: str = None) -> str:
79
- key = model_key or current_model_choice
80
- if key not in MODEL_CONFIGS:
81
- raise ValueError(f"Unknown model selection: {key}")
82
- return key
83
-
84
-
85
- def load_model_instance(model_key: str):
86
- key = resolve_model_key(model_key)
87
- if key not in MODEL_CACHE:
88
- repo_id = MODEL_CONFIGS[key]["repo_id"]
89
- MODEL_CACHE[key] = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
90
- return MODEL_CACHE[key]
91
-
92
-
93
- def get_sentence_transformer(model_name: str):
94
- if model_name not in EMBEDDING_MODEL_CACHE:
95
- EMBEDDING_MODEL_CACHE[model_name] = SentenceTransformer(model_name)
96
- return EMBEDDING_MODEL_CACHE[model_name]
97
-
98
-
99
- def get_gemini_client():
100
- global GEMINI_CLIENT
101
- if GEMINI_CLIENT is None:
102
- api_key = os.getenv("GEMINI_API_KEY")
103
- if not api_key:
104
- raise EnvironmentError(
105
- "GEMINI_API_KEY environment variable is required for LionGuard 2.1."
106
- )
107
- GEMINI_CLIENT = genai.Client(api_key=api_key)
108
- return GEMINI_CLIENT
109
-
110
-
111
- def get_model_embeddings(model_key: str, texts: List[str]) -> np.ndarray:
112
- key = resolve_model_key(model_key)
113
- config = MODEL_CONFIGS[key]
114
- strategy = config["embedding_strategy"]
115
- model_name = config.get("embedding_model")
116
-
117
- if strategy == "openai":
118
- return get_embeddings(texts, model=model_name)
119
- if strategy == "sentence_transformer":
120
- embedder = get_sentence_transformer(model_name)
121
- formatted_texts = [f"task: classification | query: {text}" for text in texts]
122
- embeddings = embedder.encode(formatted_texts)
123
- return np.array(embeddings)
124
- if strategy == "gemini":
125
- client = get_gemini_client()
126
- result = client.models.embed_content(model=model_name, contents=texts)
127
- return np.array([embedding.values for embedding in result.embeddings])
128
-
129
- raise ValueError(f"Unsupported embedding strategy: {strategy}")
130
-
131
-
132
- def predict_with_model(texts: List[str], model_key: str = None) -> Tuple[dict, str]:
133
- key = resolve_model_key(model_key)
134
- embeddings = get_model_embeddings(key, texts)
135
- model = load_model_instance(key)
136
- return model.predict(embeddings), key
137
-
138
-
139
- def set_active_model(model_key: str) -> str:
140
- if model_key not in MODEL_CONFIGS:
141
- return f"⚠️ Unknown model {model_key}"
142
- global current_model_choice
143
- current_model_choice = model_key
144
- load_model_instance(model_key)
145
- label = MODEL_CONFIGS[model_key]["label"]
146
- return f"🦁 Using {label} ({model_key})"
 
4
 
5
  # Standard imports
6
  import os
7
+ from typing import List
8
 
9
  # Third party imports
10
  import numpy as np
 
11
  from openai import OpenAI
 
 
12
 
13
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
14
 
 
21
  ) -> List[List[float]]:
22
  """
23
  Generate embeddings for a list of texts using OpenAI API synchronously.
 
24
  Args:
25
  texts: List of strings to embed.
26
  model: OpenAI embedding model to use (default: text-embedding-3-large).
 
27
  Returns:
28
  A list of embeddings (each embedding is a list of floats).
 
29
  Raises:
30
  Exception: If the OpenAI API call fails.
31
  """
 
39
  # Extract embeddings from response
40
  embeddings = np.array([data.embedding for data in response.data])
41
  return embeddings