Spaces:
Running
Running
| """ | |
| Hugging Face Spaces App for Kolam AI Generator | |
| Enhanced with StyleConditionedGenerator for more variety | |
| """ | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| import sys | |
| import matplotlib.cm as cm # for color mapping | |
| # Add project paths | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| sys.path.insert(0, str(Path(__file__).parent.parent / 'models')) | |
| from models.gan_generator import StyleConditionedGenerator | |
| from models.gan_discriminator import KolamDiscriminator | |
| from utils.metrics import KolamDesignMetrics | |
class KolamAIGenerator:
    """Wrap the Kolam GAN models behind a simple generate/analyze API.

    Builds a ``StyleConditionedGenerator`` and a ``KolamDiscriminator`` (both
    single-channel, 64x64), optionally loads pretrained generator weights, and
    exposes helpers that return a generated design as a PIL image and score it
    with ``KolamDesignMetrics``. The discriminator is built and put in eval
    mode but is not used by the methods of this class.
    """

    # Architecture sizes shared by model construction and by the latent tensors
    # sampled in generate_kolam(); keeping them in one place prevents drift.
    NOISE_DIM = 100
    FEATURE_DIM = 128
    STYLE_DIM = 32
    IMAGE_SIZE = 64

    # Checkpoint location. Relative, so it is resolved against the current
    # working directory, not against this file's directory.
    WEIGHTS_PATH = Path("models/generator.pth")

    # (display label, key in the dict returned by
    # KolamDesignMetrics.calculate_overall_quality()) pairs for analyze_quality().
    _QUALITY_FIELDS = (
        ("Overall Quality", "overall_quality"),
        ("Horizontal Symmetry", "horizontal"),
        ("Vertical Symmetry", "vertical"),
        ("Complexity", "complexity"),
        ("Balance", "balance"),
        ("Rhythm", "rhythm"),
    )

    def __init__(self):
        """Initialize the Kolam AI Generator (metrics helper + models)."""
        self.generator = None
        self.discriminator = None
        self.metrics = KolamDesignMetrics()
        self.load_models()

    def _build_generator(self):
        """Return a freshly initialised generator sized for this app."""
        return StyleConditionedGenerator(
            noise_dim=self.NOISE_DIM,
            feature_dim=self.FEATURE_DIM,
            style_dim=self.STYLE_DIM,
            output_channels=1,
            image_size=self.IMAGE_SIZE,
        )

    def load_models(self):
        """Build the models and load pretrained generator weights if present.

        * Normal path: build both models with this class's sizes, then try to
          load ``WEIGHTS_PATH`` into the generator.
        * Weights missing: keep the freshly initialised (untrained) generator.
        * Weights present but unloadable: log it and rebuild a clean generator
          with the same sizes. The old code instead swapped in a
          default-constructed one whose sizes may not match the tensors
          generate_kolam() feeds it.
        * Model construction fails: fall back to ``StyleConditionedGenerator()``
          with its own defaults. If even that fails, ``self.generator`` stays
          ``None`` and generate_kolam() returns the fallback pattern instead of
          the whole app failing to start.
        """
        try:
            self.generator = self._build_generator()
            self.discriminator = KolamDiscriminator(
                input_channels=1,
                image_size=self.IMAGE_SIZE,
            )
        except Exception as e:
            print(f"❌ Error loading models: {e}")
            try:
                self.generator = StyleConditionedGenerator()
                self.generator.eval()
            except Exception as fallback_error:
                print(f"❌ Error building fallback generator: {fallback_error}")
                self.generator = None
            return

        self._load_pretrained_weights()
        self.generator.eval()
        self.discriminator.eval()

    def _load_pretrained_weights(self):
        """Load WEIGHTS_PATH into self.generator; rebuild it cleanly on failure."""
        weights_path = self.WEIGHTS_PATH
        if not weights_path.exists():
            print("⚠️ No pretrained weights found, using untrained model.")
            return
        try:
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints (consider weights_only=True on newer torch).
            state_dict = torch.load(weights_path, map_location="cpu")
            self.generator.load_state_dict(state_dict)
        except Exception as e:
            print(f"❌ Error loading pretrained weights: {e}")
            # load_state_dict may already have copied some tensors before
            # failing, so start again from a clean, correctly sized generator.
            self.generator = self._build_generator()
        else:
            print("✅ Loaded pretrained generator weights!")

    @staticmethod
    def _resolve_seed(seed):
        """Return *seed* as an int in [0, 2**32); draw one in [0, 100000) if None.

        ``int()`` accepts the floats Gradio's Number component produces
        (truncating them). The modulo keeps negative or very large seeds inside
        the range ``np.random.seed`` accepts, instead of raising.
        """
        if seed is None:
            return int(np.random.randint(0, 100000))
        return int(seed) % 2**32

    def generate_kolam(self, complexity, symmetry, seed=None, use_color=True):
        """Generate one Kolam design.

        Args:
            complexity: "Simple", "Medium" or "Complex". It scales and perturbs
                the latent noise; any other value leaves the noise unscaled.
            symmetry: "High" gives full mirror + 90-degree rotational symmetry
                (see enhance_symmetry). "Medium" gives a left-right mirror only.
                "Low" (or anything else) leaves the raw output untouched.
            seed: Optional seed (int or float). ``None`` draws a random one.
                Both torch's and numpy's global RNGs are reseeded with it.
            use_color: If True, map intensities through matplotlib's viridis
                colormap and return an RGB image. Otherwise return grayscale.

        Returns:
            A PIL image. On any error, the geometric fallback pattern.
        """
        try:
            seed = self._resolve_seed(seed)
            torch.manual_seed(seed)
            np.random.seed(seed)

            noise = torch.randn(1, self.NOISE_DIM)
            # Complexity tuning. "Simple" deliberately draws no extra noise, so
            # the RNG stream (and hence features/style) matches earlier versions.
            if complexity == "Simple":
                noise = noise * 0.5
            elif complexity == "Medium":
                noise = noise * 1.0 + torch.randn_like(noise) * 0.3
            elif complexity == "Complex":
                noise = noise * 1.5 + torch.randn_like(noise) * 0.5

            # Random feature and style vectors for extra variety.
            features = torch.randn(1, self.FEATURE_DIM)
            style = torch.randn(1, self.STYLE_DIM)

            with torch.no_grad():
                generated = self.generator(noise, features, style)

            # Map to [0, 1]. Assumes the generator emits roughly [-1, 1]
            # (e.g. tanh); the clip guards against anything outside that.
            kolam_image = generated.squeeze().cpu().numpy()
            kolam_image = np.clip((kolam_image + 1) / 2, 0, 1)

            if symmetry == "High":
                kolam_image = self.enhance_symmetry(kolam_image)
            elif symmetry == "Medium":
                kolam_image = (kolam_image + np.fliplr(kolam_image)) / 2

            if use_color:
                rgb = cm.viridis(kolam_image)[..., :3]  # drop the alpha channel
                return Image.fromarray((rgb * 255).astype(np.uint8))
            # A 2-D uint8 array is converted to mode "L" automatically.
            return Image.fromarray((kolam_image * 255).astype(np.uint8))
        except Exception as e:
            print(f"❌ Error generating Kolam: {e}")
            return self.create_fallback_pattern()

    def enhance_symmetry(self, image):
        """Make a 2-D image mirror-symmetric, and rotationally symmetric if square.

        The image is averaged with its left-right mirror, then with its
        top-bottom mirror. For square images the result is then averaged with
        its own 90-degree rotation. Because that rotation is applied to the
        already-mirrored image (not the raw input), both mirror symmetries are
        preserved and 90-degree rotational symmetry is added.

        Returns:
            Array of the same shape, clipped to [0, 1].
        """
        img_sym = (image + np.fliplr(image)) / 2
        img_sym = (img_sym + np.flipud(img_sym)) / 2
        if img_sym.shape[0] == img_sym.shape[1]:
            img_sym = (img_sym + np.rot90(img_sym)) / 2
        return np.clip(img_sym, 0, 1)

    @staticmethod
    def _fallback_pattern_array(size=64):
        """Return a (size, size) float32 array of alternating concentric rings in {0, 1}."""
        pattern = np.zeros((size, size), dtype=np.float32)
        center = size // 2
        y, x = np.ogrid[:size, :size]
        dist_sq = (x - center) ** 2 + (y - center) ** 2
        radii = list(range(5, center, 8))
        # Paint from the outermost disk inwards, alternating 1/0, so each smaller
        # disk cuts a visible ring into the previous one. Painting every disk
        # with 1.0 collapses into a single solid disk.
        for index, radius in enumerate(reversed(radii)):
            pattern[dist_sq <= radius ** 2] = 1.0 if index % 2 == 0 else 0.0
        return pattern

    def create_fallback_pattern(self):
        """Return the geometric fallback pattern as a grayscale PIL image."""
        pattern = self._fallback_pattern_array(self.IMAGE_SIZE)
        return Image.fromarray((pattern * 255).astype(np.uint8))

    def analyze_quality(self, image):
        """Score a design with KolamDesignMetrics.

        Args:
            image: A PIL image in any mode (collapsed to 8-bit grayscale and
                scaled to [0, 1]), or an array passed through unchanged.

        Returns:
            Dict mapping display label to a score formatted with 3 decimals.
            Every value is "N/A" if scoring fails.
        """
        try:
            if isinstance(image, Image.Image):
                # generate_and_analyze passes the viridis-coloured RGB image.
                # NOTE(review): assumes the metrics expect a 2-D intensity map,
                # so collapse to grayscale rather than passing (H, W, 3).
                image_array = np.asarray(image.convert("L"), dtype=np.float64) / 255.0
            else:
                image_array = image
            quality = self.metrics.calculate_overall_quality(image_array)
            return {label: f"{quality[key]:.3f}" for label, key in self._QUALITY_FIELDS}
        except Exception as e:
            print(f"❌ Error analyzing quality: {e}")
            return {label: "N/A" for label, _ in self._QUALITY_FIELDS}
# ---------------------------------------------------------------------------
# Gradio interface wiring
# ---------------------------------------------------------------------------
# One shared instance, created when this module is imported. The Gradio
# callback below reuses it, so the models are built and loaded only once.
kolam_ai = KolamAIGenerator()
def generate_and_analyze(complexity, symmetry, seed):
    """Gradio callback: generate one coloured Kolam and score it.

    Returns an ``(image, metrics_dict)`` pair, in the order of the
    ``[output_image, quality_output]`` outputs the UI binds to this function.
    """
    image = kolam_ai.generate_kolam(complexity, symmetry, seed, use_color=True)
    return image, kolam_ai.analyze_quality(image)
def create_interface():
    """Build the Gradio Blocks UI for the Kolam generator.

    The left column holds the controls: complexity, symmetry, an optional seed
    and a generate button. The right column holds the generated image and a
    JSON view of the quality metrics. Clicking the button and loading the page
    both run ``generate_and_analyze``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    css = """
    .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
    .title { text-align: center; color: #2E86AB; margin-bottom: 20px; }
    .description { text-align: center; color: #666; margin-bottom: 30px; }
    """
    with gr.Blocks(css=css, title="Kolam AI Generator") as interface:
        gr.HTML("""
        <div class="title">
            <h1>🎨 Kolam AI Generator</h1>
            <p class="description">Generate beautiful, diverse Kolam designs using AI</p>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 🎛️ Controls")
                complexity = gr.Dropdown(
                    ["Simple", "Medium", "Complex"],
                    value="Medium",
                    label="Pattern Complexity",
                )
                symmetry = gr.Dropdown(
                    ["Low", "Medium", "High"],
                    value="Medium",
                    label="Symmetry Level",
                )
                seed = gr.Number(value=None, label="Random Seed (Optional)")
                generate_btn = gr.Button("🎨 Generate Kolam", variant="primary", size="lg")
            with gr.Column(scale=2):
                gr.Markdown("### 🖼️ Generated Kolam")
                output_image = gr.Image(label="Generated Design", type="pil", height=400)
                gr.Markdown("### 📊 Quality Analysis")
                quality_output = gr.JSON(label="Design Quality Metrics", value={})

        inputs = [complexity, symmetry, seed]
        outputs = [output_image, quality_output]
        generate_btn.click(fn=generate_and_analyze, inputs=inputs, outputs=outputs)
        # The initial render reads the controls' own default values instead of
        # re-hard-coding "Medium"/"Medium"/None, so the two cannot drift apart.
        interface.load(fn=generate_and_analyze, inputs=inputs, outputs=outputs)
    return interface
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 and also request a public Gradio
    # share link. NOTE(review): share links are typically ignored when running
    # on Hugging Face Spaces; confirm whether share=True is wanted locally.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_error": True,
    }
    interface = create_interface()
    interface.launch(**launch_kwargs)