# -*- coding: utf-8 -*-
"""
Gradio Web Demo for Customer Segmentation Project

A comprehensive interactive interface showcasing:
- Tab 1: Dashboard with KPIs and EDA visualizations
- Tab 2: Clustering Playground with interactive K selection
- Tab 3: Customer DNA analysis with Radar charts
- Tab 4: Segment Prediction for new customers
  (not created by main() in this module)
"""

import sys
import os
import numpy as np
import pandas as pd
import gradio as gr
from datetime import datetime, timedelta
from pathlib import Path

# Add src path for clustering_library
sys.path.insert(0, '../src')

# Import utilities
from utils.data_loader import DataLoader, get_data_loader
from utils.clustering_models import ClusteringModels, init_clustering_models
from utils.visualizations import (
    create_kpi_display,
    plot_revenue_over_time,
    plot_hourly_daily_heatmap,
    plot_elbow_silhouette,
    plot_clusters_pca_2d,
    plot_radar_chart,
    create_cluster_stats_table,
)
from sklearn.preprocessing import StandardScaler

# K values that are trained/loaded, cached and exposed by the slider.
K_RANGE = range(2, 11)
# K shown by default in the playground and the Customer DNA tab.
DEFAULT_K = 4


# ============================================================================
# INITIALIZATION
# ============================================================================

def initialize_app():
    """Load the data, load or train the clustering models, and cache plots.

    Returns:
        A tuple ``(data_loader, cm, raw_data, scaled_features,
        original_features, pca_plots_cache)`` where ``pca_plots_cache`` maps
        each K that has cluster labels to its pre-rendered PCA scatter plot.
    """
    print("Initializing app...")

    # Load data
    data_loader = DataLoader("./data/processed")
    scaled_features = data_loader.scaled_features
    original_features = data_loader.original_features
    raw_data = data_loader.raw_data

    # Initialize clustering models
    models_dir = "./models"
    cm = ClusteringModels(scaled_features, original_features, models_dir)

    # Try to load existing models, otherwise train them
    models_path = Path(models_dir)
    if models_path.exists() and any(models_path.glob("kmeans_k*.pkl")):
        print("Loading pre-trained models...")
        cm.load_models(k_range=K_RANGE)
    else:
        print("Models not found. Training models...")
        cm.train_models(k_range=K_RANGE)
        cm.apply_pca(n_components=None)
        cm.save_models()

    # If PCA wasn't loaded, apply it
    if cm.pca_features is None:
        print("Applying PCA...")
        cm.apply_pca(n_components=None)

    # NOTE(review): presumably registers a shared instance inside
    # utils.clustering_models; its return value is not used here - confirm.
    init_clustering_models(scaled_features, original_features, models_dir)

    # Pre-compute all PCA plots for Tab 2 so slider changes are a dict lookup.
    print("Pre-computing PCA plots for all K values...")
    pca_plots_cache = {}
    for k in K_RANGE:
        if k in cm.cluster_labels:
            labels = cm.cluster_labels[k]
            pca_plots_cache[k] = plot_clusters_pca_2d(cm.pca_features, labels, k)
            print(f"  Cached PCA plot for K={k}")

    # Only claim full success when every K actually produced a plot.
    missing = [k for k in K_RANGE if k not in pca_plots_cache]
    if missing:
        print(f"Warning: no cluster labels for K={missing}; "
              "those PCA plots are unavailable.")
    else:
        print("All PCA plots cached successfully!")

    return data_loader, cm, raw_data, scaled_features, original_features, pca_plots_cache


# Global variables (will be initialized at app startup)
data_loader = None
cm = None
raw_data = None
scaled_features = None
original_features = None
pca_plots_cache = None


# ============================================================================
# TAB 1: DASHBOARD OVERVIEW
# ============================================================================

def get_kpi_data():
    """Return the KPI metrics reported by the global data loader."""
    return data_loader.get_kpi_metrics()


def get_dashboard_plots():
    """Build the dashboard artefacts from the global raw data.

    Returns:
        A tuple ``(kpi_html, revenue_fig, heatmap_fig)``. Nothing is memoised
        here; every call rebuilds all three.
    """
    kpi_html = create_kpi_display(get_kpi_data())

    # Revenue plot
    revenue_fig = plot_revenue_over_time(raw_data)

    # Heatmap
    heatmap_fig = plot_hourly_daily_heatmap(raw_data)

    return kpi_html, revenue_fig, heatmap_fig


def create_tab1():
    """Create Tab 1: Dashboard Overview.

    Must be called inside an active ``gr.Blocks`` context, after the module
    globals have been populated by ``main``.
    """
    with gr.TabItem("Dashboard - Overview"):
        gr.Markdown("# Data Overview Analysis")

        # KPI Metrics (HTML display)
        kpi_html, revenue_fig, heatmap_fig = get_dashboard_plots()
        gr.HTML(kpi_html)

        gr.Markdown("## Revenue Over Time")

        # Date range picker for revenue
        with gr.Row():
            date_start = gr.DateTime(
                label="From Date",
                value=raw_data["InvoiceDate"].min()
            )
            date_end = gr.DateTime(
                label="To Date",
                value=raw_data["InvoiceDate"].max()
            )

        revenue_plot = gr.Plot(
            label="Revenue Chart",
            value=revenue_fig
        )

        # Update revenue plot when dates change
        def update_revenue_plot(start, end):
            # An incomplete range falls back to the unfiltered figure.
            if start is None or end is None:
                return revenue_fig
            # NOTE(review): gr.DateTime's default output type should be
            # checked against what plot_revenue_over_time expects for
            # start/end - confirm.
            return plot_revenue_over_time(raw_data, start, end)

        # Either end of the range re-renders the same plot.
        for picker in (date_start, date_end):
            picker.change(
                fn=update_revenue_plot,
                inputs=[date_start, date_end],
                outputs=revenue_plot
            )

        gr.Markdown("## Shopping Behavior by Hour and Day")
        gr.Plot(
            label="Shopping Activity Heatmap",
            value=heatmap_fig
        )

        gr.Markdown("""
        ### Insights:
        - **Heatmap** shows shopping patterns by hour (0-23) and day of week
        - **Revenue Over Time** shows overall sales trend (12 months)
        - Filter by date range to zoom into peak months (Christmas, etc.)
        """)


# ============================================================================
# TAB 2: CLUSTERING PLAYGROUND
# ============================================================================

def get_optimal_clusters_data():
    """Return the Elbow/Silhouette inputs held by the global model object.

    Returns:
        A tuple ``(inertias, silhouette_scores, k_list)`` where the first two
        are read from ``cm`` and ``k_list`` is the list form of ``K_RANGE``.
    """
    k_list = list(K_RANGE)
    inertias = cm.inertias
    silhouette_scores = cm.silhouette_scores
    return inertias, silhouette_scores, k_list


def create_tab2():
    """Create Tab 2: Clustering Playground.

    Must be called inside an active ``gr.Blocks`` context, after the module
    globals have been populated by ``main``.
    """
    with gr.TabItem("Clustering - Playground"):
        gr.Markdown("# Explore K-Means Clustering Algorithm")

        gr.Markdown("""
        Adjust the slider to select different numbers of clusters (K)
        and see how the algorithm divides customers into different groups.
        """)

        # Get optimal data (the K list is not needed; K_RANGE is passed below)
        inertias, silhouette_scores, _ = get_optimal_clusters_data()

        # Elbow + Silhouette plot (static, built once when the tab is created)
        gr.Markdown("## Determine Optimal Number of Clusters")
        optimal_fig = plot_elbow_silhouette(inertias, silhouette_scores, K_RANGE)
        gr.Plot(value=optimal_fig)

        gr.Markdown("""
        **Explanation:**
        - **Elbow Method**: Find the "elbow" point where increasing K doesn't significantly reduce inertia
        - **Silhouette Score**: Higher is better. Clusters are more distinct when score is high
        - **Recommendation**: K=3 or K=4 are both good choices
        """)

        # K slider and PCA visualization
        gr.Markdown("## Visualize Clusters in PCA Space")

        k_slider = gr.Slider(
            minimum=K_RANGE[0],
            maximum=K_RANGE[-1],
            value=DEFAULT_K,
            step=1,
            label="Select number of clusters (K)",
            interactive=True
        )

        def update_pca_plot(k):
            """Return the pre-rendered PCA plot for K, or None if not cached."""
            return pca_plots_cache.get(k)

        pca_plot = gr.Plot(
            label="Scatter Plot: PC1 vs PC2",
            value=update_pca_plot(DEFAULT_K)
        )

        k_slider.change(
            fn=update_pca_plot,
            inputs=k_slider,
            outputs=pca_plot
        )

        gr.Markdown("""
        **How to Use:**
        - Each **point** represents one customer
        - **Color** indicates which cluster the customer belongs to
        - When changing K, clusters will be instantly updated from cache
        """)


# ============================================================================
# TAB 3: CUSTOMER DNA
# ============================================================================

def create_tab3():
    """Create Tab 3: Customer DNA.

    Must be called inside an active ``gr.Blocks`` context, after the module
    globals have been populated by ``main``.
    """
    with gr.TabItem("Analysis - Customer DNA"):
        gr.Markdown("# Deep Analysis: Characteristics of Each Cluster")

        gr.Markdown("""
        Select a cluster to see detailed characteristics of customers in it.
        The Radar chart shows how this cluster differs from the overall average.
        """)

        # Get available clusters (K=3 and K=4)
        k_choices = [3, 4]
        initial_k = DEFAULT_K

        with gr.Row():
            k_select = gr.Radio(
                choices=k_choices,
                value=initial_k,
                label="Select Model (K clusters)"
            )

            cluster_select = gr.Dropdown(
                # Derived from the initial K so the two cannot drift apart.
                choices=list(range(initial_k)),
                value=0,
                label="Select Cluster",
                interactive=True
            )

        def update_cluster_choices(k):
            """Update cluster choices based on selected K."""
            return gr.Dropdown(
                choices=list(range(k)),
                value=0,
                interactive=True
            )

        k_select.change(
            fn=update_cluster_choices,
            inputs=k_select,
            outputs=cluster_select
        )

        # Radar chart
        gr.Markdown("### Radar Chart - Comparison with Overall Average")

        def update_radar_and_stats(k, cluster_idx):
            """Return ``(radar_fig, stats_df)`` for cluster ``cluster_idx`` of model K."""
            # A K change fires this handler with the dropdown's previous
            # selection, which may not exist in the smaller model (e.g.
            # cluster 3 when switching to K=3); fall back to cluster 0, the
            # value the dropdown is being reset to.
            if cluster_idx is None or not 0 <= cluster_idx < k:
                cluster_idx = 0

            cluster_means = cm.get_cluster_info(k)["means"]

            # Create radar chart for selected cluster
            radar_fig = plot_radar_chart(cluster_means, k, cluster_idx=cluster_idx)

            # Create stats table
            stats_df = create_cluster_stats_table(cluster_means, k)

            return radar_fig, stats_df

        # Initial load: pass the values through the constructors so the
        # components process them like any other initial value, instead of
        # assigning to ``.value`` after construction.
        initial_radar, initial_stats = update_radar_and_stats(initial_k, 0)
        radar_plot = gr.Plot(label="Radar Chart", value=initial_radar)
        stats_table = gr.Dataframe(label="Detailed Statistics", value=initial_stats)

        # Update when K or cluster changes
        for selector in (k_select, cluster_select):
            selector.change(
                fn=update_radar_and_stats,
                inputs=[k_select, cluster_select],
                outputs=[radar_plot, stats_table]
            )

        gr.Markdown("""
        ### How to Read Radar Chart:
        - **Each axis = 1 customer characteristic** (normalized 0-1 scale)
        - **Further from center = higher value** for that characteristic
        - **Shape of polygon** represents the cluster's profile
        - **Compare clusters** by looking at shape and size
        """)


# ============================================================================
# MAIN APP
# ============================================================================

def main():
    """Initialise the module globals and build the Gradio Blocks app.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    global data_loader, cm, raw_data, scaled_features, original_features, pca_plots_cache

    print("Starting Gradio app initialization...")
    (
        data_loader,
        cm,
        raw_data,
        scaled_features,
        original_features,
        pca_plots_cache,
    ) = initialize_app()
    print("App initialized successfully!")

    # Create interface
    with gr.Blocks(
        title="Customer Segmentation Demo"
    ) as demo:

        # Header
        gr.Markdown("""
        # Customer Segmentation - Advanced Analysis

        Interactive demo showcasing customer clustering analysis with K-Means.
        Explore data stories, clustering patterns, and predict segments for new customers.
        """)

        # Tabs
        create_tab1()
        create_tab2()
        create_tab3()

        # Footer
        gr.Markdown("""
        ---

        **Project:** Advanced Customer Segmentation

        **Data:** Online Retail (2010-2011)
        - Customers: 3,920+
        - Transactions: 354,000+

        **Built from:** Project by Dr.Nguyen Thai Ha
        """)

    return demo


if __name__ == "__main__":
    demo = main()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )