""" Hugging Face Spaces App for Kolam AI Generator Enhanced with StyleConditionedGenerator for more variety """ import gradio as gr import torch import numpy as np from PIL import Image from pathlib import Path import sys import matplotlib.cm as cm # for color mapping # Add project paths sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).parent.parent / 'models')) from models.gan_generator import StyleConditionedGenerator from models.gan_discriminator import KolamDiscriminator from utils.metrics import KolamDesignMetrics class KolamAIGenerator: def __init__(self): """Initialize the Kolam AI Generator.""" self.generator = None self.discriminator = None self.metrics = KolamDesignMetrics() self.load_models() def load_models(self): """Load the pre-trained models.""" try: # Use StyleConditionedGenerator self.generator = StyleConditionedGenerator( noise_dim=100, feature_dim=128, style_dim=32, output_channels=1, image_size=64 ) self.discriminator = KolamDiscriminator( input_channels=1, image_size=64 ) # Try to load pretrained weights weights_path = Path("models/generator.pth") if weights_path.exists(): self.generator.load_state_dict( torch.load(weights_path, map_location="cpu") ) print("✅ Loaded pretrained generator weights!") else: print("⚠️ No pretrained weights found, using untrained model.") self.generator.eval() self.discriminator.eval() except Exception as e: print(f"❌ Error loading models: {e}") self.generator = StyleConditionedGenerator() self.generator.eval() def generate_kolam(self, complexity, symmetry, seed=None, use_color=True): """Generate a Kolam design with specified parameters.""" try: # Seed control if seed is not None: torch.manual_seed(int(seed)) np.random.seed(int(seed)) else: seed = np.random.randint(0, 100000) torch.manual_seed(seed) np.random.seed(seed) # Random noise noise = torch.randn(1, 100) # Complexity tuning if complexity == "Simple": noise = noise * 0.5 elif complexity == "Medium": noise = noise * 1.0 + torch.randn_like(noise) * 0.3 elif complexity == "Complex": noise = noise * 1.5 + torch.randn_like(noise) * 0.5 # Random features & style vector for variety features = torch.randn(1, 128) style = torch.randn(1, 32) # Generate image with torch.no_grad(): generated_kolam = self.generator(noise, features, style) # Normalize to [0,1] kolam_image = generated_kolam.squeeze().cpu().numpy() kolam_image = (kolam_image + 1) / 2 kolam_image = np.clip(kolam_image, 0, 1) # Apply symmetry if symmetry == "High": kolam_image = self.enhance_symmetry(kolam_image) # Convert to color if use_color: kolam_colored = cm.viridis(kolam_image)[:, :, :3] kolam_pil = Image.fromarray((kolam_colored * 255).astype(np.uint8)) else: kolam_pil = Image.fromarray((kolam_image * 255).astype(np.uint8), mode='L') return kolam_pil except Exception as e: print(f"❌ Error generating Kolam: {e}") return self.create_fallback_pattern() def enhance_symmetry(self, image): """Enhance symmetry with mirroring + rotation.""" img_sym = (image + np.fliplr(image)) / 2 img_sym = (img_sym + np.flipud(img_sym)) / 2 rotated = np.rot90(image) img_sym = (img_sym + rotated) / 2 return np.clip(img_sym, 0, 1) def create_fallback_pattern(self): """Fallback geometric pattern.""" size = 64 pattern = np.zeros((size, size), dtype=np.float32) center = size // 2 for radius in range(5, center, 8): y, x = np.ogrid[:size, :size] mask = (x - center) ** 2 + (y - center) ** 2 <= radius ** 2 pattern[mask] = 1.0 return Image.fromarray((pattern * 255).astype(np.uint8), mode='L') def analyze_quality(self, image): """Analyze the quality of the Kolam.""" try: if isinstance(image, Image.Image): image_array = np.array(image) / 255.0 else: image_array = image quality = self.metrics.calculate_overall_quality(image_array) return { "Overall Quality": f"{quality['overall_quality']:.3f}", "Horizontal Symmetry": f"{quality['horizontal']:.3f}", "Vertical Symmetry": f"{quality['vertical']:.3f}", "Complexity": f"{quality['complexity']:.3f}", "Balance": f"{quality['balance']:.3f}", "Rhythm": f"{quality['rhythm']:.3f}" } except Exception as e: print(f"❌ Error analyzing quality: {e}") return {k: "N/A" for k in [ "Overall Quality", "Horizontal Symmetry", "Vertical Symmetry", "Complexity", "Balance", "Rhythm" ]} # ------------------------- # Interface setup # ------------------------- kolam_ai = KolamAIGenerator() def generate_and_analyze(complexity, symmetry, seed): kolam_image = kolam_ai.generate_kolam(complexity, symmetry, seed, use_color=True) quality_metrics = kolam_ai.analyze_quality(kolam_image) return kolam_image, quality_metrics def create_interface(): css = """ .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; } .title { text-align: center; color: #2E86AB; margin-bottom: 20px; } .description { text-align: center; color: #666; margin-bottom: 30px; } """ with gr.Blocks(css=css, title="Kolam AI Generator") as interface: gr.HTML("""
Generate beautiful, diverse Kolam designs using AI