import base64
import io

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

from data import ModelBenchmarkData

# Select the non-interactive Agg backend and turn interactive mode off so
# figures are only rendered when explicitly saved.
matplotlib.use('Agg')
plt.ioff()

DATA = ModelBenchmarkData("data.json")


def refresh_plot_data():
    """Fetch the median TTFT/TPOT data and return it as a DataFrame.

    The raw payload returned by ``DATA.get_ttft_tpot_data`` is also printed
    to stdout before being wrapped in a ``pandas.DataFrame``.
    """
    data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
    print(data)
    return pd.DataFrame(data)


def load_css():
    """Load CSS styling.

    Returns:
        The text of ``styles.css`` read from the current working directory,
        or a minimal white-on-black fallback rule when that file does not
        exist.
    """
    try:
        # Pin the encoding: without it, decoding depends on the platform
        # locale and a UTF-8 stylesheet can be mis-read or raise on Windows.
        with open("styles.css", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "body { background: #000; color: #fff; }"


def create_matplotlib_bar_charts():
    """Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
    data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)

    # Create figure with dark theme - wider for side-by-side plots
    plt.style.use('dark_background')
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 12))
    fig.patch.set_facecolor('#000000')

    # Prepare data
    labels = data['label']
    ttft_values = data['ttft']
    tpot_values = data['tpot']

    # Define color mapping based on configuration keywords
    def get_color_for_config(label):
        is_eager = 'eager' in label.lower()
        is_sdpa = 'sdpa' in label.lower()
        is_compiled = '_compiled' in label.lower()
        if is_eager:
            if is_compiled:
                return '#FF4444'  # Red for eager compiled
            else:
                return '#FF6B6B'  # Light red for eager uncompiled
        elif is_sdpa:
            if is_compiled:
                return '#4A90E2'  # Blue for SDPA compiled
            else:
                return '#7BB3F0'  # Light blue for SDPA uncompiled
        else:
            return '#FFD700'  # Yellow for others

    # Get colors for each bar
    colors = [get_color_for_config(label) for label in labels]

    # TTFT Plot (left)
    ax1.set_facecolor('#000000')
    bars1 = ax1.bar(range(len(labels)), ttft_values, color=colors, width=1.0,
                    edgecolor='white', linewidth=1)
    ax1.set_xlabel('Model Configuration', color='white', fontsize=14)
    ax1.set_ylabel('TTFT (seconds)', color='white', fontsize=14)
    ax1.set_title('Time To First Token by Configuration', color='white',
                  fontsize=16, pad=20)
ax1.set_xticks(range(len(labels))) ax1.set_xticklabels([label[:12] + '...' if len(label) > 12 else label for label in labels], rotation=45, ha='right', color='white', fontsize=10) ax1.tick_params(colors='white') ax1.grid(True, alpha=0.3, color='white') # TPOT Plot (right) ax2.set_facecolor('#000000') bars2 = ax2.bar(range(len(labels)), tpot_values, color=colors, width=1.0, edgecolor='white', linewidth=1) ax2.set_xlabel('Model Configuration', color='white', fontsize=14) ax2.set_ylabel('TPOT (seconds)', color='white', fontsize=14) ax2.set_title('Time Per Output Token by Configuration', color='white', fontsize=16, pad=20) ax2.set_xticks(range(len(labels))) ax2.set_xticklabels([label[:12] + '...' if len(label) > 12 else label for label in labels], rotation=45, ha='right', color='white', fontsize=10) ax2.tick_params(colors='white') ax2.grid(True, alpha=0.3, color='white') # Tight layout to prevent label cutoff plt.tight_layout() # Save plot to bytes buffer = io.BytesIO() plt.savefig(buffer, format='png', facecolor='#000000', bbox_inches='tight', dpi=100) buffer.seek(0) # Convert to base64 for HTML embedding img_data = base64.b64encode(buffer.getvalue()).decode() plt.close(fig) # Return HTML with embedded image - almost full height html = f"""