from pathlib import Path
import json
import os
import re
from datetime import datetime
from html import escape as html_escape

import pandas as pd
import numpy as np
import gradio as gr
from datasets import load_dataset
from gradio_leaderboard import Leaderboard

from about import (
    PROBLEM_TYPES,
    TOKEN,
    CACHE_PATH,
    API,
    submissions_repo,
    results_repo,
    COLUMN_DISPLAY_NAMES,
    COUNT_BASED_METRICS,
    METRIC_GROUPS,
    METRIC_GROUP_COLORS,
    COLUMN_TO_GROUP,
)

# Any character outside this set is replaced by "_" when user input is used to
# build a repository path component (prevents "/" or ".." from escaping the
# per-submission folder).
_UNSAFE_PATH_CHARS = re.compile(r"[^A-Za-z0-9._-]")


def get_leaderboard():
    """Download the results dataset and return it as a DataFrame.

    The dataset is force-redownloaded on every call so the latest results are
    shown. When the dataset has no rows, an empty frame with placeholder
    columns is returned instead.
    """
    ds = load_dataset(results_repo, split='train', download_mode="force_redownload")
    full_df = pd.DataFrame(ds)
    print(full_df.columns)
    if len(full_df) == 0:
        return pd.DataFrame({'date': [], 'model': [], 'score': [], 'verified': []})
    return full_df


def format_dataframe(df, show_percentage=False, selected_groups=None, compact_view=True):
    """Format the dataframe with proper column names and optional percentages.

    Args:
        df: Raw results frame (one row per submission).
        show_percentage: If True, count-based metrics are shown as a
            percentage of ``n_structures`` (as strings ending in ``%``).
        selected_groups: Metric-group names to display when ``compact_view``
            is False; falsy means "all groups".
        compact_view: If True, show only ``COMPACT_VIEW_COLUMNS``.

    Returns:
        A pandas ``Styler`` with display column names and group colouring,
        or ``df`` unchanged when it is empty.
    """
    if len(df) == 0:
        return df

    if compact_view:
        # Predefined compact column set, restricted to columns that exist.
        from about import COMPACT_VIEW_COLUMNS
        selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]
    else:
        # Identity columns first (only if present, to avoid a KeyError below),
        # then every column of each selected metric group, without duplicates.
        selected_cols = [col for col in ('model_name', 'n_structures') if col in df.columns]
        groups = selected_groups or list(METRIC_GROUPS.keys())
        for group in groups:
            if group not in METRIC_GROUPS:
                continue
            for col in METRIC_GROUPS[group]:
                if col in df.columns and col not in selected_cols:
                    selected_cols.append(col)

    display_df = df[selected_cols].copy()

    # Decorate model names with marker symbols (explained in the UI legend).
    if 'model_name' in display_df.columns:
        has_relaxed = 'relaxed' in df.columns

        def add_model_symbols(row):
            name = row['model_name']
            symbols = []
            relaxed_flag = row.get('relaxed', False)
            # NaN is truthy, so a missing value must be excluded explicitly.
            if has_relaxed and pd.notna(relaxed_flag) and relaxed_flag:
                symbols.append('⚡')
            # ★ in-distribution sources that are part of the reference dataset;
            # ◆ out-of-distribution relative to it.
            if name in ('Alexandria', 'OQMD'):
                symbols.append('★')
            elif name == 'AFLOW':
                symbols.append('◆')
            return f"{name} {' '.join(symbols)}" if symbols else name

        display_df['model_name'] = df.apply(add_model_symbols, axis=1)

    # Count-based metrics -> percentage strings such as "42.0%".
    if show_percentage and 'n_structures' in df.columns:
        n_structures = df['n_structures']
        for col in COUNT_BASED_METRICS:
            if col in display_df.columns:
                display_df[col] = (df[col] / n_structures * 100).round(1).astype(str) + '%'

    # Round numeric columns for display; non-numeric (e.g. the percentage
    # strings above) are left untouched and integers are unaffected.
    display_df = display_df.round(4)

    display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)

    return apply_color_styling(display_df, selected_cols)


def apply_color_styling(display_df, original_cols):
    """Colour each column by its metric group using a pandas ``Styler``.

    ``original_cols`` must list the raw column names in the same order as
    ``display_df.columns``; columns are matched to groups positionally
    because ``display_df`` has already been renamed for display.
    """
    def style_by_group(frame):
        styles = pd.DataFrame('', index=frame.index, columns=frame.columns)
        for display_col, original_col in zip(frame.columns, original_cols):
            if original_col not in COLUMN_TO_GROUP:
                continue
            color = METRIC_GROUP_COLORS.get(COLUMN_TO_GROUP[original_col], '')
            if color:
                styles[display_col] = f'background-color: {color}'
        return styles

    return display_df.style.apply(style_by_group, axis=None)


def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction):
    """Update the leaderboard based on user selections.

    Works from the cached dataframe so the dataset is not re-downloaded on
    every UI change. ``sort_by`` is a display name (or "None"); it is mapped
    back to the raw column before sorting so that numeric values, not
    formatted strings, are sorted.
    """
    df_to_format = cached_df.copy()

    if sort_by and sort_by != "None":
        display_to_raw = {v: k for k, v in COLUMN_DISPLAY_NAMES.items()}
        raw_column_name = display_to_raw.get(sort_by, sort_by)
        if raw_column_name in df_to_format.columns:
            df_to_format = df_to_format.sort_values(
                by=raw_column_name,
                ascending=(sort_direction == "Ascending"),
            )

    return format_dataframe(df_to_format, show_percentage, selected_groups, compact_view)


def show_output_box(message):
    """Make the status textbox visible with ``message`` as its value."""
    return gr.update(value=message, visible=True)


def submit_cif_files(model_name, problem_type, cif_files, relaxed, profile: gr.OAuthProfile | None):
    """Submit structures to the leaderboard.

    Uploads the user's file and a ``metadata.json`` to
    ``submissions/<submission_id>/`` in the submissions dataset repo.

    Returns:
        ``(status_message, submission_id)``; ``submission_id`` is None on
        validation failure or error.
    """
    from huggingface_hub import upload_file

    if not model_name or not model_name.strip():
        return "Error: Please provide a model name.", None
    if not problem_type:
        return "Error: Please select a problem type.", None
    if not cif_files:
        return "Error: Please upload a file.", None
    if not profile:
        return "Error: Please log in to submit.", None

    try:
        username = profile.username
        timestamp = datetime.now().isoformat()
        clean_model_name = model_name.strip()
        file_path = Path(cif_files)

        submission_data = {
            "username": username,
            "model_name": clean_model_name,
            "problem_type": problem_type,
            "relaxed": relaxed,
            "timestamp": timestamp,
            "file_name": file_path.name,
        }

        # The model name is untrusted: reduce it to safe characters so it
        # cannot introduce extra path segments (e.g. "/" or "..").
        safe_model_name = _UNSAFE_PATH_CHARS.sub('_', clean_model_name)
        submission_id = f"{username}_{safe_model_name}_{timestamp.replace(':', '-')}"

        upload_file(
            path_or_fileobj=str(file_path),
            path_in_repo=f"submissions/{submission_id}/{file_path.name}",
            repo_id=submissions_repo,
            token=TOKEN,
            repo_type="dataset",
        )

        # Upload the metadata straight from memory; no temp file to clean up
        # (the previous temp file leaked whenever this upload raised).
        metadata_bytes = json.dumps(submission_data, indent=2).encode("utf-8")
        upload_file(
            path_or_fileobj=metadata_bytes,
            path_in_repo=f"submissions/{submission_id}/metadata.json",
            repo_id=submissions_repo,
            token=TOKEN,
            repo_type="dataset",
        )

        return f"Success! Submitted {model_name} for {problem_type} evaluation. Submission ID: {submission_id}", submission_id

    except Exception as e:
        return f"Error during submission: {str(e)}", None


def generate_metric_legend_html():
    """Generate an HTML table with a colour-coded legend of metric groups.

    One row per entry in ``METRIC_GROUP_COLORS``: a colour swatch, the group
    name (arrows removed), its metrics and the optimisation direction.
    """
    metric_details = {
        'Validity ↑': ('Valid, Charge Neutral, Distance Valid, Plausibility Valid', '↑ Higher is better'),
        'Uniqueness & Novelty ↑': ('Unique, Novel', '↑ Higher is better'),
        'Energy Metrics ↓': ('E Above Hull, Formation Energy, Relaxation RMSD (with std)', '↓ Lower is better'),
        'Stability ↑': ('Stable, Unique in Stable, SUN', '↑ Higher is better'),
        'Metastability ↑': ('Metastable, Unique in Metastable, MSUN', '↑ Higher is better'),
        'Distribution ↓': ('JS Distance, MMD, FID', '↓ Lower is better'),
        'Diversity ↑': ('Element, Space Group, Atomic Site, Crystal Size', '↑ Higher is better'),
        'HHI ↓': ('HHI Production, HHI Reserve', '↓ Lower is better'),
    }

    cell = 'style="padding: 6px 8px; border: 1px solid #ddd; text-align: left;"'
    # NOTE(review): table styling attributes are a reconstruction; adjust to
    # taste. Cell text is HTML-escaped (group names contain "&").
    parts = [
        '<table style="width: 100%; border-collapse: collapse;">',
        '<thead><tr>',
        f'<th {cell}>Color</th>',
        f'<th {cell}>Group</th>',
        f'<th {cell}>Metrics</th>',
        f'<th {cell}>Direction</th>',
        '</tr></thead>',
        '<tbody>',
    ]
    for group, color in METRIC_GROUP_COLORS.items():
        metrics, direction = metric_details.get(group, ('', ''))
        group_name = group.replace('↑', '').replace('↓', '').strip()
        parts.append('<tr>')
        parts.append(
            f'<td style="padding: 6px 8px; border: 1px solid #ddd; width: 40px; '
            f'background-color: {html_escape(str(color))};"></td>'
        )
        parts.append(f'<td {cell}><strong>{html_escape(group_name)}</strong></td>')
        parts.append(f'<td {cell}>{html_escape(metrics)}</td>')
        parts.append(f'<td {cell}>{html_escape(direction)}</td>')
        parts.append('</tr>')
    parts.append('</tbody></table>')
    return ''.join(parts)


def gradio_interface() -> gr.Blocks:
    """Build the Gradio app: a leaderboard tab and a submission tab."""
    with gr.Blocks() as demo:
        gr.Markdown("""
        # 🔬 LeMat-GenBench: A Unified Benchmark for Generative Models of Crystalline Materials

        Generative machine learning models hold great promise for accelerating materials discovery, particularly through the inverse design of inorganic crystals, enabling an unprecedented exploration of chemical space. Yet, the lack of standardized evaluation frameworks makes it difficult to evaluate, compare and further develop these ML models meaningfully.

        **LeMat-GenBench** introduces a unified benchmark for generative models of crystalline materials, with standardized evaluation metrics for meaningful model comparison, diverse tasks, and this leaderboard to encourage and track community progress.

        📄 **Paper**: [arXiv preprint](https://arxiv.org/abs/XXXX.XXXXX) | 💻 **Code**: [GitHub](https://github.com/LeMaterial/lemat-genbench) | 📧 **Contact**: siddharth.betala-ext [at] entalpic.ai, alexandre.duval [at] entalpic.ai
        """)

        with gr.Tabs(elem_classes="tab-buttons"):
            with gr.TabItem("🚀 Leaderboard", elem_id="boundary-benchmark-tab-table"):
                gr.Markdown("# LeMat-GenBench")

                with gr.Row():
                    with gr.Column(scale=1):
                        compact_view = gr.Checkbox(
                            value=True,
                            label="Compact View",
                            info="Show only key metrics",
                        )
                        show_percentage = gr.Checkbox(
                            value=True,
                            label="Show as Percentages",
                            info="Display count-based metrics as percentages of total structures",
                        )
                    with gr.Column(scale=1):
                        # Choices are display names; update_leaderboard maps
                        # them back to raw column names.
                        sort_choices = ["None"] + list(COLUMN_DISPLAY_NAMES.values())
                        sort_by = gr.Dropdown(
                            choices=sort_choices,
                            value="None",
                            label="Sort By",
                            info="Select column to sort by",
                        )
                        sort_direction = gr.Radio(
                            choices=["Ascending", "Descending"],
                            value="Descending",
                            label="Sort Direction",
                        )
                    with gr.Column(scale=2):
                        selected_groups = gr.CheckboxGroup(
                            choices=list(METRIC_GROUPS.keys()),
                            value=list(METRIC_GROUPS.keys()),
                            label="Metric Families (only active when Compact View is off)",
                            info="Select which metric groups to display",
                        )

                with gr.Accordion("Metric Groups Legend", open=False):
                    gr.HTML(generate_metric_legend_html())

                try:
                    # Load once; later UI changes reformat this cached copy.
                    initial_df = get_leaderboard()
                    cached_df_state = gr.State(initial_df)
                    formatted_df = format_dataframe(
                        initial_df,
                        show_percentage=True,
                        selected_groups=list(METRIC_GROUPS.keys()),
                        compact_view=True,
                    )

                    n_cols = len(formatted_df.columns)
                    leaderboard_table = gr.Dataframe(
                        label="GenBench Leaderboard",
                        value=formatted_df,
                        interactive=False,
                        wrap=True,
                        column_widths=["180px"] + ["160px"] * (n_cols - 1) if n_cols > 0 else None,
                        show_fullscreen_button=True,
                    )

                    # Every display control triggers the same re-render.
                    update_inputs = [
                        show_percentage,
                        selected_groups,
                        compact_view,
                        cached_df_state,
                        sort_by,
                        sort_direction,
                    ]
                    for control in (show_percentage, selected_groups, compact_view, sort_by, sort_direction):
                        control.change(
                            fn=update_leaderboard,
                            inputs=update_inputs,
                            outputs=leaderboard_table,
                        )

                except Exception as e:
                    gr.Markdown(f"Leaderboard is empty or error loading: {str(e)}")

                gr.Markdown("""
                **Symbol Legend:**
                - ⚡ Structures were already relaxed
                - ★ Contributes to LeMat-Bulk reference dataset (in-distribution)
                - ◆ Out-of-distribution relative to LeMat-Bulk reference dataset

                Verified submissions mean the results came from a model submission rather than a CIF submission.
                """)

            # Distinct elem_id: DOM ids must be unique (this tab previously
            # reused the leaderboard tab's id).
            with gr.TabItem("✉️ Submit", elem_id="boundary-benchmark-tab-submit"):
                gr.Markdown(
                    """
                # Materials Submission
                Upload a CSV, pkl, or a ZIP of CIFs with your structures.
                """
                )
                filename = gr.State(value=None)

                gr.LoginButton()

                with gr.Row():
                    with gr.Column():
                        model_name_input = gr.Textbox(
                            label="Model Name",
                            placeholder="Enter your model name",
                            info="Provide a name for your model/method",
                        )
                        problem_type = gr.Dropdown(PROBLEM_TYPES, label="Problem Type")
                    with gr.Column():
                        cif_file = gr.File(label="Upload a CSV, a pkl, or a ZIP of CIF files.")
                        relaxed = gr.Checkbox(
                            value=False,
                            label="Structures are already relaxed",
                            info="Check this box if your submitted structures have already been relaxed",
                        )

                submit_btn = gr.Button("Submission")
                message = gr.Textbox(label="Status", lines=1, visible=False)
                gr.Markdown("If you have issues with submission or using the leaderboard, please start a discussion in the Community tab of this Space.")

                # The OAuth profile is injected by Gradio from the type hint on
                # submit_cif_files, so it is not listed in `inputs`.
                submit_btn.click(
                    submit_cif_files,
                    inputs=[model_name_input, problem_type, cif_file, relaxed],
                    outputs=[message, filename],
                ).then(
                    fn=show_output_box,
                    inputs=[message],
                    outputs=[message],
                )

    return demo


if __name__ == "__main__":
    gradio_interface().launch()