happybamboo commited on
Commit
834908b
·
verified ·
1 Parent(s): 07dfe96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -50
app.py CHANGED
@@ -1,98 +1,111 @@
1
  import gradio as gr
2
  import re
 
3
  from transformers import pipeline
 
4
 
5
- # 1. 加载 NER 模型
6
- # 这是一个中文命名实体识别模型
7
- print("正在加载模型,请稍候...")
8
  ner_pipeline = pipeline("ner", model="uer/roberta-base-finetuned-cluener2020-chinese", aggregation_strategy="simple")
9
 
 
 
10
  def extract_money(text):
11
- """
12
- 使用正则表达式提取金额,避免模型漏掉。
13
- """
14
  money_entities = []
15
- # 匹配规则:匹配 ¥, $, 元, 万, 亿 等金额格式
16
  pattern = r'([¥$€USDCNY人民币]*\s*\d{1,3}(?:,\d{3})*(?:\.\d+)?\s*[万亿]?(?:元|美元|欧元|CNY|HKD)?)'
17
-
18
  matches = re.finditer(pattern, text)
19
  for match in matches:
20
  val = match.group(0).strip()
21
- # 过滤掉单纯的数字(如年份 2023),只保留像金额的
22
  if len(val) > 1 and (re.search(r'[^\d.,]', val) or '.' in val):
23
  money_entities.append(val)
24
-
25
  return money_entities
26
 
27
  def run_ner_on_long_text(text):
28
- """
29
- 【核心修复】
30
- 将长文本切分成短片段(每段 400 字),防止超过 BERT 512 Token 的限制。
31
- """
32
- # 设定步长,BERT限制是512 token,我们保守取 400 字符作为一段
33
  chunk_size = 400
34
  all_results = []
35
-
36
- # 循环切分文本
37
  for i in range(0, len(text), chunk_size):
38
  chunk = text[i : i + chunk_size]
39
  if not chunk.strip():
40
  continue
41
-
42
  try:
43
- # 对每一小段运行模型
44
  chunk_results = ner_pipeline(chunk)
45
  all_results.extend(chunk_results)
46
  except Exception as e:
47
  print(f"片段处理出错: {e}")
48
  continue
49
-
50
  return all_results
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def extract_audit_info(contract_text, file_obj):
53
- # 1. 读取内容
54
- content = contract_text
55
  if file_obj is not None:
56
- try:
57
- with open(file_obj.name, "r", encoding="utf-8") as f:
58
- content = f.read()
59
- except Exception as e:
60
- return f"文件读取失败: {str(e)}"
61
-
62
- if not content:
63
- return "请输入文本或上传文件"
64
 
65
- # 限制总处理长度,防止恶意上传超大文件卡死服务器(例如限制前 5000 字)
66
- # 审计合同通常关键信息在前几页和最后几页
 
 
67
  if len(content) > 5000:
68
  process_text = content[:5000]
69
- output_warning = "(提示:文本过长,仅分析了前 5000 字)\n\n"
70
  else:
71
  process_text = content
72
- output_warning = ""
73
 
74
- # 2. 分段运行 NER 模型(修复报错的关键点)
75
  ner_results = run_ner_on_long_text(process_text)
76
 
77
- # 3. 全文运行正则提取金额(正则不需要分段)
78
  money_results = extract_money(process_text)
79
 
80
- # 4. 整理输出
81
- output_str = f"=== 提取报告 ===\n{output_warning}\n"
82
 
83
- # --- 整理金额 ---
84
  output_str += "💰【涉及金额】:\n"
85
- # 去重
86
- money_results = sorted(list(set(money_results)))
87
  if money_results:
88
  for m in money_results:
89
  output_str += f"- {m}\n"
90
  else:
91
- output_str += "(未检测到明确金额)\n"
92
-
93
  output_str += "\n"
94
 
95
- # --- 整理模型实体 ---
96
  label_map = {
97
  "organization": "🏢 组织/公司",
98
  "company": "🏢 公司",
@@ -106,7 +119,7 @@ def extract_audit_info(contract_text, file_obj):
106
  for item in ner_results:
107
  group = item['entity_group']
108
  word = item['word']
109
- # 过滤低置信度和太短的词
110
  if item['score'] > 0.4 and len(word) > 1:
111
  cn_label = label_map.get(group, group)
112
  if cn_label not in found_entities:
@@ -121,15 +134,16 @@ def extract_audit_info(contract_text, file_obj):
121
 
122
  return output_str
123
 
124
- # --- Gradio 界面 ---
125
  with gr.Blocks() as demo:
126
- gr.Markdown("# 🧾 审计合同关键信息提取")
127
- gr.Markdown("自动提取:**金额**、**公司名称**、**签订日期**、**签约人**等信息。")
128
 
129
  with gr.Row():
130
  with gr.Column():
131
- input_text = gr.Textbox(label="直接粘贴合同文本", lines=8, placeholder="在此粘贴文本...")
132
- input_file = gr.File(label="或上传 TXT 文件", file_types=[".txt"])
 
133
  btn = gr.Button("🚀 开始分析", variant="primary")
134
 
135
  with gr.Column():
 
1
  import gradio as gr
2
  import re
3
+ import os
4
  from transformers import pipeline
5
+ from pypdf import PdfReader # 引入 PDF 读取库
6
 
7
+ # --- 1. 加载模型 ---
8
+ print("正在加载 NER 模型...")
 
9
  ner_pipeline = pipeline("ner", model="uer/roberta-base-finetuned-cluener2020-chinese", aggregation_strategy="simple")
10
 
11
+ # --- 2. 核心功能函数 ---
12
+
13
  def extract_money(text):
14
+ """正则提取金额"""
 
 
15
  money_entities = []
 
16
  pattern = r'([¥$€USDCNY人民币]*\s*\d{1,3}(?:,\d{3})*(?:\.\d+)?\s*[万亿]?(?:元|美元|欧元|CNY|HKD)?)'
 
17
  matches = re.finditer(pattern, text)
18
  for match in matches:
19
  val = match.group(0).strip()
 
20
  if len(val) > 1 and (re.search(r'[^\d.,]', val) or '.' in val):
21
  money_entities.append(val)
 
22
  return money_entities
23
 
24
  def run_ner_on_long_text(text):
25
+ """分段处理长文本,避免 BERT 报错"""
 
 
 
 
26
  chunk_size = 400
27
  all_results = []
 
 
28
  for i in range(0, len(text), chunk_size):
29
  chunk = text[i : i + chunk_size]
30
  if not chunk.strip():
31
  continue
 
32
  try:
 
33
  chunk_results = ner_pipeline(chunk)
34
  all_results.extend(chunk_results)
35
  except Exception as e:
36
  print(f"片段处理出错: {e}")
37
  continue
 
38
  return all_results
39
 
40
+ def read_file_content(file_obj):
41
+ """
42
+ 识别文件类型并提取文本
43
+ 支持:.txt, .pdf
44
+ """
45
+ content = ""
46
+ try:
47
+ file_path = file_obj.name
48
+ file_ext = os.path.splitext(file_path)[1].lower()
49
+
50
+ if file_ext == ".pdf":
51
+ # 处理 PDF
52
+ reader = PdfReader(file_path)
53
+ for page in reader.pages:
54
+ # 提取每一页的文本并拼接
55
+ text = page.extract_text()
56
+ if text:
57
+ content += text + "\n"
58
+ else:
59
+ # 默认当作 TXT 处理
60
+ with open(file_path, "r", encoding="utf-8") as f:
61
+ content = f.read()
62
+
63
+ except Exception as e:
64
+ return f"ERROR: 文件读取失败 ({str(e)})"
65
+
66
+ return content
67
+
68
  def extract_audit_info(contract_text, file_obj):
69
+ # 1. 获取文本内容(优先读文件,否则读文本框)
70
+ content = ""
71
  if file_obj is not None:
72
+ content = read_file_content(file_obj)
73
+ if content.startswith("ERROR"):
74
+ return content # 返回错误信息
75
+ else:
76
+ content = contract_text
 
 
 
77
 
78
+ if not content or not content.strip():
79
+ return "⚠️ 未能提取到文本内容。请确保:\n1. 上传了正确的文件。\n2. 如果是 PDF,请确保是【文字版】而非【扫描图片版】(图片版需要OCR功能)。"
80
+
81
+ # 限制分析长度(防止内存爆炸,取前 5000 字)
82
  if len(content) > 5000:
83
  process_text = content[:5000]
84
+ warning = f"(提示:文本共 {len(content)} 字,仅分析前 5000 字)\n\n"
85
  else:
86
  process_text = content
87
+ warning = ""
88
 
89
+ # 2. AI 模型分析(分段)
90
  ner_results = run_ner_on_long_text(process_text)
91
 
92
+ # 3. 正则提取金额
93
  money_results = extract_money(process_text)
94
 
95
+ # 4. 生成报告
96
+ output_str = f"=== 📊 提取报告 ===\n{warning}"
97
 
98
+ # 金额部分
99
  output_str += "💰【涉及金额】:\n"
100
+ money_results = sorted(list(set(money_results))) # 去重排序
 
101
  if money_results:
102
  for m in money_results:
103
  output_str += f"- {m}\n"
104
  else:
105
+ output_str += "()\n"
 
106
  output_str += "\n"
107
 
108
+ # 实体部分
109
  label_map = {
110
  "organization": "🏢 组织/公司",
111
  "company": "🏢 公司",
 
119
  for item in ner_results:
120
  group = item['entity_group']
121
  word = item['word']
122
+ # 过滤杂质:置信度>0.4 且 长度>1
123
  if item['score'] > 0.4 and len(word) > 1:
124
  cn_label = label_map.get(group, group)
125
  if cn_label not in found_entities:
 
134
 
135
  return output_str
136
 
137
+ # --- 3. Gradio 界面 ---
138
  with gr.Blocks() as demo:
139
+ gr.Markdown("# 🧾 智能审计/合同信息提取")
140
+ gr.Markdown("支持上传 **.txt** 或 **.pdf** 文件,自动提取金额、日期、公司名等。")
141
 
142
  with gr.Row():
143
  with gr.Column():
144
+ input_text = gr.Textbox(label="直接粘贴文本", lines=8, placeholder="在此粘贴合同文本...")
145
+ # 修改 file_types 支持 pdf
146
+ input_file = gr.File(label="上传文件 (PDF / TXT)", file_types=[".txt", ".pdf"])
147
  btn = gr.Button("🚀 开始分析", variant="primary")
148
 
149
  with gr.Column():