Rishab7310 commited on
Commit
ad62dd2
·
verified ·
1 Parent(s): f0d80b6

Create utils/visualization.py

Browse files
Files changed (1) hide show
  1. utils/visualization.py +356 -0
utils/visualization.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization utilities for Kolam images and training progress.
3
+ """
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
+ import seaborn as sns
9
+ from pathlib import Path
10
+ from typing import List, Optional, Tuple
11
+ import cv2
12
+
13
+
14
+ def plot_kolam_grid(images: List[np.ndarray], grid_size: Tuple[int, int] = (4, 4),
15
+ save_path: Optional[str] = None, title: str = "Kolam Designs",
16
+ figsize: Tuple[int, int] = (12, 12)) -> None:
17
+ """Plot a grid of Kolam images."""
18
+ rows, cols = grid_size
19
+ fig, axes = plt.subplots(rows, cols, figsize=figsize)
20
+
21
+ # Flatten axes for easier indexing
22
+ if rows == 1:
23
+ axes = [axes]
24
+ elif cols == 1:
25
+ axes = [[ax] for ax in axes]
26
+ else:
27
+ axes = axes.flatten()
28
+
29
+ for i, ax in enumerate(axes):
30
+ if i < len(images):
31
+ # Ensure image is in correct format
32
+ img = images[i]
33
+ if isinstance(img, torch.Tensor):
34
+ img = img.detach().cpu().numpy()
35
+
36
+ # Handle different image formats
37
+ if len(img.shape) == 3 and img.shape[0] == 1:
38
+ img = img.squeeze(0)
39
+ elif len(img.shape) == 3 and img.shape[2] == 1:
40
+ img = img.squeeze(2)
41
+
42
+ # Normalize to [0, 1] if needed
43
+ if img.max() > 1.0:
44
+ img = img / 255.0
45
+
46
+ ax.imshow(img, cmap='gray', vmin=0, vmax=1)
47
+ else:
48
+ ax.axis('off')
49
+
50
+ ax.set_xticks([])
51
+ ax.set_yticks([])
52
+ ax.set_aspect('equal')
53
+
54
+ plt.suptitle(title, fontsize=16, fontweight='bold')
55
+ plt.tight_layout()
56
+
57
+ if save_path:
58
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
59
+ print(f"Grid saved to {save_path}")
60
+
61
+ plt.show()
62
+
63
+
64
+ def plot_training_curves(train_losses: List[float], val_losses: List[float] = None,
65
+ train_accs: List[float] = None, val_accs: List[float] = None,
66
+ save_path: Optional[str] = None, title: str = "Training Progress") -> None:
67
+ """Plot training curves for loss and accuracy."""
68
+ fig, axes = plt.subplots(1, 2, figsize=(15, 5))
69
+
70
+ # Loss curves
71
+ axes[0].plot(train_losses, label='Train Loss', color='blue', alpha=0.7)
72
+ if val_losses:
73
+ axes[0].plot(val_losses, label='Validation Loss', color='red', alpha=0.7)
74
+ axes[0].set_xlabel('Epoch')
75
+ axes[0].set_ylabel('Loss')
76
+ axes[0].set_title('Training and Validation Loss')
77
+ axes[0].legend()
78
+ axes[0].grid(True, alpha=0.3)
79
+
80
+ # Accuracy curves
81
+ if train_accs is not None:
82
+ axes[1].plot(train_accs, label='Train Accuracy', color='blue', alpha=0.7)
83
+ if val_accs:
84
+ axes[1].plot(val_accs, label='Validation Accuracy', color='red', alpha=0.7)
85
+ axes[1].set_xlabel('Epoch')
86
+ axes[1].set_ylabel('Accuracy (%)')
87
+ axes[1].set_title('Training and Validation Accuracy')
88
+ axes[1].legend()
89
+ axes[1].grid(True, alpha=0.3)
90
+ else:
91
+ axes[1].axis('off')
92
+
93
+ plt.suptitle(title, fontsize=16, fontweight='bold')
94
+ plt.tight_layout()
95
+
96
+ if save_path:
97
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
98
+ print(f"Training curves saved to {save_path}")
99
+
100
+ plt.show()
101
+
102
+
103
+ def plot_gan_training_curves(g_losses: List[float], d_losses: List[float],
104
+ real_scores: List[float] = None, fake_scores: List[float] = None,
105
+ save_path: Optional[str] = None) -> None:
106
+ """Plot GAN training curves."""
107
+ fig, axes = plt.subplots(2, 2, figsize=(15, 10))
108
+
109
+ # Generator and Discriminator losses
110
+ axes[0, 0].plot(g_losses, label='Generator Loss', color='blue', alpha=0.7)
111
+ axes[0, 0].plot(d_losses, label='Discriminator Loss', color='red', alpha=0.7)
112
+ axes[0, 0].set_xlabel('Epoch')
113
+ axes[0, 0].set_ylabel('Loss')
114
+ axes[0, 0].set_title('Generator and Discriminator Loss')
115
+ axes[0, 0].legend()
116
+ axes[0, 0].grid(True, alpha=0.3)
117
+
118
+ # Discriminator scores
119
+ if real_scores and fake_scores:
120
+ axes[0, 1].plot(real_scores, label='Real Score', color='green', alpha=0.7)
121
+ axes[0, 1].plot(fake_scores, label='Fake Score', color='orange', alpha=0.7)
122
+ axes[0, 1].set_xlabel('Epoch')
123
+ axes[0, 1].set_ylabel('Score')
124
+ axes[0, 1].set_title('Discriminator Scores')
125
+ axes[0, 1].legend()
126
+ axes[0, 1].grid(True, alpha=0.3)
127
+ else:
128
+ axes[0, 1].axis('off')
129
+
130
+ # Generator loss only
131
+ axes[1, 0].plot(g_losses, color='blue', alpha=0.7)
132
+ axes[1, 0].set_xlabel('Epoch')
133
+ axes[1, 0].set_ylabel('Generator Loss')
134
+ axes[1, 0].set_title('Generator Loss')
135
+ axes[1, 0].grid(True, alpha=0.3)
136
+
137
+ # Discriminator loss only
138
+ axes[1, 1].plot(d_losses, color='red', alpha=0.7)
139
+ axes[1, 1].set_xlabel('Epoch')
140
+ axes[1, 1].set_ylabel('Discriminator Loss')
141
+ axes[1, 1].set_title('Discriminator Loss')
142
+ axes[1, 1].grid(True, alpha=0.3)
143
+
144
+ plt.suptitle('GAN Training Progress', fontsize=16, fontweight='bold')
145
+ plt.tight_layout()
146
+
147
+ if save_path:
148
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
149
+ print(f"GAN training curves saved to {save_path}")
150
+
151
+ plt.show()
152
+
153
+
154
+ def plot_image_comparison(original: np.ndarray, generated: np.ndarray,
155
+ save_path: Optional[str] = None, title: str = "Original vs Generated") -> None:
156
+ """Plot comparison between original and generated images."""
157
+ fig, axes = plt.subplots(1, 2, figsize=(10, 5))
158
+
159
+ # Original image
160
+ axes[0].imshow(original, cmap='gray', vmin=0, vmax=1)
161
+ axes[0].set_title('Original')
162
+ axes[0].axis('off')
163
+
164
+ # Generated image
165
+ axes[1].imshow(generated, cmap='gray', vmin=0, vmax=1)
166
+ axes[1].set_title('Generated')
167
+ axes[1].axis('off')
168
+
169
+ plt.suptitle(title, fontsize=16, fontweight='bold')
170
+ plt.tight_layout()
171
+
172
+ if save_path:
173
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
174
+ print(f"Comparison saved to {save_path}")
175
+
176
+ plt.show()
177
+
178
+
179
+ def plot_feature_visualization(features: np.ndarray, save_path: Optional[str] = None) -> None:
180
+ """Visualize extracted features."""
181
+ fig, axes = plt.subplots(2, 2, figsize=(12, 10))
182
+
183
+ # Feature histogram
184
+ axes[0, 0].hist(features.flatten(), bins=50, alpha=0.7, color='blue')
185
+ axes[0, 0].set_title('Feature Distribution')
186
+ axes[0, 0].set_xlabel('Feature Value')
187
+ axes[0, 0].set_ylabel('Frequency')
188
+ axes[0, 0].grid(True, alpha=0.3)
189
+
190
+ # Feature heatmap
191
+ if len(features.shape) == 2:
192
+ im = axes[0, 1].imshow(features, cmap='viridis', aspect='auto')
193
+ axes[0, 1].set_title('Feature Heatmap')
194
+ plt.colorbar(im, ax=axes[0, 1])
195
+
196
+ # Feature statistics
197
+ stats_text = f"""
198
+ Mean: {features.mean():.4f}
199
+ Std: {features.std():.4f}
200
+ Min: {features.min():.4f}
201
+ Max: {features.max():.4f}
202
+ """
203
+ axes[1, 0].text(0.1, 0.5, stats_text, fontsize=12, verticalalignment='center')
204
+ axes[1, 0].set_title('Feature Statistics')
205
+ axes[1, 0].axis('off')
206
+
207
+ # Feature correlation (if 2D)
208
+ if len(features.shape) == 2 and features.shape[1] > 1:
209
+ corr_matrix = np.corrcoef(features.T)
210
+ im = axes[1, 1].imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
211
+ axes[1, 1].set_title('Feature Correlation')
212
+ plt.colorbar(im, ax=axes[1, 1])
213
+ else:
214
+ axes[1, 1].axis('off')
215
+
216
+ plt.suptitle('Feature Visualization', fontsize=16, fontweight='bold')
217
+ plt.tight_layout()
218
+
219
+ if save_path:
220
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
221
+ print(f"Feature visualization saved to {save_path}")
222
+
223
+ plt.show()
224
+
225
+
226
+ def plot_symmetry_analysis(image: np.ndarray, save_path: Optional[str] = None) -> None:
227
+ """Analyze and visualize symmetry properties of an image."""
228
+ fig, axes = plt.subplots(2, 3, figsize=(15, 10))
229
+
230
+ # Original image
231
+ axes[0, 0].imshow(image, cmap='gray', vmin=0, vmax=1)
232
+ axes[0, 0].set_title('Original')
233
+ axes[0, 0].axis('off')
234
+
235
+ # Horizontal flip
236
+ h_flipped = np.fliplr(image)
237
+ axes[0, 1].imshow(h_flipped, cmap='gray', vmin=0, vmax=1)
238
+ axes[0, 1].set_title('Horizontal Flip')
239
+ axes[0, 1].axis('off')
240
+
241
+ # Vertical flip
242
+ v_flipped = np.flipud(image)
243
+ axes[0, 2].imshow(v_flipped, cmap='gray', vmin=0, vmax=1)
244
+ axes[0, 2].set_title('Vertical Flip')
245
+ axes[0, 2].axis('off')
246
+
247
+ # Difference with horizontal flip
248
+ h_diff = np.abs(image - h_flipped)
249
+ axes[1, 0].imshow(h_diff, cmap='hot', vmin=0, vmax=1)
250
+ axes[1, 0].set_title(f'Horizontal Symmetry\n(Error: {h_diff.mean():.4f})')
251
+ axes[1, 0].axis('off')
252
+
253
+ # Difference with vertical flip
254
+ v_diff = np.abs(image - v_flipped)
255
+ axes[1, 1].imshow(v_diff, cmap='hot', vmin=0, vmax=1)
256
+ axes[1, 1].set_title(f'Vertical Symmetry\n(Error: {v_diff.mean():.4f})')
257
+ axes[1, 1].axis('off')
258
+
259
+ # Rotational symmetry
260
+ rotated = np.rot90(image, k=2)
261
+ r_diff = np.abs(image - rotated)
262
+ axes[1, 2].imshow(r_diff, cmap='hot', vmin=0, vmax=1)
263
+ axes[1, 2].set_title(f'Rotational Symmetry\n(Error: {r_diff.mean():.4f})')
264
+ axes[1, 2].axis('off')
265
+
266
+ plt.suptitle('Symmetry Analysis', fontsize=16, fontweight='bold')
267
+ plt.tight_layout()
268
+
269
+ if save_path:
270
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
271
+ print(f"Symmetry analysis saved to {save_path}")
272
+
273
+ plt.show()
274
+
275
+
276
+ def plot_metrics_comparison(metrics_dict: dict, save_path: Optional[str] = None) -> None:
277
+ """Plot comparison of different metrics."""
278
+ fig, axes = plt.subplots(2, 2, figsize=(15, 10))
279
+
280
+ # Extract metric names and values
281
+ metric_names = list(metrics_dict.keys())
282
+ metric_values = list(metrics_dict.values())
283
+
284
+ # Bar plot of metrics
285
+ axes[0, 0].bar(metric_names, metric_values, alpha=0.7, color='skyblue')
286
+ axes[0, 0].set_title('Metrics Comparison')
287
+ axes[0, 0].set_ylabel('Value')
288
+ axes[0, 0].tick_params(axis='x', rotation=45)
289
+ axes[0, 0].grid(True, alpha=0.3)
290
+
291
+ # Pie chart of relative importance
292
+ axes[0, 1].pie(metric_values, labels=metric_names, autopct='%1.1f%%', startangle=90)
293
+ axes[0, 1].set_title('Relative Metric Importance')
294
+
295
+ # Line plot (if metrics are time series)
296
+ if len(metric_values) > 1:
297
+ axes[1, 0].plot(metric_names, metric_values, marker='o', linewidth=2, markersize=8)
298
+ axes[1, 0].set_title('Metrics Trend')
299
+ axes[1, 0].set_ylabel('Value')
300
+ axes[1, 0].tick_params(axis='x', rotation=45)
301
+ axes[1, 0].grid(True, alpha=0.3)
302
+
303
+ # Summary statistics
304
+ stats_text = f"""
305
+ Best Metric: {metric_names[np.argmax(metric_values)]}
306
+ Worst Metric: {metric_names[np.argmin(metric_values)]}
307
+ Average: {np.mean(metric_values):.4f}
308
+ Std Dev: {np.std(metric_values):.4f}
309
+ """
310
+ axes[1, 1].text(0.1, 0.5, stats_text, fontsize=12, verticalalignment='center')
311
+ axes[1, 1].set_title('Summary Statistics')
312
+ axes[1, 1].axis('off')
313
+
314
+ plt.suptitle('Metrics Analysis', fontsize=16, fontweight='bold')
315
+ plt.tight_layout()
316
+
317
+ if save_path:
318
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
319
+ print(f"Metrics comparison saved to {save_path}")
320
+
321
+ plt.show()
322
+
323
+
324
+ def create_animation_frames(images: List[np.ndarray], save_dir: str) -> None:
325
+ """Create frames for animation from a list of images."""
326
+ save_dir = Path(save_dir)
327
+ save_dir.mkdir(parents=True, exist_ok=True)
328
+
329
+ for i, img in enumerate(images):
330
+ plt.figure(figsize=(8, 8))
331
+ plt.imshow(img, cmap='gray', vmin=0, vmax=1)
332
+ plt.title(f'Frame {i+1}')
333
+ plt.axis('off')
334
+ plt.tight_layout()
335
+ plt.savefig(save_dir / f'frame_{i:04d}.png', dpi=150, bbox_inches='tight')
336
+ plt.close()
337
+
338
+ print(f"Animation frames saved to {save_dir}")
339
+
340
+
341
+ if __name__ == "__main__":
342
+ # Test visualization functions
343
+ print("Testing visualization functions...")
344
+
345
+ # Create sample images
346
+ sample_images = [np.random.rand(64, 64) for _ in range(16)]
347
+
348
+ # Test grid plotting
349
+ plot_kolam_grid(sample_images, title="Test Grid")
350
+
351
+ # Test training curves
352
+ train_losses = [1.0 - i * 0.01 for i in range(100)]
353
+ val_losses = [1.1 - i * 0.009 for i in range(100)]
354
+ plot_training_curves(train_losses, val_losses, title="Test Training Curves")
355
+
356
+ print("Visualization tests completed!")