import os
from contextlib import asynccontextmanager
from typing import Optional

import asyncpg
import pandas as pd
import psycopg2
from dotenv import load_dotenv

# Load database configuration from the environment
load_dotenv()

@asynccontextmanager
async def get_async_connection(schema="talmudexplore", auto_commit=True):
    """
    Get a connection for the current request.

    Args:
        schema: Database schema to use
        auto_commit: If True (default), each statement auto-commits.
                     If False, wraps the block in a transaction that
                     requires an explicit commit.
    """
    conn = None
    tx = None
    try:
        # Create a single connection without relying on a shared pool
        conn = await asyncpg.connect(
            database=os.getenv("pg_dbname"),
            user=os.getenv("pg_user"),
            password=os.getenv("pg_password"),
            host=os.getenv("pg_host"),
            port=os.getenv("pg_port")
        )
        await conn.execute(f'SET search_path TO {schema}')
        if not auto_commit:
            # Start a transaction that requires explicit commit
            tx = conn.transaction()
            await tx.start()
        yield conn
        if not auto_commit and tx:
            await tx.commit()
    except Exception:
        # Roll back any open transaction before re-raising
        if tx:
            await tx.rollback()
        raise
    finally:
        if conn:
            await conn.close()

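# Usage sketch (hypothetical caller, not part of this module): with the
# @asynccontextmanager decorator applied above, the helper is consumed via
# `async with`, e.g.:
#
#     async with get_async_connection(auto_commit=False) as conn:
#         await conn.execute("INSERT INTO questions (question_text) VALUES ($1)",
#                            "example question")
#         # the transaction commits when the block exits without raising
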
async def get_questions(conn: asyncpg.Connection, source_finder_run_id: int, baseline_source_finder_run_id: int):
    questions = await conn.fetch("""
        SELECT DISTINCT q.id, question_text FROM talmudexplore.questions q
        JOIN (SELECT question_id FROM talmudexplore.source_finder_run_question_metadata
              WHERE source_finder_run_id = $1) sfrqm1 ON sfrqm1.question_id = q.id
        JOIN (SELECT question_id FROM talmudexplore.source_finder_run_question_metadata
              WHERE source_finder_run_id = $2) sfrqm2 ON sfrqm2.question_id = q.id;
    """, source_finder_run_id, baseline_source_finder_run_id)
    return [{"id": q["id"], "text": q["question_text"]} for q in questions]

async def get_metadata(conn: asyncpg.Connection, question_id: int, source_finder_run_id: int):
    metadata = await conn.fetchrow('''
        SELECT metadata
        FROM source_finder_run_question_metadata sfrqm
        WHERE sfrqm.question_id = $1 AND sfrqm.source_finder_run_id = $2;
    ''', question_id, source_finder_run_id)
    if metadata is None:
        return ""
    return metadata.get('metadata')

# Get distinct source finders that have at least one run with results
async def get_source_finders(conn: asyncpg.Connection):
    finders = await conn.fetch("""
        SELECT DISTINCT sf.id, sf.source_finder_type AS name
        FROM talmudexplore.source_finder_runs sfr
        JOIN talmudexplore.source_finders sf ON sf.id = sfr.source_finder_id
        WHERE EXISTS (
            SELECT 1
            FROM talmudexplore.source_run_results srr
            WHERE srr.source_finder_run_id = sfr.id
        )
        ORDER BY sf.id
    """)
    return [{"id": f["id"], "name": f["name"]} for f in finders]

# Get distinct run IDs for a source finder, optionally filtered by question
async def get_run_ids(conn: asyncpg.Connection, source_finder_id: int, question_id: Optional[int] = None):
    query = """
        SELECT DISTINCT sfr.description, srs.source_finder_run_id AS run_id
        FROM source_run_results srs
        JOIN source_finder_runs sfr ON srs.source_finder_run_id = sfr.id
        JOIN source_finders sf ON sfr.source_finder_id = sf.id
        WHERE sfr.source_finder_id = $1
    """
    if question_id is not None:
        query += " AND srs.question_id = $2"
        params = (source_finder_id, question_id)
    else:
        params = (source_finder_id,)
    query += " ORDER BY run_id DESC;"
    run_ids = await conn.fetch(query, *params)
    return {r["description"]: r["run_id"] for r in run_ids}

async def get_baseline_rankers(conn: asyncpg.Connection):
    query = """
        SELECT sfr.id, sf.source_finder_type, sfr.description
        FROM source_finder_runs sfr
        JOIN source_finders sf ON sf.id = sfr.source_finder_id
        WHERE EXISTS (
            SELECT 1
            FROM source_run_results srr
            WHERE srr.source_finder_run_id = sfr.id
        )
        ORDER BY sf.id DESC
    """
    rankers = await conn.fetch(query)
    return [{"id": r["id"], "name": f"{r['source_finder_type']} : {r['description']}"} for r in rankers]

async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connection, baseline_sources, source_runs_sources):
    # For a given question, compare the sources a run found against the
    # baseline ranker's sources: overlap, high-ranked overlap, etc.
    actual_sources_set = {s["id"] for s in source_runs_sources}
    baseline_sources_set = {s["id"] for s in baseline_sources}

    # Calculate overlap
    overlap = actual_sources_set.intersection(baseline_sources_set)

    # Calculate high-ranked overlap (rank >= 4)
    actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
    baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
    high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)

    results = {
        "total_baseline_sources": len(baseline_sources),
        "total_found_sources": len(source_runs_sources),
        "overlap_count": len(overlap),
        "overlap_percentage": round(len(overlap) * 100 / max(len(actual_sources_set), len(baseline_sources_set)), 2)
                              if max(len(actual_sources_set), len(baseline_sources_set)) > 0 else 0,
        "num_high_ranked_baseline_sources": len(baseline_high_ranked),
        "num_high_ranked_found_sources": len(actual_high_ranked),
        "high_ranked_overlap_count": len(high_ranked_overlap),
        "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / max(len(actual_high_ranked), len(baseline_high_ranked)), 2)
                                          if max(len(actual_high_ranked), len(baseline_high_ranked)) > 0 else 0
    }

    # Convert results to a single-row DataFrame
    results_df = pd.DataFrame([results])
    return results_df

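# Worked example (hypothetical data, for illustration only): if the baseline
# holds ids {1, 2, 3} with ranks {1: 5, 2: 4, 3: 1} and the run found
# {2, 3, 4} with ranks {2: 5, 3: 2, 4: 4}, then:
#   overlap            = {2, 3}  -> overlap_count = 2
#   overlap_percentage = 2 * 100 / max(3, 3) = 66.67
#   high-ranked (>= 4) baseline = {1, 2}, found = {2, 4}
#   high_ranked_overlap = {2}   -> high_ranked_overlap_percentage = 50.0
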
async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connection, question_ids, source_finder_run_id: int, ranker_id: int):
    """
    Calculate cumulative statistics across all questions for a specific source finder run and ranker.

    Args:
        conn (asyncpg.Connection): Database connection
        question_ids (list): List of question IDs to analyze
        source_finder_run_id (int): ID of the source finder run as it appears in source run results
        ranker_id (int): ID of the baseline ranker

    Returns:
        pd.DataFrame: DataFrame containing aggregated statistics
    """
    # Initialize aggregates
    total_baseline_sources = 0
    total_found_sources = 0
    total_overlap = 0
    total_high_ranked_baseline = 0
    total_high_ranked_found = 0
    total_high_ranked_overlap = 0

    # Process each question
    valid_questions = 0
    for question_id in question_ids:
        try:
            # Get unified sources for this question
            sources, stats = await get_unified_sources(conn, question_id, source_finder_run_id, ranker_id)
            if sources and len(sources) > 0:
                valid_questions += 1
                stats_dict = stats.iloc[0].to_dict()
                # Add to running totals
                total_baseline_sources += stats_dict.get('total_baseline_sources', 0)
                total_found_sources += stats_dict.get('total_found_sources', 0)
                total_overlap += stats_dict.get('overlap_count', 0)
                total_high_ranked_baseline += stats_dict.get('num_high_ranked_baseline_sources', 0)
                total_high_ranked_found += stats_dict.get('num_high_ranked_found_sources', 0)
                total_high_ranked_overlap += stats_dict.get('high_ranked_overlap_count', 0)
        except Exception:
            # Skip questions that fail to load or compute
            continue

    # Calculate overall percentages
    overlap_percentage = round(total_overlap * 100 / max(total_baseline_sources, total_found_sources), 2) \
        if max(total_baseline_sources, total_found_sources) > 0 else 0
    high_ranked_overlap_percentage = round(
        total_high_ranked_overlap * 100 / max(total_high_ranked_baseline, total_high_ranked_found), 2) \
        if max(total_high_ranked_baseline, total_high_ranked_found) > 0 else 0

    # Compile cumulative results
    cumulative_stats = {
        "total_questions_analyzed": valid_questions,
        "total_baseline_sources": total_baseline_sources,
        "total_found_sources": total_found_sources,
        "total_overlap_count": total_overlap,
        "overall_overlap_percentage": overlap_percentage,
        "total_high_ranked_baseline_sources": total_high_ranked_baseline,
        "total_high_ranked_found_sources": total_high_ranked_found,
        "total_high_ranked_overlap_count": total_high_ranked_overlap,
        "overall_high_ranked_overlap_percentage": high_ranked_overlap_percentage,
        "avg_baseline_sources_per_question": round(total_baseline_sources / valid_questions, 2) if valid_questions > 0 else 0,
        "avg_found_sources_per_question": round(total_found_sources / valid_questions, 2) if valid_questions > 0 else 0
    }
    return pd.DataFrame([cumulative_stats])

async def get_unified_sources(conn: asyncpg.Connection, question_id: int, source_finder_run_id: int, ranker_id: int):
    """
    Create a unified view of sources from both the baseline ranker run and the
    source finder run, with indicators of where each source appears and its
    respective ranks.
    """
    query_runs = """
        SELECT tb.tractate_chunk_id AS id,
               sr.rank AS source_rank,
               sr.tractate,
               sr.folio,
               sr.reason AS source_reason
        FROM source_run_results sr
        JOIN talmud_bavli tb ON sr.sugya_id = tb.xml_id
        WHERE sr.question_id = $1
          AND sr.source_finder_run_id = $2
    """
    source_runs = await conn.fetch(query_runs, question_id, source_finder_run_id)

    # The baseline lives in the same table, so reuse the query with renamed
    # output columns and the ranker's run id as the filter
    baseline_query = query_runs.replace("source_rank", "baseline_rank").replace("source_reason", "baseline_reason")
    baseline_sources = await conn.fetch(baseline_query, question_id, ranker_id)

    stats_df = await calculate_baseline_vs_source_stats_for_question(conn, baseline_sources, source_runs)

    # Convert to dictionaries for easier lookup
    source_runs_dict = {s["id"]: dict(s) for s in source_runs}
    baseline_dict = {s["id"]: dict(s) for s in baseline_sources}

    # Get all unique sugya ids
    all_sugya_ids = set(source_runs_dict.keys()) | set(baseline_dict.keys())

    # Build unified results
    unified_results = []
    for sugya_id in all_sugya_ids:
        in_source_run = sugya_id in source_runs_dict
        in_baseline = sugya_id in baseline_dict
        # Prefer baseline metadata (tractate/folio) when the source is in both
        info = baseline_dict[sugya_id] if in_baseline else source_runs_dict[sugya_id]
        result = {
            "id": sugya_id,
            "tractate": info.get("tractate"),
            "folio": info.get("folio"),
            "in_baseline": "Yes" if in_baseline else "No",
            "baseline_rank": baseline_dict.get(sugya_id, {}).get("baseline_rank", "N/A"),
            "in_source_run": "Yes" if in_source_run else "No",
            "source_run_rank": source_runs_dict.get(sugya_id, {}).get("source_rank", "N/A"),
            "source_reason": source_runs_dict.get(sugya_id, {}).get("source_reason", "N/A"),
            "baseline_reason": baseline_dict.get(sugya_id, {}).get("baseline_reason", "N/A"),
        }
        unified_results.append(result)
    return unified_results, stats_df

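# Usage sketch (run and ranker ids are hypothetical): the unified rows are
# plain dicts, so they drop straight into a DataFrame for display alongside
# the per-question stats.
#
#     sources, stats_df = await get_unified_sources(conn, question_id=3,
#                                                   source_finder_run_id=12,
#                                                   ranker_id=7)
#     sources_df = pd.DataFrame(sources)
#     print(stats_df[["overlap_count", "overlap_percentage"]])
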
async def get_source_text(conn: asyncpg.Connection, tractate_chunk_id: int):
    """
    Retrieves the text content for a given tractate chunk ID.
    """
    query = """
        SELECT tb.text AS text
        FROM talmud_bavli tb
        WHERE tb.tractate_chunk_id = $1
    """
    result = await conn.fetchrow(query, tractate_chunk_id)
    return result["text"] if result else "Source text not found"

def get_pg_sync_connection(schema="talmudexplore"):
    # Synchronous psycopg2 connection for code paths that need a blocking
    # DB-API connection instead of asyncpg
    conn = psycopg2.connect(dbname=os.getenv("pg_dbname"),
                            user=os.getenv("pg_user"),
                            password=os.getenv("pg_password"),
                            host=os.getenv("pg_host"),
                            port=os.getenv("pg_port"),
                            options=f"-c search_path={schema}")
    return conn

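# Minimal end-to-end sketch. The run and ranker ids below are hypothetical
# placeholders; this assumes the pg_* environment variables point at a
# populated database.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        async with get_async_connection() as conn:
            questions = await get_questions(conn, source_finder_run_id=12,
                                            baseline_source_finder_run_id=7)
            stats_df = await calculate_cumulative_statistics_for_all_questions(
                conn, [q["id"] for q in questions],
                source_finder_run_id=12, ranker_id=7)
            print(stats_df.T)

    asyncio.run(_demo())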