Update app.py
app.py CHANGED

@@ -1,98 +1,111 @@
 import gradio as gr
 import re
+import os
 from transformers import pipeline
+from pypdf import PdfReader  # bring in the PDF reading library

-# 1.
-
-print("正在加载模型,请稍候...")
+# --- 1. Load the model ---
+print("正在加载 NER 模型...")
 ner_pipeline = pipeline("ner", model="uer/roberta-base-finetuned-cluener2020-chinese", aggregation_strategy="simple")

+# --- 2. Core functions ---
+
 def extract_money(text):
-    """
-    Extract amounts with a regular expression so the model does not miss any.
-    """
+    """Extract monetary amounts with a regex"""
     money_entities = []
-    # Matching rule: amount formats such as ¥, $, 元, 万, 亿
     pattern = r'([¥$€USDCNY人民币]*\s*\d{1,3}(?:,\d{3})*(?:\.\d+)?\s*[万亿]?(?:元|美元|欧元|CNY|HKD)?)'
-
     matches = re.finditer(pattern, text)
     for match in matches:
         val = match.group(0).strip()
-        # Filter out bare numbers (e.g. the year 2023); keep only amount-like matches
         if len(val) > 1 and (re.search(r'[^\d.,]', val) or '.' in val):
             money_entities.append(val)
-
     return money_entities

 def run_ner_on_long_text(text):
-    """
-    [Core fix]
-    Split long text into short chunks (400 characters each) so it stays under BERT's 512-token limit.
-    """
-    # Chunk size: BERT's limit is 512 tokens, so 400 characters per chunk is a conservative choice
+    """Process long text in chunks to avoid BERT errors"""
     chunk_size = 400
     all_results = []
-
-    # Slice the text in a loop
    for i in range(0, len(text), chunk_size):
         chunk = text[i : i + chunk_size]
         if not chunk.strip():
             continue
-
         try:
-            # Run the model on each small chunk
             chunk_results = ner_pipeline(chunk)
             all_results.extend(chunk_results)
         except Exception as e:
             print(f"片段处理出错: {e}")
             continue
-
     return all_results

+def read_file_content(file_obj):
+    """
+    Detect the file type and extract its text.
+    Supported: .txt, .pdf
+    """
+    content = ""
+    try:
+        file_path = file_obj.name
+        file_ext = os.path.splitext(file_path)[1].lower()
+
+        if file_ext == ".pdf":
+            # Handle PDF
+            reader = PdfReader(file_path)
+            for page in reader.pages:
+                # Extract each page's text and concatenate
+                text = page.extract_text()
+                if text:
+                    content += text + "\n"
+        else:
+            # Treat everything else as plain TXT
+            with open(file_path, "r", encoding="utf-8") as f:
+                content = f.read()
+
+    except Exception as e:
+        return f"ERROR: 文件读取失败 ({str(e)})"
+
+    return content
+
 def extract_audit_info(contract_text, file_obj):
-    # 1.
-    content =
+    # 1. Get the text (prefer the uploaded file, otherwise the textbox)
+    content = ""
     if file_obj is not None:
-
-
-
-
-
-
-    if not content:
-        return "请输入文本或上传文件"
+        content = read_file_content(file_obj)
+        if content.startswith("ERROR"):
+            return content  # return the error message
+    else:
+        content = contract_text

-
-
+    if not content or not content.strip():
+        return "⚠️ 未能提取到文本内容。请确保:\n1. 上传了正确的文件。\n2. 如果是 PDF,请确保是【文字版】而非【扫描图片版】(图片版需要OCR功能)。"
+
+    # Cap the analysis length (avoid blowing up memory; keep only the first 5000 characters)
     if len(content) > 5000:
         process_text = content[:5000]
-
+        warning = f"(提示:文本共 {len(content)} 字,仅分析前 5000 字)\n\n"
     else:
         process_text = content
-
+        warning = ""

-    # 2.
+    # 2. AI model analysis (chunked)
     ner_results = run_ner_on_long_text(process_text)

-    # 3.
+    # 3. Regex-based amount extraction
     money_results = extract_money(process_text)

-    # 4.
-    output_str = f"=== 提取报告 ===\n{
+    # 4. Build the report
+    output_str = f"=== 📊 提取报告 ===\n{warning}"

-    #
+    # Amounts section
     output_str += "💰【涉及金额】:\n"
-    #
-    money_results = sorted(list(set(money_results)))
+    money_results = sorted(list(set(money_results)))  # de-duplicate and sort
     if money_results:
         for m in money_results:
             output_str += f"- {m}\n"
     else:
-        output_str += "(
-
+        output_str += "(无)\n"
     output_str += "\n"

-    #
+    # Entities section
     label_map = {
         "organization": "🏢 组织/公司",
         "company": "🏢 公司",

@@ -106,7 +119,7 @@ def extract_audit_info(contract_text, file_obj):
     for item in ner_results:
         group = item['entity_group']
         word = item['word']
-        #
+        # Filter noise: confidence > 0.4 and length > 1
         if item['score'] > 0.4 and len(word) > 1:
             cn_label = label_map.get(group, group)
             if cn_label not in found_entities:

@@ -121,15 +134,16 @@ def extract_audit_info(contract_text, file_obj):

     return output_str

-# --- Gradio UI ---
+# --- 3. Gradio UI ---
 with gr.Blocks() as demo:
-    gr.Markdown("# 🧾
-    gr.Markdown("
+    gr.Markdown("# 🧾 智能审计/合同信息提取")
+    gr.Markdown("支持上传 **.txt** 或 **.pdf** 文件,自动提取金额、日期、公司名等。")

     with gr.Row():
         with gr.Column():
-            input_text = gr.Textbox(label="
-
+            input_text = gr.Textbox(label="直接粘贴文本", lines=8, placeholder="在此粘贴合同文本...")
+            # file_types updated to also accept pdf
+            input_file = gr.File(label="上传文件 (PDF / TXT)", file_types=[".txt", ".pdf"])
             btn = gr.Button("🚀 开始分析", variant="primary")

         with gr.Column():
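
The amount pattern used in extract_money() is easy to sanity-check on its own, outside Gradio and the NER pipeline. A minimal standalone sketch follows; the sample sentence is made up for illustration, and the expected output noted in the comments is an assumption based on the pattern above:

import re

# Same pattern and filter as extract_money() in the diff above.
pattern = r'([¥$€USDCNY人民币]*\s*\d{1,3}(?:,\d{3})*(?:\.\d+)?\s*[万亿]?(?:元|美元|欧元|CNY|HKD)?)'

# Made-up sample: "total contract price RMB 1,200,000, plus a ¥35,000 service fee, signed in 2023"
sample = "合同总价为人民币 1,200,000 元,另含服务费 ¥35,000,签订于 2023 年。"

for m in re.finditer(pattern, sample):
    val = m.group(0).strip()
    # Drop bare numbers such as the year 2023; keep only amount-like matches
    if len(val) > 1 and (re.search(r'[^\d.,]', val) or '.' in val):
        print(val)  # should print 人民币 1,200,000 元 and ¥35,000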
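
For a quick local smoke test of the new functions, something like the sketch below could work. Assumptions: the updated file is saved as app.py next to the test script, the model download in the top-level pipeline(...) call succeeds, and app.py does not call demo.launch() at import time (that part of the file is outside the hunks shown; if it does, importing will also start the Gradio server).

# Rough smoke test; all names are imported from the app.py shown in this commit.
from app import extract_money, run_ner_on_long_text, extract_audit_info

# Made-up contract snippet (a date, two company names, one amount)
sample = "2023年5月10日,北京某某科技有限公司与上海某某贸易有限公司签订采购合同,总金额为人民币 1,200,000 元。"

print(extract_money(sample))             # regex amounts
print(run_ner_on_long_text(sample))      # chunked NER output: list of dicts with entity_group/word/score
print(extract_audit_info(sample, None))  # full report path, no file uploaded

One related deployment note: the new "from pypdf import PdfReader" import suggests pypdf also has to be listed in the Space's requirements.txt alongside gradio and transformers; that file is not part of this commit.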