IFMedTechdemo commited on
Commit
c9ad6ed
·
verified ·
1 Parent(s): f574169

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +339 -339
app.py CHANGED
@@ -1,323 +1,5 @@
1
  #################################################################################################
2
 
3
- import subprocess
4
- import sys
5
-
6
- import spaces
7
- import torch
8
-
9
- import gradio as gr
10
- from PIL import Image
11
- import numpy as np
12
- import cv2
13
- import pypdfium2 as pdfium
14
- from transformers import (
15
- LightOnOCRForConditionalGeneration,
16
- LightOnOCRProcessor,
17
- )
18
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
19
- import re
20
-
21
- device = "cuda" if torch.cuda.is_available() else "cpu"
22
- if device == "cuda":
23
- attn_implementation = "sdpa"
24
- dtype = torch.bfloat16
25
- else:
26
- attn_implementation = "eager"
27
- dtype = torch.float32
28
-
29
- ocr_model = LightOnOCRForConditionalGeneration.from_pretrained(
30
- "lightonai/LightOnOCR-1B-1025",
31
- attn_implementation=attn_implementation,
32
- torch_dtype=dtype,
33
- trust_remote_code=True,
34
- ).to(device).eval()
35
-
36
- processor = LightOnOCRProcessor.from_pretrained(
37
- "lightonai/LightOnOCR-1B-1025",
38
- trust_remote_code=True,
39
- )
40
-
41
- ner_tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
42
- ner_model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
43
- ner_pipeline = pipeline(
44
- "ner",
45
- model=ner_model,
46
- tokenizer=ner_tokenizer,
47
- aggregation_strategy="simple",
48
- )
49
-
50
- def render_pdf_page(page, max_resolution=1540, scale=2.77):
51
- width, height = page.get_size()
52
- pixel_width = width * scale
53
- pixel_height = height * scale
54
- resize_factor = min(1, max_resolution / pixel_width, max_resolution / pixel_height)
55
- target_scale = scale * resize_factor
56
- return page.render(scale=target_scale, rev_byteorder=True).to_pil()
57
-
58
- def process_pdf(pdf_path, page_num=1):
59
- pdf = pdfium.PdfDocument(pdf_path)
60
- total_pages = len(pdf)
61
- page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
62
- page = pdf[page_idx]
63
- img = render_pdf_page(page)
64
- pdf.close()
65
- return img, total_pages, page_idx + 1
66
-
67
- def clean_output_text(text):
68
- markers_to_remove = ["system", "user", "assistant"]
69
- lines = text.split('\n')
70
- cleaned_lines = []
71
- for line in lines:
72
- stripped = line.strip()
73
- if stripped.lower() not in markers_to_remove:
74
- cleaned_lines.append(line)
75
- cleaned = '\n'.join(cleaned_lines).strip()
76
- if "assistant" in text.lower():
77
- parts = text.split("assistant", 1)
78
- if len(parts) > 1:
79
- cleaned = parts[1].strip()
80
- return cleaned
81
-
82
- def preprocess_image_for_ocr(image):
83
- image_rgb = image.convert("RGB")
84
- img_np = np.array(image_rgb)
85
- gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
86
- adaptive_threshold = cv2.adaptiveThreshold(
87
- gray,
88
- 255,
89
- cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
90
- cv2.THRESH_BINARY,
91
- 85,
92
- 11,
93
- )
94
- preprocessed_pil = Image.fromarray(adaptive_threshold)
95
- return preprocessed_pil
96
-
97
-
98
-
99
- def extract_medication_lines(text):
100
- """
101
- Extracts medication/drug lines from text using regex.
102
- Matches lines beginning with tab, tablet, cap, capsule, syrup, syp, oral, inj, injection, ointment, drops, patch, sol, solution, etc.
103
- Handles case-insensitivity and abbreviations like T., C., tab., cap. etc.
104
- """
105
- # "|" means OR. (?:...) is a non-capturing group.
106
- pattern = r"""^\s* # Leading spaces allowed
107
- (
108
- T\.?|TAB\.?|TABLET # T., T, TAB, TAB., TABLET
109
- |C\.?|CAP\.?|CAPSULE # C., C, CAP, CAP., CAPSULE
110
- |SYRUP|SYP
111
- |ORAL
112
- |INJ\.?|INJECTION # INJ., INJ, INJECTION
113
- |OINTMENT|DROPS|PATCH|SOL\.?|SOLUTION
114
- )
115
- \s+[A-Z0-9 \-\(\)/,.]+ # Name/dose/other info (at least one space/letter after the pattern)
116
- """
117
- # Compile with re.IGNORECASE and re.VERBOSE for readability
118
- med_regex = re.compile(pattern, re.IGNORECASE | re.VERBOSE)
119
- meds = []
120
- for line in text.split('\n'):
121
- line = line.strip()
122
- if med_regex.match(line):
123
- meds.append(line)
124
- return '\n'.join(meds)
125
-
126
-
127
- def extract_meds(text, use_ner):
128
- """
129
- Switches between Clinical NER or regex extraction.
130
- Returns medications string.
131
- """
132
- if use_ner:
133
- entities = ner_pipeline(text)
134
- meds = []
135
- for ent in entities:
136
- if ent["entity_group"] == "treatment":
137
- word = ent["word"]
138
- if word.startswith("##") and meds:
139
- meds[-1] += word[2:]
140
- else:
141
- meds.append(word)
142
- return ", ".join(set(meds)) if meds else "None detected"
143
- else:
144
- return extract_medication_lines(text) or "None detected"
145
-
146
- @spaces.GPU
147
- def extract_text_from_image(image, temperature=0.2):
148
- """OCR with adaptive thresholding."""
149
- processed_img = preprocess_image_for_ocr(image)
150
- chat = [
151
- {
152
- "role": "user",
153
- "content": [
154
- {"type": "image", "image": processed_img}
155
- ],
156
- }
157
- ]
158
- inputs = processor.apply_chat_template(
159
- chat,
160
- add_generation_prompt=True,
161
- tokenize=True,
162
- return_dict=True,
163
- return_tensors="pt",
164
- )
165
- # Move inputs to device
166
- inputs = {
167
- k: (
168
- v.to(device=device, dtype=dtype)
169
- if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
170
- else v.to(device)
171
- if isinstance(v, torch.Tensor)
172
- else v
173
- )
174
- for k, v in inputs.items()
175
- }
176
- generation_kwargs = dict(
177
- **inputs,
178
- max_new_tokens=2048,
179
- temperature=temperature if temperature > 0 else 0.0,
180
- use_cache=True,
181
- do_sample=temperature > 0,
182
- )
183
- with torch.no_grad():
184
- outputs = ocr_model.generate(**generation_kwargs)
185
- output_text = processor.decode(outputs[0], skip_special_tokens=True)
186
- cleaned_text = clean_output_text(output_text)
187
- yield cleaned_text, output_text, processed_img
188
-
189
- def process_input(file_input, temperature, page_num, extraction_mode):
190
- if file_input is None:
191
- yield "Please upload an image or PDF first.", "", "", "", "No file!", 1
192
- return
193
-
194
- image_to_process = None
195
- page_info = ""
196
- slider_value = page_num
197
- file_path = file_input if isinstance(file_input, str) else file_input.name
198
-
199
- if file_path.lower().endswith(".pdf"):
200
- try:
201
- image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
202
- page_info = f"Processing page {actual_page} of {total_pages}"
203
- slider_value = actual_page
204
- except Exception as e:
205
- msg = f"Error processing PDF: {str(e)}"
206
- yield msg, "", msg, "", None, slider_value
207
- return
208
- else:
209
- try:
210
- image_to_process = Image.open(file_path)
211
- page_info = "Processing image"
212
- except Exception as e:
213
- msg = f"Error opening image: {str(e)}"
214
- yield msg, "", msg, "", None, slider_value
215
- return
216
-
217
- use_ner = extraction_mode == "Regex" #"Clinical NER"
218
- try:
219
- for cleaned_text, raw_md, processed_img in extract_text_from_image(
220
- image_to_process, temperature
221
- ):
222
- meds_out = extract_meds(cleaned_text, use_ner)
223
- yield cleaned_text, meds_out, raw_md, page_info, processed_img, slider_value
224
- except Exception as e:
225
- error_msg = f"Error during text extraction: {str(e)}"
226
- yield error_msg, "", error_msg, page_info, image_to_process, slider_value
227
-
228
- def update_slider(file_input):
229
- if file_input is None:
230
- return gr.update(maximum=20, value=1)
231
- file_path = file_input if isinstance(file_input, str) else file_input.name
232
- if file_path.lower().endswith('.pdf'):
233
- try:
234
- pdf = pdfium.PdfDocument(file_path)
235
- total_pages = len(pdf)
236
- pdf.close()
237
- return gr.update(maximum=total_pages, value=1)
238
- except:
239
- return gr.update(maximum=20, value=1)
240
- else:
241
- return gr.update(maximum=1, value=1)
242
-
243
- with gr.Blocks(title="💊 Medicine Extraction", theme=gr.themes.Soft()) as demo:
244
- file_input = gr.File(
245
- label="🖼️ Upload Image or PDF",
246
- file_types=[".pdf", ".png", ".jpg", ".jpeg"],
247
- type="filepath"
248
- )
249
- temperature = gr.Slider(
250
- minimum=0.0,
251
- maximum=1.0,
252
- value=0.2,
253
- step=0.05,
254
- label="Temperature"
255
- )
256
- page_slider = gr.Slider(
257
- minimum=1, maximum=20, value=1, step=1,
258
- label="Page Number (PDF only)",
259
- interactive=True
260
- )
261
- extraction_mode = gr.Radio(
262
- choices=["Clinical NER", "Regex"],
263
- value="Regex",
264
- label="Extraction Method",
265
- info="Clinical NER uses ML, Regex uses rules"
266
- )
267
- output_text = gr.Textbox(
268
- label="📝 Extracted Text",
269
- lines=4,
270
- max_lines=10,
271
- interactive=False,
272
- show_copy_button=True
273
- )
274
- medicines_output = gr.Textbox(
275
- label="💊 Extracted Medicines/Drugs",
276
- placeholder="Medicine/drug names will appear here...",
277
- lines=2,
278
- max_lines=10,
279
- interactive=False,
280
- show_copy_button=True
281
- )
282
- raw_output = gr.Textbox(
283
- label="Raw Model Output",
284
- lines=2,
285
- max_lines=5,
286
- interactive=False
287
- )
288
- page_info = gr.Markdown(
289
- value="" # Info of PDF page
290
- )
291
- rendered_image = gr.Image(
292
- label="Processed Image (Thresholded for OCR)",
293
- interactive=False
294
- )
295
- num_pages = gr.Number(
296
- value=1, label="Current Page (slider)", visible=False
297
- )
298
- submit_btn = gr.Button("Extract Medicines", variant="primary")
299
-
300
- submit_btn.click(
301
- fn=process_input,
302
- inputs=[file_input, temperature, page_slider, extraction_mode],
303
- outputs=[output_text, medicines_output, raw_output, page_info, rendered_image, num_pages]
304
- )
305
-
306
- file_input.change(
307
- fn=update_slider,
308
- inputs=[file_input],
309
- outputs=[page_slider]
310
- )
311
-
312
- if __name__ == "__main__":
313
- demo.launch()
314
-
315
-
316
-
317
- #################################################### running code only NER #######################
318
-
319
- #!/usr/bin/env python3
320
-
321
  # import subprocess
322
  # import sys
323
 
@@ -334,6 +16,7 @@ if __name__ == "__main__":
334
  # LightOnOCRProcessor,
335
  # )
336
  # from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
 
337
 
338
  # device = "cuda" if torch.cuda.is_available() else "cpu"
339
  # if device == "cuda":
@@ -397,7 +80,6 @@ if __name__ == "__main__":
397
  # return cleaned
398
 
399
  # def preprocess_image_for_ocr(image):
400
- # """Convert PIL.Image to adaptive thresholded image for OCR."""
401
  # image_rgb = image.convert("RGB")
402
  # img_np = np.array(image_rgb)
403
  # gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
@@ -412,9 +94,58 @@ if __name__ == "__main__":
412
  # preprocessed_pil = Image.fromarray(adaptive_threshold)
413
  # return preprocessed_pil
414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  # @spaces.GPU
416
  # def extract_text_from_image(image, temperature=0.2):
417
- # """OCR + clinical NER, with preprocessing."""
418
  # processed_img = preprocess_image_for_ocr(image)
419
  # chat = [
420
  # {
@@ -451,22 +182,11 @@ if __name__ == "__main__":
451
  # )
452
  # with torch.no_grad():
453
  # outputs = ocr_model.generate(**generation_kwargs)
454
-
455
  # output_text = processor.decode(outputs[0], skip_special_tokens=True)
456
  # cleaned_text = clean_output_text(output_text)
457
- # entities = ner_pipeline(cleaned_text)
458
- # medications = []
459
- # for ent in entities:
460
- # if ent["entity_group"] == "treatment":
461
- # word = ent["word"]
462
- # if word.startswith("##") and medications:
463
- # medications[-1] += word[2:]
464
- # else:
465
- # medications.append(word)
466
- # medications_str = ", ".join(set(medications)) if medications else "None detected"
467
- # yield cleaned_text, medications_str, output_text, processed_img
468
-
469
- # def process_input(file_input, temperature, page_num):
470
  # if file_input is None:
471
  # yield "Please upload an image or PDF first.", "", "", "", "No file!", 1
472
  # return
@@ -494,11 +214,13 @@ if __name__ == "__main__":
494
  # yield msg, "", msg, "", None, slider_value
495
  # return
496
 
 
497
  # try:
498
- # for cleaned_text, medications, raw_md, processed_img in extract_text_from_image(
499
  # image_to_process, temperature
500
  # ):
501
- # yield cleaned_text, medications, raw_md, page_info, processed_img, slider_value
 
502
  # except Exception as e:
503
  # error_msg = f"Error during text extraction: {str(e)}"
504
  # yield error_msg, "", error_msg, page_info, image_to_process, slider_value
@@ -536,6 +258,12 @@ if __name__ == "__main__":
536
  # label="Page Number (PDF only)",
537
  # interactive=True
538
  # )
 
 
 
 
 
 
539
  # output_text = gr.Textbox(
540
  # label="📝 Extracted Text",
541
  # lines=4,
@@ -547,7 +275,7 @@ if __name__ == "__main__":
547
  # label="💊 Extracted Medicines/Drugs",
548
  # placeholder="Medicine/drug names will appear here...",
549
  # lines=2,
550
- # max_lines=5,
551
  # interactive=False,
552
  # show_copy_button=True
553
  # )
@@ -558,7 +286,7 @@ if __name__ == "__main__":
558
  # interactive=False
559
  # )
560
  # page_info = gr.Markdown(
561
- # value="" # Info of PDF page
562
  # )
563
  # rendered_image = gr.Image(
564
  # label="Processed Image (Thresholded for OCR)",
@@ -571,7 +299,7 @@ if __name__ == "__main__":
571
 
572
  # submit_btn.click(
573
  # fn=process_input,
574
- # inputs=[file_input, temperature, page_slider],
575
  # outputs=[output_text, medicines_output, raw_output, page_info, rendered_image, num_pages]
576
  # )
577
 
@@ -586,6 +314,278 @@ if __name__ == "__main__":
586
 
587
 
588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
 
590
  ########################################## #############################################################
591
 
 
1
  #################################################################################################
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  # import subprocess
4
  # import sys
5
 
 
16
  # LightOnOCRProcessor,
17
  # )
18
  # from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
19
+ # import re
20
 
21
  # device = "cuda" if torch.cuda.is_available() else "cpu"
22
  # if device == "cuda":
 
80
  # return cleaned
81
 
82
  # def preprocess_image_for_ocr(image):
 
83
  # image_rgb = image.convert("RGB")
84
  # img_np = np.array(image_rgb)
85
  # gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
 
94
  # preprocessed_pil = Image.fromarray(adaptive_threshold)
95
  # return preprocessed_pil
96
 
97
+
98
+
99
+ # def extract_medication_lines(text):
100
+ # """
101
+ # Extracts medication/drug lines from text using regex.
102
+ # Matches lines beginning with tab, tablet, cap, capsule, syrup, syp, oral, inj, injection, ointment, drops, patch, sol, solution, etc.
103
+ # Handles case-insensitivity and abbreviations like T., C., tab., cap. etc.
104
+ # """
105
+ # # "|" means OR. (?:...) is a non-capturing group.
106
+ # pattern = r"""^\s* # Leading spaces allowed
107
+ # (
108
+ # T\.?|TAB\.?|TABLET # T., T, TAB, TAB., TABLET
109
+ # |C\.?|CAP\.?|CAPSULE # C., C, CAP, CAP., CAPSULE
110
+ # |SYRUP|SYP
111
+ # |ORAL
112
+ # |INJ\.?|INJECTION # INJ., INJ, INJECTION
113
+ # |OINTMENT|DROPS|PATCH|SOL\.?|SOLUTION
114
+ # )
115
+ # \s+[A-Z0-9 \-\(\)/,.]+ # Name/dose/other info (at least one space/letter after the pattern)
116
+ # """
117
+ # # Compile with re.IGNORECASE and re.VERBOSE for readability
118
+ # med_regex = re.compile(pattern, re.IGNORECASE | re.VERBOSE)
119
+ # meds = []
120
+ # for line in text.split('\n'):
121
+ # line = line.strip()
122
+ # if med_regex.match(line):
123
+ # meds.append(line)
124
+ # return '\n'.join(meds)
125
+
126
+
127
+ # def extract_meds(text, use_ner):
128
+ # """
129
+ # Switches between Clinical NER or regex extraction.
130
+ # Returns medications string.
131
+ # """
132
+ # if use_ner:
133
+ # entities = ner_pipeline(text)
134
+ # meds = []
135
+ # for ent in entities:
136
+ # if ent["entity_group"] == "treatment":
137
+ # word = ent["word"]
138
+ # if word.startswith("##") and meds:
139
+ # meds[-1] += word[2:]
140
+ # else:
141
+ # meds.append(word)
142
+ # return ", ".join(set(meds)) if meds else "None detected"
143
+ # else:
144
+ # return extract_medication_lines(text) or "None detected"
145
+
146
  # @spaces.GPU
147
  # def extract_text_from_image(image, temperature=0.2):
148
+ # """OCR with adaptive thresholding."""
149
  # processed_img = preprocess_image_for_ocr(image)
150
  # chat = [
151
  # {
 
182
  # )
183
  # with torch.no_grad():
184
  # outputs = ocr_model.generate(**generation_kwargs)
 
185
  # output_text = processor.decode(outputs[0], skip_special_tokens=True)
186
  # cleaned_text = clean_output_text(output_text)
187
+ # yield cleaned_text, output_text, processed_img
188
+
189
+ # def process_input(file_input, temperature, page_num, extraction_mode):
 
 
 
 
 
 
 
 
 
 
190
  # if file_input is None:
191
  # yield "Please upload an image or PDF first.", "", "", "", "No file!", 1
192
  # return
 
214
  # yield msg, "", msg, "", None, slider_value
215
  # return
216
 
217
+ # use_ner = extraction_mode == "Regex" #"Clinical NER"
218
  # try:
219
+ # for cleaned_text, raw_md, processed_img in extract_text_from_image(
220
  # image_to_process, temperature
221
  # ):
222
+ # meds_out = extract_meds(cleaned_text, use_ner)
223
+ # yield cleaned_text, meds_out, raw_md, page_info, processed_img, slider_value
224
  # except Exception as e:
225
  # error_msg = f"Error during text extraction: {str(e)}"
226
  # yield error_msg, "", error_msg, page_info, image_to_process, slider_value
 
258
  # label="Page Number (PDF only)",
259
  # interactive=True
260
  # )
261
+ # extraction_mode = gr.Radio(
262
+ # choices=["Clinical NER", "Regex"],
263
+ # value="Regex",
264
+ # label="Extraction Method",
265
+ # info="Clinical NER uses ML, Regex uses rules"
266
+ # )
267
  # output_text = gr.Textbox(
268
  # label="📝 Extracted Text",
269
  # lines=4,
 
275
  # label="💊 Extracted Medicines/Drugs",
276
  # placeholder="Medicine/drug names will appear here...",
277
  # lines=2,
278
+ # max_lines=10,
279
  # interactive=False,
280
  # show_copy_button=True
281
  # )
 
286
  # interactive=False
287
  # )
288
  # page_info = gr.Markdown(
289
+ # value="" # Info of PDF page
290
  # )
291
  # rendered_image = gr.Image(
292
  # label="Processed Image (Thresholded for OCR)",
 
299
 
300
  # submit_btn.click(
301
  # fn=process_input,
302
+ # inputs=[file_input, temperature, page_slider, extraction_mode],
303
  # outputs=[output_text, medicines_output, raw_output, page_info, rendered_image, num_pages]
304
  # )
305
 
 
314
 
315
 
316
 
317
+ #################################################### running code only NER #######################
318
+
319
+ #!/usr/bin/env python3
320
+
321
+ import subprocess
322
+ import sys
323
+
324
+ import spaces
325
+ import torch
326
+
327
+ import gradio as gr
328
+ from PIL import Image
329
+ import numpy as np
330
+ import cv2
331
+ import pypdfium2 as pdfium
332
+ from transformers import (
333
+ LightOnOCRForConditionalGeneration,
334
+ LightOnOCRProcessor,
335
+ )
336
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
337
+
338
+ device = "cuda" if torch.cuda.is_available() else "cpu"
339
+ if device == "cuda":
340
+ attn_implementation = "sdpa"
341
+ dtype = torch.bfloat16
342
+ else:
343
+ attn_implementation = "eager"
344
+ dtype = torch.float32
345
+
346
+ ocr_model = LightOnOCRForConditionalGeneration.from_pretrained(
347
+ "lightonai/LightOnOCR-1B-1025",
348
+ attn_implementation=attn_implementation,
349
+ torch_dtype=dtype,
350
+ trust_remote_code=True,
351
+ ).to(device).eval()
352
+
353
+ processor = LightOnOCRProcessor.from_pretrained(
354
+ "lightonai/LightOnOCR-1B-1025",
355
+ trust_remote_code=True,
356
+ )
357
+
358
+ ner_tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
359
+ ner_model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
360
+ ner_pipeline = pipeline(
361
+ "ner",
362
+ model=ner_model,
363
+ tokenizer=ner_tokenizer,
364
+ aggregation_strategy="simple",
365
+ )
366
+
367
+ def render_pdf_page(page, max_resolution=1540, scale=2.77):
368
+ width, height = page.get_size()
369
+ pixel_width = width * scale
370
+ pixel_height = height * scale
371
+ resize_factor = min(1, max_resolution / pixel_width, max_resolution / pixel_height)
372
+ target_scale = scale * resize_factor
373
+ return page.render(scale=target_scale, rev_byteorder=True).to_pil()
374
+
375
+ def process_pdf(pdf_path, page_num=1):
376
+ pdf = pdfium.PdfDocument(pdf_path)
377
+ total_pages = len(pdf)
378
+ page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
379
+ page = pdf[page_idx]
380
+ img = render_pdf_page(page)
381
+ pdf.close()
382
+ return img, total_pages, page_idx + 1
383
+
384
+ def clean_output_text(text):
385
+ markers_to_remove = ["system", "user", "assistant"]
386
+ lines = text.split('\n')
387
+ cleaned_lines = []
388
+ for line in lines:
389
+ stripped = line.strip()
390
+ if stripped.lower() not in markers_to_remove:
391
+ cleaned_lines.append(line)
392
+ cleaned = '\n'.join(cleaned_lines).strip()
393
+ if "assistant" in text.lower():
394
+ parts = text.split("assistant", 1)
395
+ if len(parts) > 1:
396
+ cleaned = parts[1].strip()
397
+ return cleaned
398
+
399
+ def preprocess_image_for_ocr(image):
400
+ """Convert PIL.Image to adaptive thresholded image for OCR."""
401
+ image_rgb = image.convert("RGB")
402
+ img_np = np.array(image_rgb)
403
+ gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
404
+ adaptive_threshold = cv2.adaptiveThreshold(
405
+ gray,
406
+ 255,
407
+ cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
408
+ cv2.THRESH_BINARY,
409
+ 85,
410
+ 11,
411
+ )
412
+ preprocessed_pil = Image.fromarray(adaptive_threshold)
413
+ return preprocessed_pil
414
+
415
+ @spaces.GPU
416
+ def extract_text_from_image(image, temperature=0.2):
417
+ """OCR + clinical NER, with preprocessing."""
418
+ processed_img = preprocess_image_for_ocr(image)
419
+ chat = [
420
+ {
421
+ "role": "user",
422
+ "content": [
423
+ {"type": "image", "image": processed_img}
424
+ ],
425
+ }
426
+ ]
427
+ inputs = processor.apply_chat_template(
428
+ chat,
429
+ add_generation_prompt=True,
430
+ tokenize=True,
431
+ return_dict=True,
432
+ return_tensors="pt",
433
+ )
434
+ # Move inputs to device
435
+ inputs = {
436
+ k: (
437
+ v.to(device=device, dtype=dtype)
438
+ if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
439
+ else v.to(device)
440
+ if isinstance(v, torch.Tensor)
441
+ else v
442
+ )
443
+ for k, v in inputs.items()
444
+ }
445
+ generation_kwargs = dict(
446
+ **inputs,
447
+ max_new_tokens=2048,
448
+ temperature=temperature if temperature > 0 else 0.0,
449
+ use_cache=True,
450
+ do_sample=temperature > 0,
451
+ )
452
+ with torch.no_grad():
453
+ outputs = ocr_model.generate(**generation_kwargs)
454
+
455
+ output_text = processor.decode(outputs[0], skip_special_tokens=True)
456
+ cleaned_text = clean_output_text(output_text)
457
+ entities = ner_pipeline(cleaned_text)
458
+ medications = []
459
+ for ent in entities:
460
+ if ent["entity_group"] == "treatment":
461
+ word = ent["word"]
462
+ if word.startswith("##") and medications:
463
+ medications[-1] += word[2:]
464
+ else:
465
+ medications.append(word)
466
+ medications_str = ", ".join(set(medications)) if medications else "None detected"
467
+ yield cleaned_text, medications_str, output_text, processed_img
468
+
469
+ def process_input(file_input, temperature, page_num):
470
+ if file_input is None:
471
+ yield "Please upload an image or PDF first.", "", "", "", "No file!", 1
472
+ return
473
+
474
+ image_to_process = None
475
+ page_info = ""
476
+ slider_value = page_num
477
+ file_path = file_input if isinstance(file_input, str) else file_input.name
478
+
479
+ if file_path.lower().endswith(".pdf"):
480
+ try:
481
+ image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
482
+ page_info = f"Processing page {actual_page} of {total_pages}"
483
+ slider_value = actual_page
484
+ except Exception as e:
485
+ msg = f"Error processing PDF: {str(e)}"
486
+ yield msg, "", msg, "", None, slider_value
487
+ return
488
+ else:
489
+ try:
490
+ image_to_process = Image.open(file_path)
491
+ page_info = "Processing image"
492
+ except Exception as e:
493
+ msg = f"Error opening image: {str(e)}"
494
+ yield msg, "", msg, "", None, slider_value
495
+ return
496
+
497
+ try:
498
+ for cleaned_text, medications, raw_md, processed_img in extract_text_from_image(
499
+ image_to_process, temperature
500
+ ):
501
+ yield cleaned_text, medications, raw_md, page_info, processed_img, slider_value
502
+ except Exception as e:
503
+ error_msg = f"Error during text extraction: {str(e)}"
504
+ yield error_msg, "", error_msg, page_info, image_to_process, slider_value
505
+
506
+ def update_slider(file_input):
507
+ if file_input is None:
508
+ return gr.update(maximum=20, value=1)
509
+ file_path = file_input if isinstance(file_input, str) else file_input.name
510
+ if file_path.lower().endswith('.pdf'):
511
+ try:
512
+ pdf = pdfium.PdfDocument(file_path)
513
+ total_pages = len(pdf)
514
+ pdf.close()
515
+ return gr.update(maximum=total_pages, value=1)
516
+ except:
517
+ return gr.update(maximum=20, value=1)
518
+ else:
519
+ return gr.update(maximum=1, value=1)
520
+
521
+ with gr.Blocks(title="💊 Medicine Extraction", theme=gr.themes.Soft()) as demo:
522
+ file_input = gr.File(
523
+ label="🖼️ Upload Image or PDF",
524
+ file_types=[".pdf", ".png", ".jpg", ".jpeg"],
525
+ type="filepath"
526
+ )
527
+ temperature = gr.Slider(
528
+ minimum=0.0,
529
+ maximum=1.0,
530
+ value=0.2,
531
+ step=0.05,
532
+ label="Temperature"
533
+ )
534
+ page_slider = gr.Slider(
535
+ minimum=1, maximum=20, value=1, step=1,
536
+ label="Page Number (PDF only)",
537
+ interactive=True
538
+ )
539
+ output_text = gr.Textbox(
540
+ label="📝 Extracted Text",
541
+ lines=4,
542
+ max_lines=10,
543
+ interactive=False,
544
+ show_copy_button=True
545
+ )
546
+ medicines_output = gr.Textbox(
547
+ label="💊 Extracted Medicines/Drugs",
548
+ placeholder="Medicine/drug names will appear here...",
549
+ lines=2,
550
+ max_lines=5,
551
+ interactive=False,
552
+ show_copy_button=True
553
+ )
554
+ raw_output = gr.Textbox(
555
+ label="Raw Model Output",
556
+ lines=2,
557
+ max_lines=5,
558
+ interactive=False
559
+ )
560
+ page_info = gr.Markdown(
561
+ value="" # Info of PDF page
562
+ )
563
+ rendered_image = gr.Image(
564
+ label="Processed Image (Thresholded for OCR)",
565
+ interactive=False
566
+ )
567
+ num_pages = gr.Number(
568
+ value=1, label="Current Page (slider)", visible=False
569
+ )
570
+ submit_btn = gr.Button("Extract Medicines", variant="primary")
571
+
572
+ submit_btn.click(
573
+ fn=process_input,
574
+ inputs=[file_input, temperature, page_slider],
575
+ outputs=[output_text, medicines_output, raw_output, page_info, rendered_image, num_pages]
576
+ )
577
+
578
+ file_input.change(
579
+ fn=update_slider,
580
+ inputs=[file_input],
581
+ outputs=[page_slider]
582
+ )
583
+
584
+ if __name__ == "__main__":
585
+ demo.launch()
586
+
587
+
588
+
589
 
590
  ########################################## #############################################################
591