IFMedTechdemo commited on
Commit
e22346a
·
verified ·
1 Parent(s): 5ea9a7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -215
app.py CHANGED
@@ -1,120 +1,21 @@
1
  #################################################################################################
2
 
3
- import subprocess
4
- import sys
5
-
6
  import spaces
7
- import torch
8
-
9
  import gradio as gr
10
  from PIL import Image
11
  import numpy as np
12
  import cv2
13
- import pypdfium2 as pdfium
14
- from transformers import (
15
- LightOnOCRForConditionalGeneration,
16
- LightOnOCRProcessor,
17
- )
18
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
19
  import re
20
 
21
- device = "cuda" if torch.cuda.is_available() else "cpu"
22
- if device == "cuda":
23
- attn_implementation = "sdpa"
24
- dtype = torch.bfloat16
25
- else:
26
- attn_implementation = "eager"
27
- dtype = torch.float32
28
-
29
- ocr_model = LightOnOCRForConditionalGeneration.from_pretrained(
30
- "lightonai/LightOnOCR-1B-1025",
31
- attn_implementation=attn_implementation,
32
- torch_dtype=dtype,
33
- trust_remote_code=True,
34
- ).to(device).eval()
35
-
36
- processor = LightOnOCRProcessor.from_pretrained(
37
- "lightonai/LightOnOCR-1B-1025",
38
- trust_remote_code=True,
39
- )
40
-
41
- ner_tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
42
- ner_model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
43
- ner_pipeline = pipeline(
44
- "ner",
45
- model=ner_model,
46
- tokenizer=ner_tokenizer,
47
- aggregation_strategy="simple",
48
- )
49
-
50
- def render_pdf_page(page, max_resolution=1540, scale=2.77):
51
- width, height = page.get_size()
52
- pixel_width = width * scale
53
- pixel_height = height * scale
54
- resize_factor = min(1, max_resolution / pixel_width, max_resolution / pixel_height)
55
- target_scale = scale * resize_factor
56
- return page.render(scale=target_scale, rev_byteorder=True).to_pil()
57
-
58
- def process_pdf(pdf_path, page_num=1):
59
- pdf = pdfium.PdfDocument(pdf_path)
60
- total_pages = len(pdf)
61
- page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
62
- page = pdf[page_idx]
63
- img = render_pdf_page(page)
64
- pdf.close()
65
- return img, total_pages, page_idx + 1
66
-
67
- def clean_output_text(text):
68
- markers_to_remove = ["system", "user", "assistant"]
69
- lines = text.split('\n')
70
- cleaned_lines = []
71
- for line in lines:
72
- stripped = line.strip()
73
- if stripped.lower() not in markers_to_remove:
74
- cleaned_lines.append(line)
75
- cleaned = '\n'.join(cleaned_lines).strip()
76
- if "assistant" in text.lower():
77
- parts = text.split("assistant", 1)
78
- if len(parts) > 1:
79
- cleaned = parts[1].strip()
80
- return cleaned
81
-
82
- def preprocess_image_for_ocr(image):
83
- image_rgb = image.convert("RGB")
84
- img_np = np.array(image_rgb)
85
- gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
86
- adaptive_threshold = cv2.adaptiveThreshold(
87
- gray,
88
- 255,
89
- cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
90
- cv2.THRESH_BINARY,
91
- 85,
92
- 35,
93
- )
94
- preprocessed_pil = Image.fromarray(adaptive_threshold)
95
- return preprocessed_pil
96
-
97
-
98
-
99
  def extract_medication_lines(text):
100
- """
101
- Extracts medication/drug lines from text using regex.
102
- Matches lines beginning with tab, tablet, cap, capsule, syrup, syp, oral, inj, injection, ointment, drops, patch, sol, solution, etc.
103
- Handles case-insensitivity and abbreviations like T., C., tab., cap. etc.
104
- """
105
- # "|" means OR. (?:...) is a non-capturing group.
106
- pattern = r"""^\s* # Leading spaces allowed
107
- (
108
- T\.?|TAB\.?|TABLET # T., T, TAB, TAB., TABLET
109
- |C\.?|CAP\.?|CAPSULE # C., C, CAP, CAP., CAPSULE
110
  |SYRUP|SYP
111
  |ORAL
112
- |INJ\.?|INJECTION # INJ., INJ, INJECTION
113
  |OINTMENT|DROPS|PATCH|SOL\.?|SOLUTION
114
- )
115
- \s+[A-Z0-9 \-\(\)/,.]+ # Name/dose/other info (at least one space/letter after the pattern)
116
- """
117
- # Compile with re.IGNORECASE and re.VERBOSE for readability
118
  med_regex = re.compile(pattern, re.IGNORECASE | re.VERBOSE)
119
  meds = []
120
  for line in text.split('\n'):
@@ -123,30 +24,51 @@ def extract_medication_lines(text):
123
  meds.append(line)
124
  return '\n'.join(meds)
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- def extract_meds(text, use_ner):
128
- """
129
- Switches between Clinical NER or regex extraction.
130
- Returns medications string.
131
- """
 
 
 
 
 
 
 
132
  if use_ner:
133
- entities = ner_pipeline(text)
134
- meds = []
135
- for ent in entities:
136
- if ent["entity_group"] == "treatment":
137
- word = ent["word"]
138
- if word.startswith("##") and meds:
139
- meds[-1] += word[2:]
140
- else:
141
- meds.append(word)
142
- return ", ".join(set(meds)) if meds else "None detected"
143
- else:
144
- return extract_medication_lines(text) or "None detected"
145
 
146
- @spaces.GPU
147
- def extract_text_from_image(image, temperature=0.2):
148
- """OCR with adaptive thresholding."""
149
  processed_img = preprocess_image_for_ocr(image)
 
150
  chat = [
151
  {
152
  "role": "user",
@@ -162,15 +84,13 @@ def extract_text_from_image(image, temperature=0.2):
162
  return_dict=True,
163
  return_tensors="pt",
164
  )
165
- # Move inputs to device
166
  inputs = {
167
- k: (
168
- v.to(device=device, dtype=dtype)
169
  if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
170
  else v.to(device)
171
  if isinstance(v, torch.Tensor)
172
- else v
173
- )
174
  for k, v in inputs.items()
175
  }
176
  generation_kwargs = dict(
@@ -182,68 +102,40 @@ def extract_text_from_image(image, temperature=0.2):
182
  )
183
  with torch.no_grad():
184
  outputs = ocr_model.generate(**generation_kwargs)
 
185
  output_text = processor.decode(outputs[0], skip_special_tokens=True)
186
- cleaned_text = clean_output_text(output_text)
187
- yield cleaned_text, output_text, processed_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  def process_input(file_input, temperature, page_num, extraction_mode):
190
  if file_input is None:
191
- yield "Please upload an image or PDF first.", "", "", "", "No file!", 1
192
  return
 
 
193
 
194
- image_to_process = None
195
- page_info = ""
196
- slider_value = page_num
197
- file_path = file_input if isinstance(file_input, str) else file_input.name
198
-
199
- if file_path.lower().endswith(".pdf"):
200
- try:
201
- image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
202
- page_info = f"Processing page {actual_page} of {total_pages}"
203
- slider_value = actual_page
204
- except Exception as e:
205
- msg = f"Error processing PDF: {str(e)}"
206
- yield msg, "", msg, "", None, slider_value
207
- return
208
- else:
209
- try:
210
- image_to_process = Image.open(file_path)
211
- page_info = "Processing image"
212
- except Exception as e:
213
- msg = f"Error opening image: {str(e)}"
214
- yield msg, "", msg, "", None, slider_value
215
- return
216
-
217
- use_ner = extraction_mode == "Regex" #"Clinical NER"
218
- try:
219
- for cleaned_text, raw_md, processed_img in extract_text_from_image(
220
- image_to_process, temperature
221
- ):
222
- meds_out = extract_meds(cleaned_text, use_ner)
223
- yield cleaned_text, meds_out, raw_md, page_info, processed_img, slider_value
224
- except Exception as e:
225
- error_msg = f"Error during text extraction: {str(e)}"
226
- yield error_msg, "", error_msg, page_info, image_to_process, slider_value
227
-
228
- def update_slider(file_input):
229
- if file_input is None:
230
- return gr.update(maximum=20, value=1)
231
- file_path = file_input if isinstance(file_input, str) else file_input.name
232
- if file_path.lower().endswith('.pdf'):
233
- try:
234
- pdf = pdfium.PdfDocument(file_path)
235
- total_pages = len(pdf)
236
- pdf.close()
237
- return gr.update(maximum=total_pages, value=1)
238
- except:
239
- return gr.update(maximum=20, value=1)
240
- else:
241
- return gr.update(maximum=1, value=1)
242
 
243
  with gr.Blocks(title="💊 Medicine Extraction", theme=gr.themes.Soft()) as demo:
244
  file_input = gr.File(
245
- label="🖼️ Upload Image or PDF",
246
- file_types=[".pdf", ".png", ".jpg", ".jpeg"],
247
  type="filepath"
248
  )
249
  temperature = gr.Slider(
@@ -253,24 +145,12 @@ with gr.Blocks(title="💊 Medicine Extraction", theme=gr.themes.Soft()) as demo
253
  step=0.05,
254
  label="Temperature"
255
  )
256
- page_slider = gr.Slider(
257
- minimum=1, maximum=20, value=1, step=1,
258
- label="Page Number (PDF only)",
259
- interactive=True
260
- )
261
  extraction_mode = gr.Radio(
262
  choices=["Clinical NER", "Regex"],
263
  value="Regex",
264
  label="Extraction Method",
265
  info="Clinical NER uses ML, Regex uses rules"
266
  )
267
- output_text = gr.Textbox(
268
- label="📝 Extracted Text",
269
- lines=4,
270
- max_lines=10,
271
- interactive=False,
272
- show_copy_button=True
273
- )
274
  medicines_output = gr.Textbox(
275
  label="💊 Extracted Medicines/Drugs",
276
  placeholder="Medicine/drug names will appear here...",
@@ -279,34 +159,16 @@ with gr.Blocks(title="💊 Medicine Extraction", theme=gr.themes.Soft()) as demo
279
  interactive=False,
280
  show_copy_button=True
281
  )
282
- raw_output = gr.Textbox(
283
- label="Raw Model Output",
284
- lines=2,
285
- max_lines=5,
286
- interactive=False
287
- )
288
- page_info = gr.Markdown(
289
- value="" # Info of PDF page
290
- )
291
  rendered_image = gr.Image(
292
- label="Processed Image (Thresholded for OCR)",
293
  interactive=False
294
  )
295
- num_pages = gr.Number(
296
- value=1, label="Current Page (slider)", visible=False
297
- )
298
  submit_btn = gr.Button("Extract Medicines", variant="primary")
299
 
300
  submit_btn.click(
301
  fn=process_input,
302
- inputs=[file_input, temperature, page_slider, extraction_mode],
303
- outputs=[output_text, medicines_output, raw_output, page_info, rendered_image, num_pages]
304
- )
305
-
306
- file_input.change(
307
- fn=update_slider,
308
- inputs=[file_input],
309
- outputs=[page_slider]
310
  )
311
 
312
  if __name__ == "__main__":
 
1
  #################################################################################################
2
 
 
 
 
3
  import spaces
 
 
4
  import gradio as gr
5
  from PIL import Image
6
  import numpy as np
7
  import cv2
 
 
 
 
 
 
8
  import re
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def extract_medication_lines(text):
11
+ pattern = r"""^\s*(
12
+ T\.?|TAB\.?|TABLET
13
+ |C\.?|CAP\.?|CAPSULE
 
 
 
 
 
 
 
14
  |SYRUP|SYP
15
  |ORAL
16
+ |INJ\.?|INJECTION
17
  |OINTMENT|DROPS|PATCH|SOL\.?|SOLUTION
18
+ )\s+[A-Z0-9 \-\(\)/,.]+"""
 
 
 
19
  med_regex = re.compile(pattern, re.IGNORECASE | re.VERBOSE)
20
  meds = []
21
  for line in text.split('\n'):
 
24
  meds.append(line)
25
  return '\n'.join(meds)
26
 
27
+ def preprocess_image_for_ocr(image):
28
+ image_rgb = image.convert("RGB")
29
+ img_np = np.array(image_rgb)
30
+ gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
31
+ adaptive_threshold = cv2.adaptiveThreshold(
32
+ gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 85,35,
33
+ )
34
+ preprocessed_pil = Image.fromarray(adaptive_threshold)
35
+ return preprocessed_pil
36
+
37
+ @spaces.GPU
38
+ def extract_text_from_image(image, temperature=0.2, use_ner=False):
39
+ # Import and load within GPU context!
40
+ import torch
41
+ from transformers import (
42
+ LightOnOCRForConditionalGeneration,
43
+ LightOnOCRProcessor,
44
+ AutoTokenizer, AutoModelForTokenClassification, pipeline,
45
+ )
46
+
47
+ device = "cuda" if torch.cuda.is_available() else "cpu"
48
+ attn_implementation = "sdpa" if device == "cuda" else "eager"
49
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
50
 
51
+ ocr_model = LightOnOCRForConditionalGeneration.from_pretrained(
52
+ "lightonai/LightOnOCR-1B-1025",
53
+ attn_implementation=attn_implementation,
54
+ torch_dtype=dtype,
55
+ trust_remote_code=True,
56
+ ).to(device).eval()
57
+
58
+ processor = LightOnOCRProcessor.from_pretrained(
59
+ "lightonai/LightOnOCR-1B-1025",
60
+ trust_remote_code=True,
61
+ )
62
+ # NER only if requested
63
  if use_ner:
64
+ ner_tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
65
+ ner_model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
66
+ ner_pipeline = pipeline(
67
+ "ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple"
68
+ )
 
 
 
 
 
 
 
69
 
 
 
 
70
  processed_img = preprocess_image_for_ocr(image)
71
+
72
  chat = [
73
  {
74
  "role": "user",
 
84
  return_dict=True,
85
  return_tensors="pt",
86
  )
87
+
88
  inputs = {
89
+ k: (v.to(device=device, dtype=dtype)
 
90
  if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
91
  else v.to(device)
92
  if isinstance(v, torch.Tensor)
93
+ else v)
 
94
  for k, v in inputs.items()
95
  }
96
  generation_kwargs = dict(
 
102
  )
103
  with torch.no_grad():
104
  outputs = ocr_model.generate(**generation_kwargs)
105
+
106
  output_text = processor.decode(outputs[0], skip_special_tokens=True)
107
+ cleaned_text = output_text.strip()
108
+ # Extract medicines
109
+ if use_ner:
110
+ entities = ner_pipeline(cleaned_text)
111
+ meds = []
112
+ for ent in entities:
113
+ if ent["entity_group"] == "treatment":
114
+ word = ent["word"]
115
+ if word.startswith("##") and meds:
116
+ meds[-1] += word[2:]
117
+ else:
118
+ meds.append(word)
119
+ result_meds = ", ".join(set(meds)) if meds else "None detected"
120
+ else:
121
+ result_meds = extract_medication_lines(cleaned_text) or "None detected"
122
+
123
+ yield result_meds, processed_img # Only medicines and processed image
124
 
125
  def process_input(file_input, temperature, page_num, extraction_mode):
126
  if file_input is None:
127
+ yield "Please upload an image or PDF first.", None
128
  return
129
+ image_to_process = Image.open(file_input) if not str(file_input).lower().endswith(".pdf") else None # simplify to image only
130
+ use_ner = extraction_mode == "Clinical NER"
131
 
132
+ for meds_out, processed_img in extract_text_from_image(image_to_process, temperature, use_ner):
133
+ yield meds_out, processed_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  with gr.Blocks(title="💊 Medicine Extraction", theme=gr.themes.Soft()) as demo:
136
  file_input = gr.File(
137
+ label="🖼️ Upload Image",
138
+ file_types=[".png", ".jpg", ".jpeg"],
139
  type="filepath"
140
  )
141
  temperature = gr.Slider(
 
145
  step=0.05,
146
  label="Temperature"
147
  )
 
 
 
 
 
148
  extraction_mode = gr.Radio(
149
  choices=["Clinical NER", "Regex"],
150
  value="Regex",
151
  label="Extraction Method",
152
  info="Clinical NER uses ML, Regex uses rules"
153
  )
 
 
 
 
 
 
 
154
  medicines_output = gr.Textbox(
155
  label="💊 Extracted Medicines/Drugs",
156
  placeholder="Medicine/drug names will appear here...",
 
159
  interactive=False,
160
  show_copy_button=True
161
  )
 
 
 
 
 
 
 
 
 
162
  rendered_image = gr.Image(
163
+ label="Processed Image (Adaptive Thresholded for OCR)",
164
  interactive=False
165
  )
 
 
 
166
  submit_btn = gr.Button("Extract Medicines", variant="primary")
167
 
168
  submit_btn.click(
169
  fn=process_input,
170
+ inputs=[file_input, temperature, 1, extraction_mode], # page_num not used for image, set to 1
171
+ outputs=[medicines_output, rendered_image]
 
 
 
 
 
 
172
  )
173
 
174
  if __name__ == "__main__":