# Spaces status banner captured together with this source: Runtime error
| import asyncio | |
| import logging | |
| import gradio as gr | |
| import pandas as pd | |
| from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \ | |
| get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \ | |
| get_async_connection | |
logger = logging.getLogger(__name__)

# Label of the synthetic question-dropdown entry that selects every question at once.
ALL_QUESTIONS_STR = "All questions"

# Maximum length for baseline_reason display (not referenced by the handlers in this file).
TRUNCATE_REASON_LEN = 50

# ---------------------------------------------------------------------------
# Mutable module-level state shared by the Gradio event handlers.
# Everything starts empty and is (re)populated by the handlers below.
# ---------------------------------------------------------------------------

# Questions
questions = []
questions_dict = {}          # question text -> question id
question_options = []

# Source finders
source_finders = []
source_finders_dict = {}     # source finder name -> id
finder_options = []

# Baseline rankers
baseline_rankers_dict = {}   # baseline ranker name -> id
baseline_ranker_options = []

# Source-finder runs
run_ids = []
available_run_id_dict = {}   # run-id label shown in the dropdown -> run id
run_id_options = []
previous_run_id = "initial_run"
run_id_dropdown = None       # set to the run-id Dropdown component in main()

# Rows from the most recent get_unified_sources() call
# (originally intended for retrieving the full baseline_reason on selection).
last_source_runs = []
# Initialize dropdown data in a single async function.
async def initialize_data():
    """Load source finders and baseline rankers and build the dropdown lookups.

    Populates the module-level ``source_finders``, ``source_finders_dict``,
    ``finder_options``, ``baseline_rankers_dict`` and ``baseline_ranker_options``.
    """
    global source_finders, source_finders_dict, finder_options
    global baseline_rankers_dict, baseline_ranker_options

    async with get_async_connection() as conn:
        source_finders = await get_source_finders(conn)
        rankers = await get_baseline_rankers(conn)

        # name -> id maps used to translate dropdown selections back into ids
        baseline_rankers_dict = dict((ranker["name"], ranker["id"]) for ranker in rankers)
        source_finders_dict = dict((finder["name"], finder["id"]) for finder in source_finders)

        # Dropdown labels, in the order the data-access helpers returned them
        finder_options = [finder["name"] for finder in source_finders]
        baseline_ranker_options = [ranker["name"] for ranker in rankers]
def update_run_ids(question_option, source_finder_name, baseline_ranker_name):
    """Synchronous Gradio entry point that delegates to ``update_run_ids_async``.

    The coroutine is driven with ``asyncio.run``, i.e. on a fresh event loop,
    so this must not be called from inside a running loop.
    """
    pending = update_run_ids_async(question_option, source_finder_name, baseline_ranker_name)
    return asyncio.run(pending)
async def update_run_ids_async(question_option, source_finder_name, baseline_ranker_name):
    """Refresh the run-id choices after the source finder changes.

    Looks up the runs available for the selected source finder, stores them in
    the module-level ``available_run_id_dict`` / ``run_id_options`` and resets
    the dependent widgets. ``question_option`` and ``baseline_ranker_name`` are
    supplied by the Gradio event but are not used here.

    Returns:
        A 7-tuple in the order of the ``source_finder_dropdown.change`` outputs:
        (question dropdown, results table, statistics table, run-id dropdown,
        result text, metadata text, source text).
    """
    global available_run_id_dict, run_id_options
    async with get_async_connection() as conn:
        finder_id_int = source_finders_dict.get(source_finder_name)
        available_run_id_dict = await get_run_ids(conn, finder_id_int)
        run_id_options = list(available_run_id_dict.keys())
    return (
        gr.Dropdown(choices=[]),                          # clear the question list
        None,                                             # results table
        None,                                             # statistics table
        gr.Dropdown(choices=run_id_options, value=None),  # new runs, nothing selected
        "Select Question to see results",
        "",                                               # metadata text
        "",                                               # source text
    )
def update_questions_list(source_finder_name, run_id, baseline_ranker_name):
    """Synchronous Gradio entry point that delegates to ``update_questions_list_async``.

    The coroutine is driven with ``asyncio.run`` on a fresh event loop.
    """
    coro = update_questions_list_async(source_finder_name, run_id, baseline_ranker_name)
    return asyncio.run(coro)
async def update_questions_list_async(source_finder_name, run_id, baseline_ranker_name):
    """Rebuild the question dropdown for the selected run and baseline ranker.

    Returns:
        A 6-tuple: (question dropdown, result text, metadata text, results
        table, statistics table, source text). If any of the three selections
        is missing, no query is made; every output except the source text
        receives ``None``.
    """
    selections_complete = all((source_finder_name, run_id, baseline_ranker_name))
    if not selections_complete:
        return None, None, None, None, None, ""

    async with get_async_connection() as conn:
        # Translate dropdown labels back into database ids
        run_id_int = available_run_id_dict.get(run_id)
        ranker_id = baseline_rankers_dict.get(baseline_ranker_name)
        choices = await get_updated_question_list(conn, ranker_id, run_id_int)
        return gr.Dropdown(choices=choices, value=None), None, None, None, None, ""
async def get_updated_question_list(conn, baseline_ranker_id, finder_id_int):
    """Fetch the questions available for a run / baseline-ranker pair.

    Updates the module-level ``questions`` and ``questions_dict`` (question
    text -> id) and returns the dropdown choices.

    Note: despite its name, ``finder_id_int`` is what gets forwarded to
    ``get_questions``; the caller in this file passes a *run* id here. The
    parameter name is kept for backward compatibility.

    Returns:
        ``[ALL_QUESTIONS_STR, <question texts>...]``, or ``[]`` when no
        questions are returned.
    """
    global questions_dict, questions
    fetched = await get_questions(conn, finder_id_int, baseline_ranker_id)
    questions = fetched or []
    # Always rebuild the lookup so questions from a previous selection cannot linger.
    questions_dict = {q["text"]: q["id"] for q in questions}
    if not questions:
        return []
    return [ALL_QUESTIONS_STR] + [q["text"] for q in questions]
def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str,
                        evt: gr.EventData = None):
    """Synchronous Gradio handler that loads the sources for the selected question.

    Despite their names, ``source_finder_id`` and ``baseline_ranker_id`` receive
    the *names* selected in the dropdowns; they are forwarded unchanged to
    ``update_sources_list_async``. ``evt`` is injected by Gradio and identifies
    the component that fired the event.

    Returns:
        The 5-tuple produced by ``update_sources_list_async``, or five
        ``gr.skip()`` values when the event is ignored.
    """
    global previous_run_id
    if evt:
        logger.info("event: %s", evt.target.elem_id)
        # Presumably filters duplicate run-id dropdown events: a list value or an
        # unchanged run id means there is nothing new to load.
        if evt.target.elem_id == "run_id_dropdown" and (
            isinstance(run_id, list) or run_id == previous_run_id
        ):
            return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
    if isinstance(run_id, str):
        previous_run_id = run_id
    return asyncio.run(
        update_sources_list_async(question_option, source_finder_id, run_id, baseline_ranker_id)
    )
# Main function to handle UI interactions
async def update_sources_list_async(question_option, source_finder_name, run_id, baseline_ranker_name: str):
    """Load sources, statistics and metadata for the selected question.

    Returns:
        A 5-tuple: (results table, statistics table, result text, metadata
        text, source text).

        * Missing question or selections: the tables are left untouched
          (``gr.skip()``) and an explanatory message is shown.
        * ``ALL_QUESTIONS_STR``: cumulative statistics over every question in
          ``questions_dict``, with no per-source table.
        * A single question: its unified source rows, statistics and metadata.
    """
    global previous_run_id, last_source_runs

    if not question_option:
        return gr.skip(), gr.skip(), "No question selected", "", ""
    if not source_finder_name or not run_id or not baseline_ranker_name:
        return gr.skip(), gr.skip(), "Need to select source finder and baseline", "", ""

    logger.info("processing update")
    async with get_async_connection() as conn:
        # The ranker value may arrive as a list; use its first entry.
        if isinstance(baseline_ranker_name, list):
            baseline_ranker_name = baseline_ranker_name[0]
        # An empty ranker name falls back to ranker id 1 (meaning of id 1 not visible here).
        baseline_ranker_id_int = (
            1 if len(baseline_ranker_name) == 0
            else baseline_rankers_dict.get(baseline_ranker_name)
        )
        finder_id_int = source_finders_dict.get(source_finder_name) if len(source_finder_name) else None

        if question_option == ALL_QUESTIONS_STR:
            if not finder_id_int:
                return None, None, "Select Run Id and source finder to see results", "", ""
            run_id_int = available_run_id_dict.get(run_id)
            all_stats = await calculate_cumulative_statistics_for_all_questions(
                conn, list(questions_dict.values()), run_id_int, baseline_ranker_id_int
            )
            return None, all_stats, "Showing cumulative statistics for all questions", "", ""

        question_id = questions_dict.get(question_option)
        # Resolve the run id against this question's runs, but keep the result local:
        # overwriting the module-level available_run_id_dict would leave it narrowed
        # to one question and break later run-id lookups made when the run changes.
        question_run_ids = await get_run_ids(conn, finder_id_int, question_id)
        previous_run_id = run_id
        run_id_int = question_run_ids.get(run_id)

        source_runs = None
        stats = None
        if finder_id_int:
            source_runs, stats = await get_unified_sources(
                conn, question_id, run_id_int, baseline_ranker_id_int
            )
        last_source_runs = source_runs

        if not source_runs:
            return None, None, "No results found for the selected filters", "", ""

        df = pd.DataFrame(source_runs)
        columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run',
                              'source_run_rank', 'tractate', 'folio', 'reason']
        # Show the curated column subset only if every column exists; otherwise show everything.
        df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df

        metadata = await get_metadata(conn, question_id, run_id_int)
        return df_display, stats, f"Found {len(source_runs)} results", metadata, ""
# Handle a click on a row of the sources table.
async def handle_row_selection_async(evt: gr.SelectData):
    """Return the text of the source in the clicked results-table row.

    The first cell of the selected row is passed to ``get_source_text`` as the
    chunk id (assumes the table's first column holds that id; confirm against
    the displayed column order). Failures are shown to the user as a message
    and logged with their traceback.
    """
    if evt is None or evt.value is None:
        return "No source selected"
    try:
        tractate_chunk_id = evt.row_value[0]
        async with get_async_connection() as conn:
            return await get_source_text(conn, tractate_chunk_id)
    except Exception as e:
        # UI boundary: report the failure in the text panel instead of crashing the
        # handler, but keep the traceback in the logs for debugging.
        logger.exception("Error retrieving source text")
        return f"Error retrieving source text: {str(e)}"
def handle_row_selection(evt: gr.SelectData):
    """Synchronous Gradio ``select`` handler that delegates to ``handle_row_selection_async``.

    The ``gr.SelectData`` annotation is kept so Gradio injects the selection event.
    """
    lookup = handle_row_selection_async(evt)
    return asyncio.run(lookup)
# Create the Gradio app and launch it.
async def main():
    """Load the dropdown data, build the Source Runs Explorer UI and launch it.

    Layout, top to bottom:
      * selection dropdowns (source finder, run id, baseline ranker, question)
        next to a help sidebar;
      * a status line;
      * the statistics table;
      * the metadata box;
      * the sources table next to a read-only source-text panel.

    Event handlers are wired inside the ``gr.Blocks`` context, then the app is
    queued and launched.
    """
    # Store the run-id dropdown component at module level.
    global run_id_dropdown
    # Fills finder_options / baseline_ranker_options used as dropdown choices below.
    await initialize_data()
    with gr.Blocks(title="Source Runs Explorer", theme=gr.themes.Citrus()) as app:
        gr.Markdown("# Source Runs Explorer")
        with gr.Row():
            # Wider left column: the selection controls
            with gr.Column(scale=3):
                with gr.Row():
                    with gr.Column(scale=1):
                        source_finder_dropdown = gr.Dropdown(
                            choices=finder_options,
                            value=None,
                            label="Source Finder",
                            interactive=True,
                            elem_id="source_finder_dropdown"
                        )
                    with gr.Column(scale=1):
                        # Choices are replaced by update_run_ids when a source finder is picked
                        run_id_dropdown = gr.Dropdown(
                            choices=run_id_options,
                            value=None,
                            allow_custom_value=True,
                            label="source finder Run ID",
                            interactive=True,
                            elem_id="run_id_dropdown"
                        )
                    with gr.Column(scale=1):
                        baseline_rankers_dropdown = gr.Dropdown(
                            choices=baseline_ranker_options,
                            value=None,
                            label="Select Baseline Ranker",
                            interactive=True,
                            elem_id="baseline_rankers_dropdown"
                        )
                with gr.Row():
                    with gr.Column(scale=1):
                        # Main content area; choices are replaced by update_questions_list
                        question_dropdown = gr.Dropdown(
                            choices=[ALL_QUESTIONS_STR] + question_options,
                            label="Select Question (if list is empty this means there is no overlap between source run and baseline)",
                            value=None,
                            interactive=True,
                            elem_id="question_dropdown"
                        )
            with gr.Column(scale=1):
                # Sidebar area: usage instructions
                gr.Markdown("""To Get started select the following:
* Source Finder
* Source Finder Run ID (corresponds to a run of the source finder for a group of questions)
* Baseline Ranker (corresponds to a run of the baseline ranker for a group of questions)
**Note: if there is no overlap between the baseline questions and the source finder questions, the question list will be empty.**
""")
        with gr.Row():
            # Status line written by the handlers
            result_text = gr.Markdown("Select a question to view source runs")
        with gr.Row():
            gr.Markdown("# Source Run Statistics")
        with gr.Row():
            statistics_table = gr.DataFrame(
                headers=["num_high_ranked_baseline_sources",
                         "num_high_ranked_found_sources",
                         "overlap_count",
                         "overlap_percentage",
                         "high_ranked_overlap_count",
                         "high_ranked_overlap_percentage"
                         ],
                interactive=False,
            )
        with gr.Row():
            metadata_text = gr.TextArea(
                label="Metadata of Source Finder for Selected Question",
                elem_id="metadata",
                lines=2
            )
        with gr.Row():
            gr.Markdown("# Sources Found")
        with gr.Row():
            with gr.Column(scale=3):
                # NOTE(review): these headers differ from the column subset chosen in
                # update_sources_list_async; the returned DataFrame's columns are what is shown.
                results_table = gr.DataFrame(
                    headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run',
                             'source_run_rank', 'source_reason', 'baseline_reason'],
                    interactive=False
                )
            with gr.Column(scale=1):
                source_text = gr.TextArea(
                    value="Text of the source will appear here",
                    lines=15,
                    label="Source Text",
                    interactive=False,
                    elem_id="source_text"
                )
                # Disabled CSV download button (kept for reference):
                # download_button = gr.DownloadButton(
                #     label="Download Results as CSV",
                #     interactive=True,
                #     visible=True
                # )
        # Set up event handlers
        # Clicking a results row shows that source's text.
        results_table.select(
            handle_row_selection,
            inputs=None,
            outputs=source_text
        )
        # Changing the baseline ranker or the run id rebuilds the question list and clears the results.
        baseline_rankers_dropdown.change(
            update_questions_list,
            inputs=[source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[question_dropdown, result_text, metadata_text, results_table, statistics_table, source_text]
        )
        run_id_dropdown.change(
            update_questions_list,
            inputs=[source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[question_dropdown, result_text, metadata_text, results_table, statistics_table, source_text]
        )
        # Picking a question loads its sources, statistics and metadata.
        question_dropdown.change(
            update_sources_list,
            inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
            outputs=[results_table, statistics_table, result_text, metadata_text, source_text]
        )
        # Picking a source finder reloads its run ids and resets everything downstream.
        source_finder_dropdown.change(
            update_run_ids,
            inputs=[question_dropdown, source_finder_dropdown, baseline_rankers_dropdown],
            # outputs=[run_id_dropdown, results_table, result_text, download_button]
            outputs=[question_dropdown, results_table, statistics_table, run_id_dropdown, result_text, metadata_text, source_text]
        )
    app.queue()
    app.launch()
if __name__ == "__main__":
    # Configure root logging (INFO) before starting, so handler log lines are emitted.
    logging.basicConfig(level="INFO")
    asyncio.run(main())