Rishab7310 committed on
Commit
5cfe98f
·
verified ·
1 Parent(s): 08c01a4

Update models/gan_generator.py

Browse files
Files changed (1) hide show
  1. models/gan_generator.py +37 -110
models/gan_generator.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
- GAN Generator model for creating new Kolam designs.
3
- Uses a deep convolutional architecture to generate high-quality Kolam patterns.
4
  """
5
 
6
  import torch
@@ -9,10 +9,7 @@ import torch.nn.functional as F
9
 
10
 
class KolamGenerator(nn.Module):
    """Generator network for creating Kolam designs.

    A noise vector, concatenated with a feature vector (zeros when none is
    given), is projected to a ``256 x 4 x 4`` map. Each upsampling block then
    doubles the spatial size with a stride-2 transposed convolution followed
    by a 3x3 convolution, and a final 3x3 convolution plus ``tanh`` produces
    the image.

    Args:
        noise_dim: Width of the noise vector.
        feature_dim: Width of the conditioning feature vector.
        output_channels: Number of channels of the generated image.
        image_size: Side length of the generated image. Must equal
            ``4 * 2**k`` for some integer ``k >= 0``.

    Raises:
        ValueError: If ``image_size`` cannot be reached by repeatedly
            doubling the 4x4 starting map.
    """

    def __init__(self, noise_dim=100, feature_dim=128, output_channels=1, image_size=64):
        super().__init__()

        self.noise_dim = noise_dim
        self.feature_dim = feature_dim
        self.image_size = image_size

        # The projection emits a start_size x start_size map; every block
        # afterwards doubles it.
        self.start_size = 4
        self.num_upsamples = self._count_upsamples(image_size, self.start_size)

        # Input projection layer
        self.input_projection = nn.Linear(
            noise_dim + feature_dim, 256 * self.start_size * self.start_size
        )

        # Upsampling blocks (attribute names are part of the state_dict layout)
        self.upsample_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()

        in_channels = 256
        for i in range(self.num_upsamples):
            # Halve the channel count per block; the last block always ends at 64.
            is_last = i == self.num_upsamples - 1
            out_channels = 64 if is_last else in_channels // 2

            self.upsample_layers.append(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
            )
            self.conv_layers.append(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            )
            self.bn_layers.append(nn.BatchNorm2d(out_channels))

            in_channels = out_channels

        # Final output layer. ``in_channels`` is 64 whenever at least one block
        # exists and 256 when image_size == start_size (no blocks), which the
        # previously hard-coded 64 could not handle.
        self.final_conv = nn.Conv2d(in_channels, output_channels, kernel_size=3, padding=1)

    @staticmethod
    def _count_upsamples(image_size, start_size):
        """Return ``k`` such that ``start_size * 2**k == image_size``."""
        size = int(image_size)
        ratio, remainder = divmod(size, start_size)
        is_power_of_two = ratio >= 1 and (ratio & (ratio - 1)) == 0
        if size != image_size or remainder != 0 or not is_power_of_two:
            raise ValueError(
                f"image_size must be {start_size} * 2**k for an integer k >= 0, "
                f"got {image_size!r}"
            )
        # Exact integer log2 of a power of two.
        return ratio.bit_length() - 1

    def forward(self, noise, features=None):
        """Generate Kolam images from noise and optional features.

        Args:
            noise: Random noise tensor of shape (batch_size, noise_dim).
            features: Optional feature tensor of shape
                (batch_size, feature_dim); zeros are used when omitted.

        Returns:
            Generated images of shape
            (batch_size, output_channels, image_size, image_size) with values
            in [-1, 1].
        """
        batch_size = noise.size(0)

        if features is None:
            # Placeholder features follow the noise tensor's device and dtype.
            features = noise.new_zeros(batch_size, self.feature_dim)

        # Project to the initial feature map
        x = self.input_projection(torch.cat([noise, features], dim=1))
        x = x.view(batch_size, 256, self.start_size, self.start_size)

        # Upsample and refine. The same BatchNorm module of a block is applied
        # after both convolutions (shared statistics); this is kept so the
        # parameter layout stays unchanged.
        for upsample, conv, bn in zip(self.upsample_layers, self.conv_layers, self.bn_layers):
            x = F.relu(bn(upsample(x)))
            x = F.relu(bn(conv(x)))

        # Final output in range [-1, 1]
        return torch.tanh(self.final_conv(x))

    def generate(self, num_samples=1, features=None, device='cpu'):
        """Generate samples without gradients (for inference).

        Note that this switches the module to eval mode and leaves it there.

        Args:
            num_samples: Number of samples to generate.
            features: Optional feature tensor of shape
                (num_samples, feature_dim).
            device: Device on which the noise is sampled.

        Returns:
            Generated images.
        """
        self.eval()
        with torch.no_grad():
            noise = torch.randn(num_samples, self.noise_dim, device=device)
            # Go through __call__ rather than forward() so module hooks run.
            return self(noise, features)
110
 
111
 
class StyleConditionedGenerator(KolamGenerator):
    """Style-conditioned generator for Kolam designs.

    Extends ``KolamGenerator`` by embedding a style vector with a small MLP
    and concatenating the embedding to the noise and feature inputs before
    the input projection.

    Args:
        noise_dim: Width of the noise vector.
        feature_dim: Width of the design feature vector.
        style_dim: Width of the raw style vector.
        output_channels: Number of channels of the generated image.
        image_size: Side length of the generated image.
    """

    def __init__(self, noise_dim=100, feature_dim=128, style_dim=32,
                 output_channels=1, image_size=64):
        super().__init__(noise_dim, feature_dim, output_channels, image_size)

        # Kept on the instance so callers can build matching style tensors.
        self.style_dim = style_dim
        # Width of the embedded style vector; single source of truth for the
        # embedding MLP, the zero fallback and the projection input size.
        self.style_embed_dim = 128

        # Style embedding layer
        self.style_embedding = nn.Sequential(
            nn.Linear(style_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.style_embed_dim),
        )

        # Replace the parent's projection with one wide enough to also take
        # the style embedding.
        self.input_projection = nn.Linear(
            noise_dim + feature_dim + self.style_embed_dim,
            256 * self.start_size * self.start_size,
        )

    def forward(self, noise, features=None, style=None):
        """Generate with style conditioning.

        Args:
            noise: Random noise of shape (batch_size, noise_dim).
            features: Optional design features of shape
                (batch_size, feature_dim); zeros are used when omitted.
            style: Optional style vector of shape (batch_size, style_dim);
                a zero embedding is used when omitted.

        Returns:
            Generated images with values in [-1, 1].
        """
        batch_size = noise.size(0)

        # Process style
        if style is None:
            style_embed = noise.new_zeros(batch_size, self.style_embed_dim)
        else:
            style_embed = self.style_embedding(style)

        if features is None:
            features = noise.new_zeros(batch_size, self.feature_dim)

        # The parent forward concatenates [noise, features] and feeds the
        # result to self.input_projection, so handing it features + style
        # reproduces the [noise, features, style] layout without duplicating
        # the projection / upsampling / tanh pipeline here.
        conditioning = torch.cat([features, style_embed], dim=1)
        return super().forward(noise, conditioning)
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
 
if __name__ == "__main__":
    # Smoke test: push one random batch through each generator variant and
    # report the tensor shapes.
    batch = 4

    base_generator = KolamGenerator()
    noise = torch.randn(batch, 100)     # 100-dim noise per sample
    features = torch.randn(batch, 128)  # 128-dim features per sample
    generated = base_generator(noise, features)

    for line in (
        f"Noise shape: {noise.shape}",
        f"Features shape: {features.shape}",
        f"Generated shape: {generated.shape}",
    ):
        print(line)

    # Same inputs plus a 32-dim style vector through the style-conditioned variant.
    styled_generator = StyleConditionedGenerator()
    style = torch.randn(batch, 32)
    style_generated = styled_generator(noise, features, style)
    print(f"Style-generated shape: {style_generated.shape}")
 
1
  """
2
+ Enhanced GAN Generator for Kolam designs.
3
+ Adds style-conditioning and more diverse outputs.
4
  """
5
 
6
  import torch
 
9
 
10
 
class KolamGenerator(nn.Module):
    """Base generator network for Kolam designs.

    A noise vector, concatenated with a feature vector (zeros when none is
    given), is projected to a ``256 x 4 x 4`` map. Each upsampling block then
    doubles the spatial size with a stride-2 transposed convolution followed
    by a 3x3 convolution, and a final 3x3 convolution plus ``tanh`` produces
    the image.

    Args:
        noise_dim: Width of the noise vector.
        feature_dim: Width of the conditioning feature vector.
        output_channels: Number of channels of the generated image.
        image_size: Side length of the generated image. Must equal
            ``4 * 2**k`` for some integer ``k >= 0``.

    Raises:
        ValueError: If ``image_size`` cannot be reached by repeatedly
            doubling the 4x4 starting map.
    """

    def __init__(self, noise_dim=100, feature_dim=128, output_channels=1, image_size=64):
        super().__init__()

        self.noise_dim = noise_dim
        self.feature_dim = feature_dim
        self.image_size = image_size
        self.start_size = 4
        self.num_upsamples = self._count_upsamples(image_size, self.start_size)

        # Input projection
        self.input_projection = nn.Linear(
            noise_dim + feature_dim, 256 * self.start_size * self.start_size
        )

        # Upsampling blocks (attribute names are part of the state_dict layout)
        self.upsample_layers = nn.ModuleList()
        self.conv_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()

        in_channels = 256
        for i in range(self.num_upsamples):
            # Halve the channel count per block; the last block always ends at 64.
            is_last = i == self.num_upsamples - 1
            out_channels = 64 if is_last else in_channels // 2
            self.upsample_layers.append(nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1))
            self.conv_layers.append(nn.Conv2d(out_channels, out_channels, 3, padding=1))
            self.bn_layers.append(nn.BatchNorm2d(out_channels))
            in_channels = out_channels

        # Final output. ``in_channels`` is 64 whenever at least one block
        # exists and 256 when image_size == start_size (no blocks), which the
        # previously hard-coded 64 could not handle.
        self.final_conv = nn.Conv2d(in_channels, output_channels, 3, padding=1)

    @staticmethod
    def _count_upsamples(image_size, start_size):
        """Return ``k`` such that ``start_size * 2**k == image_size``."""
        size = int(image_size)
        ratio, remainder = divmod(size, start_size)
        is_power_of_two = ratio >= 1 and (ratio & (ratio - 1)) == 0
        if size != image_size or remainder != 0 or not is_power_of_two:
            raise ValueError(
                f"image_size must be {start_size} * 2**k for an integer k >= 0, "
                f"got {image_size!r}"
            )
        # Exact integer log2 of a power of two.
        return ratio.bit_length() - 1

    def forward(self, noise, features=None):
        """Generate images from noise and optional features.

        Args:
            noise: Tensor of shape (batch_size, noise_dim).
            features: Optional tensor of shape (batch_size, feature_dim);
                zeros are used when omitted.

        Returns:
            Tensor of shape
            (batch_size, output_channels, image_size, image_size) with values
            in [-1, 1].
        """
        batch_size = noise.size(0)

        if features is None:
            # Placeholder features follow the noise tensor's device and dtype.
            features = noise.new_zeros(batch_size, self.feature_dim)

        x = self.input_projection(torch.cat([noise, features], dim=1))
        x = x.view(batch_size, 256, self.start_size, self.start_size)

        # The same BatchNorm module of a block is applied after both
        # convolutions (shared statistics); this is kept so the parameter
        # layout stays unchanged.
        for upsample, conv, bn in zip(self.upsample_layers, self.conv_layers, self.bn_layers):
            x = F.relu(bn(upsample(x)))
            x = F.relu(bn(conv(x)))

        return torch.tanh(self.final_conv(x))  # [-1, 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
 
class StyleConditionedGenerator(KolamGenerator):
    """Generator with style-conditioning for more variety.

    Extends ``KolamGenerator`` by embedding a style vector with a small MLP
    and concatenating the embedding to the noise and feature inputs before
    the input projection.

    Args:
        noise_dim: Width of the noise vector.
        feature_dim: Width of the design feature vector.
        style_dim: Width of the raw style vector.
        output_channels: Number of channels of the generated image.
        image_size: Side length of the generated image.
    """

    def __init__(self, noise_dim=100, feature_dim=128, style_dim=32, output_channels=1, image_size=64):
        super().__init__(noise_dim, feature_dim, output_channels, image_size)

        # Kept on the instance so callers can build matching style tensors.
        self.style_dim = style_dim
        # Width of the embedded style vector; single source of truth for the
        # embedding MLP, the zero fallback and the projection input size.
        self.style_embed_dim = 128

        self.style_embedding = nn.Sequential(
            nn.Linear(style_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.style_embed_dim),
        )
        # Replace the parent's projection with one wide enough to also take
        # the style embedding.
        self.input_projection = nn.Linear(
            noise_dim + feature_dim + self.style_embed_dim,
            256 * self.start_size * self.start_size,
        )

    def forward(self, noise, features=None, style=None):
        """Generate images conditioned on optional features and style.

        Args:
            noise: Tensor of shape (batch_size, noise_dim).
            features: Optional tensor of shape (batch_size, feature_dim);
                zeros are used when omitted.
            style: Optional tensor of shape (batch_size, style_dim); a zero
                embedding is used when omitted.

        Returns:
            Generated images with values in [-1, 1].
        """
        batch_size = noise.size(0)

        if style is None:
            style_embed = noise.new_zeros(batch_size, self.style_embed_dim)
        else:
            style_embed = self.style_embedding(style)

        if features is None:
            features = noise.new_zeros(batch_size, self.feature_dim)

        # The parent forward concatenates [noise, features] and feeds the
        # result to self.input_projection, so handing it features + style
        # reproduces the [noise, features, style] layout without duplicating
        # the projection / upsampling / tanh pipeline here.
        conditioning = torch.cat([features, style_embed], dim=1)
        return super().forward(noise, conditioning)
98
+
99
+
# -------------------------------
# Utility: easy generation method
# -------------------------------
def generate_kolam_samples(generator, num_samples=4, device="cpu"):
    """Generate sample Kolams from random noise, features and (if supported) styles.

    The generator is switched to eval mode (and left there) and sampling runs
    without gradient tracking.

    Args:
        generator: A ``KolamGenerator`` or ``StyleConditionedGenerator``;
            must expose ``noise_dim`` and ``feature_dim``.
        num_samples: Number of images to generate.
        device: Device on which the random inputs are created; it has to
            match the device holding the generator's parameters.

    Returns:
        The tensor of generated images returned by ``generator``.
    """
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, generator.noise_dim, device=device)
        features = torch.randn(num_samples, generator.feature_dim, device=device)

        if not isinstance(generator, StyleConditionedGenerator):
            return generator(noise, features)

        # Read the style width from the embedding's first linear layer instead
        # of assuming the default 32, so generators built with another
        # style_dim work as well.
        style_dim = generator.style_embedding[0].in_features
        style = torch.randn(num_samples, style_dim, device=device)
        return generator(noise, features, style)
115
 
116
 
if __name__ == "__main__":
    # Smoke test: sample two images from a default style-conditioned
    # generator and report the resulting tensor shape.
    style_generator = StyleConditionedGenerator()
    batch = generate_kolam_samples(style_generator, num_samples=2)
    print("Generated:", batch.shape)