"""Gradio inference app: embed a JSON payload with BGE-M3 and classify it
with a pre-trained scikit-learn RandomForest pipeline.

NOTE(review): this file arrived collapsed onto one physical line; the inline
`#` comments swallowed the rest of the statement stream, so it could not run.
Reformatted to standard multi-line Python — token stream otherwise unchanged.
"""
import gradio as gr
import joblib
import numpy as np
import torch
import transformers

# ── load artefacts once ───────────────────────────────────────────────
# NOTE(review): the pipeline filename says "tfidf" but it is fed dense
# embedding vectors in predict() below — confirm the saved pipeline really
# accepts an ndarray (and was not trained on raw text via a TfidfVectorizer).
rf = joblib.load("models/RandomForest_tfidf_pipe.joblib")
# Decision threshold(s) compared element-wise against predict_proba output
# (per-label array or scalar — can't tell from here; verify artefact shape).
thr = np.load("models/RandomForest_tfidf_thresh.npy")
# MultiLabelBinarizer: maps class indices back to label names.
mlb = joblib.load("models/mlb.joblib")

embedder = transformers.AutoModel.from_pretrained(
    "BAAI/bge-m3", trust_remote_code=True
).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained("BAAI/bge-m3")


def predict(payload: dict) -> dict:
    """
    payload: a dict of key→value pairs.
    Concatenate all values (in arbitrary order) into one text blob.

    Returns a dict with ``labels`` (names whose probability cleared the
    threshold) and ``scores`` (matching probabilities, rounded to 4 dp).
    """
    # 1. concatenate all values into one string
    text = " ".join(str(v) for v in payload.values())
    # 2. embed + classify
    with torch.inference_mode():
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        # NOTE(review): assumes the remote-code model exposes `pooler_output`;
        # many BGE checkpoints only provide last_hidden_state — verify, or
        # fall back to the CLS token embedding if this attribute is None.
        vec = embedder(**inputs).pooler_output.numpy()
    proba = rf.predict_proba(vec)[0]
    mask = proba >= thr
    return {
        "labels": mlb.classes_[mask].tolist(),
        "scores": proba[mask].round(4).tolist(),
    }


# ── Gradio needs exactly this global var at import time ────────────────
demo = gr.Interface(
    fn=predict,
    inputs=gr.JSON(label="Input JSON payload"),
    outputs=gr.JSON(label="Predicted labels & scores"),
    title="Drill-Category Classifier",
    description="Send a JSON object; its values will be concatenated and classified.",
    api_name="predict",
)

if __name__ == "__main__":
    demo.queue()  # optional: enables concurrency & progress events
    demo.launch()