# Author: Aiden McComiskey
# Changing the model to de-duped
# (commit 8d89728)
import gradio as gr
import joblib, numpy as np, torch, transformers
# ── load artefacts once ───────────────────────────────────────────────
# Fitted classifier pipeline. NOTE(review): the filename says "tfidf" but
# predict() feeds it BGE-M3 embeddings — confirm the artefact matches.
rf = joblib.load("models/RandomForest_tfidf_pipe.joblib")
# Decision threshold(s) applied to predict_proba output; presumably one
# per class — verify shape against mlb.classes_.
thr = np.load("models/RandomForest_tfidf_thresh.npy")
# MultiLabelBinarizer holding the label vocabulary (classes_).
mlb = joblib.load("models/mlb.joblib")
# Sentence embedder, loaded once at import time and put in eval mode.
# trust_remote_code is required because BGE-M3 ships custom model code.
embedder = transformers.AutoModel.from_pretrained(
"BAAI/bge-m3", trust_remote_code=True
).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained("BAAI/bge-m3")
def predict(payload: dict):
    """
    Classify a JSON payload into drill categories.

    payload: normally a dict of key→value pairs (the gr.JSON input); all
        values are stringified and concatenated (in insertion order) into
        one text blob. gr.JSON can also deliver lists, scalars or null —
        those are normalised rather than crashing on ``.values()``.

    Returns a dict with two parallel lists:
        "labels": class names whose probability met the threshold
        "scores": the matching probabilities, rounded to 4 decimals
    """
    # Guard: only dicts have .values(); null → no prediction at all.
    if payload is None:
        return {"labels": [], "scores": []}
    if isinstance(payload, dict):
        values = payload.values()
    elif isinstance(payload, (list, tuple)):
        values = payload
    else:
        values = [payload]

    # 1. concatenate all values into one string
    text = " ".join(str(v) for v in values)

    # 2. embed + classify (no autograd bookkeeping needed at inference)
    with torch.inference_mode():
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        # NOTE(review): assumes the remote-code model exposes pooler_output
        # — confirm; many BGE recipes use last_hidden_state[:, 0] instead.
        vec = embedder(**inputs).pooler_output.numpy()

    proba = rf.predict_proba(vec)[0]
    mask = proba >= thr  # per-class thresholding — TODO confirm thr shape
    return {
        "labels": mlb.classes_[mask].tolist(),
        "scores": proba[mask].round(4).tolist(),
    }
# ── Gradio needs exactly this global var at import time ────────────────
# Single-endpoint UI: one JSON box in, one JSON box out, wired to predict().
demo = gr.Interface(
fn=predict,
inputs=gr.JSON(label="Input JSON payload"),
outputs=gr.JSON(label="Predicted labels & scores"),
title="Drill-Category Classifier",
description="Send a JSON object; its values will be concatenated and classified.",
api_name="predict"
)
if __name__ == "__main__":
    # queue() enables concurrency & progress events and returns the app,
    # so it chains straight into launch().
    demo.queue().launch()