xxnithicxx commited on
Commit
1f4d19b
·
1 Parent(s): 92756c6

update tab 2

Browse files
Files changed (2) hide show
  1. app.py +20 -12
  2. utils/visualizations.py +11 -4
app.py CHANGED
@@ -70,7 +70,18 @@ def initialize_app():
70
 
71
  init_clustering_models(scaled_features, original_features, models_dir)
72
 
73
- return data_loader, cm, raw_data, scaled_features, original_features
 
 
 
 
 
 
 
 
 
 
 
74
 
75
 
76
  # Global variables (will be initialized at app startup)
@@ -79,6 +90,7 @@ cm = None
79
  raw_data = None
80
  scaled_features = None
81
  original_features = None
 
82
 
83
 
84
  # ============================================================================
@@ -213,13 +225,10 @@ def create_tab2():
213
  )
214
 
215
  def update_pca_plot(k):
216
- """Update PCA plot based on selected K."""
217
- if k not in cm.cluster_labels:
218
- return None
219
-
220
- labels = cm.cluster_labels[k]
221
- pca_fig = plot_clusters_pca_2d(cm.pca_features, labels, k)
222
- return pca_fig
223
 
224
  pca_plot = gr.Plot(
225
  label="Scatter Plot: PC1 vs PC2",
@@ -236,8 +245,7 @@ def create_tab2():
236
  **How to Use:**
237
  - Each **point** represents one customer
238
  - **Color** indicates which cluster the customer belongs to
239
- - Hover over points to see **CustomerID**
240
- - When changing K, clusters will be recalculated
241
  """)
242
 
243
 
@@ -342,10 +350,10 @@ def create_tab3():
342
 
343
  def main():
344
  """Main Gradio app."""
345
- global data_loader, cm, raw_data, scaled_features, original_features
346
 
347
  print("Starting Gradio app initialization...")
348
- data_loader, cm, raw_data, scaled_features, original_features = initialize_app()
349
  print("App initialized successfully!")
350
 
351
  # Create interface
 
70
 
71
  init_clustering_models(scaled_features, original_features, models_dir)
72
 
73
+ # Pre-compute all PCA plots for Tab 2 (K=2 to K=10)
74
+ print("Pre-computing PCA plots for all K values...")
75
+ pca_plots_cache = {}
76
+ for k in range(2, 11):
77
+ if k in cm.cluster_labels:
78
+ labels = cm.cluster_labels[k]
79
+ pca_plots_cache[k] = plot_clusters_pca_2d(cm.pca_features, labels, k)
80
+ print(f" Cached PCA plot for K={k}")
81
+
82
+ print("All PCA plots cached successfully!")
83
+
84
+ return data_loader, cm, raw_data, scaled_features, original_features, pca_plots_cache
85
 
86
 
87
  # Global variables (will be initialized at app startup)
 
90
  raw_data = None
91
  scaled_features = None
92
  original_features = None
93
+ pca_plots_cache = None
94
 
95
 
96
  # ============================================================================
 
225
  )
226
 
227
  def update_pca_plot(k):
228
+ """Update PCA plot based on selected K (from cache)."""
229
+ if k in pca_plots_cache:
230
+ return pca_plots_cache[k]
231
+ return None
 
 
 
232
 
233
  pca_plot = gr.Plot(
234
  label="Scatter Plot: PC1 vs PC2",
 
245
  **How to Use:**
246
  - Each **point** represents one customer
247
  - **Color** indicates which cluster the customer belongs to
248
+ - When changing K, clusters will be instantly updated from cache
 
249
  """)
250
 
251
 
 
350
 
351
  def main():
352
  """Main Gradio app."""
353
+ global data_loader, cm, raw_data, scaled_features, original_features, pca_plots_cache
354
 
355
  print("Starting Gradio app initialization...")
356
+ data_loader, cm, raw_data, scaled_features, original_features, pca_plots_cache = initialize_app()
357
  print("App initialized successfully!")
358
 
359
  # Create interface
utils/visualizations.py CHANGED
@@ -9,6 +9,8 @@ import numpy as np
9
  import plotly.graph_objects as go
10
  import plotly.express as px
11
  from plotly.subplots import make_subplots
 
 
12
 
13
 
14
  def create_kpi_display(kpi_metrics):
@@ -120,7 +122,7 @@ def plot_hourly_daily_heatmap(df):
120
  ))
121
 
122
  fig.update_layout(
123
- title="Mẫu mua hàng: Giờ trong ngày x Ngày trong tuần",
124
  xaxis_title="Giờ trong ngày",
125
  yaxis_title="Ngày trong tuần",
126
  height=400,
@@ -204,7 +206,7 @@ def plot_elbow_silhouette(inertias, silhouette_scores, k_range=range(2, 11)):
204
 
205
  def plot_clusters_pca_2d(pca_features, cluster_labels, k):
206
  """
207
- Plot clusters in 2D PCA space.
208
 
209
  Args:
210
  pca_features: DataFrame with PCA features
@@ -216,18 +218,23 @@ def plot_clusters_pca_2d(pca_features, cluster_labels, k):
216
  """
217
  df_plot = pca_features.copy()
218
  df_plot['Cluster'] = cluster_labels
219
- df_plot['CustomerID'] = pca_features.index
220
 
 
221
  fig = px.scatter(
222
  df_plot,
223
  x='PC1', y='PC2',
224
  color='Cluster',
225
- hover_data={'CustomerID': True, 'PC1': ':.3f', 'PC2': ':.3f'},
226
  color_continuous_scale='Viridis',
227
  title=f'Phân cụm K-Means (K={k}) - Không gian PCA',
228
  labels={'Cluster': 'Cluster'},
229
  )
230
 
 
 
 
 
 
231
  fig.update_layout(
232
  height=500,
233
  template='plotly_white',
 
9
  import plotly.graph_objects as go
10
  import plotly.express as px
11
  from plotly.subplots import make_subplots
12
+ from functools import lru_cache
13
+ import hashlib
14
 
15
 
16
  def create_kpi_display(kpi_metrics):
 
122
  ))
123
 
124
  fig.update_layout(
125
+ title="Heatmap thời gian mua hàng: Giờ trong ngày x Ngày trong tuần",
126
  xaxis_title="Giờ trong ngày",
127
  yaxis_title="Ngày trong tuần",
128
  height=400,
 
206
 
207
  def plot_clusters_pca_2d(pca_features, cluster_labels, k):
208
  """
209
+ Plot clusters in 2D PCA space with minimal hover data for performance.
210
 
211
  Args:
212
  pca_features: DataFrame with PCA features
 
218
  """
219
  df_plot = pca_features.copy()
220
  df_plot['Cluster'] = cluster_labels
 
221
 
222
+ # Minimal hover data for faster rendering
223
  fig = px.scatter(
224
  df_plot,
225
  x='PC1', y='PC2',
226
  color='Cluster',
227
+ hover_data={'PC1': ':.2f', 'PC2': ':.2f'},
228
  color_continuous_scale='Viridis',
229
  title=f'Phân cụm K-Means (K={k}) - Không gian PCA',
230
  labels={'Cluster': 'Cluster'},
231
  )
232
 
233
+ fig.update_traces(
234
+ marker=dict(size=4, opacity=0.7),
235
+ hovertemplate='<b>Cluster %{customdata[0]}</b><br>PC1: %{x:.2f}<br>PC2: %{y:.2f}<extra></extra>'
236
+ )
237
+
238
  fig.update_layout(
239
  height=500,
240
  template='plotly_white',