# app.py
# -*- coding: utf-8 -*-
"""
Top-level Gradio app for ESG ABSA (Rule-based | Classical ML | Deep | Hybrid++ | Explainability)
Expected directory layout:

    ./app.py
    ./core/
        __init__.py
        utils.py
        lexicons.py
        rule_based.py
        classical_ml.py
        deep_model.py
        hybrid_model.py
        explainability.py
        app_state.py      # small dict wrapper (see notes below)
"""
import os
import tempfile
import traceback
import gradio as gr
import pandas as pd
# Try to import core modules; provide helpful error if missing
try:
    from core.utils import parse_document, safe_plot
    from core.rule_based import run_rule_based, explain_rule_based_sentence
    from core.classical_ml import run_classical_ml, explain_classical_sentence
    from core.deep_model import run_deep_learning, explain_deep_sentence, plot_attention_plotly
    from core.hybrid_model import run_hierarchical_hybrid, explain_hybrid_sentence, plot_ontology_scatter
    from core.explainability import compare_explain, explain_sentence_across_models, plot_consistency_summary
    from core.app_state import app_state
except Exception as e:
    raise ImportError(
        f"Failed to import core modules. Make sure the core/ package exists and contains the required modules. Error: {e}"
    ) from e
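# core/app_state.py is used here only as a shared, mutable dict. A minimal sketch of that module
# (an assumption; the real wrapper may add validation or persistence) would be:
#
#     # core/app_state.py
#     # Shared state written by the tab handlers below; keys used in this file:
#     #   "last_sentences", "classical", "deep", "hybrid"
#     app_state = {}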
# Example input: an ESG report excerpt. The Indonesian header below translates to
# "Challenges and Responses to Sustainability Issues".
EXAMPLE_TEXT = """## TANTANGAN DAN RESPONS TERHADAP ISU KEBERLANJUTAN
The ban on corn imports pushed us to become more self-sufficient by using locally sourced raw materials. Partnerships with local farmers became vital to secure supply and reduce reliance on international markets.
"""
def _safe_call(fn, *args, fallback=(None, pd.DataFrame(), None)):
    """
    Call a model function safely and return consistently shaped outputs.
    The fallback default should match the (csv_path, df, fig, ...) shape expected by the caller.
    """
    try:
        return fn(*args)
    except Exception:
        traceback.print_exc()
        # return a reasonable fallback shaped to the caller's expected outputs
        return fallback
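# _safe_call is not currently wired into the tab handlers below; a hypothetical use would be to
# guard a model run so the UI still receives outputs of the expected shape on failure, e.g.:
#
#     csv, df, fig = _safe_call(run_rule_based, text, fallback=(None, pd.DataFrame(), None))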
# ---------------------------
# Gradio app UI
# ---------------------------
with gr.Blocks(title="ESG ABSA – Unified Explainability Dashboard (CPU-friendly demo)") as demo:
gr.Markdown("# 🌱 ESG ABSA — Unified (Rule-based • Classical • Deep • Hybrid++)")
gr.Markdown("Paste an ESG section or full report (headers `## ...` supported). Run any tab to produce CSV, preview and visualizations. Use Explainability tab to compare results across models.")
    # Tab 1: Rule-based
    with gr.Tab("1) Rule-based"):
        t1 = gr.Textbox(lines=14, value=EXAMPLE_TEXT, label="Input Text")
        r1_btn = gr.Button("Run Rule-based")
        r1_file = gr.File(label="Download CSV")
        r1_df = gr.DataFrame(label="Preview (Rule-based)", interactive=False)
        r1_plot = gr.Plot(label="Visualization")
        r1_expl = gr.DataFrame(label="Per-sentence Rule Explanations", interactive=False)

        def _run_rule(text):
            csv, df, fig = run_rule_based(text)
            # explanation column
            df2 = df.copy()
            df2["Rule_Explanation"] = df2["Sentence_Text"].apply(lambda t: "; ".join(explain_rule_based_sentence(t)))
            app_state["last_sentences"] = df2
            return csv, df2, fig, df2[["Sentence_ID", "Rule_Explanation"]]

        r1_btn.click(fn=_run_rule, inputs=t1, outputs=[r1_file, r1_df, r1_plot, r1_expl])
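        # Assumed contract (inferred from the call above): run_rule_based(text) returns
        # (csv_path, dataframe, plotly_figure), and the dataframe includes the "Sentence_ID"
        # and "Sentence_Text" columns used to build the per-sentence explanations.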
    # Tab 2: Classical ML
    with gr.Tab("2) Classical ML"):
        t2 = gr.Textbox(lines=14, value=EXAMPLE_TEXT, label="Input Text")
        r2_btn = gr.Button("Run Classical ML")
        r2_file = gr.File(label="Download CSV")
        r2_df = gr.DataFrame(label="Validation Predictions", interactive=False)
        r2_plot = gr.Plot(label="Visualization")
        r2_coef_sent = gr.DataFrame(label="Global Coefficients (Sentiment)", interactive=False)
        r2_coef_aspect = gr.DataFrame(label="Global Coefficients (Aspect)", interactive=False)

        def _run_classical(text):
            csv, out, fig, coef_s, coef_a = run_classical_ml(text)
            # update last_sentences for dashboard (prefer df_val)
            if app_state.get("classical") and app_state["classical"].get("df_val") is not None:
                app_state["last_sentences"] = app_state["classical"]["df_val"]
            return csv, out, fig, coef_s, coef_a

        r2_btn.click(fn=_run_classical, inputs=t2, outputs=[r2_file, r2_df, r2_plot, r2_coef_sent, r2_coef_aspect])
        # optional initial load for convenience
        demo.load(fn=_run_classical, inputs=[t2], outputs=[r2_file, r2_df, r2_plot, r2_coef_sent, r2_coef_aspect])
    # Tab 3: Deep Learning
    with gr.Tab("3) Deep Learning (mBERT demo)"):
        t3 = gr.Textbox(lines=14, value=EXAMPLE_TEXT, label="Input Text")
        e3 = gr.Slider(1, 2, value=1, step=1, label="Epochs (light demo)")
        r3_btn = gr.Button("Train & Predict (mBERT)")
        r3_file = gr.File(label="Download CSV")
        r3_df = gr.DataFrame(label="Predictions", interactive=False)
        r3_plot = gr.Plot(label="Visualization")
        r3_interp = gr.DataFrame(label="Interpretability (tokens)", interactive=False)

        def _run_deep(text, epochs):
            csv, df, fig, interp = run_deep_learning(text, epochs)
            if app_state.get("deep"):
                app_state["last_sentences"] = app_state["deep"]["df"]
            return csv, df, fig, interp

        r3_btn.click(fn=_run_deep, inputs=[t3, e3], outputs=[r3_file, r3_df, r3_plot, r3_interp])
        demo.load(fn=_run_deep, inputs=[t3, e3], outputs=[r3_file, r3_df, r3_plot, r3_interp])
    # Tab 4: Hybrid++ (Hierarchical + MTL + Ontology)
    with gr.Tab("4) Hybrid++ (Hierarchical + MTL + Ontology)"):
        t4 = gr.Textbox(lines=14, value=EXAMPLE_TEXT, label="Input Text")
        e4 = gr.Slider(minimum=1, maximum=50, value=3, label="Epochs")
        tw4 = gr.Slider(minimum=0.0, maximum=2.0, value=1.5, step=0.1, label="Tone weight")
        aw4 = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.05, label="Alignment weight")
        r4_btn = gr.Button("Run Hybrid++")
        r4_file = gr.File(label="Download CSV")
        r4_df = gr.DataFrame(label="Predictions", interactive=False)
        r4_plot_t = gr.Plot(label="Tone × Sentiment")
        r4_plot_a = gr.Plot(label="Ontology Alignment (cosine)")
        r4_plot_s = gr.Plot(label="Tone Distribution by Section")
        r4_metrics = gr.DataFrame(label="Metrics", interactive=False)

        def _run_hybrid(text, epochs, tw, aw):
            csv, df, f1, f2, f3, metrics = run_hierarchical_hybrid(text, epochs, tone_weight=tw, align_weight=aw)
            if app_state.get("hybrid"):
                app_state["last_sentences"] = app_state["hybrid"]["df"]
            return csv, df, f1, f2, f3, metrics

        r4_btn.click(fn=_run_hybrid, inputs=[t4, e4, tw4, aw4], outputs=[r4_file, r4_df, r4_plot_t, r4_plot_a, r4_plot_s, r4_metrics])
        demo.load(fn=_run_hybrid, inputs=[t4, e4, tw4, aw4], outputs=[r4_file, r4_df, r4_plot_t, r4_plot_a, r4_plot_s, r4_metrics])
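    # Note: tabs 2-4 each register a demo.load listener, so the classical, deep, and hybrid
    # pipelines all run once when the page first loads. On CPU-only hosts the mBERT and
    # Hybrid++ runs can make that initial load slow; remove those demo.load calls if needed.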
    # Tab 5: Explainability Dashboard
    with gr.Tab("5) 🧠 Explainability Dashboard"):
        gr.Markdown("Compare explanations across Rule-based, Classical ML, Deep, and Hybrid models.")
        # Input could be a dropdown populated from the last run; we provide a direct textbox instead
        sent_input = gr.Textbox(label="Enter a sentence for cross-model explanation", lines=3, placeholder="Type or paste a sentence from your input...")
        compare_btn = gr.Button("Compare Explainability Across Models")
        summary_table = gr.DataFrame(label="Cross-model summary", interactive=False)
        deep_plot_out = gr.Plot(label="Deep: Token attention (Plotly)")
        cls_plot_out = gr.Plot(label="Classical: Top TF-IDF features")
        hybrid_plot_out = gr.Plot(label="Hybrid: Ontology embedding scatter")
        consistency_plot = gr.Plot(label="Model consistency summary")

        def _compare_sentence(sentence_text):
            if not sentence_text or str(sentence_text).strip() == "":
                return pd.DataFrame([["Error", "Please enter a sentence."]], columns=["Model", "Explanation"]), None, None, None, None
            summary_df, deep_fig, cls_fig, hy_fig = compare_explain_for_sentence(sentence_text)
            return summary_df, deep_fig, cls_fig, hy_fig, plot_consistency_summary_safe()

        # Wrapper that reuses logic from core.explainability but accepts direct sentence text
        def compare_explain_for_sentence(sentence_text):
            # Build summary rows like the standalone compare_explain function, but for a single sentence
            records = []
            # Rule
            try:
                rule_expl = explain_rule_based_sentence(sentence_text)
                records.append(["Rule-based", ", ".join(rule_expl[:6])])
            except Exception as e:
                records.append(["Rule-based", f"Error: {e}"])
            # Classical
            try:
                cls_expl = explain_classical_sentence(sentence_text)
                if isinstance(cls_expl, dict) and "error" in cls_expl:
                    records.append(["Classical ML", cls_expl["error"]])
                    cls_fig = None
                else:
                    pred = cls_expl.get("prediction", "N/A")
                    local = cls_expl.get("local_features", [])
                    local_text = "; ".join([f"{f['feature']} ({f['contribution']:.3f})" for f in local]) if local else "No local features"
                    records.append(["Classical ML", f"Pred: {pred}; Local: {local_text}"])
                    # global plot if available
                    top = cls_expl.get("global_top", None)
                    if top is not None and isinstance(top, pd.DataFrame) and not top.empty:
                        import plotly.express as px
                        cls_fig = px.bar(top.nlargest(12, "Coefficient"), x="Feature", y="Coefficient", color="Direction", title="Classical: Top features (sentiment)")
                        cls_fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
                    else:
                        cls_fig = None
            except Exception as e:
                records.append(["Classical ML", f"Error: {e}"])
                cls_fig = None
            # Deep
            try:
                deep_expl = explain_deep_sentence(sentence_text)
                if isinstance(deep_expl, dict) and "error" in deep_expl:
                    records.append(["Deep (mBERT)", deep_expl["error"]])
                    deep_fig = None
                else:
                    toks = deep_expl.get("tokens", [])
                    w = deep_expl.get("weights", [])
                    pairs = []
                    for tok, wt in zip(toks, w):
                        if tok in ["[PAD]", "[CLS]", "[SEP]"]:
                            continue
                        pairs.append(f"{tok}:{wt:.3f}")
                    records.append(["Deep (mBERT)", ", ".join(pairs[:12]) or "No tokens"])
                    deep_fig = plot_attention_plotly(toks, w, title="mBERT: Token attention (first-layer avg)")
            except Exception as e:
                records.append(["Deep (mBERT)", f"Error: {e}"])
                deep_fig = None
            # Hybrid
            try:
                hy_expl = explain_hybrid_sentence(sentence_text)
                if isinstance(hy_expl, dict) and "error" in hy_expl:
                    records.append(["Hybrid++", hy_expl["error"]])
                    hy_fig = None
                else:
                    records.append(["Hybrid++", f"Ontology alignment: {hy_expl['ontology_alignment']:.3f}; Path: {hy_expl['ontology_path']}; Section: '{hy_expl.get('section_name', 'N/A')}' (influence: {hy_expl.get('section_influence', 0.0):.2f})"])
                    # build the hybrid scatter if hybrid state is available
                    hy_state = app_state.get("hybrid")
                    hy_fig = plot_ontology_scatter(hy_state) if hy_state else None
            except Exception as e:
                records.append(["Hybrid++", f"Error: {e}"])
                hy_fig = None
            summary_df = pd.DataFrame(records, columns=["Model", "Top Explanation"])
            return summary_df, deep_fig, cls_fig, hy_fig

        def plot_consistency_summary_safe():
            try:
                return plot_consistency_summary()
            except Exception:
                return None

        compare_btn.click(fn=_compare_sentence, inputs=[sent_input], outputs=[summary_table, deep_plot_out, cls_plot_out, hybrid_plot_out, consistency_plot])
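        # Note: the hybrid scatter above is built from app_state["hybrid"], which Tab 4 populates,
        # and the explain_* helpers live in core/ where they may rely on state from earlier runs,
        # so the dashboard is most informative after tabs 2-4 (or their demo.load events) have run.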

if __name__ == "__main__":
    demo.launch()