|
|
|
|
|
""" |
|
|
Script to pre-train and save K-Means models for the Gradio app. |
|
|
Run this once to generate the models/ folder containing the trained models.
Relative paths used by the script are resolved against the current working directory.
|
|
""" |
|
|
|
|
|
import sys |
|
|
sys.path.insert(0, '../src') |
|
|
|
|
|
from utils.data_loader import DataLoader |
|
|
from utils.clustering_models import ClusteringModels |
|
|
import os |
|
|
|
|
|
|
|
|
def main(
    data_dir="./data/processed",
    models_dir="./models",
    k_range=range(2, 11),
):
    """Train K-Means models for a range of cluster counts and save them.

    Pipeline (each step is announced on stdout):
      1. Load the scaled and original feature matrices with ``DataLoader``.
      2. Create ``models_dir`` if needed and build ``ClusteringModels``.
      3. Train one K-Means model per k in ``k_range``.
      4. Apply PCA for visualization.
      5. Save the models via ``ClusteringModels.save_models``.
    Afterwards a summary is printed: number of saved models, PCA component
    count, the silhouette score for each k and the best-scoring k.

    Args:
        data_dir: Directory passed to ``DataLoader``. Relative paths are
            resolved against the current working directory.
        models_dir: Output directory for the saved models; created if missing.
        k_range: Re-iterable sequence of cluster counts (default: k=2..10).
            It is passed unchanged to ``ClusteringModels.train_models``.

    Raises:
        ValueError: If ``k_range`` is empty.
    """
    # Materialize once so the same k values drive the header, the score
    # labels and the best-k lookup (previously hard-coded three times).
    k_values = list(k_range)
    if not k_values:
        raise ValueError("k_range must contain at least one cluster count")
    total_steps = 5

    print("=" * 70)
    print("TRAINING K-MEANS MODELS FOR GRADIO APP")
    print("=" * 70)

    print(f"\n[1/{total_steps}] Loading data...")
    data_loader = DataLoader(data_dir)
    scaled_features = data_loader.scaled_features
    original_features = data_loader.original_features
    print(f" Scaled features shape: {scaled_features.shape}")
    print(f" Original features shape: {original_features.shape}")

    print(f"\n[2/{total_steps}] Initializing clustering models...")
    os.makedirs(models_dir, exist_ok=True)
    cm = ClusteringModels(scaled_features, original_features, models_dir)

    print(
        f"\n[3/{total_steps}] Training K-Means models "
        f"(k={min(k_values)} to k={max(k_values)})..."
    )
    cm.train_models(k_range=k_range)

    print(f"\n[4/{total_steps}] Applying PCA for visualization...")
    cm.apply_pca(n_components=None)

    print(f"\n[5/{total_steps}] Saving models to disk...")
    cm.save_models()

    print("\n" + "=" * 70)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print("=" * 70)

    print("\nSummary:")
    saved_ks = list(cm.kmeans_models)
    if saved_ks:
        print(f" Models saved: {len(saved_ks)} (k={min(saved_ks)} to k={max(saved_ks)})")
    else:
        print(" Models saved: 0")
    print(f" PCA components: {cm.pca_features.shape[1]}")

    print("\n✓ Checking models...")
    # NOTE(review): assumes silhouette_scores holds one score per k, in
    # k_range order — confirm against ClusteringModels.train_models.
    scores = list(cm.silhouette_scores)
    if len(scores) != len(k_values):
        # zip() below pairs by position and truncates, so a mismatch would
        # silently mislabel or drop scores; make it visible instead.
        print(f" WARNING: expected {len(k_values)} silhouette scores, got {len(scores)}")

    print("\nSilhouette Scores by K:")
    scored = list(zip(k_values, scores))
    for k, score in scored:
        print(f" k={k}: {score:.4f}")

    if not scored:
        print("\nNo silhouette scores available; cannot determine best K.")
        return

    # max() returns the first maximal pair, matching the earlier
    # list.index(max(...)) tie-break, but works for any score sequence.
    best_k, _ = max(scored, key=lambda pair: pair[1])
    print(f"\nBest K (by Silhouette Score): {best_k}")
|
|
|
|
|
if __name__ == "__main__":
    # Train only when run as a script; importing this module has no side
    # effects beyond the sys.path tweak and imports at the top of the file.


    main()
|
|
|