fix connection reuse

Files changed:
- app.py (+49 -47)
- data_access.py (+141 -153)
- eval_tables.py (+5 -0)
- tests/test_db_layer.py (+21 -17)

app.py
CHANGED
@@ -5,7 +5,8 @@ import pandas as pd
 import logging
 
 from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
-    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata
+    get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \
+    get_async_connection
 
 logger = logging.getLogger(__name__)
 
@@ -20,7 +21,7 @@ baseline_ranker_options = []
 run_ids = []
 available_run_id_dict = {}
 finder_options = []
-previous_run_id =
+previous_run_id = "initial_run"
 
 run_id_dropdown = None
 
@@ -29,13 +30,13 @@ run_id_dropdown = None
 # Initialize data in a single async function
 async def initialize_data():
     global questions, source_finders, questions_dict, source_finders_dict, question_options, finder_options, baseline_rankers_dict, source_finders_dict, baseline_ranker_options
-
-    questions = await get_questions()
-    source_finders = await get_source_finders()
-
-    baseline_rankers = await get_baseline_rankers()
+    async with get_async_connection() as conn:
+        # Get questions and source finders
+        questions = await get_questions(conn)
+        source_finders = await get_source_finders(conn)
+        baseline_rankers = await get_baseline_rankers(conn)
+
     baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
-
     # Convert to dictionaries for easier lookup
     questions_dict = {q["text"]: q["id"] for q in questions}
     baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
 
@@ -52,7 +53,7 @@ def update_sources_list(question_option, source_finder_id, run_id: str, baseline
     if evt:
         logger.info(f"event: {evt.target.elem_id}")
         if evt.target.elem_id == "run_id_dropdown" and (type(run_id) == list or run_id == previous_run_id):
-            return gr.skip(), gr.skip(), gr.skip(), gr.skip()
+            return gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
 
     if type(run_id) == str:
         previous_run_id = run_id
 
@@ -65,55 +66,56 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
     if not question_option:
         return gr.skip(), gr.skip(), gr.skip(), "No question selected", ""
     logger.info("processing update")
-    baseline_ranker_name
-    baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
-        finder_id_int = source_finders_dict.get(source_finder_name)
-    else:
-        finder_id_int = None
-            run_id_int = available_run_id_dict.get(run_id)
-            all_stats = await calculate_cumulative_statistics_for_all_questions(run_id_int, baseline_ranker_id_int)
-        return None, all_stats, gr.skip(), "Select Run Id and source finder to see results", ""
-    if run_id not in run_id_options:
-        run_id = run_id_options[0]
-    source_runs = None
-    stats = None
-    # Get source runs data
-    if finder_id_int:
-        source_runs, stats = await get_unified_sources(question_id, run_id_int, baseline_ranker_id_int)
-        # Create DataFrame for display
-        df = pd.DataFrame(source_runs)
-        'folio', 'reason']
-    df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
+    async with get_async_connection() as conn:
+        if type(baseline_ranker_name) == list:
+            baseline_ranker_name = baseline_ranker_name[0]
+
+        baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(baseline_ranker_name)
+
+        if len(source_finder_name):
+            finder_id_int = source_finders_dict.get(source_finder_name)
+        else:
+            finder_id_int = None
+
+        if question_option == "All questions":
+            if finder_id_int and type(run_id) == str:
+                run_id_int = available_run_id_dict.get(run_id)
+                all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int, baseline_ranker_id_int)
+            else:
+                all_stats = None
+            return None, all_stats, gr.skip(), "Select Run Id and source finder to see results", ""
+
+        # Extract question ID from selection
+        question_id = questions_dict.get(question_option)
+
+        available_run_id_dict = await get_run_ids(conn, question_id, finder_id_int)
+        run_id_options = list(available_run_id_dict.keys())
+        if run_id not in run_id_options:
+            run_id = run_id_options[0]
+
+        run_id_int = available_run_id_dict.get(run_id)
+
+        source_runs = None
+        stats = None
+        # Get source runs data
+        if finder_id_int:
+            source_runs, stats = await get_unified_sources(conn, question_id, run_id_int, baseline_ranker_id_int)
+            # Create DataFrame for display
+            df = pd.DataFrame(source_runs)
+
+        if not source_runs:
+            return None, None, run_id_options, "No results found for the selected filters",
+
+        # Format table columns
+        columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank', 'tractate',
+                              'folio', 'reason']
+        df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
+
+        # CSV for download
+        # csv_data = df.to_csv(index=False)
+        metadata = await get_metadata(conn, question_id, run_id_int)
 
     result_message = f"Found {len(source_runs)} results"
     return df_display, stats, gr.Dropdown(choices=run_id_options, value=run_id), result_message, metadata
 
@@ -128,7 +130,8 @@ async def handle_row_selection_async(evt: gr.SelectData):
         # Get the ID from the selected row
         tractate_chunk_id = evt.row_value[0]
         # Get the source text
+        async with get_async_connection() as conn:
+            text = await get_source_text(conn, tractate_chunk_id)
         return text
     except Exception as e:
         return f"Error retrieving source text: {str(e)}"
 
@@ -248,7 +251,6 @@ async def main():
         question_dropdown.change(
             update_sources_list,
             inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
-            # outputs=[run_id_dropdown, results_table, result_text, download_button]
            outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
         )
 
data_access.py
CHANGED
@@ -30,85 +30,80 @@ async def get_async_connection(schema="talmudexplore"):
         await conn.close()
 
 
-async def get_questions():
-        return ""
-    return metadata.get('metadata')
+async def get_questions(conn: asyncpg.Connection):
+    questions = await conn.fetch("SELECT id, question_text FROM questions ORDER BY id")
+    return [{"id": q["id"], "text": q["question_text"]} for q in questions]
+
+async def get_metadata(conn: asyncpg.Connection, question_id: int, source_finder_id_run_id: int):
+    metadata = await conn.fetchrow('''
+        SELECT metadata
+        FROM source_finder_run_question_metadata sfrqm
+        WHERE sfrqm.question_id = $1 and sfrqm.source_finder_run_id = $2;
+    ''', question_id, source_finder_id_run_id)
+    if metadata is None:
+        return ""
+    return metadata.get('metadata')
 
 
 # Get distinct source finders
-async def get_source_finders():
-    return [{"id": f["id"], "name": f["name"]} for f in finders]
+async def get_source_finders(conn: asyncpg.Connection):
+    finders = await conn.fetch("SELECT id, source_finder_type as name FROM source_finders ORDER BY id")
+    return [{"id": f["id"], "name": f["name"]} for f in finders]
 
 
 # Get distinct run IDs for a question
-async def get_run_ids(question_id: int, source_finder_id: int):
+async def get_run_ids(conn: asyncpg.Connection, question_id: int, source_finder_id: int):
+    query = """
+        select distinct sfr.description, srs.source_finder_run_id as run_id
+        from talmudexplore.source_run_results srs
+        join talmudexplore.source_finder_runs sfr on srs.source_finder_run_id = sfr.id
+        join talmudexplore.source_finders sf on sfr.source_finder_id = sf.id
+        where sfr.source_finder_id = $1
+        and srs.question_id = $2
+    """
+    run_ids = await conn.fetch(query, source_finder_id, question_id)
+    return {r["description"]: r["run_id"] for r in run_ids}
+
+
+async def get_baseline_rankers(conn: asyncpg.Connection):
+    rankers = await conn.fetch("SELECT id, ranker FROM rankers ORDER BY id")
+    return [{"id": f["id"], "name": f["ranker"]} for f in rankers]
+
-async def calculate_baseline_vs_source_stats_for_question(baseline_sources , source_runs_sources):
+async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connection, baseline_sources, source_runs_sources):
     # for a given question_id and source_finder_id and run_id calculate the baseline vs source stats
     # e.g. overlap, high ranked overlap, etc.
-    async with get_async_connection() as conn:
-        actual_sources_set = {s["id"] for s in source_runs_sources}
-        baseline_sources_set = {s["id"] for s in baseline_sources}
-
-        # Calculate overlap
-        overlap = actual_sources_set.intersection(baseline_sources_set)
-        # only_in_1 = actual_sources_set - baseline_sources_set
-        # only_in_2 = baseline_sources_set - actual_sources_set
-
-        # Calculate high-ranked overlap (rank >= 4)
-        actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
-        baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
-
-        high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)
-
-        results = {
-            "total_baseline_sources": len(baseline_sources),
-            "total_found_sources": len(source_runs_sources),
-            "overlap_count": len(overlap),
-            "overlap_percentage": round(len(overlap) * 100 / max(len(actual_sources_set), len(baseline_sources_set)),
-                                        2) if max(len(actual_sources_set), len(baseline_sources_set)) > 0 else 0,
-            "num_high_ranked_baseline_sources": len(baseline_high_ranked),
-            "num_high_ranked_found_sources": len(actual_high_ranked),
-            "high_ranked_overlap_count": len(high_ranked_overlap),
-            "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / max(len(actual_high_ranked), len(baseline_high_ranked)), 2) if max(len(actual_high_ranked), len(baseline_high_ranked)) > 0 else 0
-        }
-        #convert results to dataframe
-        results_df = pd.DataFrame([results])
-        return results_df
+    actual_sources_set = {s["id"] for s in source_runs_sources}
+    baseline_sources_set = {s["id"] for s in baseline_sources}
+
+    # Calculate overlap
+    overlap = actual_sources_set.intersection(baseline_sources_set)
+    # only_in_1 = actual_sources_set - baseline_sources_set
+    # only_in_2 = baseline_sources_set - actual_sources_set
+
+    # Calculate high-ranked overlap (rank >= 4)
+    actual_high_ranked = {s["id"] for s in source_runs_sources if int(s["source_rank"]) >= 4}
+    baseline_high_ranked = {s["id"] for s in baseline_sources if int(s["baseline_rank"]) >= 4}
+
+    high_ranked_overlap = actual_high_ranked.intersection(baseline_high_ranked)
+
+    results = {
+        "total_baseline_sources": len(baseline_sources),
+        "total_found_sources": len(source_runs_sources),
+        "overlap_count": len(overlap),
+        "overlap_percentage": round(len(overlap) * 100 / max(len(actual_sources_set), len(baseline_sources_set)),
+                                    2) if max(len(actual_sources_set), len(baseline_sources_set)) > 0 else 0,
+        "num_high_ranked_baseline_sources": len(baseline_high_ranked),
+        "num_high_ranked_found_sources": len(actual_high_ranked),
+        "high_ranked_overlap_count": len(high_ranked_overlap),
+        "high_ranked_overlap_percentage": round(len(high_ranked_overlap) * 100 / max(len(actual_high_ranked), len(baseline_high_ranked)), 2) if max(len(actual_high_ranked), len(baseline_high_ranked)) > 0 else 0
+    }
+    #convert results to dataframe
+    results_df = pd.DataFrame([results])
+    return results_df
+
+
+async def calculate_cumulative_statistics_for_all_questions(conn: asyncpg.Connection, source_finder_run_id: int, ranker_id: int):
     """
     Calculate cumulative statistics across all questions for a specific source finder, run, and ranker.
 
@@ -119,83 +114,75 @@ async def calculate_cumulative_statistics_for_all_questions(source_finder_run_id
     Returns:
         pd.DataFrame: DataFrame containing aggregated statistics
     """
+    # Get all questions
+    query = "SELECT id FROM questions ORDER BY id"
+    questions = await conn.fetch(query)
+    question_ids = [q["id"] for q in questions]
+
+    # Initialize aggregates
+    total_baseline_sources = 0
+    total_found_sources = 0
+    total_overlap = 0
+    total_high_ranked_baseline = 0
+    total_high_ranked_found = 0
+    total_high_ranked_overlap = 0
+
+    # Process each question
+    valid_questions = 0
+    for question_id in question_ids:
+        try:
+            # Get unified sources for this question
+            sources, stats = await get_unified_sources(conn, question_id, ranker_id, source_finder_run_id)
+
+            if sources and len(sources) > 0:
+                valid_questions += 1
+                stats_dict = stats.iloc[0].to_dict()
+
+                # Add to running totals
+                total_baseline_sources += stats_dict.get('total_baseline_sources', 0)
+                total_found_sources += stats_dict.get('total_found_sources', 0)
+                total_overlap += stats_dict.get('overlap_count', 0)
+                total_high_ranked_baseline += stats_dict.get('num_high_ranked_baseline_sources', 0)
+                total_high_ranked_found += stats_dict.get('num_high_ranked_found_sources', 0)
+                total_high_ranked_overlap += stats_dict.get('high_ranked_overlap_count', 0)
+        except Exception as e:
+            # Skip questions with errors
+            continue
+
+    # Calculate overall percentages
+    overlap_percentage = round(total_overlap * 100 / max(total_baseline_sources, total_found_sources), 2) \
+        if max(total_baseline_sources, total_found_sources) > 0 else 0
+
+    high_ranked_overlap_percentage = round(
+        total_high_ranked_overlap * 100 / max(total_high_ranked_baseline, total_high_ranked_found), 2) \
+        if max(total_high_ranked_baseline, total_high_ranked_found) > 0 else 0
+
+    # Compile results
+    cumulative_stats = {
+        "total_questions_analyzed": valid_questions,
+        "total_baseline_sources": total_baseline_sources,
+        "total_found_sources": total_found_sources,
+        "total_overlap_count": total_overlap,
+        "overall_overlap_percentage": overlap_percentage,
+        "total_high_ranked_baseline_sources": total_high_ranked_baseline,
+        "total_high_ranked_found_sources": total_high_ranked_found,
+        "total_high_ranked_overlap_count": total_high_ranked_overlap,
+        "overall_high_ranked_overlap_percentage": high_ranked_overlap_percentage,
+        "avg_baseline_sources_per_question": round(total_baseline_sources / valid_questions,
+                                                   2) if valid_questions > 0 else 0,
+        "avg_found_sources_per_question": round(total_found_sources / valid_questions,
+                                                2) if valid_questions > 0 else 0
+    }
+
+    return pd.DataFrame([cumulative_stats])
+
+
-async def get_unified_sources(question_id: int, source_finder_run_id: int, ranker_id: int):
+async def get_unified_sources(conn: asyncpg.Connection, question_id: int, source_finder_run_id: int, ranker_id: int):
     """
     Create unified view of sources from both baseline_sources and source_runs
     with indicators of where each source appears and their respective ranks.
     """
-    async with get_async_connection() as conn:
-        stats_df, unified_results = await get_stats(conn, question_id, ranker_id, source_finder_run_id)
-
-    return unified_results, stats_df
-
-
-async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
-    # Get sources from source_runs
     query_runs = """
         SELECT tb.tractate_chunk_id as id,
         sr.rank as source_rank,
 
@@ -217,7 +204,7 @@ async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
         AND bs.ranker_id = $2
     """
     baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
-    stats_df = await calculate_baseline_vs_source_stats_for_question(baseline_sources, source_runs)
+    stats_df = await calculate_baseline_vs_source_stats_for_question(conn, baseline_sources, source_runs)
     # Convert to dictionaries for easier lookup
     source_runs_dict = {s["id"]: dict(s) for s in source_runs}
     baseline_dict = {s["id"]: dict(s) for s in baseline_sources}
 
@@ -244,21 +231,22 @@ async def get_stats(conn, question_id, ranker_id, source_finder_run_id):
             "metadata": source_runs_dict.get(sugya_id, {}).get("metadata", "")
         }
         unified_results.append(result)
+
+    return unified_results, stats_df
 
 
-async def get_source_text(tractate_chunk_id: int):
+async def get_source_text(conn: asyncpg.Connection, tractate_chunk_id: int):
     """
     Retrieves the text content for a given tractate chunk ID.
     """
+    query = """
+        SELECT tb.text_with_nikud as text
+        FROM talmud_bavli tb
+        WHERE tb.tractate_chunk_id = $1
+    """
+    result = await conn.fetchrow(query, tractate_chunk_id)
+    return result["text"] if result else "Source text not found"
 
 def get_pg_sync_connection(schema="talmudexplore"):
     conn = psycopg2.connect(dbname=os.getenv("pg_dbname"),
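
The pattern running through this commit is that a caller opens one connection via get_async_connection() and passes that conn into every data-access helper, instead of each helper opening and closing its own connection per query. The body of get_async_connection is not part of this diff (only its trailing await conn.close() appears as context), so the following is only a minimal sketch of what such an asyncpg async context manager typically looks like; the connection parameters and the search_path handling are assumptions, not taken from the repository.

import os
from contextlib import asynccontextmanager

import asyncpg


@asynccontextmanager
async def get_async_connection(schema="talmudexplore"):
    # Open one connection; callers reuse it for every query inside the `async with` block.
    # Environment variable names mirror get_pg_sync_connection above; exact details are assumed.
    conn = await asyncpg.connect(
        database=os.getenv("pg_dbname"),
        user=os.getenv("pg_user"),
        password=os.getenv("pg_password"),
        host=os.getenv("pg_host"),
        port=int(os.getenv("pg_port", "5432")),
    )
    try:
        # Scope unqualified table names to the given schema (assumed behaviour).
        await conn.execute(f"SET search_path TO {schema}")
        yield conn
    finally:
        await conn.close()

With a context manager of this shape, the app.py handlers and the tests wrap their work in `async with get_async_connection() as conn:` and hand conn to the query functions, which is exactly the change the hunks above show.
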
eval_tables.py
CHANGED
@@ -92,6 +92,11 @@ def create_eval_database():
     );
     ''')
 
+    cursor.execute('''alter table source_run_results
+    add constraint source_run_results_pk
+        unique (source_finder_run_id, question_id, sugya_id);
+    ''')
+
     conn.commit()
     conn.close()
 
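
The new constraint guarantees at most one row per (source_finder_run_id, question_id, sugya_id). One thing this enables, sketched below purely as an illustration and not something this diff adds, is writing result rows idempotently with ON CONFLICT; the rank column used here is an assumed extra column of source_run_results.

# Hedged illustration: with source_run_results_pk in place, re-running a source finder
# can upsert rather than insert duplicate rows. The `rank` column is an assumption.
UPSERT_SQL = """
    INSERT INTO source_run_results (source_finder_run_id, question_id, sugya_id, rank)
    VALUES ($1, $2, $3, $4)
    ON CONFLICT (source_finder_run_id, question_id, sugya_id)
    DO UPDATE SET rank = EXCLUDED.rank
"""


async def upsert_source_run_result(conn, source_finder_run_id, question_id, sugya_id, rank):
    # Reuses a shared asyncpg connection, matching the pattern introduced in this commit.
    await conn.execute(UPSERT_SQL, source_finder_run_id, question_id, sugya_id, rank)
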
tests/test_db_layer.py
CHANGED
@@ -1,13 +1,15 @@
 import pandas as pd
 import pytest
 
-from data_access import calculate_cumulative_statistics_for_all_questions, get_metadata, get_run_ids
+from data_access import calculate_cumulative_statistics_for_all_questions, get_metadata, get_run_ids, \
+    get_async_connection
 from data_access import get_unified_sources
 
 
 @pytest.mark.asyncio
 async def test_get_unified_sources():
+    async with get_async_connection() as conn:
+        results, stats = await get_unified_sources(conn, 2, 2, 1)
     assert results is not None
     assert stats is not None
 
@@ -23,12 +25,12 @@ async def test_get_unified_sources():
 @pytest.mark.asyncio
 async def test_calculate_cumulative_statistics_for_all_questions():
     # Test with known source_finder_id, run_id, and ranker_id
-    run_id = 1
+    source_finder_run_id = 2
     ranker_id = 1
 
     # Call the function to test
+    async with get_async_connection() as conn:
+        result = await calculate_cumulative_statistics_for_all_questions(conn, source_finder_run_id, ranker_id)
 
     # Check basic structure of results
     assert isinstance(result, pd.DataFrame), "Result should be a pandas DataFrame"
 
@@ -65,12 +67,12 @@ async def test_calculate_cumulative_statistics_for_all_questions():
 @pytest.mark.asyncio
 async def test_get_metadata_none_returned():
     # Test with known source_finder_id, run_id, and ranker_id
-    run_id = 1
+    source_finder_run_id = 1
     question_id = 1
 
     # Call the function to test
+    async with get_async_connection() as conn:
+        result = await get_metadata(conn, question_id, source_finder_run_id)
 
     assert result == "", "Should return empty string when no metadata is found"
 
@@ -81,7 +83,8 @@ async def test_get_metadata():
     question_id = 1
 
     # Call the function to test
+    async with get_async_connection() as conn:
+        result = await get_metadata(conn, question_id, source_finder_run_id)
 
     assert result is not None, "Should return metadata when it exists"
 
@@ -93,16 +96,17 @@ async def test_get_run_ids():
     source_finder_id = 2  # Using a source finder ID that exists in the test database
 
     # Call the function to test
+    async with get_async_connection() as conn:
+        result = await get_run_ids(conn, question_id, source_finder_id)
 
+        # Verify the result is a dictionary
+        assert isinstance(result, dict), "Result should be a dictionary"
 
+        # Check that the dictionary is not empty (assuming there are run IDs for this question/source finder)
+        assert len(result) > 0, "Should return at least one run ID"
 
+        # Test with a non-existent question_id
+        non_existent_question_id = 9999
+        empty_result = await get_run_ids(conn, non_existent_question_id, source_finder_id)
     assert isinstance(empty_result, dict), "Should return an empty dictionary for non-existent question"
     assert len(empty_result) == 0, "Should return empty dictionary for non-existent question"