-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample_usage.py
More file actions
326 lines (258 loc) · 11 KB
/
example_usage.py
File metadata and controls
326 lines (258 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
#!/usr/bin/env python3
"""
Example Usage Script
Demonstrates how to use the advanced OCR prediction pipeline
"""
import json
import os
import shutil
import time
import traceback
from pathlib import Path

import numpy as np
import torch

# Import our custom modules
from advanced_model import AdvancedFastPlateOCR, create_model
from image_preprocessor import ImagePreprocessor, ImageQualityAssessor
from ocr_predictor import PlateOCRPredictor, create_predictor
from config import ConfigManager, ConfigFactory, get_training_config
from model_export import ModelExporter
def example_1_basic_prediction():
    """Example 1: End-to-end single-image prediction with throwaway artifacts.

    Builds a dummy vocabulary and an untrained model, saves both to disk,
    runs one greedy prediction on a random image, and always deletes the
    temporary files afterwards.
    """
    print("🔍 Example 1: Basic Prediction")
    print("-" * 40)

    # Dummy vocabulary: special tokens first, then the plate character set.
    vocab = ['<pad>', '<sos>', '<eos>'] + list("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")

    with open("example_vocab.json", "w") as f:
        json.dump(vocab, f)

    # Untrained checkpoint — the predicted text is meaningless, but the
    # full load/predict code path is exercised.
    model = create_model(len(vocab))
    torch.save({"model": model.state_dict()}, "example_model.pth")

    try:
        predictor = PlateOCRPredictor("example_model.pth", "example_vocab.json")

        # Random RGB array standing in for a cropped license-plate image.
        dummy_image = np.random.randint(0, 255, (100, 300, 3), dtype=np.uint8)

        result = predictor.predict_single(dummy_image, method="greedy")
        print(f"✅ Predicted text: '{result.text}'")
        print(f"✅ Confidence: {result.confidence:.3f}")
        print(f"✅ Processing time: {result.processing_time:.3f}s")
    finally:
        # Remove the temporary artifacts even if prediction raised.
        for path in ("example_vocab.json", "example_model.pth"):
            if os.path.exists(path):
                os.remove(path)
def example_2_image_preprocessing():
    """Example 2: Run the preprocessing pipeline and quality assessment on a dummy image."""
    print("\n🖼️ Example 2: Image Preprocessing")
    print("-" * 40)

    # Pipeline configured for the model's expected input geometry.
    pipeline = ImagePreprocessor(target_height=96, target_width=512, augment=True)

    # Random RGB array standing in for a real plate crop.
    sample = np.random.randint(0, 255, (200, 400, 3), dtype=np.uint8)

    # Same image through both modes: augmented (training) and plain (inference).
    for label, train_flag in (("Training", True), ("Inference", False)):
        transformed = pipeline.preprocess_image(sample, training=train_flag)
        print(f"✅ {label} preprocessing: {sample.shape} -> {transformed.shape}")

    # Heuristic quality metrics for the raw image.
    metrics = ImageQualityAssessor.assess_quality(sample)
    print(f"✅ Image quality score: {metrics['quality_score']:.3f}")
    print(f"✅ Sharpness: {metrics['sharpness']:.1f}")
    print(f"✅ Contrast: {metrics['contrast']:.1f}")
def example_3_configuration_management():
    """Example 3: Load preset configs, update fields, and validate."""
    print("\n⚙️ Example 3: Configuration Management")
    print("-" * 40)

    manager = ConfigManager()

    # Three built-in presets, fetched by name.
    train_cfg = ConfigFactory.create_config("training")
    infer_cfg = ConfigFactory.create_config("inference")
    light_cfg = ConfigFactory.create_config("lightweight")

    print(f"✅ Training config - batch size: {train_cfg.training.batch_size}")
    print(f"✅ Inference config - beam width: {infer_cfg.inference.beam_width}")
    print(f"✅ Lightweight config - hidden dim: {light_cfg.model.hidden_dim}")

    # Double-underscore keys address nested fields (section__field).
    manager.config = train_cfg
    manager.update_config(training__batch_size=64, model__hidden_dim=512)
    refreshed = manager.get_config()
    print(f"✅ Updated batch size: {refreshed.training.batch_size}")
    print(f"✅ Updated hidden dim: {refreshed.model.hidden_dim}")

    print(f"✅ Configuration valid: {manager.validate_config()}")
def example_4_model_export():
    """Example 4: Export a dummy model to all supported deployment formats.

    Creates throwaway checkpoint/vocabulary files, runs the exporter, prints
    each produced file with its size, then removes every artifact — including
    the export output directory — regardless of success.
    """
    print("\n📦 Example 4: Model Export")
    print("-" * 40)

    # Dummy vocabulary and untrained model, just to have valid inputs.
    vocab = ['<pad>', '<sos>', '<eos>'] + list("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
    model = create_model(len(vocab))

    with open("export_vocab.json", "w") as f:
        json.dump(vocab, f)
    torch.save({"model": model.state_dict()}, "export_model.pth")

    try:
        exporter = ModelExporter("export_model.pth", "export_vocab.json", "exported_example")
        exported_files = exporter.export_all()

        print("✅ Exported files:")
        for format_name, file_path in exported_files.items():
            # Report size in megabytes for each exported format.
            file_size = Path(file_path).stat().st_size / (1024 * 1024)
            print(f" • {format_name.upper()}: {Path(file_path).name} ({file_size:.1f} MB)")
    finally:
        # Remove inputs and the whole export directory even on failure.
        for path in ("export_vocab.json", "export_model.pth"):
            if os.path.exists(path):
                os.remove(path)
        if os.path.exists("exported_example"):
            shutil.rmtree("exported_example")
def example_5_advanced_model():
    """Example 5: Inspect the advanced model and exercise its decode paths."""
    print("\n🏗️ Example 5: Advanced Model Architecture")
    print("-" * 40)

    # Instantiate the transformer-based OCR model directly.
    net = AdvancedFastPlateOCR(
        vocab_size=39,
        hidden=256,
        num_layers=4,
        nhead=8,
        use_pe=True,
    )

    # Report size/parameter statistics exposed by the model itself.
    info = net.get_model_info()
    print(f"✅ Model parameters: {info['total_parameters']:,}")
    print(f"✅ Model size: {info['model_size_mb']:.1f} MB")
    print(f"✅ Uses positional encoding: {info['use_pe']}")

    import torch
    batch = torch.randn(2, 3, 96, 512)
    targets = torch.randint(0, 39, (2, 10))

    # Teacher-forced forward pass (no gradients needed for a demo).
    with torch.no_grad():
        logits = net(batch, targets)
        print(f"✅ Forward pass: {batch.shape} -> {logits.shape}")

        # Greedy decoding over the whole batch.
        decoded = net.greedy_decode(batch)
        print(f"✅ Greedy decode: {len(decoded)} sequences")

        # Beam search runs on a single image.
        first = batch[0:1]
        beam = net.beam_decode(first)
        print(f"✅ Beam search: {len(beam)} tokens")
def example_6_batch_processing():
    """Example 6: Predict over a batch of differently sized dummy images.

    Sets up throwaway model/vocabulary files, runs `predict_batch`, prints
    aggregate and per-image results, then always cleans up the files.
    """
    print("\n📦 Example 6: Batch Processing")
    print("-" * 40)

    # Dummy vocabulary and untrained model checkpoint.
    vocab = ['<pad>', '<sos>', '<eos>'] + list("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
    model = create_model(len(vocab))

    with open("batch_vocab.json", "w") as f:
        json.dump(vocab, f)
    torch.save({"model": model.state_dict()}, "batch_model.pth")

    try:
        predictor = PlateOCRPredictor("batch_model.pth", "batch_vocab.json")

        # Varied shapes to show the predictor handles heterogeneous inputs.
        dummy_images = [
            np.random.randint(0, 255, (100, 300, 3), dtype=np.uint8),
            np.random.randint(0, 255, (120, 350, 3), dtype=np.uint8),
            np.random.randint(0, 255, (90, 280, 3), dtype=np.uint8),
        ]

        batch_result = predictor.predict_batch(dummy_images, method="greedy")

        # Aggregate statistics for the whole batch.
        print(f"✅ Processed {len(dummy_images)} images")
        print(f"✅ Success rate: {batch_result.success_rate:.1%}")
        print(f"✅ Average confidence: {batch_result.average_confidence:.3f}")
        print(f"✅ Total time: {batch_result.total_time:.3f}s")

        # Per-image results, numbered from 1 for display.
        for idx, result in enumerate(batch_result.results, start=1):
            print(f" Image {idx}: '{result.text}' (conf: {result.confidence:.3f})")
    finally:
        # Remove temporary artifacts even if prediction raised.
        for path in ("batch_vocab.json", "batch_model.pth"):
            if os.path.exists(path):
                os.remove(path)
def example_7_performance_benchmark():
    """Example 7: Benchmark single-image prediction latency.

    Warms the predictor up, times 10 greedy predictions with the monotonic
    `time.perf_counter` clock (immune to wall-clock adjustments, unlike
    `time.time`), reports min/avg/max/FPS, then cleans up the dummy files.
    """
    print("\n⚡ Example 7: Performance Benchmark")
    print("-" * 40)

    # Dummy vocabulary and untrained model checkpoint.
    vocab = ['<pad>', '<sos>', '<eos>'] + list("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
    model = create_model(len(vocab))

    with open("bench_vocab.json", "w") as f:
        json.dump(vocab, f)
    torch.save({"model": model.state_dict()}, "bench_model.pth")

    try:
        predictor = PlateOCRPredictor("bench_model.pth", "bench_vocab.json")
        dummy_image = np.random.randint(0, 255, (100, 300, 3), dtype=np.uint8)

        # Warmup runs so one-time costs (lazy init, caches) don't skew timings.
        print("🔥 Warming up...")
        for _ in range(3):
            predictor.predict_single(dummy_image, method="greedy")

        print("📊 Running benchmark...")
        times = []
        for _ in range(10):
            start_time = time.perf_counter()
            predictor.predict_single(dummy_image, method="greedy")
            times.append(time.perf_counter() - start_time)

        avg_time = sum(times) / len(times)
        min_time = min(times)
        max_time = max(times)
        fps = 1.0 / avg_time

        print(f"✅ Average time: {avg_time:.3f}s")
        print(f"✅ Min time: {min_time:.3f}s")
        print(f"✅ Max time: {max_time:.3f}s")
        print(f"✅ FPS: {fps:.1f}")

        # Counters maintained internally by the predictor (include warmup runs).
        stats = predictor.get_performance_stats()
        print(f"✅ Total predictions: {stats['total_predictions']}")
        print(f"✅ Average time per prediction: {stats['average_time']:.3f}s")
    finally:
        # Remove temporary artifacts even if the benchmark raised.
        for path in ("bench_vocab.json", "bench_model.pth"):
            if os.path.exists(path):
                os.remove(path)
def main():
    """Run every example in order, reporting the first failure with a traceback."""
    print("🚀 Advanced OCR Prediction Pipeline - Examples")
    print("=" * 60)

    # Ordered roster of demos; each is a zero-argument callable.
    examples = (
        example_1_basic_prediction,
        example_2_image_preprocessing,
        example_3_configuration_management,
        example_4_model_export,
        example_5_advanced_model,
        example_6_batch_processing,
        example_7_performance_benchmark,
    )

    try:
        for run_example in examples:
            run_example()
        print("\n🎉 All examples completed successfully!")
        print("=" * 60)
    except Exception as e:
        # Top-level demo boundary: report and dump the traceback, don't re-raise.
        print(f"\n❌ Example failed: {str(e)}")
        import traceback
        traceback.print_exc()
# Script entry point: run all examples when executed directly (not on import).
if __name__ == "__main__":
    main()