diabolic6045 commited on
Commit
81c9f4d
Β·
verified Β·
1 Parent(s): 36c94eb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +298 -0
app.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio app for Sanskrit text transcription using Qwen2.5-VL model
4
+ Based on quick_test_improved.py
5
+ """
6
+
7
+ import gradio as gr
8
+ import torch
9
+ import base64
10
+ import io
11
+ from PIL import Image
12
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
13
+ from qwen_vl_utils import process_vision_info
14
+ from peft import PeftModel
15
+ import os
16
+ import logging
17
+ import spaces
18
+
19
+ # Set up logging
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+ class SanskritTranscriptionModel:
24
+ def __init__(self, model_path: str, adapter_path: str = None):
25
+ """Initialize the model and processor"""
26
+ self.model_path = model_path
27
+ self.adapter_path = adapter_path
28
+ self.model = None
29
+ self.processor = None
30
+ self.is_loaded = False
31
+
32
+ def load_model(self):
33
+ """Load the model and processor"""
34
+ if self.is_loaded:
35
+ return
36
+
37
+ try:
38
+ logger.info("Loading processor...")
39
+ self.processor = AutoProcessor.from_pretrained(self.model_path)
40
+
41
+ logger.info("Loading base model...")
42
+ # Check if CUDA is available, otherwise use CPU
43
+ device_map = "auto" if torch.cuda.is_available() else "cpu"
44
+ self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
45
+ self.model_path,
46
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
47
+ device_map=device_map
48
+ )
49
+
50
+ if self.adapter_path and os.path.exists(self.adapter_path):
51
+ logger.info("Loading LoRA adapters...")
52
+ self.model = PeftModel.from_pretrained(self.model, self.adapter_path)
53
+ else:
54
+ logger.info("No adapter path found, using base model only")
55
+
56
+ self.model.eval()
57
+ device = next(self.model.parameters()).device
58
+ logger.info(f"Model loaded on device: {device}")
59
+ self.is_loaded = True
60
+
61
+ except Exception as e:
62
+ logger.error(f"Error loading model: {e}")
63
+ raise e
64
+
65
+ def transcribe_image(self, image: Image.Image, prompt: str = None) -> str:
66
+ """Transcribe Sanskrit text from image"""
67
+ if not self.is_loaded:
68
+ self.load_model()
69
+
70
+ if prompt is None:
71
+ prompt = "Please transcribe the Sanskrit text shown in this image:"
72
+
73
+ try:
74
+ messages = [
75
+ {
76
+ "role": "user",
77
+ "content": [
78
+ {"type": "image", "image": image},
79
+ {"type": "text", "text": prompt}
80
+ ]
81
+ }
82
+ ]
83
+
84
+ # Preparation for inference
85
+ text = self.processor.apply_chat_template(
86
+ messages, tokenize=False, add_generation_prompt=True
87
+ )
88
+ image_inputs, video_inputs = process_vision_info(messages)
89
+ inputs = self.processor(
90
+ text=[text],
91
+ images=image_inputs,
92
+ videos=video_inputs,
93
+ padding=True,
94
+ return_tensors="pt",
95
+ )
96
+
97
+ # Get model device and move inputs there
98
+ model_device = next(self.model.parameters()).device
99
+ inputs = {k: v.to(model_device) for k, v in inputs.items()}
100
+
101
+ with torch.no_grad():
102
+ generated_ids = self.model.generate(
103
+ **inputs,
104
+ max_new_tokens=512,
105
+ do_sample=False,
106
+ pad_token_id=self.processor.tokenizer.eos_token_id,
107
+ use_cache=True,
108
+ repetition_penalty=1.1
109
+ )
110
+
111
+ # Extract only the generated part
112
+ generated_ids_trimmed = [
113
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
114
+ ]
115
+ output_text = self.processor.batch_decode(
116
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
117
+ )
118
+
119
+ return output_text[0] if output_text else ""
120
+
121
+ except Exception as e:
122
+ logger.error(f"Error generating response: {e}")
123
+ return f"Error: {str(e)}"
124
+
125
+ # Initialize the model
126
+ model_instance = None
127
+
128
+ @spaces.GPU(duration=60) # 2 minutes for model loading and inference
129
+ def initialize_model():
130
+ """Initialize the model instance with ZeroGPU support"""
131
+ global model_instance
132
+ if model_instance is None:
133
+ model_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
134
+ adapter_path = './outputs/out-qwen2-5-vl'
135
+ model_instance = SanskritTranscriptionModel(model_path, adapter_path)
136
+ return model_instance
137
+
138
+ def check_model_status():
139
+ """Check if model is loaded and ready"""
140
+ try:
141
+ model = initialize_model()
142
+ if model.is_loaded:
143
+ return "βœ… Model loaded and ready"
144
+ else:
145
+ return "⏳ Model not loaded yet"
146
+ except Exception as e:
147
+ return f"❌ Model error: {str(e)}"
148
+
149
+ @spaces.GPU(duration=30) # 1 minute for transcription
150
+ def transcribe_sanskrit(image, custom_prompt, progress=gr.Progress()):
151
+ """Gradio interface function for transcription with ZeroGPU support"""
152
+ if image is None:
153
+ return "Please upload an image first."
154
+
155
+ try:
156
+ progress(0.1, desc="Requesting GPU resources...")
157
+ model = initialize_model()
158
+
159
+ progress(0.3, desc="Processing image...")
160
+ # Use custom prompt if provided, otherwise use default
161
+ prompt = custom_prompt if custom_prompt.strip() else "Please transcribe the Sanskrit text shown in this image:"
162
+
163
+ progress(0.5, desc="Generating transcription...")
164
+ result = model.transcribe_image(image, prompt)
165
+
166
+ progress(1.0, desc="Complete!")
167
+ return result
168
+
169
+ except Exception as e:
170
+ logger.error(f"Error in transcribe_sanskrit: {e}")
171
+ return f"❌ Error occurred: {str(e)}\n\nPlease try again or check if the model files are properly loaded."
172
+
173
+ def create_gradio_interface():
174
+ """Create and configure the Gradio interface"""
175
+
176
+ with gr.Blocks(
177
+ title="Sanskrit Text Transcription",
178
+ theme=gr.themes.Soft()
179
+ ) as app:
180
+
181
+ gr.HTML("""
182
+ <div class="main-header">
183
+ <h1>πŸ•‰οΈ Sanskrit Text Transcription</h1>
184
+ <p>Upload an image containing Sanskrit text and get an accurate transcription using AI</p>
185
+ <p><strong>πŸš€ Powered by ZeroGPU:</strong> Dynamic GPU allocation for efficient processing</p>
186
+ </div>
187
+ """)
188
+
189
+ with gr.Row():
190
+ with gr.Column(scale=1):
191
+ gr.Markdown("### Upload Image")
192
+ image_input = gr.Image(
193
+ type="pil",
194
+ label="Sanskrit Text Image",
195
+ height=400
196
+ )
197
+
198
+ gr.Markdown("### Custom Prompt (Optional)")
199
+ custom_prompt = gr.Textbox(
200
+ label="Custom transcription prompt",
201
+ placeholder="Please transcribe the Sanskrit text shown in this image:",
202
+ lines=2,
203
+ value="Please transcribe the Sanskrit text shown in this image:"
204
+ )
205
+
206
+ transcribe_btn = gr.Button(
207
+ "πŸ•‰οΈ Transcribe Sanskrit Text",
208
+ variant="primary",
209
+ size="lg"
210
+ )
211
+
212
+ gr.Markdown("""
213
+ ### Instructions:
214
+ 1. Upload an image containing Sanskrit text
215
+ 2. Optionally modify the prompt for better results
216
+ 3. Click the transcribe button
217
+ 4. View the transcribed text below
218
+ """)
219
+
220
+ with gr.Column(scale=1):
221
+ gr.Markdown("### Transcription Result")
222
+ output_text = gr.Textbox(
223
+ label="Transcribed Sanskrit Text",
224
+ lines=10,
225
+ max_lines=20,
226
+ show_copy_button=True
227
+ )
228
+
229
+ gr.Markdown("### Model Information")
230
+ model_status = gr.Textbox(
231
+ label="Model Status",
232
+ value="Checking...",
233
+ interactive=False
234
+ )
235
+
236
+ check_status_btn = gr.Button("πŸ”„ Check Model Status", size="sm")
237
+
238
+ gr.Markdown("""
239
+ **Model:** Qwen2.5-VL-7B-Instruct with LoRA fine-tuning
240
+
241
+ **Features:**
242
+ - Multimodal vision-language model
243
+ - Fine-tuned on Sanskrit text data
244
+ - Supports various Sanskrit scripts
245
+ - High accuracy transcription
246
+ """)
247
+
248
+ # Example section
249
+ with gr.Row():
250
+ gr.Markdown("### Example Images")
251
+
252
+ # Event handlers
253
+ transcribe_btn.click(
254
+ fn=transcribe_sanskrit,
255
+ inputs=[image_input, custom_prompt],
256
+ outputs=output_text,
257
+ show_progress=True
258
+ )
259
+
260
+ # Auto-transcribe when image is uploaded
261
+ image_input.change(
262
+ fn=transcribe_sanskrit,
263
+ inputs=[image_input, custom_prompt],
264
+ outputs=output_text,
265
+ show_progress=True
266
+ )
267
+
268
+ # Model status check
269
+ check_status_btn.click(
270
+ fn=check_model_status,
271
+ outputs=model_status
272
+ )
273
+
274
+ # Check model status on app load
275
+ app.load(
276
+ fn=check_model_status,
277
+ outputs=model_status
278
+ )
279
+
280
+ return app
281
+
282
+ def main():
283
+ """Main function to launch the Gradio app"""
284
+ logger.info("Starting Sanskrit Transcription Gradio App...")
285
+
286
+ # Create the interface
287
+ app = create_gradio_interface()
288
+
289
+ # Launch the app
290
+ app.launch(
291
+ server_name="0.0.0.0", # Allow external access
292
+ server_port=7860, # Default Gradio port
293
+ share=False, # Enable request queuing
294
+ max_threads=4 # Limit concurrent requests
295
+ )
296
+
297
+ if __name__ == "__main__":
298
+ main()