arubaDev committed · Commit 225355f · verified · 1 Parent(s): 4574c3a

Update app.py

Files changed (1): app.py (+29 -4)
app.py CHANGED
@@ -3,6 +3,10 @@ import sqlite3
 from datetime import datetime
 import gradio as gr
 from huggingface_hub import InferenceClient
+from datasets import load_dataset
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+import numpy as np
 
 # ---------------------------
 # Config
@@ -10,7 +14,7 @@ from huggingface_hub import InferenceClient
 MODELS = {
     "Meta LLaMA 3.1 (8B Instruct)": "meta-llama/Llama-3.1-8B-Instruct",
     "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
-    # Later you can add your fine-tuned backend-focused model here
+    # Add your backend-focused fine-tuned model here if available
     # "Backend-Finetuned Model": "your-username/backend-crud-model"
 }
 
@@ -22,7 +26,9 @@ SYSTEM_DEFAULT = (
     "Always prioritize database, API, authentication, routing, migrations, and CRUD logic. "
     "Provide full backend code scaffolds with files, paths, and commands. "
     "Only include frontend if required for framework integration "
-    "(e.g., Laravel Blade, Django templates). Ignore other frontend/UI tasks."
+    "(e.g., Laravel Blade, Django templates). Ignore other frontend/UI tasks. "
+    "If user asks for excessive frontend work, politely respond: "
+    "'I am a backend assistant and focus only on backend tasks.'"
 )
 
 # ---------------------------
@@ -120,6 +126,23 @@ def update_session_title_if_needed(session_id: int, first_user_text: str):
     conn.commit()
     conn.close()
 
+# ---------------------------
+# Dataset & Embeddings Setup
+# ---------------------------
+print("Loading dataset and embeddings... (this runs only once)")
+dataset = load_dataset("codeparrot/codeparrot-clean-python", split="train[:5%]")  # small % for speed
+backend_snippets = [d["content"] for d in dataset if any(k in d["content"].lower() for k in
+    ["db", "database", "api", "crud", "auth", "routing", "migration"])]
+
+embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+snippet_embeddings = embed_model.encode(backend_snippets, convert_to_numpy=True)
+
+def get_relevant_snippets(user_text, top_k=3):
+    user_emb = embed_model.encode([user_text], convert_to_numpy=True)
+    sims = cosine_similarity(user_emb, snippet_embeddings)[0]
+    top_indices = np.argsort(sims)[-top_k:][::-1]
+    return "\n\n".join([backend_snippets[i] for i in top_indices])
+
 # ---------------------------
 # Helpers
 # ---------------------------
@@ -131,9 +154,11 @@ def label_to_id(label: str | None) -> int | None:
     except Exception:
         return None
 
-def build_api_messages(session_id: int, system_message: str):
+def build_api_messages(session_id: int, system_message: str, user_text: str):
+    relevant_snippets = get_relevant_snippets(user_text)
     msgs = [{"role": "system", "content": system_message.strip()}]
     msgs.extend(get_messages(session_id))
+    msgs.append({"role": "user", "content": relevant_snippets + "\n\n" + user_text})
    return msgs
 
 def get_client(model_choice: str):
@@ -178,7 +203,7 @@ def send_cb(user_text, selected_label, chatbot_msgs, system_message, max_tokens,
     add_message(sid, "user", user_text)
     update_session_title_if_needed(sid, user_text)
 
-    api_messages = build_api_messages(sid, system_message)
+    api_messages = build_api_messages(sid, system_message, user_text)
     display_msgs = get_messages(sid)
     display_msgs.append({"role": "assistant", "content": ""})
 
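
For anyone trying the change locally, here is a minimal, self-contained sketch of the retrieval step this commit introduces: like get_relevant_snippets, it embeds a query with the same all-MiniLM-L6-v2 model, scores it against pre-computed snippet embeddings with cosine similarity, and returns the top-k matches. The tiny in-memory corpus and the example query are illustrative placeholders (assumptions, not part of the commit), so the sketch runs without downloading the codeparrot split.

# Minimal sketch of the retrieval path used by get_relevant_snippets.
# Assumes sentence-transformers, scikit-learn, and numpy are installed;
# the corpus and query below are made-up examples, not from the commit.
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

corpus = [
    "def create_user(db, payload): ...  # CRUD insert with an ORM session",
    "@app.route('/api/items', methods=['GET'])  # routing example",
    "ALTER TABLE users ADD COLUMN last_login TIMESTAMP;  -- migration snippet",
]

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
corpus_embeddings = model.encode(corpus, convert_to_numpy=True)

def top_k_snippets(query: str, k: int = 2) -> list[str]:
    # Embed the query, score it against every stored snippet, keep the best k.
    query_embedding = model.encode([query], convert_to_numpy=True)
    scores = cosine_similarity(query_embedding, corpus_embeddings)[0]
    best = np.argsort(scores)[-k:][::-1]  # indices of the k highest scores
    return [corpus[i] for i in best]

print(top_k_snippets("How do I add an API route for creating users?"))

Note that the new imports (datasets, sentence_transformers, sklearn, numpy) mean the corresponding packages (datasets, sentence-transformers, scikit-learn, numpy) must be available in the Space's environment for the updated app.py to start.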