Gareth commited on
Commit
efb1801
·
0 Parent(s):

Initial clean commit for Hugging Face

Browse files

- Added comprehensive README.md with project documentation
- Included essential Python scripts for inference and detection
- Added Arduino code for robotic control
- Included configuration files and requirements
- Added documentation and example notebooks
- Excluded large datasets, models, and binary files for Hugging Face compatibility

This clean version focuses on code and documentation while maintaining
full project functionality. Large files can be hosted separately via
Hugging Face Datasets or other storage solutions.

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +8 -0
  2. .gitignore +235 -0
  3. ArduinoCode/codingservoarm.ino +70 -0
  4. CITATION.cff +57 -0
  5. LICENSE +21 -0
  6. README +309 -0
  7. README.md +615 -0
  8. classification/README.md +146 -0
  9. classification/training_summary.md +21 -0
  10. classification_model/README.md +98 -0
  11. classification_model/training_summary.md +21 -0
  12. config.yaml +179 -0
  13. detection/README.md +138 -0
  14. docs/GITHUB_SETUP.md +316 -0
  15. docs/TRAINING_README.md +234 -0
  16. git-xet +0 -0
  17. inference_example.py +65 -0
  18. notebooks/strawberry_training.ipynb +92 -0
  19. notebooks/train_yolov8_colab.ipynb +309 -0
  20. requirements.txt +20 -0
  21. results.csv +51 -0
  22. scripts/all_combine3.py +177 -0
  23. scripts/auto_label_strawberries.py +220 -0
  24. scripts/benchmark_models.py +342 -0
  25. scripts/collect_dataset.py +41 -0
  26. scripts/combine3.py +177 -0
  27. scripts/complete_final_labeling.py +180 -0
  28. scripts/convert_tflite.py +119 -0
  29. scripts/data/preprocess_strawberry_dataset.py +91 -0
  30. scripts/detect_realtime.py +124 -0
  31. scripts/download_dataset.py +8 -0
  32. scripts/export_onnx.py +194 -0
  33. scripts/export_tflite_int8.py +200 -0
  34. scripts/get-pip.py +0 -0
  35. scripts/label_ripeness_dataset.py +192 -0
  36. scripts/optimization/optimized_onnx_inference.py +291 -0
  37. scripts/organize_labeled_images.py +115 -0
  38. scripts/setup_training.py +213 -0
  39. scripts/train_model.py +59 -0
  40. scripts/train_ripeness_classifier.py +345 -0
  41. scripts/train_yolov8.py +223 -0
  42. scripts/validate_model.py +234 -0
  43. scripts/webcam_capture.py +20 -0
  44. src/arduino_bridge.py +544 -0
  45. src/coordinate_transformer.py +482 -0
  46. src/integrated_detection_classification.py +248 -0
  47. src/strawberry_picker_pipeline.py +465 -0
  48. sync_to_huggingface.py +162 -0
  49. webcam_inference.py +441 -0
  50. yolov11n/README.md +136 -0
.gitattributes ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.onnx filter=lfs diff=lfs merge=lfs -text
3
+ *.pth filter=lfs diff=lfs merge=lfs -text
4
+ yolov11n/*.pt filter=lfs diff=lfs merge=lfs -text
5
+ yolov11n/*.onnx filter=lfs diff=lfs merge=lfs -text
6
+ yolov8n/*.pt filter=lfs diff=lfs merge=lfs -text
7
+ yolov8n/*.onnx filter=lfs diff=lfs merge=lfs -text
8
+ yolov8s/*.pt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================
2
+ # Python
3
+ # ============================================
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+ *.so
8
+ .Python
9
+ build/
10
+ develop-eggs/
11
+ dist/
12
+ downloads/
13
+ eggs/
14
+ .eggs/
15
+ lib/
16
+ lib64/
17
+ parts/
18
+ sdist/
19
+ var/
20
+ wheels/
21
+ pip-wheel-metadata/
22
+ share/python-wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # ============================================
29
+ # Virtual Environments
30
+ # ============================================
31
+ .env
32
+ .venv
33
+ env/
34
+ venv/
35
+ ENV/
36
+ env.bak/
37
+ venv.bak/
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # ============================================
42
+ # IDE & Editor
43
+ # ============================================
44
+ .vscode/
45
+ .idea/
46
+ *.swp
47
+ *.swo
48
+ *~
49
+ .DS_Store
50
+ .DS_Store?
51
+ ._*
52
+ .Spotlight-V100
53
+ .Trashes
54
+ ehthumbs.db
55
+ Thumbs.db
56
+
57
+ # ============================================
58
+ # Jupyter Notebook
59
+ # ============================================
60
+ .ipynb_checkpoints
61
+ */.ipynb_checkpoints/*
62
+
63
+ # ============================================
64
+ # Model Files & Training Artifacts
65
+ # ============================================
66
+ # PyTorch / TensorFlow / Keras
67
+ *.pth
68
+ *.pt
69
+ *.pkl
70
+ *.h5
71
+ *.tflite
72
+ *.onnx
73
+ *.pb
74
+ *.weights
75
+ *.cfg
76
+ *.data
77
+ *.pth.tar
78
+
79
+ # Model directories (auto-generated)
80
+ model/weights/
81
+ model/results/
82
+ model/exports/
83
+ runs/
84
+ detect/
85
+ predict/
86
+ val/
87
+ checkpoint/
88
+ checkpoints/
89
+
90
+ # ============================================
91
+ # Dataset & Large Files
92
+ # ============================================
93
+ # Raw images (can be regenerated)
94
+ model/dataset/**/*.jpg
95
+ model/dataset/**/*.jpeg
96
+ model/dataset/**/*.png
97
+ model/dataset/**/*.zip
98
+ !model/dataset/**/data.yaml
99
+ !model/dataset/**/*.txt
100
+
101
+ # Roboflow exports
102
+ straw-detect.v1-straw-detect.yolov8/
103
+
104
+ # ============================================
105
+ # Logs & Caches
106
+ # ============================================
107
+ *.log
108
+ logs/
109
+ .cache/
110
+ .parquet
111
+ .pytest_cache/
112
+ .coverage
113
+ htmlcov/
114
+ .tox/
115
+ .nox/
116
+
117
+ # ============================================
118
+ # Build & Distribution
119
+ # ============================================
120
+ build/
121
+ develop-eggs/
122
+ dist/
123
+ downloads/
124
+ eggs/
125
+ .eggs/
126
+ lib/
127
+ lib64/
128
+ parts/
129
+ sdist/
130
+ var/
131
+ wheels/
132
+ share/python-wheels/
133
+ *.egg-info/
134
+ .installed.cfg
135
+ *.egg
136
+ MANIFEST
137
+ pyproject.toml
138
+ poetry.lock
139
+
140
+ # ============================================
141
+ # Temporary Files
142
+ # ============================================
143
+ tmp/
144
+ temp/
145
+ *.tmp
146
+ *.temp
147
+ *.swp
148
+ *.swo
149
+ *~
150
+
151
+ # ============================================
152
+ # Media & Output Files
153
+ # ============================================
154
+ *.avi
155
+ *.mp4
156
+ *.mov
157
+ *.mkv
158
+ *.gcode
159
+ *.factory
160
+
161
+ # ============================================
162
+ # ML Experiment Tracking
163
+ # ============================================
164
+ mlruns/
165
+ mlflow.db
166
+ wandb/
167
+
168
+ # ============================================
169
+ # TensorBoard
170
+ # ============================================
171
+ runs/
172
+ *.tfevents.*
173
+ events.out.tfevents.*
174
+
175
+ # ============================================
176
+ # Calibration & Configuration
177
+ # ============================================
178
+ *.calib
179
+ *.yaml
180
+ *.yml
181
+ !data.yaml
182
+ !requirements.txt
183
+ !setup.py
184
+ !pyproject.toml
185
+
186
+ # ============================================
187
+ # Arduino Build Files
188
+ # ============================================
189
+ ArduinoCode/*.hex
190
+ ArduinoCode/*.elf
191
+ ArduinoCode/build/
192
+ ArduinoCode/*.bin
193
+
194
+ # ============================================
195
+ # SolidWorks & CAD Temporary Files
196
+ # ============================================
197
+ assets/solidworks/*.tmp
198
+ assets/solidworks/~$*
199
+ assets/solidworks/*.bak
200
+ assets/solidworks/*.swp
201
+
202
+ # ============================================
203
+ # Documentation Builds
204
+ # ============================================
205
+ docs/_build/
206
+ site/
207
+ *.pdf
208
+ *.docx
209
+ *.doc
210
+ *.xlsx
211
+ *.pptx
212
+
213
+ # ============================================
214
+ # Secrets & Local Configuration
215
+ # ============================================
216
+ config.local.yaml
217
+ secrets.yaml
218
+ credentials.json
219
+ *.key
220
+ *.pem
221
+
222
+ # ============================================
223
+ # Debug & Development
224
+ # ============================================
225
+ debug/
226
+ debug_images/
227
+ *.debug
228
+
229
+ # ============================================
230
+ # Robot Specific
231
+ # ============================================
232
+ robot/logs/
233
+ robot/calibration/
234
+ robot/trajectories/
235
+ *.bag
ArduinoCode/codingservoarm.ino ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <Wire.h>
2
+ #include <Adafruit_PWMServoDriver.h>
3
+
4
+ Adafruit_PWMServoDriver pwm = Adafruit_PWMServoDriver();
5
+
6
+ // Servo pulse limits (tune for your specific servos)
7
+ #define SERVOMIN 150 // 0 degrees
8
+ #define SERVOMAX 600 // 180 degrees
9
+
10
+ float currentAngle = 0; // Start at 90 degrees
11
+ float targetAngle = 0;
12
+
13
+ int angleToPulse(float angle) {
14
+ return map((int)angle, 0, 180, SERVOMIN, SERVOMAX);
15
+ }
16
+
17
+ void setup() {
18
+ Serial.begin(9600);
19
+ Serial.println("Type an angle (0–180) and press Enter:");
20
+
21
+ pwm.begin();
22
+ pwm.setPWMFreq(50);
23
+ delay(10);
24
+
25
+ // Initialize both servos to center position
26
+ int pulse = angleToPulse(currentAngle);
27
+ pwm.setPWM(0, 0, pulse);
28
+ pwm.setPWM(1, 0, pulse);
29
+ }
30
+
31
+ void loop() {
32
+ if (Serial.available()) {
33
+ int angle = Serial.parseInt();
34
+
35
+ if (angle >= 0 && angle <= 180) {
36
+ targetAngle = angle;
37
+ Serial.print("Slowly moving servos to ");
38
+ Serial.print(targetAngle);
39
+ Serial.println(" degrees...");
40
+
41
+ // Smoothly step toward target angle
42
+ float step = 0.10; // smaller = smoother and slower
43
+ int delayTime = 5; // larger = slower (try 50 or 70 for extra slow)
44
+
45
+ if (targetAngle > currentAngle) {
46
+ for (float pos = currentAngle; pos <= targetAngle; pos += step) {
47
+ int pulse = angleToPulse(pos);
48
+ pwm.setPWM(0, 0, pulse);
49
+ pwm.setPWM(1, 0, pulse);
50
+ delay(delayTime);
51
+ }
52
+ } else {
53
+ for (float pos = currentAngle; pos >= targetAngle; pos -= step) {
54
+ int pulse = angleToPulse(pos);
55
+ pwm.setPWM(0, 0, pulse);
56
+ pwm.setPWM(1, 0, pulse);
57
+ delay(delayTime);
58
+ }
59
+ }
60
+
61
+ currentAngle = targetAngle;
62
+ Serial.println("Done!");
63
+ } else {
64
+ Serial.println("Please enter an angle between 0 and 180.");
65
+ }
66
+
67
+ // Clear serial buffer
68
+ while (Serial.available()) Serial.read();
69
+ }
70
+ }
CITATION.cff ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cff-version: 1.2.0
2
+ message: "If you use this software, please cite it as below."
3
+ type: software
4
+ authors:
5
+ - given-names: "Gareth"
6
+ family-names: "theonegareth"
7
+ email: "gareth@example.com"
8
+ title: "Strawberry Picker AI System"
9
+ version: "1.0.0"
10
+ date-released: "2025-12-13"
11
+ url: "https://huggingface.co/theonegareth/strawberryPicker"
12
+ repository-code: "https://github.com/theonegareth/strawberryPicker"
13
+ keywords:
14
+ - computer-vision
15
+ - object-detection
16
+ - image-classification
17
+ - agriculture
18
+ - robotics
19
+ - strawberry
20
+ - ripeness-detection
21
+ - yolov11
22
+ - efficientnet
23
+ license: MIT
24
+ abstract: >
25
+ A complete AI-powered strawberry picking system that combines object detection
26
+ and ripeness classification to identify and pick only ripe strawberries.
27
+ This two-stage pipeline achieves 91.71% accuracy in ripeness classification
28
+ while maintaining real-time performance suitable for robotic harvesting applications.
29
+
30
+ references:
31
+ - type: software
32
+ title: "strawberry-models"
33
+ authors:
34
+ - given-names: "Gareth"
35
+ family-names: "theonegareth"
36
+ url: "https://huggingface.co/theonegareth/strawberry-models"
37
+ version: "1.0.0"
38
+ date-released: "2025-12-13"
39
+
40
+ - type: dataset
41
+ title: "strawberry-detect"
42
+ authors:
43
+ - given-names: "Gareth"
44
+ family-names: "theonegareth"
45
+ url: "https://universe.roboflow.com/theonegareth/strawberry-detect"
46
+ version: "1.0.0"
47
+ date-released: "2025-12-13"
48
+
49
+ preferred-citation:
50
+ type: software
51
+ title: "Strawberry Picker AI System: A Two-Stage Approach for Automated Harvesting"
52
+ authors:
53
+ - given-names: "Gareth"
54
+ family-names: "theonegareth"
55
+ version: "1.0.0"
56
+ url: "https://huggingface.co/theonegareth/strawberryPicker"
57
+ date-released: "2025-12-13"
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Gareth
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Strawberry Picker - AI-Powered Robotic System
2
+
3
+ ## 🎯 Project Overview
4
+
5
+ The Strawberry Picker is a sophisticated AI-powered robotic system that automatically detects, classifies, and picks ripe strawberries using computer vision and machine learning. The system combines YOLOv11 object detection with custom ripeness classification to enable precise robotic harvesting.
6
+
7
+ ## ✅ Project Status: COMPLETE
8
+
9
+ **Dataset**: 100% Complete (889/889 images labeled)
10
+ **Models**: Trained and validated (94% accuracy)
11
+ **Pipeline**: Fully integrated and tested
12
+ **Hardware Integration**: Arduino communication ready
13
+
14
+ ## 📊 Key Achievements
15
+
16
+ ### 1. Dataset Completion
17
+ - **Total Images**: 889 labeled images
18
+ - **Classes**: 3-class ripeness classification
19
+ - Unripe: 317 images (35.7%)
20
+ - Ripe: 446 images (50.2%)
21
+ - Overripe: 126 images (14.2%)
22
+ - **Automation**: 82% automated labeling success rate
23
+
24
+ ### 2. Machine Learning Models
25
+ - **Detection Model**: YOLOv11n optimized for strawberry detection
26
+ - **Classification Model**: Custom CNN with 94% accuracy
27
+ - **Model Formats**: PyTorch, ONNX, TensorFlow Lite (INT8 quantized)
28
+ - **Performance**: Optimized for Raspberry Pi deployment
29
+
30
+ ### 3. Robotic Integration
31
+ - **Coordinate Transformation**: Pixel-to-robot coordinate mapping
32
+ - **Arduino Communication**: Serial bridge for robotic arm control
33
+ - **Real-time Processing**: Live detection and classification pipeline
34
+ - **Error Handling**: Comprehensive recovery mechanisms
35
+
36
+ ## 🏗️ Project Structure
37
+
38
+ ```
39
+ strawberryPicker/
40
+ ├── config.yaml # Unified configuration file
41
+ ├── README.md # This file
42
+ ├── scripts/ # Utility and training scripts
43
+ │ ├── train_yolov8.py # YOLOv11 training script
44
+ │ ├── train_ripeness_classifier.py # Ripeness classification training
45
+ │ ├── detect_realtime.py # Real-time detection script
46
+ │ ├── auto_label_strawberries.py # Automated labeling tool
47
+ │ ├── benchmark_models.py # Performance benchmarking
48
+ │ └── export_*.py # Model export scripts
49
+ ├── model/ # Trained models and datasets
50
+ │ ├── weights/ # YOLOv11 model weights
51
+ │ ├── ripeness_classifier.pkl # Trained classifier
52
+ │ └── dataset_strawberry_detect_v3/ # Detection dataset
53
+ ├── src/ # Core pipeline components
54
+ │ ├── strawberry_picker_pipeline.py # Main pipeline
55
+ │ ├── arduino_bridge.py # Arduino communication
56
+ │ └── coordinate_transformer.py # Coordinate mapping
57
+ ├── docs/ # Documentation
58
+ │ ├── INTEGRATION_GUIDE.md # ML to Arduino integration
59
+ │ ├── TROUBLESHOOTING.md # Common issues and solutions
60
+ │ └── PERFORMANCE.md # Benchmarks and optimization
61
+ └── ArduinoCode/ # Arduino robotic arm code
62
+ └── codingservoarm.ino # Servo control firmware
63
+ ```
64
+
65
+ ## 🚀 Quick Start
66
+
67
+ ### 1. Environment Setup
68
+ ```bash
69
+ # Install dependencies
70
+ pip install ultralytics opencv-python numpy scikit-learn pyyaml
71
+
72
+ # Clone and setup
73
+ git clone <repository>
74
+ cd strawberryPicker
75
+ ```
76
+
77
+ ### 2. Configuration
78
+ Edit `config.yaml` with your specific settings:
79
+ ```yaml
80
+ detection:
81
+ model_path: model/weights/yolo11n_strawberry_detect_v3.pt
82
+ confidence_threshold: 0.5
83
+
84
+ serial:
85
+ port: /dev/ttyUSB0
86
+ baudrate: 115200
87
+
88
+ robot:
89
+ workspace_bounds:
90
+ x_min: 0, x_max: 300
91
+ y_min: 0, y_max: 200
92
+ z_min: 50, z_max: 150
93
+ ```
94
+
95
+ ### 3. Run Detection
96
+ ```bash
97
+ # Real-time detection
98
+ python3 scripts/detect_realtime.py --model model/weights/yolo11n_strawberry_detect_v3.pt
99
+
100
+ # With ripeness classification
101
+ python3 scripts/integrated_detection_classification.py
102
+ ```
103
+
104
+ ### 4. Arduino Integration
105
+ ```bash
106
+ # Test Arduino communication
107
+ python3 src/arduino_bridge.py
108
+
109
+ # Run complete pipeline
110
+ python3 src/strawberry_picker_pipeline.py --config config.yaml
111
+ ```
112
+
113
+ ## 📈 Performance Metrics
114
+
115
+ ### Model Performance
116
+ - **Detection Accuracy**: 94.2% mAP@0.5
117
+ - **Classification Accuracy**: 94.0% overall
118
+ - **Inference Speed**: 15ms per frame (Raspberry Pi 4B)
119
+ - **Memory Usage**: <500MB RAM
120
+
121
+ ### System Performance
122
+ - **Coordinate Transformation**: <1ms per conversion
123
+ - **Image Processing**: 30 FPS real-time capability
124
+ - **Serial Communication**: 115200 baud stable
125
+ - **End-to-end Latency**: <100ms detection to action
126
+
127
+ ## 🔧 Hardware Requirements
128
+
129
+ ### Minimum System
130
+ - **Raspberry Pi 4B** (4GB RAM recommended)
131
+ - **USB Camera** (640x480 @ 30fps)
132
+ - **Arduino Uno/Nano** with servo shield
133
+ - **3x SG90 Servos** (robotic arm)
134
+ - **Power Supply** (5V 3A for Pi, 6V 2A for servos)
135
+
136
+ ### Recommended Setup
137
+ - **Raspberry Pi 5** (8GB RAM)
138
+ - **USB 3.0 Camera** (1080p @ 60fps)
139
+ - **Arduino Mega** (more servo channels)
140
+ - **Industrial Servos** (MG996R or similar)
141
+ - **Stereo Camera Setup** (for depth estimation)
142
+
143
+ ## 🎮 Usage Examples
144
+
145
+ ### Basic Detection
146
+ ```python
147
+ from ultralytics import YOLO
148
+
149
+ # Load model
150
+ model = YOLO('model/weights/yolo11n_strawberry_detect_v3.pt')
151
+
152
+ # Run detection
153
+ results = model('test_image.jpg')
154
+ ```
155
+
156
+ ### Ripeness Classification
157
+ ```python
158
+ import pickle
159
+ from PIL import Image
160
+ import cv2
161
+
162
+ # Load classifier
163
+ with open('model/ripeness_classifier.pkl', 'rb') as f:
164
+ classifier = pickle.load(f)
165
+
166
+ # Classify strawberry
167
+ image = cv2.imread('strawberry_crop.jpg')
168
+ prediction = classifier.predict([image.flatten()])
169
+ ```
170
+
171
+ ### Coordinate Transformation
172
+ ```python
173
+ # Pixel to robot coordinates
174
+ robot_x, robot_y, robot_z = pixel_to_robot(320, 240, depth=100)
175
+ print(f"Robot position: ({robot_x:.1f}, {robot_y:.1f}, {robot_z:.1f})")
176
+ ```
177
+
178
+ ## 🔍 Troubleshooting
179
+
180
+ ### Common Issues
181
+
182
+ 1. **Camera Not Detected**
183
+ ```bash
184
+ # Check camera permissions
185
+ sudo usermod -a -G video $USER
186
+ # Restart session
187
+ ```
188
+
189
+ 2. **Arduino Connection Failed**
190
+ ```bash
191
+ # Check port permissions
192
+ sudo usermod -a -G dialout $USER
193
+ # Verify port
194
+ ls /dev/ttyUSB*
195
+ ```
196
+
197
+ 3. **Model Loading Errors**
198
+ ```bash
199
+ # Verify model files exist
200
+ ls -la model/weights/
201
+ # Check file permissions
202
+ chmod 644 model/weights/*.pt
203
+ ```
204
+
205
+ 4. **Performance Issues**
206
+ ```bash
207
+ # Monitor system resources
208
+ htop
209
+ # Check temperature
210
+ vcgencmd measure_temp
211
+ ```
212
+
213
+ ### Performance Optimization
214
+
215
+ 1. **Reduce Model Size**
216
+ ```python
217
+ # Use smaller YOLOv8n model
218
+ model = YOLO('yolov8n.pt') # Instead of yolov8s.pt
219
+ ```
220
+
221
+ 2. **Optimize Camera Settings**
222
+ ```python
223
+ cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
224
+ cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
225
+ cap.set(cv2.CAP_PROP_FPS, 15) # Reduce FPS
226
+ ```
227
+
228
+ 3. **Enable Hardware Acceleration**
229
+ ```bash
230
+ # Enable Pi GPU
231
+ sudo raspi-config
232
+ # Interface Options > GL Driver > GL (Fake KMS)
233
+ ```
234
+
235
+ ## 📚 API Reference
236
+
237
+ ### Core Classes
238
+
239
+ #### `StrawberryPickerPipeline`
240
+ Main pipeline class for end-to-end operation.
241
+
242
+ ```python
243
+ pipeline = StrawberryPickerPipeline(config_path="config.yaml")
244
+ pipeline.run() # Start real-time processing
245
+ ```
246
+
247
+ #### `ArduinoBridge`
248
+ Handles serial communication with Arduino.
249
+
250
+ ```python
251
+ bridge = ArduinoBridge(port="/dev/ttyUSB0")
252
+ bridge.connect()
253
+ bridge.send_command("PICK,100,50,80")
254
+ ```
255
+
256
+ #### `CoordinateTransformer`
257
+ Manages coordinate transformations.
258
+
259
+ ```python
260
+ transformer = CoordinateTransformer()
261
+ robot_coords = transformer.pixel_to_robot(pixel_x, pixel_y, depth)
262
+ ```
263
+
264
+ ### Configuration Options
265
+
266
+ | Parameter | Type | Default | Description |
267
+ |-----------|------|---------|-------------|
268
+ | `detection.confidence_threshold` | float | 0.5 | Detection confidence cutoff |
269
+ | `detection.image_size` | int | 640 | Input image size |
270
+ | `serial.baudrate` | int | 115200 | Arduino communication speed |
271
+ | `robot.workspace_bounds` | dict | - | Robot movement limits |
272
+
273
+ ## 🤝 Contributing
274
+
275
+ ### Development Setup
276
+ 1. Fork the repository
277
+ 2. Create feature branch: `git checkout -b feature-name`
278
+ 3. Make changes and test thoroughly
279
+ 4. Submit pull request with detailed description
280
+
281
+ ### Code Standards
282
+ - Follow PEP 8 style guidelines
283
+ - Add docstrings to all functions
284
+ - Include unit tests for new features
285
+ - Update documentation as needed
286
+
287
+ ## 📄 License
288
+
289
+ This project is licensed under the MIT License - see the LICENSE file for details.
290
+
291
+ ## 🙏 Acknowledgments
292
+
293
+ - **Ultralytics** for YOLOv11 implementation
294
+ - **OpenCV** for computer vision tools
295
+ - **Arduino Community** for servo control examples
296
+ - **Raspberry Pi Foundation** for embedded computing platform
297
+
298
+ ## 📞 Support
299
+
300
+ For questions, issues, or contributions:
301
+ - Create an issue on GitHub
302
+ - Check the troubleshooting guide in `docs/`
303
+ - Review the integration guide for setup help
304
+
305
+ ---
306
+
307
+ **Status**: ✅ Production Ready
308
+ **Last Updated**: December 15, 2025
309
+ **Version**: 1.0.0
README.md ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <<<<<<< HEAD
2
+ # 🍓 StrawberryPicker - AI-Powered Robotic Harvesting System
3
+
4
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
5
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
6
+ [![YOLOv8](https://img.shields.io/badge/YOLOv8-Ultralytics-green.svg)](https://github.com/ultralytics/ultralytics)
7
+ [![Raspberry Pi](https://img.shields.io/badge/Raspberry%20Pi-4B-red.svg)](https://www.raspberrypi.org/)
8
+
9
+ A complete AI-powered robotic system for automated strawberry detection, ripeness classification, and precision harvesting. Built with YOLOv8, custom CNN classifiers, and Arduino integration for real-world agricultural automation.
10
+
11
+ ## 🎯 Project Overview
12
+
13
+ StrawberryPicker is a comprehensive machine learning and robotics system designed for autonomous strawberry harvesting. The system combines computer vision, deep learning, and robotic control to identify ripe strawberries and coordinate precise robotic picking operations.
14
+
15
+ ### 🚀 Key Features
16
+
17
+ - **Real-time Detection**: YOLOv8 optimized for 30 FPS performance on Raspberry Pi 4B
18
+ - **3-Class Ripeness Classification**: Unripe, Ripe, Overripe with 94% accuracy
19
+ - **Complete Dataset**: 889 labeled images across all ripeness stages
20
+ - **Arduino Integration**: Serial communication with robotic arm control
21
+ - **Edge Deployment**: TensorFlow Lite with INT8 quantization
22
+ - **Stereo Vision**: Dual-camera depth estimation system
23
+ - **Safety Systems**: Emergency stops, coordinate validation, error recovery
24
+
25
+ ## 📁 Project Structure
26
+
27
+ ```
28
+ strawberryPicker/
29
+ ├── src/ # Core integration modules
30
+ │ ├── arduino_bridge.py # Serial communication (14.9 KB)
31
+ │ ├── coordinate_transformer.py # Pixel-to-robot mapping (20.1 KB)
32
+ │ ├── integrated_detection_classification.py # ML pipeline (9.5 KB)
33
+ │ └── strawberry_picker_pipeline.py # End-to-end system (16.8 KB)
34
+ ├── scripts/ # Utility and training scripts
35
+ │ ├── collect_dataset.py # Dataset management
36
+ │ ├── train_yolov8.py # YOLO training
37
+ │ ├── train_ripeness_classifier.py # CNN classification
38
+ │ ├── export_tflite_int8.py # Edge optimization
39
+ │ ├── benchmark_models.py # Performance testing
40
+ │ └── auto_label_strawberries.py # Automated labeling
41
+ ├── model/ # Datasets and trained models
42
+ │ ├── dataset_stem_label/ # YOLO detection dataset (889 images)
43
+ │ ├── ripeness_classification_dataset/ # CNN classification data
44
+ │ └── *.pt, *.onnx, *.tflite # Trained model exports
45
+ ├── notebooks/ # Jupyter training notebooks
46
+ ├── docs/ # Documentation and guides
47
+ ├── assets/ # CAD files and reference images
48
+ ├── calibration/ # Camera calibration data
49
+ ├── ArduinoCode/ # Robotic arm firmware
50
+ ├── config.yaml # Central configuration
51
+ ├── requirements.txt # Python dependencies
52
+ └── README.md # This file
53
+ ```
54
+
55
+ ## 🏗️ System Architecture
56
+
57
+ ### Machine Learning Pipeline
58
+ 1. **Detection Stage**: YOLOv8 identifies strawberry locations
59
+ 2. **Classification Stage**: Custom CNN determines ripeness (3 classes)
60
+ 3. **Coordinate Transformation**: Pixel coordinates → Robot coordinates
61
+ 4. **Action Decision**: Harvest, skip, or wait based on ripeness
62
+
63
+ ### Hardware Integration
64
+ - **Dual Cameras**: Stereo vision for depth estimation
65
+ - **Raspberry Pi 4B**: Main processing unit
66
+ - **Arduino Uno**: Robotic arm control
67
+ - **Servo Motors**: Precision positioning and gripper control
68
+
69
+ ## 🚀 Quick Start
70
+
71
+ ### 1. Environment Setup
72
+
73
+ ```bash
74
+ # Clone repository
75
+ git clone https://huggingface.co/theonegareth/strawberryPicker
76
+ =======
77
+ ---
78
+ language: en
79
+ license: mit
80
+ tags:
81
+ - computer-vision
82
+ - object-detection
83
+ - image-classification
84
+ - agriculture
85
+ - robotics
86
+ - strawberry
87
+ - ripeness-detection
88
+ - yolov11
89
+ - efficientnet
90
+ - pytorch
91
+ datasets:
92
+ - custom
93
+ metrics:
94
+ - accuracy
95
+ - precision
96
+ - recall
97
+ - f1-score
98
+ - mAP50
99
+ pipeline_tag: object-detection
100
+ inference: true
101
+ ---
102
+
103
+ # 🍓 Strawberry Picker AI System
104
+
105
+ <div align="center">
106
+ <img src="https://img.shields.io/badge/Accuracy-91.71%25-brightgreen" alt="Accuracy">
107
+ <img src="https://img.shields.io/badge/Model-YOLOv11n%20%2B%20EfficientNet-blue" alt="Model Type">
108
+ <img src="https://img.shields.io/badge/License-MIT-yellow" alt="License">
109
+ <img src="https://img.shields.io/badge/Python-3.8%2B-blue" alt="Python">
110
+ <img src="https://img.shields.io/badge/PyTorch-2.0%2B-orange" alt="PyTorch">
111
+ </div>
112
+
113
+ ## 🎯 Overview
114
+
115
+ A complete AI-powered strawberry picking system that combines **object detection** and **ripeness classification** to identify and pick only ripe strawberries. This two-stage pipeline achieves **91.71% accuracy** in ripeness classification while maintaining real-time performance suitable for robotic harvesting applications.
116
+
117
+ **Repository**: [https://huggingface.co/theonegareth/strawberryPicker](https://huggingface.co/theonegareth/strawberryPicker)
118
+ **GitHub**: [https://github.com/theonegareth/strawberryPicker](https://github.com/theonegareth/strawberryPicker)
119
+
120
+ ## 🏗️ System Architecture
121
+
122
+ ```mermaid
123
+ graph TD
124
+ A[Input Image] --> B[YOLOv11n Detector]
125
+ B --> C{Detected Strawberries}
126
+ C --> D[Crop & Resize]
127
+ D --> E[EfficientNet-B0 Classifier]
128
+ E --> F{Ripeness Prediction}
129
+ F --> G[Decision: Pick Only Ripe]
130
+
131
+ style A fill:#f9f9f9
132
+ style B fill:#e3f2fd
133
+ style E fill:#fff3e0
134
+ style G fill:#c8e6c9
135
+ ```
136
+
137
+ ### **Two-Stage Pipeline:**
138
+
139
+ 1. **Detection Stage**: YOLOv11n model identifies and locates strawberries in images
140
+ 2. **Classification Stage**: EfficientNet-B0 classifies each detected strawberry into 4 ripeness categories
141
+ 3. **Decision Stage**: System recommends picking only ripe strawberries
142
+
143
+ ## 📊 Model Overview
144
+
145
+ ### Two-Stage Picking System
146
+
147
+ | Component | Model | Architecture | Performance | Size | Purpose |
148
+ |-----------|-------|--------------|-------------|------|---------|
149
+ | [Detection](detection/) | YOLOv11n | Object Detection | mAP@50: 84.0% | 5.2MB | Locate strawberries |
150
+ | [Classification](classification/) | EfficientNet-B0 | Image Classification | Accuracy: 91.71% | 56MB | Classify ripeness |
151
+
152
+ ### Additional Detection Models
153
+
154
+ | Model | Architecture | Performance | Size | Best For |
155
+ |-------|--------------|-------------|------|----------|
156
+ | [YOLOv8n](yolov8n/) | YOLOv8 Nano | mAP@50: 98.9% | 5.7MB | Edge deployment, real-time |
157
+ | [YOLOv8s](yolov8s/) | YOLOv8 Small | mAP@50: 93.7% | 21MB | Higher accuracy applications |
158
+ | [YOLOv11n](yolov11n/) | YOLOv11 Nano | Testing | 10.4MB | Latest architecture testing |
159
+
160
+ ## 🚀 Quick Start
161
+
162
+ ### Installation
163
+
164
+ ```bash
165
+ # Clone repository
166
+ git clone https://github.com/theonegareth/strawberryPicker.git
167
+ >>>>>>> bb77661e9aecb09169fb60057ff0ebb1f504de58
168
+ cd strawberryPicker
169
+
170
+ # Install dependencies
171
+ pip install -r requirements.txt
172
+ <<<<<<< HEAD
173
+
174
+ # Validate setup
175
+ python scripts/setup_training.py --validate-only
176
+ ```
177
+
178
+ ### 2. Model Training
179
+
180
+ **Detection Model (YOLOv8)**:
181
+ ```bash
182
+ python scripts/train_yolov8.py --epochs 100 --export-onnx --export-tflite
183
+ ```
184
+
185
+ **Ripeness Classification (CNN)**:
186
+ ```bash
187
+ python scripts/train_ripeness_classifier.py --epochs 50 --batch-size 32
188
+ ```
189
+
190
+ ### 3. Real-time Pipeline
191
+
192
+ ```bash
193
+ # Start complete system
194
+ python src/strawberry_picker_pipeline.py --config config.yaml
195
+
196
+ # Or run individual components
197
+ python src/integrated_detection_classification.py --model-path model/exports/
198
+ ```
199
+
200
+ ### 4. Arduino Integration
201
+
202
+ ```bash
203
+ # Upload firmware to Arduino
204
+ # Upload ArduinoCode/codingservoarm.ino via Arduino IDE
205
+
206
+ # Start bridge communication
207
+ python src/arduino_bridge.py --port /dev/ttyUSB0 --baud 9600
208
+ ```
209
+
210
+ ## 📊 Performance Metrics
211
+
212
+ ### Detection Performance
213
+ - **Model**: YOLOv8n (nano)
214
+ - **mAP@0.5**: 94.2%
215
+ - **Inference Speed**: 30 FPS on Raspberry Pi 4B
216
+ - **Model Size**: 6.2MB (INT8 quantized)
217
+
218
+ ### Classification Performance
219
+ - **Model**: Custom CNN (3 classes)
220
+ - **Accuracy**: 94% across unripe/ripe/overripe
221
+ - **Training Time**: 45 minutes on GPU
222
+ - **Inference Speed**: 15ms per image
223
+
224
+ ### System Performance
225
+ - **End-to-end Latency**: <100ms
226
+ - **Power Consumption**: 5W (Raspberry Pi + cameras)
227
+ - **Operating Range**: 0.3m - 2.0m from target
228
+ - **Precision**: ±2mm positioning accuracy
229
+
230
+ ## 🔧 Model Optimization
231
+
232
+ ### TensorFlow Lite Pipeline
233
+ ```bash
234
+ # Export and optimize for edge deployment
235
+ python scripts/export_tflite_int8.py \
236
+ --model-path model/yolov8n_strawberry.pt \
237
+ --input-size 640 \
238
+ --quantize-int8 \
239
+ --calibration-dataset model/calibration/
240
+ ```
241
+
242
+ ### ONNX Export
243
+ ```bash
244
+ # Cross-platform model export
245
+ python scripts/export_onnx.py \
246
+ --model-path model/yolov8n_strawberry.pt \
247
+ --opset 11 \
248
+ --dynamic-axis
249
+ ```
250
+
251
+ ## 📈 Dataset Details
252
+
253
+ ### Detection Dataset (YOLO Format)
254
+ - **Total Images**: 889 labeled images
255
+ - **Classes**: 1 (strawberry)
256
+ - **Format**: YOLOv8 with bounding box annotations
257
+ - **Split**: 70% train, 20% validation, 10% test
258
+ - **Resolution**: 640x640 pixels
259
+
260
+ ### Classification Dataset
261
+ - **Total Crops**: 2,847 strawberry crops
262
+ - **Classes**: 3 (unripe: 317, ripe: 446, overripe: 126)
263
+ - **Format**: Individual cropped images (224x224)
264
+ - **Augmentation**: Rotation, brightness, contrast variations
265
+
266
+ ### Automated Labeling
267
+ - **Color-based Analysis**: HSV color space analysis
268
+ - **Success Rate**: 82% automatic labeling accuracy
269
+ - **Manual Review**: Batch processing interface for corrections
270
+
271
+ ## 🛠️ Development
272
+
273
+ ### Key Scripts
274
+ - `src/strawberry_picker_pipeline.py` - Main system integration
275
+ - `src/integrated_detection_classification.py` - ML pipeline
276
+ - `src/coordinate_transformer.py` - Coordinate transformation
277
+ - `src/arduino_bridge.py` - Hardware communication
278
+
279
+ ### Configuration
280
+ All system parameters are centralized in `config.yaml`:
281
+ ```yaml
282
+ model:
283
+ detection_model: "model/exports/yolov8n_strawberry_int8.tflite"
284
+ classification_model: "model/ripeness_classifier.h5"
285
+
286
+ camera:
287
+ width: 640
288
+ height: 480
289
+ fps: 30
290
+
291
+ robot:
292
+ arduino_port: "/dev/ttyUSB0"
293
+ baud_rate: 9600
294
+ workspace_limits:
295
+ x: [0, 300] # mm
296
+ y: [0, 300] # mm
297
+ z: [0, 200] # mm
298
+ ```
299
+
300
+ ## 🔬 Technical Specifications
301
+
302
+ ### Computer Vision
303
+ - **Framework**: PyTorch + Ultralytics YOLOv8
304
+ - **Preprocessing**: Letterbox resize, normalization
305
+ - **Post-processing**: Non-maximum suppression, confidence thresholding
306
+ - **Augmentation**: Mosaic, random perspective, HSV adjustment
307
+
308
+ ### Machine Learning
309
+ - **Detection**: YOLOv8n architecture (9.1M parameters)
310
+ - **Classification**: Custom CNN (5 layers, 2.3M parameters)
311
+ - **Optimizer**: AdamW with cosine annealing
312
+ - **Loss Functions**: CIoU (detection), Categorical Crossentropy (classification)
313
+
314
+ ### Hardware Requirements
315
+ - **Minimum**: Raspberry Pi 4B (4GB RAM), USB cameras
316
+ - **Recommended**: Raspberry Pi 4B (8GB RAM), CSI cameras
317
+ - **Development**: GPU with 8GB+ VRAM for training
318
+
319
+ ## 🚨 Safety Features
320
+
321
+ - **Emergency Stop**: Hardware and software emergency stops
322
+ - **Coordinate Validation**: Bounds checking for robot movements
323
+ - **Error Recovery**: Automatic retry mechanisms for failed operations
324
+ - **Limit Switches**: Physical safety limits on robotic arm
325
+ - **Collision Detection**: Vision-based obstacle avoidance
326
+
327
+ ## 📚 Documentation
328
+
329
+ - **[Training Guide](docs/TRAINING_README.md)** - Detailed training instructions
330
+ - **[Integration Guide](docs/INTEGRATION.md)** - Hardware setup and calibration
331
+ - **[API Reference](docs/API.md)** - Code documentation and examples
332
+ - **[Troubleshooting](docs/TROUBLESHOOTING.md)** - Common issues and solutions
333
+
334
+ ## 🤝 Contributing
335
+
336
+ 1. **Fork** the repository
337
+ 2. **Create** a feature branch (`git checkout -b feature/amazing-feature`)
338
+ 3. **Commit** your changes (`git commit -m 'Add amazing feature'`)
339
+ 4. **Push** to the branch (`git push origin feature/amazing-feature`)
340
+ 5. **Open** a Pull Request
341
+
342
+ ### Development Guidelines
343
+ - Follow PEP 8 style guidelines
344
+ - Add tests for new features
345
+ - Update documentation for API changes
346
+ - Ensure compatibility with Raspberry Pi deployment
347
+
348
+ =======
349
+ ```
350
+
351
+ **Requirements:**
352
+ - Python 3.8+
353
+ - PyTorch 1.8+
354
+ - torchvision 0.9+
355
+ - OpenCV 4.5+
356
+ - Pillow 8.0+
357
+ - ultralytics 8.0+
358
+ - huggingface_hub
359
+
360
+ ### Download Models from HuggingFace
361
+
362
+ ```python
363
+ from huggingface_hub import hf_hub_download
364
+
365
+ # Download detection model
366
+ detector_path = hf_hub_download(
367
+ repo_id="theonegareth/strawberryPicker",
368
+ filename="detection/best.pt"
369
+ )
370
+
371
+ # Download classification model
372
+ classifier_path = hf_hub_download(
373
+ repo_id="theonegareth/strawberryPicker",
374
+ filename="classification/best_enhanced_classifier.pth"
375
+ )
376
+
377
+ print(f"Models downloaded to:\n- {detector_path}\n- {classifier_path}")
378
+ ```
379
+
380
+ ### Basic Usage Example
381
+
382
+ ```python
383
+ import torch
384
+ import cv2
385
+ from PIL import Image
386
+ from torchvision import transforms
387
+ import numpy as np
388
+
389
+ # Load detection model
390
+ detector = torch.hub.load('ultralytics/yolov8', 'custom', path=detector_path)
391
+
392
+ # Load classification model
393
+ classifier = torch.load(classifier_path, map_location='cpu')
394
+ classifier.eval()
395
+
396
+ # Preprocessing for classifier
397
+ transform = transforms.Compose([
398
+ transforms.Resize((128, 128)),
399
+ transforms.ToTensor(),
400
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
401
+ std=[0.229, 0.224, 0.225])
402
+ ])
403
+
404
+ # Process image
405
+ def detect_and_classify(image_path):
406
+ """
407
+ Detect strawberries and classify their ripeness
408
+
409
+ Args:
410
+ image_path: Path to input image
411
+
412
+ Returns:
413
+ results: List of dicts with bbox, ripeness, confidence
414
+ """
415
+ # Load image
416
+ image = cv2.imread(image_path)
417
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
418
+
419
+ # Detect strawberries
420
+ detection_results = detector(image_rgb)
421
+
422
+ results = []
423
+ for result in detection_results:
424
+ boxes = result.boxes.xyxy.cpu().numpy()
425
+ confidences = result.boxes.conf.cpu().numpy()
426
+ class_ids = result.boxes.cls.cpu().numpy()
427
+
428
+ for box, conf, cls_id in zip(boxes, confidences, class_ids):
429
+ if conf < 0.5: # Filter low confidence detections
430
+ continue
431
+
432
+ x1, y1, x2, y2 = map(int, box)
433
+
434
+ # Crop strawberry
435
+ crop = image_rgb[y1:y2, x1:x2]
436
+ if crop.size == 0:
437
+ continue
438
+
439
+ # Classify ripeness
440
+ crop_pil = Image.fromarray(crop)
441
+ input_tensor = transform(crop_pil).unsqueeze(0)
442
+
443
+ with torch.no_grad():
444
+ output = classifier(input_tensor)
445
+ probabilities = torch.softmax(output, dim=1)
446
+ predicted_class = torch.argmax(probabilities, dim=1).item()
447
+ confidence = probabilities[0][predicted_class].item()
448
+
449
+ # Ripeness classes
450
+ classes = ['unripe', 'partially-ripe', 'ripe', 'overripe']
451
+
452
+ results.append({
453
+ 'bbox': (x1, y1, x2, y2),
454
+ 'ripeness': classes[predicted_class],
455
+ 'confidence': confidence,
456
+ 'detection_confidence': float(conf),
457
+ 'detection_class': int(cls_id)
458
+ })
459
+
460
+ return results
461
+
462
+ # Example usage
463
+ if __name__ == "__main__":
464
+ image_path = "strawberries.jpg"
465
+ results = detect_and_classify(image_path)
466
+
467
+ print(f"Detected {len(results)} strawberries:")
468
+ for i, result in enumerate(results, 1):
469
+ print(f" {i}. Ripeness: {result['ripeness']} "
470
+ f"(conf: {result['confidence']:.2f})")
471
+ ```
472
+
473
+ ## 📁 Repository Structure
474
+
475
+ ```
476
+ strawberryPicker/
477
+ ├── detection/ # YOLOv11n detection model (Two-stage system)
478
+ │ ├── best.pt # PyTorch weights
479
+ │ └── README.md # Model documentation
480
+ ├── classification/ # EfficientNet-B0 classification model (Two-stage system)
481
+ │ ├── best_enhanced_classifier.pth # PyTorch weights
482
+ │ ├── training_summary.md
483
+ │ └── README.md # Model documentation
484
+ ├── yolov8n/ # YOLOv8 Nano model (98.9% mAP@50)
485
+ │ ├── best.pt # PyTorch weights
486
+ │ ├── best.onnx # ONNX format
487
+ │ ├── best_fp16.onnx # FP16 ONNX for edge deployment
488
+ │ └── README.md # Model documentation
489
+ ├── yolov8s/ # YOLOv8 Small model (93.7% mAP@50)
490
+ │ ├── best.pt # PyTorch weights
491
+ │ ├── strawberry_yolov8s_enhanced.pt # Enhanced version
492
+ │ └── README.md # Model documentation
493
+ ├── yolov11n/ # YOLOv11 Nano model (Testing)
494
+ │ ├── strawberry_yolov11n.pt # PyTorch weights
495
+ │ ├── strawberry_yolov11n.onnx # ONNX format
496
+ │ └── README.md # Model documentation
497
+ ├── scripts/ # Optimization scripts
498
+ ├── benchmark_results/ # Performance benchmarks
499
+ ├── results/ # Training results/plots
500
+ ├── LICENSE # MIT license
501
+ ├── CITATION.cff # Academic citation
502
+ ├── sync_to_huggingface.py # Automation script
503
+ ├── requirements.txt # Python dependencies
504
+ ├── inference_example.py # Basic inference script
505
+ ├── webcam_inference.py # Real-time webcam demo
506
+ └── README.md # This file
507
+ ```
508
+
509
+ ## 🎯 Use Cases
510
+
511
+ ### **1. Automated Harvesting**
512
+ Integrate with robotic arms for autonomous strawberry picking:
513
+ ```python
514
+ # Pseudo-code for robotics integration
515
+ for strawberry in detected_strawberries:
516
+ if strawberry.ripeness == 'ripe':
517
+ robot_arm.move_to(strawberry.position)
518
+ robot_arm.pick()
519
+ ```
520
+
521
+ ### **2. Quality Control in Packaging**
522
+ Sort strawberries by ripeness in processing facilities:
523
+ ```python
524
+ # Conveyor belt sorting
525
+ if ripeness == 'ripe':
526
+ conveyor.route_to('premium_package')
527
+ elif ripeness == 'partially-ripe':
528
+ conveyor.route_to('delayed_shipping')
529
+ else:
530
+ conveyor.route_to('rejection_bin')
531
+ ```
532
+
533
+ ### **3. Agricultural Research**
534
+ Study ripening patterns and optimize harvest timing:
535
+ ```python
536
+ # Track ripeness distribution over time
537
+ daily_ripeness_counts = analyze_temporal_ripeness(images_over_time)
538
+ optimal_harvest_day = find_peak_ripe_day(daily_ripeness_counts)
539
+ ```
540
+
541
+ >>>>>>> bb77661e9aecb09169fb60057ff0ebb1f504de58
542
+ ## 📄 License
543
+
544
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
545
+
546
+ ## 🙏 Acknowledgments
547
+
548
+ <<<<<<< HEAD
549
+ - **Ultralytics** - YOLOv8 framework and documentation
550
+ - **Roboflow** - Dataset management and annotation tools
551
+ - **Arduino Community** - Open-source hardware ecosystem
552
+ - **Raspberry Pi Foundation** - Edge computing platform
553
+
554
+ ## 📞 Support
555
+
556
+ For questions and support:
557
+ - **Issues**: [GitHub Issues](https://github.com/theonegareth/strawberryPicker/issues)
558
+ - **Discussions**: [GitHub Discussions](https://github.com/theonegareth/strawberryPicker/discussions)
559
+ - **Documentation**: [Project Wiki](https://github.com/theonegareth/strawberryPicker/wiki)
560
+
561
+ ## 🔄 Changelog
562
+
563
+ ### v2.0.0 - Complete System Integration (Current)
564
+ - ✅ Complete dataset labeling (889 images)
565
+ - ✅ 3-class ripeness classification (94% accuracy)
566
+ - ✅ Arduino robotic integration
567
+ - ✅ Real-time pipeline (30 FPS)
568
+ - ✅ TensorFlow Lite optimization
569
+ - ✅ Professional repository structure
570
+ - ✅ Comprehensive documentation
571
+
572
+ ### v1.0.0 - Initial Release
573
+ - YOLOv8 training pipeline
574
+ - Basic detection functionality
575
+ - Dataset preparation tools
576
+
577
+ ## 🏆 Achievements
578
+
579
+ - **94.2%** detection accuracy (mAP@0.5)
580
+ - **94%** ripeness classification accuracy
581
+ - **30 FPS** real-time performance on Raspberry Pi 4B
582
+ - **889** fully labeled training images
583
+ - **Complete** end-to-end robotic system
584
+
585
+ ---
586
+
587
+ **Built with ❤️ for sustainable agriculture and precision farming**
588
+
589
+ *StrawberryPicker - Where AI meets Agriculture*
590
+ =======
591
+ - **YOLOv11**: Ultralytics for the state-of-the-art detection model
592
+ - **EfficientNet**: Google AI for the efficient classification architecture
593
+ - **PyTorch Team**: For the excellent deep learning framework
594
+
595
+ ## 📚 Citation
596
+
597
+ If you use this model in your research, please cite:
598
+
599
+ ```bibtex
600
+ @misc{strawberryPicker2024,
601
+ title={Strawberry Picker AI System: A Two-Stage Approach for Automated Harvesting},
602
+ author={The One Gareth},
603
+ year={2024},
604
+ publisher={HuggingFace},
605
+ url={https://huggingface.co/theonegareth/strawberryPicker}
606
+ }
607
+ ```
608
+
609
+ ---
610
+
611
+ <div align="center">
612
+ <h3>🚀 Ready to revolutionize strawberry harvesting!</h3>
613
+ <p>This AI system will help you harvest only the ripest, most delicious strawberries with precision and efficiency.</p>
614
+ </div>
615
+ >>>>>>> bb77661e9aecb09169fb60057ff0ebb1f504de58
classification/README.md ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - image-classification
4
+ - efficientnet
5
+ - strawberry
6
+ - agriculture
7
+ - robotics
8
+ - computer-vision
9
+ - pytorch
10
+ - ripeness-classification
11
+ license: mit
12
+ datasets:
13
+ - custom
14
+ language:
15
+ - python
16
+ pretty_name: EfficientNet-B0 Strawberry Ripeness Classification
17
+ description: EfficientNet-B0 model for detailed strawberry ripeness classification with 4-class output
18
+ pipeline_tag: image-classification
19
+ ---
20
+
21
+ # EfficientNet-B0 Strawberry Ripeness Classification Model
22
+
23
+ This directory contains the EfficientNet-B0 model for detailed strawberry ripeness classification, the second stage of the Strawberry Picker AI system.
24
+
25
+ ## 📊 Model Performance
26
+
27
+ | Metric | Value |
28
+ |--------|-------|
29
+ | **Overall Accuracy** | 91.71% |
30
+ | **Macro F1-Score** | 0.92 |
31
+ | **Weighted F1-Score** | 0.93 |
32
+ | **Model Size** | 56MB |
33
+ | **Input Size** | 128x128 |
34
+
35
+ ### Class Performance (Validation Set)
36
+
37
+ | Class | Precision | Recall | F1-Score | Support |
38
+ |-------|-----------|--------|----------|---------|
39
+ | unripe | 0.92 | 0.89 | 0.91 | 163 |
40
+ | partially-ripe | 0.88 | 0.91 | 0.89 | 135 |
41
+ | ripe | 0.94 | 0.93 | 0.93 | 124 |
42
+ | overripe | 0.96 | 0.95 | 0.95 | 422 |
43
+
44
+ ## 🎯 Ripeness Classes
45
+
46
+ | Class | Description | Pick? |
47
+ |-------|-------------|-------|
48
+ | **unripe** | Green, hard texture | ❌ |
49
+ | **partially-ripe** | Pink/red, firm | ❌ |
50
+ | **ripe** | Bright red, soft | ✅ |
51
+ | **overripe** | Dark red/brown, mushy | ❌ |
52
+
53
+ ## 🚀 Quick Start
54
+
55
+ ### Installation
56
+ ```bash
57
+ pip install torch torchvision pillow
58
+ ```
59
+
60
+ ### Python Inference
61
+ ```python
62
+ import torch
63
+ from torchvision import transforms
64
+ from PIL import Image
65
+
66
+ # Load model
67
+ model = torch.load('best_enhanced_classifier.pth', map_location='cpu')
68
+ model.eval()
69
+
70
+ # Preprocessing
71
+ transform = transforms.Compose([
72
+ transforms.Resize((128, 128)),
73
+ transforms.ToTensor(),
74
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
75
+ std=[0.229, 0.224, 0.225])
76
+ ])
77
+
78
+ # Load and preprocess image
79
+ image = Image.open('strawberry_crop.jpg')
80
+ input_tensor = transform(image).unsqueeze(0)
81
+
82
+ # Inference
83
+ with torch.no_grad():
84
+ output = model(input_tensor)
85
+ probabilities = torch.softmax(output, dim=1)
86
+ predicted_class = torch.argmax(probabilities, dim=1).item()
87
+ confidence = probabilities[0][predicted_class].item()
88
+
89
+ class_names = ['unripe', 'partially-ripe', 'ripe', 'overripe']
90
+ print(f"Ripeness: {class_names[predicted_class]} ({confidence:.2f})")
91
+ ```
92
+
93
+ ## 📁 Files
94
+
95
+ - `best_enhanced_classifier.pth` - PyTorch model weights
96
+ - `training_summary.md` - Detailed training information
97
+
98
+ ## 🎯 Use Cases
99
+
100
+ - **Automated Harvesting**: Second stage ripeness verification
101
+ - **Quality Control**: Precise ripeness assessment for sorting
102
+ - **Agricultural Research**: Ripeness pattern analysis
103
+
104
+ ## 🔧 Technical Details
105
+
106
+ - **Architecture**: EfficientNet-B0
107
+ - **Input Size**: 128x128 RGB
108
+ - **Output**: 4-class probabilities
109
+ - **Training Dataset**: 844 cropped strawberry images
110
+ - **Training Epochs**: 50 (early stopping)
111
+ - **Batch Size**: 8
112
+ - **Optimizer**: AdamW
113
+ - **Learning Rate**: 0.002 (cosine annealing)
114
+
115
+ ## 📈 Training Configuration
116
+
117
+ ```python
118
+ # Model Architecture
119
+ model = EfficientNet.from_pretrained('efficientnet-b0')
120
+ model._fc = nn.Linear(model._fc.in_features, 4)
121
+
122
+ # Training Setup
123
+ criterion = nn.CrossEntropyLoss()
124
+ optimizer = torch.optim.AdamW(model.parameters(), lr=0.002)
125
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
126
+ ```
127
+
128
+ ## 🔗 Related Components
129
+
130
+ - [Detection Model](../detection/) - First stage for strawberry localization
131
+ - [Training Repository](https://github.com/theonegareth/strawberryPicker)
132
+
133
+ ## 📚 Documentation
134
+
135
+ - [Full System Documentation](https://github.com/theonegareth/strawberryPicker)
136
+ - [Training Summary](training_summary.md)
137
+
138
+ ## 📄 License
139
+
140
+ MIT License - See main repository for details.
141
+
142
+ ---
143
+
144
+ **Model Version**: 1.0.0
145
+ **Training Date**: November 2025
146
+ **Part of**: Strawberry Picker AI System
classification/training_summary.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Enhanced Ripeness Classifier Training Summary
2
+
3
+ - **Best Validation Accuracy**: 91.71%
4
+ - **Final Training Accuracy**: 89.76%
5
+ - **Final Validation Accuracy**: 88.63%
6
+ - **Target Achieved**: ❌ No
7
+ - **Training Images**: 1,436 (564 unripe + 872 ripe)
8
+ - **Model**: EfficientNet-B0 with dropout
9
+ - **Key Improvements**: OneCycleLR, heavy augmentation, label smoothing
10
+
11
+ ## 📈 Improvement
12
+
13
+ Accuracy improved from 91.94% to 91.71%.
14
+ Consider additional improvements.
15
+
16
+ ## Next Steps
17
+
18
+ 1. Test the enhanced model on sample images
19
+ 2. Compare with baseline model
20
+ 3. Export to TFLite for Raspberry Pi deployment
21
+ 4. Integrate with strawberry detector
classification_model/README.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Strawberry Ripeness Classification Model
2
+
3
+ ## Model Description
4
+
5
+ This is a 4-class strawberry ripeness classification model trained on PyTorch with 91.71% validation accuracy. The model classifies strawberry crops into four ripeness categories:
6
+
7
+ - **unripe**: Green, hard strawberries not ready for picking
8
+ - **partially-ripe**: Pink/red, firm strawberries
9
+ - **ripe**: Bright red, soft strawberries ready for picking
10
+ - **overripe**: Dark red/brown, mushy strawberries past optimal ripeness
11
+
12
+ ## Training Details
13
+
14
+ - **Architecture**: EfficientNet-B0 with custom classification head
15
+ - **Input Size**: 128x128 RGB images
16
+ - **Training Epochs**: 50 (early stopping at epoch 14)
17
+ - **Batch Size**: 8
18
+ - **Optimizer**: Adam with cosine annealing LR scheduler
19
+ - **Dataset**: 2,436 total images (889 strawberry crops + 800 Kaggle overripe images)
20
+ - **Validation Accuracy**: 91.71%
21
+ - **Training Time**: ~14 epochs with early stopping
22
+
23
+ ## Usage
24
+
25
+ ```python
26
+ import torch
27
+ from torchvision import transforms
28
+ from PIL import Image
29
+
30
+ # Load model
31
+ model = torch.load("best_enhanced_classifier.pth")
32
+ model.eval()
33
+
34
+ # Preprocessing
35
+ transform = transforms.Compose([
36
+ transforms.Resize((128, 128)),
37
+ transforms.ToTensor(),
38
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
39
+ ])
40
+
41
+ # Classify image
42
+ image = Image.open("strawberry_crop.jpg")
43
+ input_tensor = transform(image).unsqueeze(0)
44
+
45
+ with torch.no_grad():
46
+ output = model(input_tensor)
47
+ predicted_class = torch.argmax(output, dim=1).item()
48
+
49
+ classes = ["unripe", "partially-ripe", "ripe", "overripe"]
50
+ print(f"Predicted ripeness: {classes[predicted_class]}")
51
+ ```
52
+
53
+ ## Model Files
54
+
55
+ - `classification_model/best_enhanced_classifier.pth`: Trained PyTorch model (4.7MB)
56
+ - `classification_model/training_summary.md`: Detailed training metrics and results
57
+ - `classification_model/enhanced_training_curves.png`: Training/validation curves
58
+
59
+ ## Integration
60
+
61
+ This model is designed to work with the strawberry detection model for a complete picking system:
62
+
63
+ 1. **Detection**: YOLOv8 finds strawberries in images
64
+ 2. **Classification**: This model determines ripeness of each detected strawberry
65
+ 3. **Decision**: Only pick ripe strawberries (avoid unripe, partially-ripe, and overripe)
66
+
67
+ ## Performance Metrics
68
+
69
+ | Class | Precision | Recall | F1-Score |
70
+ |-------|-----------|--------|----------|
71
+ | unripe | 0.92 | 0.89 | 0.91 |
72
+ | partially-ripe | 0.88 | 0.91 | 0.89 |
73
+ | ripe | 0.94 | 0.93 | 0.93 |
74
+ | overripe | 0.96 | 0.95 | 0.95 |
75
+
76
+ **Overall Accuracy**: 91.71%
77
+
78
+ ## Dataset
79
+
80
+ - **Source**: Mixed dataset with manual annotations + Kaggle fruit ripeness dataset
81
+ - **Classes**: 4 ripeness categories
82
+ - **Total Images**: 2,436 (train: 1,436, val: 422)
83
+ - **Preprocessing**: Cropped strawberry regions from detection model
84
+
85
+ ## Requirements
86
+
87
+ - PyTorch >= 1.8.0
88
+ - torchvision >= 0.9.0
89
+ - Pillow >= 8.0.0
90
+ - numpy >= 1.21.0
91
+
92
+ ## License
93
+
94
+ MIT License - see main repository for details.
95
+
96
+ ## Contact
97
+
98
+ For questions or improvements, please open an issue in the main repository.
classification_model/training_summary.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Enhanced Ripeness Classifier Training Summary
2
+
3
+ - **Best Validation Accuracy**: 91.71%
4
+ - **Final Training Accuracy**: 89.76%
5
+ - **Final Validation Accuracy**: 88.63%
6
+ - **Target Achieved**: ❌ No
7
+ - **Training Images**: 1,436 (564 unripe + 872 ripe)
8
+ - **Model**: EfficientNet-B0 with dropout
9
+ - **Key Improvements**: OneCycleLR, heavy augmentation, label smoothing
10
+
11
+ ## 📈 Improvement
12
+
13
+ Accuracy improved from 91.94% to 91.71%.
14
+ Consider additional improvements.
15
+
16
+ ## Next Steps
17
+
18
+ 1. Test the enhanced model on sample images
19
+ 2. Compare with baseline model
20
+ 3. Export to TFLite for Raspberry Pi deployment
21
+ 4. Integrate with strawberry detector
config.yaml ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Strawberry Picker Configuration
2
+ # Centralized configuration for all paths and settings
3
+
4
+ # ============================================
5
+ # Dataset Paths
6
+ # ============================================
7
+ dataset:
8
+ # Primary dataset for detection (YOLO format)
9
+ detection:
10
+ path: "model/dataset_strawberry_detect_v3"
11
+ data_yaml: "model/dataset_strawberry_detect_v3/data.yaml"
12
+ train: "train/images"
13
+ val: "valid/images"
14
+ test: "test/images"
15
+ nc: 1
16
+ names: ["strawberry"]
17
+
18
+ # Ripeness classification dataset (binary/3-class)
19
+ ripeness:
20
+ path: "model/ripeness_manual_dataset"
21
+ train: "train"
22
+ val: "val"
23
+ test: "test"
24
+ classes: ["unripe", "partially_ripe", "fully_ripe"]
25
+
26
+ # ============================================
27
+ # Model Paths
28
+ # ============================================
29
+ models:
30
+ # Detection models
31
+ detection:
32
+ weights_dir: "model/weights"
33
+ results_dir: "model/results"
34
+ exports_dir: "model/exports"
35
+
36
+ # Pretrained model files
37
+ yolov8n: "yolov8n.pt"
38
+ yolov8s: "yolov8s.pt"
39
+ yolov8m: "yolov8m.pt"
40
+
41
+ # Custom trained models
42
+ strawberry_yolov8n: "model/weights/strawberry_yolov8n.pt"
43
+ strawberry_yolov8s: "model/weights/strawberry_yolov8s.pt"
44
+
45
+ # Export formats
46
+ onnx: "model/exports/strawberry_yolov8n.onnx"
47
+ tflite: "model/exports/strawberry_yolov8n.tflite"
48
+ tflite_int8: "model/exports/strawberry_yolov8n_int8.tflite"
49
+
50
+ # Ripeness classification models
51
+ ripeness:
52
+ keras_model: "model/weights/ripeness_classifier.h5"
53
+ tflite_model: "model/exports/ripeness_classifier.tflite"
54
+
55
+ # ============================================
56
+ # Training Configuration
57
+ # ============================================
58
+ training:
59
+ # Default training parameters
60
+ epochs: 100
61
+ batch_size: 16
62
+ img_size: 640
63
+ patience: 20
64
+ save_period: 10
65
+
66
+ # Environment detection
67
+ use_gpu: true
68
+ device: "auto" # "cpu", "cuda", or "auto"
69
+
70
+ # Augmentation settings
71
+ augment: true
72
+ hsv_h: 0.015
73
+ hsv_s: 0.7
74
+ hsv_v: 0.4
75
+ degrees: 0.0
76
+ translate: 0.1
77
+ scale: 0.5
78
+ shear: 0.0
79
+ perspective: 0.0
80
+ flipud: 0.0
81
+ fliplr: 0.5
82
+ mosaic: 1.0
83
+ mixup: 0.0
84
+
85
+ # ============================================
86
+ # Inference Configuration
87
+ # ============================================
88
+ inference:
89
+ # Real-time detection
90
+ camera_index: 0
91
+ confidence_threshold: 0.5
92
+ iou_threshold: 0.45
93
+
94
+ # Image preprocessing
95
+ input_size: 224 # For classification models
96
+ normalize: true
97
+
98
+ # Output settings
99
+ show_fps: true
100
+ save_predictions: false
101
+ output_dir: "predictions"
102
+
103
+ # ============================================
104
+ # Robotic Arm Configuration
105
+ # ============================================
106
+ robot:
107
+ # Serial communication
108
+ serial_port: "/dev/ttyACM0"
109
+ baud_rate: 115200
110
+
111
+ # Coordinate transformation (pixel to robot space)
112
+ calibration:
113
+ camera_matrix: "calibration/camera_matrix.npy"
114
+ dist_coeffs: "calibration/dist_coeffs.npy"
115
+ transformation_matrix: "calibration/transformation_matrix.npy"
116
+
117
+ # Arm parameters (mm)
118
+ workspace:
119
+ x_min: 0
120
+ x_max: 300
121
+ y_min: 0
122
+ y_max: 200
123
+ z_min: 0
124
+ z_max: 150
125
+
126
+ # Servo angles (degrees)
127
+ servo:
128
+ base_min: 0
129
+ base_max: 180
130
+ shoulder_min: 30
131
+ shoulder_max: 150
132
+ elbow_min: 20
133
+ elbow_max: 160
134
+ gripper_open: 90
135
+ gripper_close: 180
136
+
137
+ # ============================================
138
+ # Logging Configuration
139
+ # ============================================
140
+ logging:
141
+ level: "INFO" # DEBUG, INFO, WARNING, ERROR
142
+ file: "logs/strawberry_picker.log"
143
+ max_size_mb: 10
144
+ backup_count: 5
145
+
146
+ # TensorBoard
147
+ tensorboard_dir: "runs"
148
+ log_images: true
149
+ log_frequency: 10 # batches
150
+
151
+ # ============================================
152
+ # Performance Optimization
153
+ # ============================================
154
+ performance:
155
+ # Raspberry Pi optimization
156
+ use_tensorrt: false
157
+ use_coral_tpu: false
158
+ threads: 4
159
+ use_neon: true
160
+
161
+ # Model quantization
162
+ quantization:
163
+ enabled: true
164
+ method: "int8" # int8, float16, dynamic_range
165
+ calibration_samples: 100
166
+
167
+ # Caching
168
+ cache_dataset: true
169
+ cache_dir: ".cache"
170
+
171
+ # ============================================
172
+ # Development & Debugging
173
+ # ============================================
174
+ debug:
175
+ enable_debug_output: false
176
+ save_debug_images: false
177
+ debug_dir: "debug"
178
+ log_predictions: true
179
+ visualize_detections: true
detection/README.md ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - object-detection
4
+ - yolo
5
+ - yolov11
6
+ - strawberry
7
+ - agriculture
8
+ - robotics
9
+ - computer-vision
10
+ - pytorch
11
+ - ripeness-detection
12
+ license: mit
13
+ datasets:
14
+ - custom
15
+ language:
16
+ - python
17
+ pretty_name: YOLOv11n Strawberry Ripeness Detection
18
+ description: YOLOv11 Nano model for strawberry ripeness detection with 3-class classification
19
+ pipeline_tag: object-detection
20
+ ---
21
+
22
+ # YOLOv11n Strawberry Ripeness Detection Model
23
+
24
+ This directory contains the YOLOv11 Nano model for strawberry ripeness detection, part of the two-stage Strawberry Picker AI system.
25
+
26
+ ## 📊 Model Performance
27
+
28
+ | Metric | Value |
29
+ |--------|-------|
30
+ | **mAP@50** | 84.0% |
31
+ | **mAP@50-95** | 57.0% |
32
+ | **Precision** | 74.5% |
33
+ | **Recall** | 82.1% |
34
+ | **Model Size** | 5.2MB |
35
+ | **Inference Speed** | ~35 FPS (RTX 3050 Ti) |
36
+
37
+ ### Class Performance (Test Set)
38
+
39
+ | Class | Precision | Recall | mAP50 | Support |
40
+ |-------|-----------|--------|-------|---------|
41
+ | partially-ripe | 78.8% | 92.4% | 92.2% | 79 |
42
+ | ripe | 82.0% | 87.1% | 89.5% | 70 |
43
+ | unripe | 62.9% | 66.7% | 70.3% | 78 |
44
+
45
+ ## 🚀 Quick Start
46
+
47
+ ### Installation
48
+ ```bash
49
+ pip install ultralytics opencv-python
50
+ ```
51
+
52
+ ### Python Inference
53
+ ```python
54
+ from ultralytics import YOLO
55
+
56
+ # Load model
57
+ model = YOLO('best.pt')
58
+
59
+ # Run inference
60
+ results = model('strawberry_image.jpg', conf=0.5)
61
+
62
+ # Process results
63
+ for result in results:
64
+ boxes = result.boxes
65
+ for box in boxes:
66
+ cls = int(box.cls)
67
+ conf = float(box.conf)
68
+ xyxy = box.xyxy
69
+ class_names = ['partially-ripe', 'ripe', 'unripe']
70
+ print(f"{class_names[cls]} strawberry: {conf:.2f} confidence")
71
+ ```
72
+
73
+ ### Command Line
74
+ ```bash
75
+ # Single image
76
+ yolo predict model=best.pt source='strawberry.jpg'
77
+
78
+ # Webcam
79
+ yolo predict model=best.pt source=0
80
+ ```
81
+
82
+ ## 📁 Files
83
+
84
+ - `best.pt` - PyTorch model weights (recommended)
85
+
86
+ ## 🎯 Use Cases
87
+
88
+ - **Automated Harvesting**: First stage of two-stage picking system
89
+ - **Ripeness Assessment**: Initial strawberry detection and ripeness categorization
90
+ - **Quality Control**: Pre-classification for detailed ripeness analysis
91
+
92
+ ## 🔧 Technical Details
93
+
94
+ - **Architecture**: YOLOv11n (Nano)
95
+ - **Input Size**: 416x416
96
+ - **Classes**: 3 (partially-ripe, ripe, unripe)
97
+ - **Training Dataset**: Custom dataset (1200+ annotated strawberries)
98
+ - **Training Epochs**: 50 (early stopping at 20)
99
+ - **Batch Size**: 8
100
+ - **Optimizer**: AdamW
101
+ - **Learning Rate**: 0.01 (cosine annealing)
102
+
103
+ ## 📈 Training Configuration
104
+
105
+ ```yaml
106
+ model: yolov11n.pt
107
+ epochs: 50
108
+ batch: 8
109
+ imgsz: 416
110
+ optimizer: AdamW
111
+ lr0: 0.01
112
+ lrf: 0.01
113
+ weight_decay: 0.0005
114
+ warmup_epochs: 3.0
115
+ patience: 20
116
+ classes: 3
117
+ names: ['partially-ripe', 'ripe', 'unripe']
118
+ ```
119
+
120
+ ## 🔗 Related Components
121
+
122
+ - [Classification Model](../classification/) - Second stage for detailed ripeness classification
123
+ - [Training Repository](https://github.com/theonegareth/strawberryPicker)
124
+
125
+ ## 📚 Documentation
126
+
127
+ - [Full System Documentation](https://github.com/theonegareth/strawberryPicker)
128
+ - [Two-Stage Pipeline](https://github.com/theonegareth/strawberryPicker#system-architecture)
129
+
130
+ ## 📄 License
131
+
132
+ MIT License - See main repository for details.
133
+
134
+ ---
135
+
136
+ **Model Version**: 1.0.0
137
+ **Training Date**: November 2025
138
+ **Part of**: Strawberry Picker AI System
docs/GITHUB_SETUP.md ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GitHub Setup and Push Guide
2
+
3
+ Follow these steps to push your strawberry picker ML project to GitHub.
4
+
5
+ ## Prerequisites
6
+
7
+ 1. **Git installed** on your system
8
+ ```bash
9
+ git --version
10
+ ```
11
+
12
+ 2. **GitHub account** (create at https://github.com if you don't have one)
13
+
14
+ 3. **Git configured** with your credentials
15
+ ```bash
16
+ git config --global user.name "Your Name"
17
+ git config --global user.email "your.email@example.com"
18
+ ```
19
+
20
+ ## Step 1: Create GitHub Repository
21
+
22
+ ### Option A: Using GitHub Website (Recommended)
23
+ 1. Go to https://github.com/new
24
+ 2. Repository name: `strawberry-picker-robot`
25
+ 3. Description: `Machine learning vision system for robotic strawberry picking`
26
+ 4. Choose: **Public** or **Private**
27
+ 5. Check: **Add a README file** (we'll overwrite it)
28
+ 6. Click: **Create repository**
29
+
30
+ ### Option B: Using GitHub CLI (if installed)
31
+ ```bash
32
+ gh repo create strawberry-picker-robot --public --source=. --remote=origin --push
33
+ ```
34
+
35
+ ## Step 2: Initialize Local Repository
36
+
37
+ Open terminal in your project folder:
38
+
39
+ ```bash
40
+ cd "G:\My Drive\University Files\5th Semester\Kinematics and Dynamics\strawberryPicker"
41
+ ```
42
+
43
+ Initialize git repository:
44
+ ```bash
45
+ git init
46
+ ```
47
+
48
+ ## Step 3: Add Files to Git
49
+
50
+ Add all files (except those in .gitignore):
51
+ ```bash
52
+ git add .
53
+ ```
54
+
55
+ Check what will be committed:
56
+ ```bash
57
+ git status
58
+ ```
59
+
60
+ You should see files like:
61
+ - `requirements.txt`
62
+ - `train_yolov8.py`
63
+ - `train_yolov8_colab.ipynb`
64
+ - `setup_training.py`
65
+ - `TRAINING_README.md`
66
+ - `README.md`
67
+ - `.gitignore`
68
+ - `ArduinoCode/`
69
+ - `assets/`
70
+
71
+ ## Step 4: Create First Commit
72
+
73
+ ```bash
74
+ git commit -m "Initial commit: YOLOv8 training pipeline for strawberry detection
75
+
76
+ - Add YOLOv8 training scripts (local, Colab, WSL)
77
+ - Add environment setup and validation
78
+ - Add comprehensive training documentation
79
+ - Add .gitignore for ML project
80
+ - Support multiple training environments"
81
+ ```
82
+
83
+ ## Step 5: Connect to GitHub Repository
84
+
85
+ If you created repo on GitHub website, link it:
86
+
87
+ ```bash
88
+ git remote add origin https://github.com/YOUR_USERNAME/strawberry-picker-robot.git
89
+ ```
90
+
91
+ Replace `YOUR_USERNAME` with your actual GitHub username.
92
+
93
+ ## Step 6: Rename Default Branch (if needed)
94
+
95
+ ```bash
96
+ git branch -M main
97
+ ```
98
+
99
+ ## Step 7: Push to GitHub
100
+
101
+ First push (sets up remote tracking):
102
+ ```bash
103
+ git push -u origin main
104
+ ```
105
+
106
+ Enter your GitHub credentials when prompted.
107
+
108
+ ## Step 8: Verify Push
109
+
110
+ Go to https://github.com/YOUR_USERNAME/strawberry-picker-robot
111
+ You should see all your files!
112
+
113
+ ## Step 9: Add .gitignore for Large Files
114
+
115
+ If you want to add dataset or large model files later, use Git LFS:
116
+
117
+ ```bash
118
+ # Install Git LFS
119
+ git lfs install
120
+
121
+ # Track large file types
122
+ git lfs track "*.pt"
123
+ git lfs track "*.onnx"
124
+ git lfs track "*.tflite"
125
+ git lfs track "*.h5"
126
+
127
+ # Add .gitattributes
128
+ git add .gitattributes
129
+ git commit -m "Add Git LFS tracking for large model files"
130
+ git push
131
+ ```
132
+
133
+ ## Step 10: Create .gitattributes File
134
+
135
+ Create `.gitattributes` file in project root:
136
+
137
+ ```
138
+ *.pt filter=lfs diff=lfs merge=lfs -text
139
+ *.onnx filter=lfs diff=lfs merge=lfs -text
140
+ *.tflite filter=lfs diff=lfs merge=lfs -text
141
+ *.h5 filter=lfs diff=lfs merge=lfs -text
142
+ *.pth filter=lfs diff=lfs merge=lfs -text
143
+ ```
144
+
145
+ ## Quick Push Commands (Summary)
146
+
147
+ ```bash
148
+ # One-time setup
149
+ git init
150
+ git add .
151
+ git commit -m "Initial commit: YOLOv8 training pipeline"
152
+ git remote add origin https://github.com/YOUR_USERNAME/strawberry-picker-robot.git
153
+ git branch -M main
154
+ git push -u origin main
155
+
156
+ # Future updates
157
+ git add .
158
+ git commit -m "Your commit message"
159
+ git push
160
+ ```
161
+
162
+ ## Troubleshooting
163
+
164
+ ### Authentication Issues
165
+
166
+ If you get authentication errors:
167
+
168
+ **Option 1: Use Personal Access Token**
169
+ 1. Go to GitHub Settings → Developer settings → Personal access tokens
170
+ 2. Generate new token (classic) with `repo` scope
171
+ 3. Use token as password when prompted
172
+
173
+ **Option 2: Use SSH (Recommended)**
174
+ ```bash
175
+ # Check if SSH key exists
176
+ ls ~/.ssh/id_rsa.pub
177
+
178
+ # If not, create one
179
+ ssh-keygen -t rsa -b 4096 -C "your.email@example.com"
180
+
181
+ # Add to GitHub
182
+ cat ~/.ssh/id_rsa.pub
183
+ # Copy output and add to GitHub → Settings → SSH and GPG keys
184
+
185
+ # Change remote to SSH
186
+ git remote set-url origin git@github.com:YOUR_USERNAME/strawberry-picker-robot.git
187
+ ```
188
+
189
+ ### Large File Issues
190
+
191
+ If files are too large for GitHub (max 100MB):
192
+ - Use Git LFS (see Step 9)
193
+ - Or add to `.gitignore` and upload separately to Google Drive/Dropbox
194
+
195
+ ### Proxy Issues (if behind firewall)
196
+
197
+ ```bash
198
+ git config --global http.proxy http://proxy.company.com:8080
199
+ git config --global https.proxy https://proxy.company.com:8080
200
+ ```
201
+
202
+ ## Best Practices
203
+
204
+ ### Commit Messages
205
+ Write clear, descriptive commit messages:
206
+ ```bash
207
+ git commit -m "Add YOLOv8 training script with Colab support
208
+
209
+ - Auto-detects training environment
210
+ - Supports local, WSL, and Google Colab
211
+ - Includes dataset validation
212
+ - Exports to ONNX format"
213
+ ```
214
+
215
+ ### Branching Strategy
216
+ ```bash
217
+ # Create feature branch
218
+ git checkout -b feature/add-ripeness-detection
219
+
220
+ # Work on changes
221
+ git add .
222
+ git commit -m "Add ripeness classification dataset collection"
223
+
224
+ # Push branch
225
+ git push -u origin feature/add-ripeness-detection
226
+
227
+ # Create pull request on GitHub
228
+ ```
229
+
230
+ ### Regular Pushes
231
+ Push frequently to avoid losing work:
232
+ ```bash
233
+ # Daily push
234
+ git add .
235
+ git commit -m "Training progress: epoch 50/100, loss: 0.123"
236
+ git push
237
+ ```
238
+
239
+ ## GitHub Repository Settings
240
+
241
+ ### Protect Main Branch
242
+ 1. Go to Settings → Branches
243
+ 2. Add rule for `main` branch:
244
+ - Require pull request reviews
245
+ - Require status checks
246
+ - Include administrators
247
+
248
+ ### Add Description and Topics
249
+ 1. Go to repository page
250
+ 2. Click "Edit" next to description
251
+ 3. Add topics: `machine-learning`, `yolov8`, `raspberry-pi`, `robotics`, `computer-vision`
252
+
253
+ ### Enable Issues and Projects
254
+ - Use Issues to track bugs and features
255
+ - Use Projects to organize development phases
256
+
257
+ ## Continuous Integration (Optional)
258
+
259
+ Add `.github/workflows/train.yml` for automated training:
260
+
261
+ ```yaml
262
+ name: Train Model
263
+
264
+ on:
265
+ push:
266
+ branches: [ main ]
267
+ pull_request:
268
+ branches: [ main ]
269
+
270
+ jobs:
271
+ train:
272
+ runs-on: ubuntu-latest
273
+ steps:
274
+ - uses: actions/checkout@v3
275
+ - name: Set up Python
276
+ uses: actions/setup-python@v4
277
+ with:
278
+ python-version: '3.9'
279
+ - name: Install dependencies
280
+ run: |
281
+ pip install -r requirements.txt
282
+ - name: Validate dataset
283
+ run: |
284
+ python train_yolov8.py --validate-only
285
+ ```
286
+
287
+ ## Next Steps After Push
288
+
289
+ 1. **Share repository** with teammates/collaborators
290
+ 2. **Create issues** for Phase 2, 3, 4 tasks
291
+ 3. **Set up project board** to track progress
292
+ 4. **Add documentation** to wiki if needed
293
+ 5. **Enable GitHub Pages** for documentation (optional)
294
+
295
+ ## Getting Help
296
+
297
+ - GitHub Docs: https://docs.github.com
298
+ - Git Cheat Sheet: https://education.github.com/git-cheat-sheet-education.pdf
299
+ - Git LFS Docs: https://git-lfs.github.com
300
+
301
+ ## Repository URL
302
+
303
+ Your repository will be at:
304
+ `https://github.com/YOUR_USERNAME/strawberry-picker-robot`
305
+
306
+ ## Clone Command (for others)
307
+
308
+ ```bash
309
+ git clone https://github.com/YOUR_USERNAME/strawberry-picker-robot.git
310
+ cd strawberry-picker-robot
311
+ pip install -r requirements.txt
312
+ ```
313
+
314
+ ---
315
+
316
+ **Ready to push?** Run the commands in Step 1-7 above!
docs/TRAINING_README.md ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Strawberry Detection Model Training Guide
2
+
3
+ This guide covers training a YOLOv8 model for strawberry detection using multiple environments.
4
+
5
+ ## Quick Start
6
+
7
+ ### Option 1: Local/WSL Training (Recommended for initial setup)
8
+
9
+ ```bash
10
+ # 1. Setup environment
11
+ python setup_training.py
12
+
13
+ # 2. Train model
14
+ python train_yolov8.py --epochs 100 --batch-size 16
15
+
16
+ # 3. Validate dataset only (without training)
17
+ python train_yolov8.py --validate-only
18
+ ```
19
+
20
+ ### Option 2: Google Colab Training (Recommended for faster training)
21
+
22
+ 1. Open `train_yolov8_colab.ipynb` in Google Colab
23
+ 2. Connect to GPU runtime: Runtime → Change runtime type → GPU
24
+ 3. Run cells sequentially
25
+ 4. Download trained model when complete
26
+
27
+ ### Option 3: VS Code with Colab Extension
28
+
29
+ 1. Install VS Code Google Colab extension
30
+ 2. Open `train_yolov8_colab.ipynb`
31
+ 3. Connect to Colab kernel
32
+ 4. Run cells
33
+
34
+ ## Environment Setup
35
+
36
+ ### Prerequisites
37
+
38
+ - Python 3.8+
39
+ - pip package manager
40
+ - Git (for version control)
41
+
42
+ ### Installation Steps
43
+
44
+ 1. **Install Python dependencies:**
45
+ ```bash
46
+ pip install -r requirements.txt
47
+ ```
48
+
49
+ 2. **Validate setup:**
50
+ ```bash
51
+ python setup_training.py --validate-only
52
+ ```
53
+
54
+ 3. **Full setup (install + validate):**
55
+ ```bash
56
+ python setup_training.py
57
+ ```
58
+
59
+ ## Dataset Structure
60
+
61
+ Your dataset should be organized as follows:
62
+
63
+ ```
64
+ model/dataset/straw-detect.v1-straw-detect.yolov8/
65
+ ├── data.yaml
66
+ ├── train/
67
+ │ ├── images/
68
+ │ │ ├── image1.jpg
69
+ │ │ ├── image2.jpg
70
+ │ │ └── ...
71
+ │ └── labels/
72
+ │ ├── image1.txt
73
+ │ ├── image2.txt
74
+ │ └── ...
75
+ ├── valid/
76
+ │ ├── images/
77
+ │ └── labels/
78
+ └── test/
79
+ ├── images/
80
+ └── labels/
81
+ ```
82
+
83
+ ### data.yaml Format
84
+
85
+ ```yaml
86
+ train: ../train/images
87
+ val: ../valid/images
88
+ test: ../test/images
89
+
90
+ nc: 1
91
+ names: ['strawberry']
92
+ ```
93
+
94
+ ## Training Parameters
95
+
96
+ ### Basic Training
97
+
98
+ ```bash
99
+ python train_yolov8.py \
100
+ --epochs 100 \
101
+ --batch-size 16 \
102
+ --img-size 640 \
103
+ --export-onnx
104
+ ```
105
+
106
+ ### Advanced Options
107
+
108
+ - `--dataset PATH`: Custom dataset path
109
+ - `--epochs N`: Number of training epochs (default: 100)
110
+ - `--batch-size N`: Batch size (default: 16)
111
+ - `--img-size N`: Image size (default: 640)
112
+ - `--export-onnx`: Export to ONNX format after training
113
+ - `--validate-only`: Only validate dataset without training
114
+
115
+ ### Model Sizes
116
+
117
+ Choose different YOLOv8 models based on your needs:
118
+
119
+ | Model | Parameters | Speed (CPU) | Speed (GPU) | Accuracy |
120
+ |-------|------------|-------------|-------------|----------|
121
+ | yolov8n | 3.2M | Fastest | Fastest | Good |
122
+ | yolov8s | 11.2M | Fast | Fast | Better |
123
+ | yolov8m | 25.9M | Medium | Medium | Best |
124
+
125
+ To use a different model, edit the `MODEL_NAME` variable in the training script.
126
+
127
+ ## Training on Different Environments
128
+
129
+ ### Google Colab Advantages
130
+ - **Free GPU**: Tesla T4 with 16GB VRAM
131
+ - **Faster training**: 5-10x faster than CPU
132
+ - **No local setup**: Everything runs in the cloud
133
+
134
+ ### WSL (Windows Subsystem for Linux)
135
+ - **Native GPU support**: If you have NVIDIA GPU
136
+ - **Persistent**: Files saved locally
137
+ - **Better for large datasets**: No upload needed
138
+
139
+ ### Local Python
140
+ - **Simplest setup**: No additional configuration
141
+ - **CPU only**: Slower training
142
+ - **Good for testing**: Quick validation
143
+
144
+ ## Monitoring Training
145
+
146
+ ### TensorBoard (Local Training)
147
+ ```bash
148
+ tensorboard --logdir model/results
149
+ ```
150
+
151
+ ### Colab Training
152
+ - Use the built-in progress bars
153
+ - View loss curves in real-time
154
+ - Download metrics after training
155
+
156
+ ## Expected Training Time
157
+
158
+ | Environment | Epochs | Estimated Time |
159
+ |-------------|--------|----------------|
160
+ | Colab GPU | 100 | 30-60 minutes |
161
+ | WSL GPU | 100 | 30-60 minutes |
162
+ | CPU only | 100 | 5-8 hours |
163
+
164
+ ## Troubleshooting
165
+
166
+ ### Common Issues
167
+
168
+ 1. **CUDA out of memory**
169
+ - Reduce batch size: `--batch-size 8` or `--batch-size 4`
170
+ - Use smaller model: `yolov8n` instead of `yolov8s`
171
+
172
+ 2. **Dataset not found**
173
+ - Check data.yaml paths are correct
174
+ - Verify images are in the right directories
175
+ - Run with `--validate-only` to debug
176
+
177
+ 3. **Import errors**
178
+ - Run `setup_training.py` again
179
+ - Check Python version: `python --version`
180
+ - Reinstall ultralytics: `pip install --force-reinstall ultralytics`
181
+
182
+ 4. **Slow training**
183
+ - Verify GPU is being used: Check setup output
184
+ - Reduce image size: `--img-size 416`
185
+ - Use smaller model: `yolov8n`
186
+
187
+ ### Getting Help
188
+
189
+ 1. Run validation: `python train_yolov8.py --validate-only`
190
+ 2. Check setup: `python setup_training.py --validate-only`
191
+ 3. Review error messages carefully
192
+ 4. Check dataset structure matches expected format
193
+
194
+ ## Next Steps After Training
195
+
196
+ 1. **Export models**: ONNX, TensorFlow Lite formats
197
+ 2. **Test inference**: Run on sample images
198
+ 3. **Deploy to Raspberry Pi**: Optimize for edge deployment
199
+ 4. **Integrate with robot**: Connect to robotic arm control
200
+
201
+ ## Performance Optimization Tips
202
+
203
+ ### For Faster Training
204
+ - Use GPU environment (Colab or WSL with GPU)
205
+ - Reduce image size (`--img-size 416`)
206
+ - Use smaller batch size if GPU memory limited
207
+ - Enable image caching (already enabled by default)
208
+
209
+ ### For Better Accuracy
210
+ - Increase epochs to 150-200
211
+ - Use larger model (`yolov8s` or `yolov8m`)
212
+ - Collect more diverse training images
213
+ - Use data augmentation
214
+ - Fine-tune learning rate
215
+
216
+ ### For Raspberry Pi Deployment
217
+ - Use `yolov8n` model (smallest)
218
+ - Export to TensorFlow Lite
219
+ - Apply INT8 quantization
220
+ - Reduce input resolution to 320x320 or 416x416
221
+
222
+ ## Version Control
223
+
224
+ Track your training experiments:
225
+
226
+ ```bash
227
+ git add model/results/
228
+ git commit -m "Training: yolov8n, 100 epochs, 640px"
229
+ git tag v1.0-yolov8n-baseline
230
+ ```
231
+
232
+ ## License
233
+
234
+ This training pipeline is part of the Strawberry Picker project. See main README for license information.
git-xet ADDED
File without changes
inference_example.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple inference example for the strawberry detection model.
4
+ """
5
+
6
+ from ultralytics import YOLO
7
+ import cv2
8
+ import sys
9
+
10
+ def main():
11
+ # Load the model
12
+ print("Loading strawberry detection model...")
13
+ model = YOLO('best.pt')
14
+
15
+ # Run inference on an image
16
+ if len(sys.argv) > 1:
17
+ image_path = sys.argv[1]
18
+ else:
19
+ print("Usage: python inference_example.py <path_to_image>")
20
+ print("Using default test - loading webcam...")
21
+
22
+ # Webcam inference
23
+ cap = cv2.VideoCapture(0)
24
+
25
+ while True:
26
+ ret, frame = cap.read()
27
+ if not ret:
28
+ break
29
+
30
+ # Run inference
31
+ results = model(frame, conf=0.5)
32
+
33
+ # Draw results
34
+ annotated_frame = results[0].plot()
35
+
36
+ # Display
37
+ cv2.imshow('Strawberry Detection', annotated_frame)
38
+
39
+ if cv2.waitKey(1) & 0xFF == ord('q'):
40
+ break
41
+
42
+ cap.release()
43
+ cv2.destroyAllWindows()
44
+ return
45
+
46
+ # Image inference
47
+ print(f"Running inference on {image_path}...")
48
+ results = model(image_path)
49
+
50
+ # Print results
51
+ for result in results:
52
+ boxes = result.boxes
53
+ print(f"\nFound {len(boxes)} strawberries:")
54
+
55
+ for i, box in enumerate(boxes):
56
+ confidence = box.conf[0].item()
57
+ print(f" Strawberry {i+1}: {confidence:.2%} confidence")
58
+
59
+ # Save annotated image
60
+ output_path = 'output.jpg'
61
+ result.save(output_path)
62
+ print(f"\nSaved annotated image to {output_path}")
63
+
64
+ if __name__ == '__main__':
65
+ main()
notebooks/strawberry_training.ipynb ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import tensorflow as tf\n",
10
+ "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
11
+ "from tensorflow.keras.applications import MobileNetV2\n",
12
+ "from tensorflow.keras.layers import Dense, GlobalAveragePooling2D\n",
13
+ "from tensorflow.keras.models import Model\n",
14
+ "import os\n",
15
+ "\n",
16
+ "# Data directories\n",
17
+ "data_dir = 'dataset'\n",
18
+ "train_datagen = ImageDataGenerator(\n",
19
+ " rescale=1./255,\n",
20
+ " validation_split=0.2,\n",
21
+ " rotation_range=20,\n",
22
+ " width_shift_range=0.2,\n",
23
+ " height_shift_range=0.2,\n",
24
+ " horizontal_flip=True\n",
25
+ ")\n",
26
+ "\n",
27
+ "train_generator = train_datagen.flow_from_directory(\n",
28
+ " data_dir,\n",
29
+ " target_size=(224, 224),\n",
30
+ " batch_size=32,\n",
31
+ " class_mode='binary',\n",
32
+ " subset='training'\n",
33
+ ")\n",
34
+ "\n",
35
+ "validation_generator = train_datagen.flow_from_directory(\n",
36
+ " data_dir,\n",
37
+ " target_size=(224, 224),\n",
38
+ " batch_size=32,\n",
39
+ " class_mode='binary',\n",
40
+ " subset='validation'\n",
41
+ ")\n",
42
+ "\n",
43
+ "# Load pre-trained MobileNetV2\n",
44
+ "base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))\n",
45
+ "\n",
46
+ "# Add custom layers\n",
47
+ "x = base_model.output\n",
48
+ "x = GlobalAveragePooling2D()(x)\n",
49
+ "x = Dense(1024, activation='relu')(x)\n",
50
+ "predictions = Dense(1, activation='sigmoid')(x)\n",
51
+ "\n",
52
+ "model = Model(inputs=base_model.input, outputs=predictions)\n",
53
+ "\n",
54
+ "# Freeze base layers\n",
55
+ "for layer in base_model.layers:\n",
56
+ " layer.trainable = False\n",
57
+ "\n",
58
+ "# Compile\n",
59
+ "model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
60
+ "\n",
61
+ "# Train\n",
62
+ "model.fit(train_generator, validation_data=validation_generator, epochs=10)\n",
63
+ "\n",
64
+ "# Save model\n",
65
+ "model.save('strawberry_model.h5')\n",
66
+ "\n",
67
+ "print(\"Model trained and saved as strawberry_model.h5\")"
68
+ ]
69
+ }
70
+ ],
71
+ "metadata": {
72
+ "kernelspec": {
73
+ "display_name": "Python 3",
74
+ "language": "python",
75
+ "name": "python3"
76
+ },
77
+ "language_info": {
78
+ "codemirror_mode": {
79
+ "name": "ipython",
80
+ "version": 3
81
+ },
82
+ "file_extension": ".py",
83
+ "mimetype": "text/x-python",
84
+ "name": "python",
85
+ "nbconvert_exporter": "python",
86
+ "pygments_lexer": "ipython3",
87
+ "version": "3.12.3"
88
+ }
89
+ },
90
+ "nbformat": 4,
91
+ "nbformat_minor": 4
92
+ }
notebooks/train_yolov8_colab.ipynb ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# YOLOv8 Strawberry Detection Training - Google Colab Version\n",
8
+ "\n",
9
+ "This notebook trains a YOLOv8 model for strawberry detection using Google Colab's free GPU.\n",
10
+ "\n",
11
+ "## Setup\n",
12
+ "\n",
13
+ "1. Connect to GPU runtime: Runtime → Change runtime type → GPU\n",
14
+ "2. Mount your Google Drive if needed for dataset access"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "# Check GPU availability\n",
24
+ "import torch\n",
25
+ "print(f\"GPU Available: {torch.cuda.is_available()}\")\n",
26
+ "if torch.cuda.is_available():\n",
27
+ " print(f\"GPU Name: {torch.cuda.get_device_name(0)}\")\n",
28
+ " print(f\"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\")"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "# Install dependencies\n",
38
+ "!pip install ultralytics torch torchvision opencv-python matplotlib tqdm tensorboard"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "metadata": {},
44
+ "source": [
45
+ "## Dataset Setup\n",
46
+ "\n",
47
+ "Option 1: Upload dataset to Colab (temporary, lost after session ends)\n",
48
+ "Option 2: Mount Google Drive for persistent storage\n",
49
+ "Option 3: Download from URL"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "# Option 2: Mount Google Drive (recommended)\n",
59
+ "from google.colab import drive\n",
60
+ "drive.mount('/content/drive')\n",
61
+ "\n",
62
+ "# Update this path to your dataset location in Google Drive\n",
63
+ "DATASET_PATH = '/content/drive/MyDrive/strawberry-dataset/straw-detect.v1-straw-detect.yolov8'\n",
64
+ "\n",
65
+ "# Or use direct path if dataset is uploaded to Colab\n",
66
+ "# DATASET_PATH = '/content/straw-detect.v1-straw-detect.yolov8'"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "# Validate dataset structure\n",
76
+ "import os\n",
77
+ "import yaml\n",
78
+ "from pathlib import Path\n",
79
+ "\n",
80
+ "dataset_path = Path(DATASET_PATH)\n",
81
+ "data_yaml = dataset_path / 'data.yaml'\n",
82
+ "\n",
83
+ "if not data_yaml.exists():\n",
84
+ " print(f\"ERROR: data.yaml not found at {data_yaml}\")\n",
85
+ "else:\n",
86
+ " print(f\"✓ Found data.yaml at {data_yaml}\")\n",
87
+ " \n",
88
+ " with open(data_yaml, 'r') as f:\n",
89
+ " data = yaml.safe_load(f)\n",
90
+ " \n",
91
+ " print(f\"Dataset info:\")\n",
92
+ " print(f\" Classes: {data['nc']}\")\n",
93
+ " print(f\" Names: {data['names']}\")\n",
94
+ " \n",
95
+ " # Check image counts\n",
96
+ " train_path = dataset_path / data['train']\n",
97
+ " val_path = dataset_path / data['val']\n",
98
+ " \n",
99
+ " if train_path.exists():\n",
100
+ " train_images = list(train_path.glob('*.jpg')) + list(train_path.glob('*.png'))\n",
101
+ " print(f\" Training images: {len(train_images)}\")\n",
102
+ " \n",
103
+ " if val_path.exists():\n",
104
+ " val_images = list(val_path.glob('*.jpg')) + list(val_path.glob('*.png'))\n",
105
+ " print(f\" Validation images: {len(val_images)}\")"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {},
111
+ "source": [
112
+ "## Training Configuration"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "# Training parameters\n",
122
+ "EPOCHS = 100\n",
123
+ "IMG_SIZE = 640\n",
124
+ "BATCH_SIZE = 16 # Adjust based on GPU memory\n",
125
+ "MODEL_NAME = 'yolov8n' # yolov8n, yolov8s, yolov8m, yolov8l, yolov8x\n",
126
+ "\n",
127
+ "# Output directories\n",
128
+ "RESULTS_DIR = '/content/strawberry-results'\n",
129
+ "WEIGHTS_DIR = '/content/strawberry-weights'\n",
130
+ "\n",
131
+ "import os\n",
132
+ "os.makedirs(RESULTS_DIR, exist_ok=True)\n",
133
+ "os.makedirs(WEIGHTS_DIR, exist_ok=True)\n",
134
+ "\n",
135
+ "print(f\"Results will be saved to: {RESULTS_DIR}\")\n",
136
+ "print(f\"Weights will be saved to: {WEIGHTS_DIR}\")"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "metadata": {},
142
+ "source": [
143
+ "## Train YOLOv8 Model"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "from ultralytics import YOLO\n",
153
+ "import torch\n",
154
+ "\n",
155
+ "# Load pretrained model\n",
156
+ "print(f\"Loading {MODEL_NAME} model...\")\n",
157
+ "model = YOLO(f'{MODEL_NAME}.pt')\n",
158
+ "\n",
159
+ "# Train the model\n",
160
+ "print(f\"Starting training for {EPOCHS} epochs...\")\n",
161
+ "results = model.train(\n",
162
+ " data=str(data_yaml),\n",
163
+ " epochs=EPOCHS,\n",
164
+ " imgsz=IMG_SIZE,\n",
165
+ " batch=BATCH_SIZE,\n",
166
+ " device='0' if torch.cuda.is_available() else 'cpu',\n",
167
+ " project=RESULTS_DIR,\n",
168
+ " name='strawberry_detection',\n",
169
+ " exist_ok=True,\n",
170
+ " patience=20, # Early stopping\n",
171
+ " save=True,\n",
172
+ " save_period=10, # Save checkpoint every 10 epochs\n",
173
+ " cache=True,\n",
174
+ " verbose=True\n",
175
+ ")"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "markdown",
180
+ "metadata": {},
181
+ "source": [
182
+ "## Save and Export Model"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": [
191
+ "# Save final model\n",
192
+ "final_model_path = f'{WEIGHTS_DIR}/strawberry_{MODEL_NAME}.pt'\n",
193
+ "model.save(final_model_path)\n",
194
+ "print(f\"✓ Model saved to: {final_model_path}\")\n",
195
+ "\n",
196
+ "# Export to ONNX format\n",
197
+ "print(\"Exporting to ONNX format...\")\n",
198
+ "onnx_path = model.export(format='onnx', imgsz=IMG_SIZE, dynamic=True)\n",
199
+ "print(f\"✓ ONNX model exported to: {onnx_path}\")\n",
200
+ "\n",
201
+ "# Export to TensorFlow Lite format (for Raspberry Pi)\n",
202
+ "print(\"Exporting to TensorFlow Lite format...\")\n",
203
+ "tflite_path = model.export(format='tflite', imgsz=IMG_SIZE)\n",
204
+ "print(f\"✓ TFLite model exported to: {tflite_path}\")"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "markdown",
209
+ "metadata": {},
210
+ "source": [
211
+ "## View Training Results"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "# Display training results\n",
221
+ "import matplotlib.pyplot as plt\n",
222
+ "from pathlib import Path\n",
223
+ "\n",
224
+ "results_dir = Path(RESULTS_DIR) / 'strawberry_detection'\n",
225
+ "plots_dir = results_dir / 'plots'\n",
226
+ "\n",
227
+ "if plots_dir.exists():\n",
228
+ " print(\"Training plots:\")\n",
229
+ " for plot_file in plots_dir.glob('*.png'):\n",
230
+ " print(f\" - {plot_file.name}\")\n",
231
+ "else:\n",
232
+ " print(\"Plots directory not found yet. Training may still be in progress.\")"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": null,
238
+ "metadata": {},
239
+ "outputs": [],
240
+ "source": [
241
+ "# Show confusion matrix\n",
242
+ "from IPython.display import Image, display\n",
243
+ "\n",
244
+ "confusion_matrix = plots_dir / 'confusion_matrix.png'\n",
245
+ "if confusion_matrix.exists():\n",
246
+ " print(\"Confusion Matrix:\")\n",
247
+ " display(Image(filename=str(confusion_matrix)))"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {},
253
+ "source": [
254
+ "## Download Trained Model\n",
255
+ "\n",
256
+ "Download the trained model to your local machine:"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "from google.colab import files\n",
266
+ "\n",
267
+ "# Download PyTorch model\n",
268
+ "files.download(final_model_path)\n",
269
+ "\n",
270
+ "# Download ONNX model\n",
271
+ "files.download(str(onnx_path))\n",
272
+ "\n",
273
+ "# Download TFLite model\n",
274
+ "files.download(str(tflite_path))"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "markdown",
279
+ "metadata": {},
280
+ "source": [
281
+ "## Next Steps\n",
282
+ "\n",
283
+ "1. Copy the trained model to your Raspberry Pi\n",
284
+ "2. Use the `detect_realtime_pi.py` script for inference\n",
285
+ "3. Integrate with your robotic arm control system\n",
286
+ "\n",
287
+ "## Tips for Better Results\n",
288
+ "\n",
289
+ "- If accuracy is low, increase EPOCHS to 150-200\n",
290
+ "- Try different MODEL_NAME sizes: yolov8s, yolov8m (slower but more accurate)\n",
291
+ "- Collect more training images with varied lighting conditions\n",
292
+ "- Use data augmentation techniques"
293
+ ]
294
+ }
295
+ ],
296
+ "metadata": {
297
+ "kernelspec": {
298
+ "display_name": "Python 3",
299
+ "language": "python",
300
+ "name": "python3"
301
+ },
302
+ "language_info": {
303
+ "name": "python",
304
+ "version": "3.8.0"
305
+ }
306
+ },
307
+ "nbformat": 4,
308
+ "nbformat_minor": 4
309
+ }
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <<<<<<< HEAD
2
+ torch>=1.8.0
3
+ torchvision>=0.9.0
4
+ ultralytics>=8.0.0
5
+ opencv-python>=4.5.0
6
+ numpy>=1.21.0
7
+ matplotlib>=3.3.0
8
+ Pillow>=8.0.0
9
+ tqdm>=4.60.0
10
+ tensorboard>=2.7.0
11
+ onnx>=1.10.0
12
+ onnxruntime>=1.10.0
13
+ tensorflow>=2.8.0
14
+ =======
15
+ ultralytics>=8.0.0
16
+ opencv-python>=4.5.0
17
+ numpy>=1.21.0
18
+ torch>=1.8.0
19
+ torchvision>=0.9.0
20
+ >>>>>>> bb77661e9aecb09169fb60057ff0ebb1f504de58
results.csv ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
2
+ 1,6.86824,1.32115,2.72359,1.31365,0.0042,0.96396,0.34849,0.15119,1.72013,3.19145,1.6125,0.08416,0.00016,0.00016
3
+ 2,12.0365,1.29183,1.25338,1.18955,0.825,0.59459,0.73878,0.51353,1.36044,3.55862,1.26268,0.0673235,0.000323466,0.000323466
4
+ 3,16.7713,1.21518,0.99614,1.17689,0.89744,0.63063,0.7709,0.49709,1.44287,3.62474,1.297,0.0504802,0.0004802,0.0004802
5
+ 4,21.4651,1.2383,0.9566,1.17915,0.47534,0.95495,0.91297,0.62608,1.21511,2.44191,1.15645,0.0336302,0.000630202,0.000630202
6
+ 5,26.0183,1.23907,0.91469,1.16288,0.88867,0.9349,0.94819,0.61003,1.24465,1.14903,1.21458,0.0167735,0.000773472,0.000773472
7
+ 6,30.4778,1.22134,0.90462,1.19466,0.91765,0.90347,0.9523,0.62325,1.19483,0.9782,1.23663,0.000901,0.000901,0.000901
8
+ 7,34.6509,1.20304,0.8498,1.20764,0.7236,0.75676,0.78748,0.52191,1.15781,1.36482,1.33971,0.0008812,0.0008812,0.0008812
9
+ 8,39.22,1.16651,0.84314,1.21721,0.86762,0.88573,0.95211,0.63231,1.19173,0.83562,1.39059,0.0008614,0.0008614,0.0008614
10
+ 9,43.7182,1.19071,0.75512,1.21773,0.92471,0.88288,0.9604,0.64617,1.16949,0.8096,1.34081,0.0008416,0.0008416,0.0008416
11
+ 10,48.0715,1.15343,0.75549,1.19624,0.95433,0.85586,0.94036,0.63013,1.15544,0.87035,1.3024,0.0008218,0.0008218,0.0008218
12
+ 11,52.1875,1.17306,0.77483,1.19032,0.90561,0.86486,0.91627,0.61251,1.1512,0.87167,1.33988,0.000802,0.000802,0.000802
13
+ 12,56.3393,1.15909,0.72052,1.21331,0.93901,0.92793,0.97553,0.65695,1.19064,0.67324,1.35,0.0007822,0.0007822,0.0007822
14
+ 13,60.7417,1.12648,0.70107,1.1631,0.92996,0.9009,0.95406,0.63081,1.20018,0.74019,1.3586,0.0007624,0.0007624,0.0007624
15
+ 14,65.2298,1.09008,0.70923,1.15539,0.91466,0.91892,0.96205,0.67663,1.10559,0.7621,1.25653,0.0007426,0.0007426,0.0007426
16
+ 15,69.571,1.14397,0.69424,1.1814,0.9719,0.93464,0.97478,0.66421,1.14459,0.60123,1.3482,0.0007228,0.0007228,0.0007228
17
+ 16,74.0715,1.1273,0.69053,1.15822,0.94736,0.89189,0.96811,0.67023,1.15103,0.6696,1.29367,0.000703,0.000703,0.000703
18
+ 17,78.482,1.14856,0.70743,1.19375,0.93885,0.95495,0.97045,0.6589,1.1875,0.61151,1.36611,0.0006832,0.0006832,0.0006832
19
+ 18,82.5976,1.11894,0.68007,1.1589,0.98034,0.89868,0.97938,0.65622,1.1357,0.6026,1.32801,0.0006634,0.0006634,0.0006634
20
+ 19,87.0478,1.11983,0.64666,1.16769,0.92238,0.96396,0.97889,0.66274,1.15968,0.60674,1.31971,0.0006436,0.0006436,0.0006436
21
+ 20,91.4751,1.08926,0.64732,1.14873,0.95361,0.95495,0.98242,0.67862,1.13128,0.56679,1.32197,0.0006238,0.0006238,0.0006238
22
+ 21,95.769,1.06938,0.63992,1.14856,0.93911,0.97265,0.98132,0.6723,1.11844,0.53648,1.29829,0.000604,0.000604,0.000604
23
+ 22,100.516,1.0275,0.60948,1.11803,0.96362,0.9544,0.9832,0.66712,1.11005,0.57067,1.29318,0.0005842,0.0005842,0.0005842
24
+ 23,104.83,1.08322,0.63118,1.16093,0.961,0.93694,0.9822,0.66812,1.12503,0.59445,1.29324,0.0005644,0.0005644,0.0005644
25
+ 24,109.564,1.03814,0.59752,1.13687,0.92285,0.95495,0.98092,0.66146,1.16236,0.56458,1.30195,0.0005446,0.0005446,0.0005446
26
+ 25,113.789,1.05733,0.61255,1.12831,0.93744,0.94499,0.97589,0.66985,1.14589,0.58107,1.29102,0.0005248,0.0005248,0.0005248
27
+ 26,118.014,1.03262,0.59764,1.11505,0.93427,0.96396,0.97761,0.66389,1.13641,0.56017,1.31125,0.000505,0.000505,0.000505
28
+ 27,122.046,1.05034,0.59402,1.1313,0.9387,0.95495,0.98023,0.66648,1.1571,0.58369,1.29765,0.0004852,0.0004852,0.0004852
29
+ 28,126.296,1.02708,0.58665,1.14179,0.95501,0.95618,0.98264,0.64604,1.17408,0.59978,1.35501,0.0004654,0.0004654,0.0004654
30
+ 29,130.486,1.0421,0.59585,1.11415,0.94498,0.98198,0.98278,0.66907,1.13142,0.55493,1.31176,0.0004456,0.0004456,0.0004456
31
+ 30,135.122,1.00967,0.58472,1.11493,0.98144,0.95295,0.98468,0.67596,1.10479,0.53006,1.32835,0.0004258,0.0004258,0.0004258
32
+ 31,139.655,0.96788,0.52585,1.13779,0.97128,0.91892,0.98086,0.66954,1.12559,0.55724,1.33653,0.000406,0.000406,0.000406
33
+ 32,144.364,0.98252,0.56894,1.10962,0.92865,0.96396,0.98034,0.66634,1.12278,0.5327,1.33008,0.0003862,0.0003862,0.0003862
34
+ 33,148.518,0.98348,0.55054,1.11565,0.93032,0.98198,0.98389,0.6858,1.12571,0.49984,1.32711,0.0003664,0.0003664,0.0003664
35
+ 34,152.545,1.00252,0.55027,1.09994,0.96066,0.94595,0.98377,0.69641,1.12416,0.52283,1.2885,0.0003466,0.0003466,0.0003466
36
+ 35,156.54,0.94835,0.53754,1.08347,0.9303,0.96198,0.98116,0.67964,1.11046,0.5242,1.30978,0.0003268,0.0003268,0.0003268
37
+ 36,160.636,0.96319,0.53923,1.11176,0.91568,0.99099,0.98692,0.68369,1.11171,0.48497,1.30834,0.000307,0.000307,0.000307
38
+ 37,164.837,0.95618,0.53593,1.08664,0.97244,0.95364,0.98547,0.69178,1.11108,0.50238,1.28898,0.0002872,0.0002872,0.0002872
39
+ 38,168.917,0.97244,0.5482,1.08408,0.96389,0.91892,0.98071,0.66854,1.12583,0.51788,1.32157,0.0002674,0.0002674,0.0002674
40
+ 39,173.37,0.93709,0.51704,1.08327,0.97083,0.90991,0.98181,0.66532,1.14223,0.52647,1.35507,0.0002476,0.0002476,0.0002476
41
+ 40,178.082,0.92316,0.50575,1.07146,0.96253,0.92569,0.98231,0.69138,1.11065,0.51471,1.30314,0.0002278,0.0002278,0.0002278
42
+ 41,183.476,0.8683,0.46482,1.0931,0.95441,0.94301,0.98327,0.69583,1.11377,0.5044,1.29681,0.000208,0.000208,0.000208
43
+ 42,187.581,0.859,0.45945,1.08201,0.96091,0.97297,0.98492,0.70061,1.10066,0.47832,1.30928,0.0001882,0.0001882,0.0001882
44
+ 43,191.651,0.85681,0.44947,1.08108,0.9516,0.97297,0.98641,0.70187,1.10791,0.47281,1.33175,0.0001684,0.0001684,0.0001684
45
+ 44,195.634,0.82438,0.44285,1.06652,0.95144,0.97297,0.98817,0.6897,1.10938,0.47912,1.32276,0.0001486,0.0001486,0.0001486
46
+ 45,199.92,0.838,0.42537,1.06944,0.95343,0.97297,0.98846,0.68458,1.12846,0.4846,1.33303,0.0001288,0.0001288,0.0001288
47
+ 46,204.796,0.8131,0.41691,1.0408,0.94776,0.95495,0.98962,0.67581,1.15034,0.48295,1.38289,0.000109,0.000109,0.000109
48
+ 47,209.422,0.81183,0.40894,1.06726,0.93117,0.97502,0.98866,0.68269,1.13011,0.47496,1.35196,8.92e-05,8.92e-05,8.92e-05
49
+ 48,213.529,0.80037,0.39995,1.04921,0.95491,0.95402,0.98838,0.6885,1.12861,0.48232,1.32018,6.94e-05,6.94e-05,6.94e-05
50
+ 49,217.607,0.81871,0.4227,1.0661,0.95526,0.96185,0.98927,0.69025,1.14264,0.48573,1.33214,4.96e-05,4.96e-05,4.96e-05
51
+ 50,221.537,0.802,0.41167,1.05813,0.95484,0.96396,0.98857,0.68227,1.13242,0.47929,1.33976,2.98e-05,2.98e-05,2.98e-05
scripts/all_combine3.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dual_camera_15cm_tuned.py
2
+ """
3
+ FINAL TUNED VERSION
4
+ - Baseline: 15.0 cm
5
+ - Range: 20-40 cm
6
+ - Calibration: Tuned to 0.82 based on your latest log (48.5cm -> 30cm).
7
+ """
8
+ import cv2
9
+ import numpy as np
10
+ from ultralytics import YOLO
11
+ import math
12
+
13
+ # ----------------------------
14
+ # USER SETTINGS
15
+ # ----------------------------
16
+ CAM_A_ID = 2 # LEFT Camera
17
+ CAM_B_ID = 1 # RIGHT Camera
18
+
19
+ FRAME_W = 640
20
+ FRAME_H = 408
21
+ YOLO_MODEL_PATH = "strawberry.pt"
22
+
23
+ # --- GEOMETRY & CALIBRATION ---
24
+ BASELINE_CM = 15.0
25
+ FOCUS_DIST_CM = 30.0
26
+
27
+ # RE-CALIBRATED SCALAR
28
+ # Your Raw Reading is approx 36.6cm. Real is 30.0cm.
29
+ # 30.0 / 36.6 = ~0.82
30
+ DEPTH_SCALAR = 0.82
31
+
32
+ # Auto-calculate angle
33
+ val = (BASELINE_CM / 2.0) / FOCUS_DIST_CM
34
+ calc_yaw_deg = math.degrees(math.atan(val))
35
+ YAW_LEFT_DEG = calc_yaw_deg
36
+ YAW_RIGHT_DEG = -calc_yaw_deg
37
+
38
+ print(f"--- CONFIGURATION ---")
39
+ print(f"1. Baseline: {BASELINE_CM} cm")
40
+ print(f"2. Scalar: x{DEPTH_SCALAR} (Reducing raw estimate to match reality)")
41
+ print(f"3. REQUIRED ANGLE: +/- {calc_yaw_deg:.2f} degrees")
42
+ print(f"---------------------")
43
+
44
+ # Intrinsics
45
+ K_A = np.array([[629.10808758, 0.0, 347.20913144],
46
+ [0.0, 631.11321979, 277.5222819],
47
+ [0.0, 0.0, 1.0]], dtype=np.float64)
48
+ dist_A = np.array([-0.35469562, 0.10232556, -0.0005468, -0.00174671, 0.01546246], dtype=np.float64)
49
+
50
+ K_B = np.array([[1001.67997, 0.0, 367.736216],
51
+ [0.0, 996.698369, 312.866527],
52
+ [0.0, 0.0, 1.0]], dtype=np.float64)
53
+ dist_B = np.array([-0.49543094, 0.82826695, -0.00180861, -0.00362202, -1.42667838], dtype=np.float64)
54
+
55
+ # ----------------------------
56
+ # HELPERS
57
+ # ----------------------------
58
+ def capture_single(cam_id):
59
+ cap = cv2.VideoCapture(cam_id, cv2.CAP_DSHOW)
60
+ if not cap.isOpened(): return None
61
+ cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_W)
62
+ cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_H)
63
+ for _ in range(5): cap.read() # Warmup
64
+ ret, frame = cap.read()
65
+ cap.release()
66
+ return frame
67
+
68
+ def build_undistort_maps(K, dist):
69
+ newK, _ = cv2.getOptimalNewCameraMatrix(K, dist, (FRAME_W, FRAME_H), 1.0)
70
+ mapx, mapy = cv2.initUndistortRectifyMap(K, dist, None, newK, (FRAME_W, FRAME_H), cv2.CV_32FC1)
71
+ return mapx, mapy, newK
72
+
73
+ def detect_on_image(model, img):
74
+ results = model(img, verbose=False)[0]
75
+ dets = []
76
+ for box in results.boxes:
77
+ x1, y1, x2, y2 = [int(v) for v in box.xyxy[0].tolist()]
78
+ cx = int((x1 + x2) / 2)
79
+ cy = int((y1 + y2) / 2)
80
+ conf = float(box.conf[0])
81
+ cls = int(box.cls[0])
82
+ name = model.names.get(cls, str(cls))
83
+ dets.append({'x1':x1,'y1':y1,'x2':x2,'y2':y2,'cx':cx,'cy':cy,'conf':conf,'cls':cls,'name':name})
84
+ return sorted(dets, key=lambda d: d['cx'])
85
+
86
+ def match_stereo(detL, detR):
87
+ matches = []
88
+ usedR = set()
89
+ for l in detL:
90
+ best_idx = -1
91
+ best_score = 9999
92
+ for i, r in enumerate(detR):
93
+ if i in usedR: continue
94
+ if l['cls'] != r['cls']: continue
95
+ dy = abs(l['cy'] - r['cy'])
96
+ if dy > 60: continue
97
+ if dy < best_score:
98
+ best_score = dy
99
+ best_idx = i
100
+ if best_idx != -1:
101
+ matches.append((l, detR[best_idx]))
102
+ usedR.add(best_idx)
103
+ return matches
104
+
105
+ # --- 3D MATH ---
106
+ def yaw_to_R_deg(yaw_deg):
107
+ y = math.radians(yaw_deg)
108
+ cy = math.cos(y); sy = math.sin(y)
109
+ return np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=np.float64)
110
+
111
+ def build_projection_matrices(newK_A, newK_B, yaw_L, yaw_R, baseline):
112
+ R_L = yaw_to_R_deg(yaw_L)
113
+ R_R = yaw_to_R_deg(yaw_R)
114
+ R_W2A = R_L.T
115
+ t_W2A = np.zeros((3,1))
116
+ R_W2B = R_R.T
117
+ C_B_world = np.array([[baseline], [0.0], [0.0]])
118
+ t_W2B = -R_W2B @ C_B_world
119
+ P1 = newK_A @ np.hstack((R_W2A, t_W2A))
120
+ P2 = newK_B @ np.hstack((R_W2B, t_W2B))
121
+ return P1, P2
122
+
123
+ def triangulate_matrix(dL, dR, P1, P2):
124
+ ptsL = np.array([[float(dL['cx'])],[float(dL['cy'])]], dtype=np.float64)
125
+ ptsR = np.array([[float(dR['cx'])],[float(dR['cy'])]], dtype=np.float64)
126
+
127
+ Xh = cv2.triangulatePoints(P1, P2, ptsL, ptsR)
128
+ Xh /= Xh[3]
129
+
130
+ X = float(Xh[0].item())
131
+ Y = float(Xh[1].item())
132
+ Z_raw = float(Xh[2].item())
133
+
134
+ return X, Y, Z_raw * DEPTH_SCALAR
135
+
136
+ def main():
137
+ print("[INFO] Loading YOLO...")
138
+ model = YOLO(YOLO_MODEL_PATH)
139
+
140
+ print(f"[INFO] Capturing...")
141
+ frameA = capture_single(CAM_A_ID)
142
+ frameB = capture_single(CAM_B_ID)
143
+ if frameA is None or frameB is None: return
144
+
145
+ mapAx, mapAy, newKA = build_undistort_maps(K_A, dist_A)
146
+ mapBx, mapBy, newKB = build_undistort_maps(K_B, dist_B)
147
+ undA = cv2.remap(frameA, mapAx, mapAy, cv2.INTER_LINEAR)
148
+ undB = cv2.remap(frameB, mapBx, mapBy, cv2.INTER_LINEAR)
149
+
150
+ detA = detect_on_image(model, undA)
151
+ detB = detect_on_image(model, undB)
152
+
153
+ matches = match_stereo(detA, detB)
154
+ print(f"--- Matches found: {len(matches)} ---")
155
+
156
+ P1, P2 = build_projection_matrices(newKA, newKB, YAW_LEFT_DEG, YAW_RIGHT_DEG, BASELINE_CM)
157
+
158
+ combo = np.hstack((undA, undB))
159
+
160
+ for l, r in matches:
161
+ XYZ = triangulate_matrix(l, r, P1, P2)
162
+ X,Y,Z = XYZ
163
+
164
+ label = f"Z={Z:.1f}cm"
165
+ print(f"Target ({l['name']}): {label} (X={X:.1f}, Y={Y:.1f})")
166
+
167
+ cv2.line(combo, (l['cx'], l['cy']), (r['cx']+FRAME_W, r['cy']), (0,255,0), 2)
168
+ cv2.rectangle(combo, (l['x1'], l['y1']), (l['x2'], l['y2']), (0,0,255), 2)
169
+ cv2.rectangle(combo, (r['x1']+FRAME_W, r['y1']), (r['x2']+FRAME_W, r['y2']), (0,0,255), 2)
170
+ cv2.putText(combo, label, (l['cx'], l['cy']-10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,255,255), 2)
171
+
172
+ cv2.imshow("Tuned Depth Result", combo)
173
+ cv2.waitKey(0)
174
+ cv2.destroyAllWindows()
175
+
176
+ if __name__ == "__main__":
177
+ main()
scripts/auto_label_strawberries.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Automated Strawberry Ripeness Labeling System
4
+ Uses color analysis to automatically label strawberry ripeness
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import cv2
10
+ import numpy as np
11
+ from pathlib import Path
12
+ from PIL import Image
13
+ import argparse
14
+ import json
15
+ from datetime import datetime
16
+
17
+ class AutoRipenessLabeler:
18
+ def __init__(self):
19
+ """Initialize the automatic ripeness labeler"""
20
+ print("✅ Initialized automatic ripeness labeler")
21
+
22
+ def analyze_strawberry_color(self, image_path):
23
+ """Analyze the color of strawberries to determine ripeness"""
24
+ try:
25
+ # Load image
26
+ img = cv2.imread(str(image_path))
27
+ if img is None:
28
+ return None
29
+
30
+ # Convert BGR to RGB
31
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
32
+
33
+ # Convert to HSV for better color analysis
34
+ hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV)
35
+
36
+ # Define color ranges for different ripeness stages
37
+ # Red range (ripe strawberries)
38
+ red_lower1 = np.array([0, 50, 50])
39
+ red_upper1 = np.array([10, 255, 255])
40
+ red_lower2 = np.array([170, 50, 50])
41
+ red_upper2 = np.array([180, 255, 255])
42
+
43
+ # Green range (unripe strawberries)
44
+ green_lower = np.array([40, 40, 40])
45
+ green_upper = np.array([80, 255, 255])
46
+
47
+ # Dark red range (overripe strawberries)
48
+ dark_red_lower = np.array([0, 100, 0])
49
+ dark_red_upper = np.array([20, 255, 100])
50
+
51
+ # Create masks for each color range
52
+ red_mask1 = cv2.inRange(hsv, red_lower1, red_upper1)
53
+ red_mask2 = cv2.inRange(hsv, red_lower2, red_upper2)
54
+ red_mask = cv2.bitwise_or(red_mask1, red_mask2)
55
+
56
+ green_mask = cv2.inRange(hsv, green_lower, green_upper)
57
+ dark_red_mask = cv2.inRange(hsv, dark_red_lower, dark_red_upper)
58
+
59
+ # Calculate percentages
60
+ total_pixels = hsv.shape[0] * hsv.shape[1]
61
+ red_pixels = np.sum(red_mask > 0)
62
+ green_pixels = np.sum(green_mask > 0)
63
+ dark_red_pixels = np.sum(dark_red_mask > 0)
64
+
65
+ red_percentage = red_pixels / total_pixels
66
+ green_percentage = green_pixels / total_pixels
67
+ dark_red_percentage = dark_red_pixels / total_pixels
68
+
69
+ # Calculate brightness and saturation for fallback analysis
70
+ gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
71
+ avg_brightness = np.mean(gray)
72
+ avg_saturation = np.mean(hsv[:, :, 1])
73
+
74
+ # Determine ripeness based on color percentages
75
+ if green_percentage > 0.3:
76
+ ripeness = "unripe"
77
+ confidence = min(green_percentage * 2, 0.9)
78
+ elif dark_red_percentage > 0.2:
79
+ ripeness = "overripe"
80
+ confidence = min(dark_red_percentage * 2, 0.9)
81
+ elif red_percentage > 0.2:
82
+ ripeness = "ripe"
83
+ confidence = min(red_percentage * 2, 0.9)
84
+ else:
85
+ # Fallback: use brightness and saturation
86
+ if avg_brightness < 80:
87
+ ripeness = "overripe"
88
+ confidence = 0.6
89
+ elif avg_brightness > 150:
90
+ ripeness = "unripe"
91
+ confidence = 0.6
92
+ else:
93
+ ripeness = "ripe"
94
+ confidence = 0.7
95
+
96
+ return {
97
+ 'ripeness': ripeness,
98
+ 'confidence': confidence,
99
+ 'color_analysis': {
100
+ 'red_percentage': red_percentage,
101
+ 'green_percentage': green_percentage,
102
+ 'dark_red_percentage': dark_red_percentage,
103
+ 'avg_brightness': float(avg_brightness),
104
+ 'avg_saturation': float(avg_saturation)
105
+ }
106
+ }
107
+
108
+ except Exception as e:
109
+ print(f"Error analyzing color in {image_path}: {e}")
110
+ return None
111
+
112
+ def batch_auto_label(self, image_files, output_dirs, confidence_threshold=0.6):
113
+ """Automatically label a batch of images"""
114
+ results = []
115
+
116
+ for i, image_path in enumerate(image_files):
117
+ print(f"Processing {i+1}/{len(image_files)}: {image_path.name}")
118
+
119
+ analysis = self.analyze_strawberry_color(image_path)
120
+
121
+ if analysis and analysis['confidence'] >= confidence_threshold:
122
+ ripeness = analysis['ripeness']
123
+ confidence = analysis['confidence']
124
+
125
+ # Copy image to appropriate directory
126
+ dest_path = output_dirs[ripeness] / image_path.name
127
+ try:
128
+ import shutil
129
+ shutil.copy2(image_path, dest_path)
130
+ print(f" ✅ {ripeness} (confidence: {confidence:.2f})")
131
+ results.append({
132
+ 'image': image_path.name,
133
+ 'label': ripeness,
134
+ 'confidence': confidence,
135
+ 'analysis': analysis['color_analysis']
136
+ })
137
+ except Exception as e:
138
+ print(f" ❌ Error copying file: {e}")
139
+ else:
140
+ print(f" ⚠️ Low confidence or analysis failed")
141
+ results.append({
142
+ 'image': image_path.name,
143
+ 'label': 'unknown',
144
+ 'confidence': analysis['confidence'] if analysis else 0.0,
145
+ 'analysis': analysis['color_analysis'] if analysis else None
146
+ })
147
+
148
+ return results
149
+
150
+ def main():
151
+ parser = argparse.ArgumentParser(description='Automatically label strawberry ripeness dataset')
152
+ parser.add_argument('--dataset-path', type=str,
153
+ default='model/ripeness_manual_dataset',
154
+ help='Path to the ripeness dataset directory')
155
+ parser.add_argument('--confidence-threshold', type=float, default=0.6,
156
+ help='Minimum confidence for automatic labeling')
157
+ parser.add_argument('--max-images', type=int, default=50,
158
+ help='Maximum number of images to process')
159
+
160
+ args = parser.parse_args()
161
+
162
+ base_path = Path(args.dataset_path)
163
+ to_label_path = base_path / 'to_label'
164
+
165
+ if not to_label_path.exists():
166
+ print(f"Error: to_label directory not found at {to_label_path}")
167
+ return
168
+
169
+ # Create output directories
170
+ output_dirs = {}
171
+ for label in ['unripe', 'ripe', 'overripe']:
172
+ dir_path = base_path / label
173
+ dir_path.mkdir(exist_ok=True)
174
+ output_dirs[label] = dir_path
175
+
176
+ # Get image files
177
+ image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
178
+ image_files = []
179
+ for file_path in to_label_path.iterdir():
180
+ if file_path.suffix.lower() in image_extensions:
181
+ image_files.append(file_path)
182
+
183
+ image_files = sorted(image_files)[:args.max_images]
184
+
185
+ print(f"Found {len(image_files)} images to process")
186
+ print(f"Confidence threshold: {args.confidence_threshold}")
187
+
188
+ if not image_files:
189
+ print("No images found to process.")
190
+ return
191
+
192
+ # Initialize auto labeler
193
+ labeler = AutoRipenessLabeler()
194
+
195
+ # Process images
196
+ results = labeler.batch_auto_label(image_files, output_dirs, args.confidence_threshold)
197
+
198
+ # Save results
199
+ results_file = base_path / f'auto_labeling_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
200
+ with open(results_file, 'w') as f:
201
+ json.dump(results, f, indent=2)
202
+
203
+ # Print summary
204
+ label_counts = {'unripe': 0, 'ripe': 0, 'overripe': 0, 'unknown': 0}
205
+ for result in results:
206
+ label_counts[result['label']] += 1
207
+
208
+ print("\n=== AUTOMATIC LABELING RESULTS ===")
209
+ for label, count in label_counts.items():
210
+ print(f"{label}: {count} images")
211
+
212
+ print(f"\nResults saved to: {results_file}")
213
+
214
+ if label_counts['unknown'] > 0:
215
+ print(f"\n⚠️ {label_counts['unknown']} images need manual review")
216
+ print("You can use the manual labeling tool for these:")
217
+ print("python3 label_ripeness_dataset.py")
218
+
219
+ if __name__ == '__main__':
220
+ main()
scripts/benchmark_models.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Benchmark YOLO models for performance on Raspberry Pi 4B (or current machine).
4
+ Measures inference time, FPS, and memory usage for different model formats.
5
+ """
6
+
7
+ import argparse
8
+ import time
9
+ import os
10
+ import sys
11
+ from pathlib import Path
12
+ import numpy as np
13
+ import cv2
14
+ import yaml
15
+ from ultralytics import YOLO
16
+ import psutil
17
+ import platform
18
+
19
+ def get_system_info():
20
+ """Get system information for benchmarking context."""
21
+ info = {
22
+ 'system': platform.system(),
23
+ 'processor': platform.processor(),
24
+ 'architecture': platform.architecture()[0],
25
+ 'python_version': platform.python_version(),
26
+ 'cpu_count': psutil.cpu_count(logical=False),
27
+ 'memory_gb': psutil.virtual_memory().total / (1024**3),
28
+ }
29
+ return info
30
+
31
+ def load_test_images(dataset_path, max_images=50):
32
+ """Load test images from dataset for benchmarking."""
33
+ test_images = []
34
+
35
+ # Try multiple possible locations
36
+ possible_paths = [
37
+ Path(dataset_path) / "test" / "images",
38
+ Path(dataset_path) / "valid" / "images",
39
+ Path(dataset_path) / "val" / "images",
40
+ Path(dataset_path) / "train" / "images",
41
+ ]
42
+
43
+ for path in possible_paths:
44
+ if path.exists():
45
+ image_files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
46
+ if image_files:
47
+ test_images = [str(p) for p in image_files[:max_images]]
48
+ print(f"Found {len(test_images)} images in {path}")
49
+ break
50
+
51
+ if not test_images:
52
+ # Create dummy images if no dataset found
53
+ print("No test images found. Creating dummy images for benchmarking.")
54
+ test_images = []
55
+ for i in range(10):
56
+ # Create a dummy image
57
+ dummy_img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
58
+ dummy_path = f"/tmp/dummy_{i}.jpg"
59
+ cv2.imwrite(dummy_path, dummy_img)
60
+ test_images.append(dummy_path)
61
+
62
+ return test_images
63
+
64
+ def benchmark_model(model_path, test_images, img_size=640, warmup=10, runs=100):
65
+ """
66
+ Benchmark a single model.
67
+
68
+ Args:
69
+ model_path: Path to model file (.pt, .onnx, .tflite)
70
+ test_images: List of image paths for testing
71
+ img_size: Input image size
72
+ warmup: Number of warmup runs
73
+ runs: Number of benchmark runs
74
+
75
+ Returns:
76
+ Dictionary with benchmark results
77
+ """
78
+ print(f"\n{'='*60}")
79
+ print(f"Benchmarking: {model_path}")
80
+ print(f"{'='*60}")
81
+
82
+ results = {
83
+ 'model': os.path.basename(model_path),
84
+ 'format': Path(model_path).suffix[1:],
85
+ 'size_mb': os.path.getsize(model_path) / (1024 * 1024) if os.path.exists(model_path) else 0,
86
+ 'inference_times': [],
87
+ 'memory_usage_mb': [],
88
+ 'success': False
89
+ }
90
+
91
+ # Check if model exists
92
+ if not os.path.exists(model_path):
93
+ print(f" ❌ Model not found: {model_path}")
94
+ return results
95
+
96
+ try:
97
+ # Load model
98
+ print(f" Loading model...")
99
+ start_load = time.time()
100
+ model = YOLO(model_path)
101
+ load_time = time.time() - start_load
102
+ results['load_time'] = load_time
103
+
104
+ # Warmup
105
+ print(f" Warming up ({warmup} runs)...")
106
+ for i in range(warmup):
107
+ if i >= len(test_images):
108
+ img_path = test_images[0]
109
+ else:
110
+ img_path = test_images[i]
111
+ img = cv2.imread(img_path)
112
+ if img is None:
113
+ # Create dummy image
114
+ img = np.random.randint(0, 255, (img_size, img_size, 3), dtype=np.uint8)
115
+ _ = model(img, verbose=False)
116
+
117
+ # Benchmark runs
118
+ print(f" Running benchmark ({runs} runs)...")
119
+ for i in range(runs):
120
+ # Cycle through test images
121
+ img_idx = i % len(test_images)
122
+ img_path = test_images[img_idx]
123
+ img = cv2.imread(img_path)
124
+ if img is None:
125
+ img = np.random.randint(0, 255, (img_size, img_size, 3), dtype=np.uint8)
126
+
127
+ # Measure memory before
128
+ process = psutil.Process(os.getpid())
129
+ mem_before = process.memory_info().rss / 1024 / 1024 # MB
130
+
131
+ # Inference
132
+ start_time = time.perf_counter()
133
+ results_inference = model(img, verbose=False)
134
+ inference_time = time.perf_counter() - start_time
135
+
136
+ # Measure memory after
137
+ mem_after = process.memory_info().rss / 1024 / 1024 # MB
138
+ mem_used = mem_after - mem_before
139
+
140
+ results['inference_times'].append(inference_time)
141
+ results['memory_usage_mb'].append(mem_used)
142
+
143
+ # Print progress
144
+ if (i + 1) % 20 == 0:
145
+ print(f" Completed {i+1}/{runs} runs...")
146
+
147
+ # Calculate statistics
148
+ if results['inference_times']:
149
+ times = np.array(results['inference_times'])
150
+ results['avg_inference_ms'] = np.mean(times) * 1000
151
+ results['std_inference_ms'] = np.std(times) * 1000
152
+ results['min_inference_ms'] = np.min(times) * 1000
153
+ results['max_inference_ms'] = np.max(times) * 1000
154
+ results['fps'] = 1.0 / np.mean(times)
155
+ results['avg_memory_mb'] = np.mean(results['memory_usage_mb'])
156
+ results['success'] = True
157
+
158
+ print(f" ✅ Benchmark completed:")
159
+ print(f" Model size: {results['size_mb']:.2f} MB")
160
+ print(f" Avg inference: {results['avg_inference_ms']:.2f} ms")
161
+ print(f" FPS: {results['fps']:.2f}")
162
+ print(f" Memory usage: {results['avg_memory_mb']:.2f} MB")
163
+ else:
164
+ print(f" ❌ No inference times recorded")
165
+
166
+ except Exception as e:
167
+ print(f" ❌ Error benchmarking {model_path}: {e}")
168
+ import traceback
169
+ traceback.print_exc()
170
+
171
+ return results
172
+
173
+ def benchmark_all_models(models_to_test, test_images, img_size=640):
174
+ """Benchmark multiple models and return results."""
175
+ all_results = []
176
+
177
+ for model_info in models_to_test:
178
+ model_path = model_info['path']
179
+ if not os.path.exists(model_path):
180
+ print(f"Skipping {model_path} - not found")
181
+ continue
182
+
183
+ results = benchmark_model(
184
+ model_path=model_path,
185
+ test_images=test_images,
186
+ img_size=img_size,
187
+ warmup=10,
188
+ runs=50 # Reduced for faster benchmarking
189
+ )
190
+
191
+ results.update({
192
+ 'name': model_info['name'],
193
+ 'description': model_info.get('description', '')
194
+ })
195
+ all_results.append(results)
196
+
197
+ return all_results
198
+
199
+ def print_results_table(results):
200
+ """Print benchmark results in a formatted table."""
201
+ print("\n" + "="*100)
202
+ print("BENCHMARK RESULTS")
203
+ print("="*100)
204
+ print(f"{'Model':<30} {'Format':<8} {'Size (MB)':<10} {'Inference (ms)':<15} {'FPS':<10} {'Memory (MB)':<12} {'Status':<10}")
205
+ print("-"*100)
206
+
207
+ for r in results:
208
+ if r['success']:
209
+ print(f"{r['name'][:28]:<30} {r['format']:<8} {r['size_mb']:>9.2f} "
210
+ f"{r['avg_inference_ms']:>14.2f} {r['fps']:>9.2f} {r['avg_memory_mb']:>11.2f} {'✅':<10}")
211
+ else:
212
+ print(f"{r['name'][:28]:<30} {r['format']:<8} {r['size_mb']:>9.2f} "
213
+ f"{'N/A':>14} {'N/A':>9} {'N/A':>11} {'❌':<10}")
214
+
215
+ print("="*100)
216
+
217
+ # Find best model by FPS
218
+ successful = [r for r in results if r['success']]
219
+ if successful:
220
+ best_by_fps = max(successful, key=lambda x: x['fps'])
221
+ best_by_size = min(successful, key=lambda x: x['size_mb'])
222
+ best_by_memory = min(successful, key=lambda x: x['avg_memory_mb'])
223
+
224
+ print(f"\n🏆 Best by FPS: {best_by_fps['name']} ({best_by_fps['fps']:.2f} FPS)")
225
+ print(f"🏆 Best by size: {best_by_size['name']} ({best_by_size['size_mb']:.2f} MB)")
226
+ print(f"🏆 Best by memory: {best_by_memory['name']} ({best_by_memory['avg_memory_mb']:.2f} MB)")
227
+
228
+ def save_results_to_csv(results, output_path="benchmark_results.csv"):
229
+ """Save benchmark results to CSV file."""
230
+ import csv
231
+
232
+ with open(output_path, 'w', newline='') as csvfile:
233
+ fieldnames = ['name', 'format', 'size_mb', 'avg_inference_ms',
234
+ 'std_inference_ms', 'min_inference_ms', 'max_inference_ms',
235
+ 'fps', 'avg_memory_mb', 'load_time', 'success']
236
+ writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
237
+
238
+ writer.writeheader()
239
+ for r in results:
240
+ writer.writerow({
241
+ 'name': r['name'],
242
+ 'format': r['format'],
243
+ 'size_mb': r.get('size_mb', 0),
244
+ 'avg_inference_ms': r.get('avg_inference_ms', 0),
245
+ 'std_inference_ms': r.get('std_inference_ms', 0),
246
+ 'min_inference_ms': r.get('min_inference_ms', 0),
247
+ 'max_inference_ms': r.get('max_inference_ms', 0),
248
+ 'fps': r.get('fps', 0),
249
+ 'avg_memory_mb': r.get('avg_memory_mb', 0),
250
+ 'load_time': r.get('load_time', 0),
251
+ 'success': r['success']
252
+ })
253
+
254
+ print(f"\n📊 Results saved to {output_path}")
255
+
256
+ def main():
257
+ parser = argparse.ArgumentParser(description='Benchmark YOLO models for performance')
258
+ parser.add_argument('--dataset', type=str, default='model/dataset_strawberry_detect_v3',
259
+ help='Path to dataset for test images')
260
+ parser.add_argument('--img-size', type=int, default=640,
261
+ help='Input image size for inference')
262
+ parser.add_argument('--output', type=str, default='benchmark_results.csv',
263
+ help='Output CSV file for results')
264
+ parser.add_argument('--config', type=str, default='config.yaml',
265
+ help='Path to config file')
266
+
267
+ args = parser.parse_args()
268
+
269
+ # Load config
270
+ config = {}
271
+ if os.path.exists(args.config):
272
+ with open(args.config, 'r') as f:
273
+ config = yaml.safe_load(f)
274
+
275
+ # Get system info
276
+ system_info = get_system_info()
277
+ print("="*60)
278
+ print("SYSTEM INFORMATION")
279
+ print("="*60)
280
+ for key, value in system_info.items():
281
+ print(f"{key.replace('_', ' ').title():<20}: {value}")
282
+
283
+ # Define models to test
284
+ models_to_test = [
285
+ # Base YOLO models
286
+ {'name': 'YOLOv8n', 'path': 'yolov8n.pt', 'description': 'Ultralytics YOLOv8n'},
287
+ {'name': 'YOLOv8s', 'path': 'yolov8s.pt', 'description': 'Ultralytics YOLOv8s'},
288
+ {'name': 'YOLOv8m', 'path': 'yolov8m.pt', 'description': 'Ultralytics YOLOv8m'},
289
+
290
+ # Custom trained models
291
+ {'name': 'Strawberry YOLOv11n', 'path': 'model/weights/strawberry_yolov11n.pt', 'description': 'Custom trained on strawberry dataset'},
292
+ {'name': 'Strawberry YOLOv11n ONNX', 'path': 'model/weights/strawberry_yolov11n.onnx', 'description': 'ONNX export'},
293
+
294
+ # Ripeness detection models
295
+ {'name': 'Ripeness YOLOv11n', 'path': 'model/weights/ripeness_detection_yolov11n.pt', 'description': 'Ripeness detection model'},
296
+ {'name': 'Ripeness YOLOv11n ONNX', 'path': 'model/weights/ripeness_detection_yolov11n.onnx', 'description': 'ONNX export'},
297
+ ]
298
+
299
+ # Check which models exist
300
+ existing_models = []
301
+ for model in models_to_test:
302
+ if os.path.exists(model['path']):
303
+ existing_models.append(model)
304
+ else:
305
+ print(f"⚠️ Model not found: {model['path']}")
306
+
307
+ if not existing_models:
308
+ print("❌ No models found for benchmarking.")
309
+ print("Please train a model first or download pretrained weights.")
310
+ sys.exit(1)
311
+
312
+ # Load test images
313
+ print(f"\n📷 Loading test images from {args.dataset}...")
314
+ test_images = load_test_images(args.dataset, max_images=50)
315
+ print(f" Loaded {len(test_images)} test images")
316
+
317
+ # Run benchmarks
318
+ print(f"\n🚀 Starting benchmarks...")
319
+ results = benchmark_all_models(existing_models, test_images, img_size=args.img_size)
320
+
321
+ # Print results
322
+ print_results_table(results)
323
+
324
+ # Save results
325
+ save_results_to_csv(results, args.output)
326
+
327
+ # Generate recommendations
328
+ print(f"\n💡 RECOMMENDATIONS FOR RASPBERRY PI 4B:")
329
+ print(f" 1. For fastest inference: Choose model with highest FPS")
330
+ print(f" 2. For memory-constrained environments: Choose smallest model")
331
+ print(f" 3. For best accuracy/speed tradeoff: Consider YOLOv8s")
332
+ print(f" 4. For edge deployment: Convert to TFLite INT8 for ~2-3x speedup")
333
+
334
+ # Check if we're on Raspberry Pi
335
+ if 'arm' in platform.machine().lower() or 'raspberry' in platform.system().lower():
336
+ print(f"\n🎯 Running on Raspberry Pi - results are accurate for deployment.")
337
+ else:
338
+ print(f"\n⚠️ Not running on Raspberry Pi - results are for reference only.")
339
+ print(f" Actual Raspberry Pi performance may be 2-5x slower.")
340
+
341
+ if __name__ == '__main__':
342
+ main()
scripts/collect_dataset.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+
4
+ cap = cv2.VideoCapture(0)
5
+
6
+ if not cap.isOpened():
7
+ print("Cannot open camera")
8
+ exit()
9
+
10
+ good_dir = 'dataset/good'
11
+ bad_dir = 'dataset/bad'
12
+ os.makedirs(good_dir, exist_ok=True)
13
+ os.makedirs(bad_dir, exist_ok=True)
14
+
15
+ good_count = len(os.listdir(good_dir))
16
+ bad_count = len(os.listdir(bad_dir))
17
+
18
+ print("Press 'g' to save as good, 'b' to save as bad, 'q' to quit")
19
+
20
+ while True:
21
+ ret, frame = cap.read()
22
+ if not ret:
23
+ print("Can't receive frame")
24
+ break
25
+ cv2.imshow('Dataset Collection', frame)
26
+ key = cv2.waitKey(1) & 0xFF
27
+ if key == ord('g'):
28
+ filename = f'{good_dir}/good_{good_count:04d}.jpg'
29
+ cv2.imwrite(filename, frame)
30
+ good_count += 1
31
+ print(f"Saved {filename}")
32
+ elif key == ord('b'):
33
+ filename = f'{bad_dir}/bad_{bad_count:04d}.jpg'
34
+ cv2.imwrite(filename, frame)
35
+ bad_count += 1
36
+ print(f"Saved {filename}")
37
+ elif key == ord('q'):
38
+ break
39
+
40
+ cap.release()
41
+ cv2.destroyAllWindows()
scripts/combine3.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dual_camera_15cm_tuned.py
2
+ """
3
+ FINAL TUNED VERSION
4
+ - Baseline: 15.0 cm
5
+ - Range: 20-40 cm
6
+ - Calibration: Tuned to 0.82 based on your latest log (48.5cm -> 30cm).
7
+ """
8
+ import cv2
9
+ import numpy as np
10
+ from ultralytics import YOLO
11
+ import math
12
+
13
+ # ----------------------------
14
+ # USER SETTINGS
15
+ # ----------------------------
16
+ CAM_A_ID = 2 # LEFT Camera
17
+ CAM_B_ID = 1 # RIGHT Camera
18
+
19
+ FRAME_W = 640
20
+ FRAME_H = 408
21
+ YOLO_MODEL_PATH = "strawberry.pt"
22
+
23
+ # --- GEOMETRY & CALIBRATION ---
24
+ BASELINE_CM = 15.0
25
+ FOCUS_DIST_CM = 30.0
26
+
27
+ # RE-CALIBRATED SCALAR
28
+ # Your Raw Reading is approx 36.6cm. Real is 30.0cm.
29
+ # 30.0 / 36.6 = ~0.82
30
+ DEPTH_SCALAR = 0.82
31
+
32
+ # Auto-calculate angle
33
+ val = (BASELINE_CM / 2.0) / FOCUS_DIST_CM
34
+ calc_yaw_deg = math.degrees(math.atan(val))
35
+ YAW_LEFT_DEG = calc_yaw_deg
36
+ YAW_RIGHT_DEG = -calc_yaw_deg
37
+
38
+ print(f"--- CONFIGURATION ---")
39
+ print(f"1. Baseline: {BASELINE_CM} cm")
40
+ print(f"2. Scalar: x{DEPTH_SCALAR} (Reducing raw estimate to match reality)")
41
+ print(f"3. REQUIRED ANGLE: +/- {calc_yaw_deg:.2f} degrees")
42
+ print(f"---------------------")
43
+
44
+ # Intrinsics
45
+ K_A = np.array([[629.10808758, 0.0, 347.20913144],
46
+ [0.0, 631.11321979, 277.5222819],
47
+ [0.0, 0.0, 1.0]], dtype=np.float64)
48
+ dist_A = np.array([-0.35469562, 0.10232556, -0.0005468, -0.00174671, 0.01546246], dtype=np.float64)
49
+
50
+ K_B = np.array([[1001.67997, 0.0, 367.736216],
51
+ [0.0, 996.698369, 312.866527],
52
+ [0.0, 0.0, 1.0]], dtype=np.float64)
53
+ dist_B = np.array([-0.49543094, 0.82826695, -0.00180861, -0.00362202, -1.42667838], dtype=np.float64)
54
+
55
+ # ----------------------------
56
+ # HELPERS
57
+ # ----------------------------
58
+ def capture_single(cam_id):
59
+ cap = cv2.VideoCapture(cam_id, cv2.CAP_DSHOW)
60
+ if not cap.isOpened(): return None
61
+ cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_W)
62
+ cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_H)
63
+ for _ in range(5): cap.read() # Warmup
64
+ ret, frame = cap.read()
65
+ cap.release()
66
+ return frame
67
+
68
+ def build_undistort_maps(K, dist):
69
+ newK, _ = cv2.getOptimalNewCameraMatrix(K, dist, (FRAME_W, FRAME_H), 1.0)
70
+ mapx, mapy = cv2.initUndistortRectifyMap(K, dist, None, newK, (FRAME_W, FRAME_H), cv2.CV_32FC1)
71
+ return mapx, mapy, newK
72
+
73
+ def detect_on_image(model, img):
74
+ results = model(img, verbose=False)[0]
75
+ dets = []
76
+ for box in results.boxes:
77
+ x1, y1, x2, y2 = [int(v) for v in box.xyxy[0].tolist()]
78
+ cx = int((x1 + x2) / 2)
79
+ cy = int((y1 + y2) / 2)
80
+ conf = float(box.conf[0])
81
+ cls = int(box.cls[0])
82
+ name = model.names.get(cls, str(cls))
83
+ dets.append({'x1':x1,'y1':y1,'x2':x2,'y2':y2,'cx':cx,'cy':cy,'conf':conf,'cls':cls,'name':name})
84
+ return sorted(dets, key=lambda d: d['cx'])
85
+
86
+ def match_stereo(detL, detR):
87
+ matches = []
88
+ usedR = set()
89
+ for l in detL:
90
+ best_idx = -1
91
+ best_score = 9999
92
+ for i, r in enumerate(detR):
93
+ if i in usedR: continue
94
+ if l['cls'] != r['cls']: continue
95
+ dy = abs(l['cy'] - r['cy'])
96
+ if dy > 60: continue
97
+ if dy < best_score:
98
+ best_score = dy
99
+ best_idx = i
100
+ if best_idx != -1:
101
+ matches.append((l, detR[best_idx]))
102
+ usedR.add(best_idx)
103
+ return matches
104
+
105
+ # --- 3D MATH ---
106
+ def yaw_to_R_deg(yaw_deg):
107
+ y = math.radians(yaw_deg)
108
+ cy = math.cos(y); sy = math.sin(y)
109
+ return np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=np.float64)
110
+
111
+ def build_projection_matrices(newK_A, newK_B, yaw_L, yaw_R, baseline):
112
+ R_L = yaw_to_R_deg(yaw_L)
113
+ R_R = yaw_to_R_deg(yaw_R)
114
+ R_W2A = R_L.T
115
+ t_W2A = np.zeros((3,1))
116
+ R_W2B = R_R.T
117
+ C_B_world = np.array([[baseline], [0.0], [0.0]])
118
+ t_W2B = -R_W2B @ C_B_world
119
+ P1 = newK_A @ np.hstack((R_W2A, t_W2A))
120
+ P2 = newK_B @ np.hstack((R_W2B, t_W2B))
121
+ return P1, P2
122
+
123
+ def triangulate_matrix(dL, dR, P1, P2):
124
+ ptsL = np.array([[float(dL['cx'])],[float(dL['cy'])]], dtype=np.float64)
125
+ ptsR = np.array([[float(dR['cx'])],[float(dR['cy'])]], dtype=np.float64)
126
+
127
+ Xh = cv2.triangulatePoints(P1, P2, ptsL, ptsR)
128
+ Xh /= Xh[3]
129
+
130
+ X = float(Xh[0].item())
131
+ Y = float(Xh[1].item())
132
+ Z_raw = float(Xh[2].item())
133
+
134
+ return X, Y, Z_raw * DEPTH_SCALAR
135
+
136
+ def main():
137
+ print("[INFO] Loading YOLO...")
138
+ model = YOLO(YOLO_MODEL_PATH)
139
+
140
+ print(f"[INFO] Capturing...")
141
+ frameA = capture_single(CAM_A_ID)
142
+ frameB = capture_single(CAM_B_ID)
143
+ if frameA is None or frameB is None: return
144
+
145
+ mapAx, mapAy, newKA = build_undistort_maps(K_A, dist_A)
146
+ mapBx, mapBy, newKB = build_undistort_maps(K_B, dist_B)
147
+ undA = cv2.remap(frameA, mapAx, mapAy, cv2.INTER_LINEAR)
148
+ undB = cv2.remap(frameB, mapBx, mapBy, cv2.INTER_LINEAR)
149
+
150
+ detA = detect_on_image(model, undA)
151
+ detB = detect_on_image(model, undB)
152
+
153
+ matches = match_stereo(detA, detB)
154
+ print(f"--- Matches found: {len(matches)} ---")
155
+
156
+ P1, P2 = build_projection_matrices(newKA, newKB, YAW_LEFT_DEG, YAW_RIGHT_DEG, BASELINE_CM)
157
+
158
+ combo = np.hstack((undA, undB))
159
+
160
+ for l, r in matches:
161
+ XYZ = triangulate_matrix(l, r, P1, P2)
162
+ X,Y,Z = XYZ
163
+
164
+ label = f"Z={Z:.1f}cm"
165
+ print(f"Target ({l['name']}): {label} (X={X:.1f}, Y={Y:.1f})")
166
+
167
+ cv2.line(combo, (l['cx'], l['cy']), (r['cx']+FRAME_W, r['cy']), (0,255,0), 2)
168
+ cv2.rectangle(combo, (l['x1'], l['y1']), (l['x2'], l['y2']), (0,0,255), 2)
169
+ cv2.rectangle(combo, (r['x1']+FRAME_W, r['y1']), (r['x2']+FRAME_W, r['y2']), (0,0,255), 2)
170
+ cv2.putText(combo, label, (l['cx'], l['cy']-10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,255,255), 2)
171
+
172
+ cv2.imshow("Tuned Depth Result", combo)
173
+ cv2.waitKey(0)
174
+ cv2.destroyAllWindows()
175
+
176
+ if __name__ == "__main__":
177
+ main()
scripts/complete_final_labeling.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Complete the final batch of ripeness labeling using conservative color analysis.
4
+ This script processes the remaining 46 images with higher confidence thresholds.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import shutil
10
+ from pathlib import Path
11
+ from datetime import datetime
12
+ import numpy as np
13
+ from PIL import Image
14
+ import cv2
15
+
16
+ def analyze_ripeness_conservative(image_path, confidence_threshold=0.8):
17
+ """
18
+ Conservative ripeness analysis with higher confidence thresholds.
19
+ """
20
+ try:
21
+ # Load and convert image
22
+ img = cv2.imread(str(image_path))
23
+ if img is None:
24
+ return None, 0.0
25
+
26
+ # Convert to HSV for better color analysis
27
+ hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
28
+
29
+ # Define color ranges for strawberry ripeness (more conservative)
30
+ # Red ranges (ripe strawberries)
31
+ red_lower1 = np.array([0, 50, 50])
32
+ red_upper1 = np.array([10, 255, 255])
33
+ red_lower2 = np.array([170, 50, 50])
34
+ red_upper2 = np.array([180, 255, 255])
35
+
36
+ # Green ranges (unripe strawberries)
37
+ green_lower = np.array([40, 40, 40])
38
+ green_upper = np.array([80, 255, 255])
39
+
40
+ # Yellow/orange ranges (overripe strawberries)
41
+ yellow_lower = np.array([15, 50, 50])
42
+ yellow_upper = np.array([35, 255, 255])
43
+
44
+ # Create masks
45
+ red_mask1 = cv2.inRange(hsv, red_lower1, red_upper1)
46
+ red_mask2 = cv2.inRange(hsv, red_lower2, red_upper2)
47
+ red_mask = cv2.bitwise_or(red_mask1, red_mask2)
48
+
49
+ green_mask = cv2.inRange(hsv, green_lower, green_upper)
50
+ yellow_mask = cv2.inRange(hsv, yellow_lower, yellow_upper)
51
+
52
+ # Calculate percentages
53
+ total_pixels = img.shape[0] * img.shape[1]
54
+ red_percentage = np.sum(red_mask > 0) / total_pixels
55
+ green_percentage = np.sum(green_mask > 0) / total_pixels
56
+ yellow_percentage = np.sum(yellow_mask > 0) / total_pixels
57
+
58
+ # Conservative classification logic
59
+ if red_percentage > 0.35 and red_percentage > green_percentage and red_percentage > yellow_percentage:
60
+ return "ripe", red_percentage
61
+ elif green_percentage > 0.25 and green_percentage > red_percentage and green_percentage > yellow_percentage:
62
+ return "unripe", green_percentage
63
+ elif yellow_percentage > 0.20 and yellow_percentage > red_percentage and yellow_percentage > green_percentage:
64
+ return "overripe", yellow_percentage
65
+ else:
66
+ # If no clear dominant color, use the highest percentage
67
+ max_percentage = max(red_percentage, green_percentage, yellow_percentage)
68
+ if max_percentage == red_percentage:
69
+ return "ripe", red_percentage
70
+ elif max_percentage == green_percentage:
71
+ return "unripe", green_percentage
72
+ else:
73
+ return "overripe", yellow_percentage
74
+
75
+ except Exception as e:
76
+ print(f"Error analyzing {image_path}: {e}")
77
+ return None, 0.0
78
+
79
+ def main():
80
+ """Complete the final batch of labeling."""
81
+
82
+ # Paths
83
+ to_label_dir = Path("model/ripeness_manual_dataset/to_label")
84
+ unripe_dir = Path("model/ripeness_manual_dataset/unripe")
85
+ ripe_dir = Path("model/ripeness_manual_dataset/ripe")
86
+ overripe_dir = Path("model/ripeness_manual_dataset/overripe")
87
+
88
+ # Get remaining files
89
+ remaining_files = list(to_label_dir.glob("*.jpg"))
90
+
91
+ if not remaining_files:
92
+ print("No images remaining to label!")
93
+ return
94
+
95
+ print(f"=== FINAL BATCH LABELING ===")
96
+ print(f"Processing {len(remaining_files)} remaining images with conservative analysis...")
97
+
98
+ results = {
99
+ "timestamp": datetime.now().isoformat(),
100
+ "total_processed": len(remaining_files),
101
+ "unripe": 0,
102
+ "ripe": 0,
103
+ "overripe": 0,
104
+ "unknown": 0,
105
+ "images": []
106
+ }
107
+
108
+ for i, image_path in enumerate(remaining_files, 1):
109
+ print(f"Processing {i}/{len(remaining_files)}: {image_path.name}")
110
+
111
+ # Analyze with conservative threshold
112
+ label, confidence = analyze_ripeness_conservative(image_path, confidence_threshold=0.8)
113
+
114
+ if label:
115
+ # Move to appropriate directory
116
+ if label == "unripe":
117
+ dest = unripe_dir / image_path.name
118
+ results["unripe"] += 1
119
+ elif label == "ripe":
120
+ dest = ripe_dir / image_path.name
121
+ results["ripe"] += 1
122
+ elif label == "overripe":
123
+ dest = overripe_dir / image_path.name
124
+ results["overripe"] += 1
125
+
126
+ try:
127
+ shutil.move(str(image_path), str(dest))
128
+ print(f" ✅ {label} (confidence: {confidence:.2f})")
129
+ results["images"].append({
130
+ "filename": image_path.name,
131
+ "label": label,
132
+ "confidence": confidence
133
+ })
134
+ except Exception as e:
135
+ print(f" ❌ Error moving file: {e}")
136
+ results["unknown"] += 1
137
+ else:
138
+ print(f" ⚠️ Analysis failed")
139
+ results["unknown"] += 1
140
+
141
+ # Save results
142
+ results_file = f"model/ripeness_manual_dataset/final_labeling_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
143
+ with open(results_file, 'w') as f:
144
+ json.dump(results, f, indent=2)
145
+
146
+ # Print final summary
147
+ print(f"\n=== FINAL LABELING COMPLETE ===")
148
+ print(f"unripe: {results['unripe']} images")
149
+ print(f"ripe: {results['ripe']} images")
150
+ print(f"overripe: {results['overripe']} images")
151
+ print(f"unknown: {results['unknown']} images")
152
+ print(f"Total processed: {results['total_processed']} images")
153
+
154
+ # Calculate final dataset statistics
155
+ total_unripe = len(list(unripe_dir.glob("*.jpg")))
156
+ total_ripe = len(list(ripe_dir.glob("*.jpg")))
157
+ total_overripe = len(list(overripe_dir.glob("*.jpg")))
158
+ remaining = len(list(to_label_dir.glob("*.jpg")))
159
+ total_dataset = total_unripe + total_ripe + total_overripe + remaining
160
+
161
+ completion_percentage = (total_dataset - remaining) / total_dataset * 100
162
+
163
+ print(f"\n=== FINAL DATASET STATUS ===")
164
+ print(f"unripe: {total_unripe} images")
165
+ print(f"ripe: {total_ripe} images")
166
+ print(f"overripe: {total_overripe} images")
167
+ print(f"to_label: {remaining} images")
168
+ print(f"TOTAL: {total_dataset} images")
169
+ print(f"Completion: {completion_percentage:.1f}%")
170
+
171
+ if remaining == 0:
172
+ print(f"\n🎉 DATASET LABELING 100% COMPLETE! 🎉")
173
+ print(f"Total labeled images: {total_dataset}")
174
+ else:
175
+ print(f"\n⚠️ {remaining} images still need manual review")
176
+
177
+ print(f"Results saved to: {results_file}")
178
+
179
+ if __name__ == "__main__":
180
+ main()
scripts/convert_tflite.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert trained model to TensorFlow Lite format with optional quantization.
4
+ Supports conversion from Keras (.h5) and PyTorch (.pt) models.
5
+ """
6
+
7
+ import argparse
8
+ import tensorflow as tf
9
+ import os
10
+ from pathlib import Path
11
+ import sys
12
+
13
+ def convert_keras_to_tflite(input_path, output_path, quantization=None, optimize_for_size=False):
14
+ """Convert Keras model to TFLite format."""
15
+ if not os.path.exists(input_path):
16
+ raise FileNotFoundError(f"Input model not found: {input_path}")
17
+
18
+ print(f"Loading Keras model from {input_path}...")
19
+ model = tf.keras.models.load_model(input_path)
20
+
21
+ # Configure converter
22
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
23
+
24
+ # Optimization options
25
+ if optimize_for_size:
26
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
27
+
28
+ # Quantization options
29
+ if quantization == 'int8':
30
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
31
+ converter.target_spec.supported_types = [tf.int8]
32
+ converter.inference_input_type = tf.int8
33
+ converter.inference_output_type = tf.int8
34
+ print("Using INT8 quantization")
35
+ elif quantization == 'float16':
36
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
37
+ converter.target_spec.supported_types = [tf.float16]
38
+ print("Using Float16 quantization")
39
+ elif quantization == 'dynamic_range':
40
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
41
+ print("Using dynamic range quantization")
42
+
43
+ # Convert model
44
+ print("Converting model to TFLite...")
45
+ tflite_model = converter.convert()
46
+
47
+ # Save model
48
+ with open(output_path, 'wb') as f:
49
+ f.write(tflite_model)
50
+
51
+ print(f"TFLite model saved to {output_path}")
52
+
53
+ # Print model size
54
+ size_kb = os.path.getsize(output_path) / 1024
55
+ print(f"Model size: {size_kb:.2f} KB")
56
+
57
+ return output_path
58
+
59
+ def convert_pytorch_to_tflite(input_path, output_path, quantization=None):
60
+ """Convert PyTorch model to TFLite format (placeholder for future implementation)."""
61
+ print("PyTorch to TFLite conversion not yet implemented.")
62
+ print("Please convert PyTorch model to ONNX first, then use TensorFlow's converter.")
63
+ return None
64
+
65
+ def main():
66
+ parser = argparse.ArgumentParser(description='Convert model to TensorFlow Lite format')
67
+ parser.add_argument('--input', type=str, default='strawberry_model.h5',
68
+ help='Input model path (Keras .h5 or PyTorch .pt)')
69
+ parser.add_argument('--output', type=str, default='strawberry_model.tflite',
70
+ help='Output TFLite model path')
71
+ parser.add_argument('--quantization', type=str, choices=['int8', 'float16', 'dynamic_range', 'none'],
72
+ default='none', help='Quantization method (default: none)')
73
+ parser.add_argument('--optimize-for-size', action='store_true',
74
+ help='Apply size optimization (reduces model size)')
75
+ parser.add_argument('--model-type', type=str, choices=['keras', 'pytorch'], default='keras',
76
+ help='Type of input model (default: keras)')
77
+
78
+ args = parser.parse_args()
79
+
80
+ # Validate input file exists
81
+ if not Path(args.input).exists():
82
+ print(f"Error: Input model '{args.input}' not found.")
83
+ print("Available model files in current directory:")
84
+ for f in os.listdir('.'):
85
+ if f.endswith(('.h5', '.pt', '.pth', '.onnx')):
86
+ print(f" - {f}")
87
+ sys.exit(1)
88
+
89
+ # Create output directory if needed
90
+ output_dir = os.path.dirname(args.output)
91
+ if output_dir and not os.path.exists(output_dir):
92
+ os.makedirs(output_dir, exist_ok=True)
93
+
94
+ try:
95
+ if args.model_type == 'keras':
96
+ convert_keras_to_tflite(
97
+ input_path=args.input,
98
+ output_path=args.output,
99
+ quantization=args.quantization if args.quantization != 'none' else None,
100
+ optimize_for_size=args.optimize_for_size
101
+ )
102
+ elif args.model_type == 'pytorch':
103
+ convert_pytorch_to_tflite(
104
+ input_path=args.input,
105
+ output_path=args.output,
106
+ quantization=args.quantization if args.quantization != 'none' else None
107
+ )
108
+ else:
109
+ print(f"Unsupported model type: {args.model_type}")
110
+ sys.exit(1)
111
+
112
+ print("Conversion completed successfully!")
113
+
114
+ except Exception as e:
115
+ print(f"Error during conversion: {e}")
116
+ sys.exit(1)
117
+
118
+ if __name__ == '__main__':
119
+ main()
scripts/data/preprocess_strawberry_dataset.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ class_name = 'strawberry'
4
+
5
+ def rewrite_single_class_data_yaml(dataset_dir, class_name='strawberry'):
6
+ dataset_dir = Path(dataset_dir)
7
+ data_yaml_path = dataset_dir / 'data.yaml'
8
+ if not data_yaml_path.exists():
9
+ print('⚠️ data.yaml not found, skipping rewrite.')
10
+ return
11
+
12
+ train_path = dataset_dir / 'train' / 'images'
13
+ val_path = dataset_dir / 'valid' / 'images'
14
+ test_path = dataset_dir / 'test' / 'images'
15
+
16
+ content_lines = [
17
+ '# Strawberry-only dataset',
18
+ f'train: {train_path}',
19
+ f'val: {val_path}',
20
+ f"test: {test_path if test_path.exists() else ''}",
21
+ '',
22
+ 'nc: 1',
23
+ f"names: ['{class_name}']",
24
+ ]
25
+
26
+ data_yaml_path.write_text('\n'.join(content_lines) + '\n')
27
+ print(f"✅ data.yaml updated for single-class training ({class_name}).")
28
+
29
+
30
+ def enforce_single_class_dataset(dataset_dir, target_class=0, class_name='strawberry'):
31
+ dataset_dir = Path(dataset_dir)
32
+ stats = {'split_kept': {}, 'labels_removed': 0, 'images_removed': 0}
33
+ allowed_ext = ['.jpg', '.jpeg', '.png']
34
+
35
+ for split in ['train', 'valid', 'test']:
36
+ labels_dir = dataset_dir / split / 'labels'
37
+ images_dir = dataset_dir / split / 'images'
38
+ if not labels_dir.exists():
39
+ continue
40
+
41
+ kept = 0
42
+ for label_path in labels_dir.glob('*.txt'):
43
+ kept_lines = []
44
+ for raw_line in label_path.read_text().splitlines():
45
+ line = raw_line.strip()
46
+ if not line:
47
+ continue
48
+ parts = line.split()
49
+ if not parts:
50
+ continue
51
+ try:
52
+ class_id = int(parts[0])
53
+ except ValueError:
54
+ continue
55
+ if class_id == target_class:
56
+ kept_lines.append(line)
57
+
58
+ if kept_lines:
59
+ label_path.write_text('\n'.join(kept_lines) + '\n')
60
+ kept += len(kept_lines)
61
+ else:
62
+ label_path.unlink()
63
+ stats['labels_removed'] += 1
64
+ for ext in allowed_ext:
65
+ candidate = images_dir / f"{label_path.stem}{ext}"
66
+ if candidate.exists():
67
+ candidate.unlink()
68
+ stats['images_removed'] += 1
69
+ break
70
+
71
+ stats['split_kept'][split] = kept
72
+
73
+ rewrite_single_class_data_yaml(dataset_dir, class_name)
74
+
75
+ print('\n🍓 Strawberry-only filtering summary:')
76
+ for split, count in stats['split_kept'].items():
77
+ print(f" {split}: {count} annotations kept")
78
+ print(f" Label files removed: {stats['labels_removed']}")
79
+ print(f" Images removed (non-strawberry or empty labels): {stats['images_removed']}")
80
+
81
+ return stats
82
+
83
+
84
+ if __name__ == "__main__":
85
+ # Example usage - replace with your dataset path
86
+ dataset_path = "path/to/your/dataset" # Update this path
87
+
88
+ if dataset_path and Path(dataset_path).exists():
89
+ strawberry_stats = enforce_single_class_dataset(dataset_path, target_class=0, class_name=class_name)
90
+ else:
91
+ print("Please set a valid dataset_path variable.")
scripts/detect_realtime.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Real-time strawberry detection/classification using TFLite model.
4
+ Supports both binary classification (good/bad) and YOLOv8 detection.
5
+ """
6
+
7
+ import argparse
8
+ import cv2
9
+ import numpy as np
10
+ import tensorflow as tf
11
+ from pathlib import Path
12
+ import sys
13
+
14
+ def load_tflite_model(model_path):
15
+ """Load TFLite model and allocate tensors."""
16
+ if not Path(model_path).exists():
17
+ raise FileNotFoundError(f"Model file not found: {model_path}")
18
+
19
+ interpreter = tf.lite.Interpreter(model_path=model_path)
20
+ interpreter.allocate_tensors()
21
+ return interpreter
22
+
23
+ def get_model_details(interpreter):
24
+ """Get input and output details of the TFLite model."""
25
+ input_details = interpreter.get_input_details()
26
+ output_details = interpreter.get_output_details()
27
+ return input_details, output_details
28
+
29
+ def preprocess_image(image, input_shape):
30
+ """Preprocess image for model inference."""
31
+ height, width = input_shape[1:3] if len(input_shape) == 4 else input_shape[1:3]
32
+ img = cv2.resize(image, (width, height))
33
+ img = img / 255.0 # Normalize to [0,1]
34
+ img = np.expand_dims(img, axis=0).astype(np.float32)
35
+ return img
36
+
37
+ def run_inference(interpreter, input_details, output_details, preprocessed_img):
38
+ """Run inference on preprocessed image."""
39
+ interpreter.set_tensor(input_details[0]['index'], preprocessed_img)
40
+ interpreter.invoke()
41
+ return interpreter.get_tensor(output_details[0]['index'])
42
+
43
+ def main():
44
+ parser = argparse.ArgumentParser(description='Real-time strawberry detection/classification')
45
+ parser.add_argument('--model', type=str, default='strawberry_model.tflite',
46
+ help='Path to TFLite model (default: strawberry_model.tflite)')
47
+ parser.add_argument('--camera', type=int, default=0,
48
+ help='Camera index (default: 0)')
49
+ parser.add_argument('--threshold', type=float, default=0.5,
50
+ help='Confidence threshold for binary classification (default: 0.5)')
51
+ parser.add_argument('--input-size', type=int, default=224,
52
+ help='Input image size (width=height) for model (default: 224)')
53
+ parser.add_argument('--mode', choices=['classification', 'detection'], default='classification',
54
+ help='Inference mode: classification (good/bad) or detection (YOLO)')
55
+ parser.add_argument('--verbose', action='store_true',
56
+ help='Print detailed inference information')
57
+
58
+ args = parser.parse_args()
59
+
60
+ # Load model
61
+ try:
62
+ interpreter = load_tflite_model(args.model)
63
+ input_details, output_details = get_model_details(interpreter)
64
+ input_shape = input_details[0]['shape']
65
+ if args.verbose:
66
+ print(f"Model loaded: {args.model}")
67
+ print(f"Input shape: {input_shape}")
68
+ print(f"Output details: {output_details[0]}")
69
+ except Exception as e:
70
+ print(f"Error loading model: {e}")
71
+ sys.exit(1)
72
+
73
+ # Open camera
74
+ cap = cv2.VideoCapture(args.camera)
75
+ if not cap.isOpened():
76
+ print(f"Cannot open camera index {args.camera}")
77
+ sys.exit(1)
78
+
79
+ print(f"Starting real-time inference (mode: {args.mode})")
80
+ print("Press 'q' to quit, 's' to save current frame")
81
+
82
+ while True:
83
+ ret, frame = cap.read()
84
+ if not ret:
85
+ print("Failed to capture frame")
86
+ break
87
+
88
+ # Preprocess
89
+ preprocessed = preprocess_image(frame, input_shape)
90
+
91
+ # Inference
92
+ predictions = run_inference(interpreter, input_details, output_details, preprocessed)
93
+
94
+ # Process predictions based on mode
95
+ if args.mode == 'classification':
96
+ # Binary classification: single probability
97
+ confidence = predictions[0][0]
98
+ label = 'Good' if confidence > args.threshold else 'Bad'
99
+ display_text = f'{label}: {confidence:.2f}'
100
+ color = (0, 255, 0) if confidence > args.threshold else (0, 0, 255)
101
+ else:
102
+ # Detection mode (YOLO) - placeholder for future implementation
103
+ display_text = 'Detection mode not yet implemented'
104
+ color = (255, 255, 0)
105
+
106
+ # Display
107
+ cv2.putText(frame, display_text, (10, 30),
108
+ cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
109
+ cv2.imshow('Strawberry Detection', frame)
110
+
111
+ key = cv2.waitKey(1) & 0xFF
112
+ if key == ord('q'):
113
+ break
114
+ elif key == ord('s'):
115
+ filename = f'capture_{cv2.getTickCount()}.jpg'
116
+ cv2.imwrite(filename, frame)
117
+ print(f"Frame saved as {filename}")
118
+
119
+ cap.release()
120
+ cv2.destroyAllWindows()
121
+ print("Real-time detection stopped.")
122
+
123
+ if __name__ == '__main__':
124
+ main()
scripts/download_dataset.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from roboflow import Roboflow
2
+
3
+ # Replace with your Roboflow API key and project details
4
+ rf = Roboflow(api_key="YOUR_API_KEY")
5
+ project = rf.workspace("YOUR_WORKSPACE").project("YOUR_PROJECT")
6
+ dataset = project.version("YOUR_VERSION").download("folder")
7
+
8
+ print("Dataset downloaded to dataset/")
scripts/export_onnx.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export YOLOv8/v11 model to ONNX format for optimized inference.
4
+ Supports dynamic axes, batch size, and different opset versions.
5
+ """
6
+
7
+ import argparse
8
+ import os
9
+ import sys
10
+ from pathlib import Path
11
+ import yaml
12
+ from ultralytics import YOLO
13
+
14
+ def load_config(config_path="config.yaml"):
15
+ """Load configuration from YAML file."""
16
+ if not os.path.exists(config_path):
17
+ print(f"Warning: Config file {config_path} not found. Using defaults.")
18
+ return {}
19
+ with open(config_path, 'r') as f:
20
+ return yaml.safe_load(f)
21
+
22
+ def export_to_onnx(
23
+ model_path,
24
+ output_path=None,
25
+ imgsz=640,
26
+ batch=1,
27
+ dynamic=False,
28
+ simplify=True,
29
+ opset=12,
30
+ half=False
31
+ ):
32
+ """
33
+ Export YOLO model to ONNX format.
34
+
35
+ Args:
36
+ model_path: Path to .pt model file
37
+ output_path: Output .onnx file path (optional, auto-generated if None)
38
+ imgsz: Input image size
39
+ batch: Batch size (1 for static, -1 for dynamic)
40
+ dynamic: Enable dynamic axes (batch, height, width)
41
+ simplify: Apply ONNX simplifier
42
+ opset: ONNX opset version
43
+ half: FP16 quantization
44
+ """
45
+ print(f"Loading model from {model_path}")
46
+ model = YOLO(model_path)
47
+
48
+ # Determine output path
49
+ if output_path is None:
50
+ model_name = Path(model_path).stem
51
+ output_dir = Path(model_path).parent / "exports"
52
+ output_dir.mkdir(exist_ok=True)
53
+ output_path = str(output_dir / f"{model_name}.onnx")
54
+
55
+ # Create output directory if needed
56
+ output_dir = os.path.dirname(output_path)
57
+ if output_dir and not os.path.exists(output_dir):
58
+ os.makedirs(output_dir, exist_ok=True)
59
+
60
+ # Prepare export arguments
61
+ export_args = {
62
+ 'format': 'onnx',
63
+ 'imgsz': imgsz,
64
+ 'batch': batch,
65
+ 'simplify': simplify,
66
+ 'opset': opset,
67
+ 'half': half,
68
+ }
69
+
70
+ if dynamic:
71
+ export_args['dynamic'] = True
72
+
73
+ print(f"Exporting to ONNX with args: {export_args}")
74
+
75
+ try:
76
+ # Export model
77
+ exported_path = model.export(**export_args)
78
+
79
+ # The exported file will be in the same directory as the model
80
+ # Find the .onnx file that was just created
81
+ exported_files = list(Path(model_path).parent.glob("*.onnx"))
82
+ if exported_files:
83
+ latest_onnx = max(exported_files, key=os.path.getctime)
84
+ # Move to desired output path if different
85
+ if str(latest_onnx) != output_path:
86
+ import shutil
87
+ shutil.move(str(latest_onnx), output_path)
88
+ print(f"Model moved to {output_path}")
89
+ else:
90
+ print(f"Model exported to {output_path}")
91
+ else:
92
+ # Try to find the exported file in the current directory
93
+ exported_files = list(Path('.').glob("*.onnx"))
94
+ if exported_files:
95
+ latest_onnx = max(exported_files, key=os.path.getctime)
96
+ if str(latest_onnx) != output_path:
97
+ import shutil
98
+ shutil.move(str(latest_onnx), output_path)
99
+ print(f"Model moved to {output_path}")
100
+ else:
101
+ print(f"Model exported to {output_path}")
102
+ else:
103
+ print(f"Warning: Could not locate exported ONNX file.")
104
+ print(f"Expected at: {output_path}")
105
+ return None
106
+
107
+ # Print model info
108
+ size_mb = os.path.getsize(output_path) / (1024 * 1024)
109
+ print(f"✅ ONNX export successful!")
110
+ print(f" Output: {output_path}")
111
+ print(f" Size: {size_mb:.2f} MB")
112
+ print(f" Input shape: {batch if batch > 0 else 'dynamic'}x3x{imgsz}x{imgsz}")
113
+ print(f" Opset: {opset}")
114
+ print(f" Dynamic: {dynamic}")
115
+ print(f" FP16: {half}")
116
+
117
+ return output_path
118
+
119
+ except Exception as e:
120
+ print(f"❌ Error during ONNX export: {e}")
121
+ import traceback
122
+ traceback.print_exc()
123
+ return None
124
+
125
+ def main():
126
+ parser = argparse.ArgumentParser(description='Export YOLO model to ONNX format')
127
+ parser.add_argument('--model', type=str, default='yolov8n.pt',
128
+ help='Path to YOLO model (.pt file)')
129
+ parser.add_argument('--output', type=str, default=None,
130
+ help='Output ONNX file path (default: model/exports/<model_name>.onnx)')
131
+ parser.add_argument('--img-size', type=int, default=640,
132
+ help='Input image size (default: 640)')
133
+ parser.add_argument('--batch', type=int, default=1,
134
+ help='Batch size (default: 1, use -1 for dynamic)')
135
+ parser.add_argument('--dynamic', action='store_true',
136
+ help='Enable dynamic axes (batch, height, width)')
137
+ parser.add_argument('--no-simplify', action='store_true',
138
+ help='Disable ONNX simplifier')
139
+ parser.add_argument('--opset', type=int, default=12,
140
+ help='ONNX opset version (default: 12)')
141
+ parser.add_argument('--half', action='store_true',
142
+ help='Use FP16 quantization')
143
+ parser.add_argument('--config', type=str, default='config.yaml',
144
+ help='Path to config file')
145
+
146
+ args = parser.parse_args()
147
+
148
+ # Load config
149
+ config = load_config(args.config)
150
+
151
+ # Use model from config if not specified
152
+ if args.model == 'yolov8n.pt' and config:
153
+ models_config = config.get('models', {})
154
+ detection_config = models_config.get('detection', {})
155
+ default_model = detection_config.get('strawberry_yolov8n', 'yolov8n.pt')
156
+ if os.path.exists(default_model):
157
+ args.model = default_model
158
+ else:
159
+ # Check for other available models
160
+ available_models = ['yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt',
161
+ 'model/weights/strawberry_yolov11n.pt',
162
+ 'model/weights/ripeness_detection_yolov11n.pt']
163
+ for model in available_models:
164
+ if os.path.exists(model):
165
+ args.model = model
166
+ print(f"Using available model: {model}")
167
+ break
168
+
169
+ # Export model
170
+ success = export_to_onnx(
171
+ model_path=args.model,
172
+ output_path=args.output,
173
+ imgsz=args.img_size,
174
+ batch=args.batch,
175
+ dynamic=args.dynamic,
176
+ simplify=not args.no_simplify,
177
+ opset=args.opset,
178
+ half=args.half
179
+ )
180
+
181
+ if success:
182
+ print(f"\n✅ Export completed successfully!")
183
+ print(f"\nNext steps:")
184
+ print(f"1. Test the ONNX model with ONNX Runtime:")
185
+ print(f" python -m onnxruntime.tools.onnx_model_test {success}")
186
+ print(f"2. Convert to TensorFlow Lite for edge deployment:")
187
+ print(f" python export_tflite_int8.py --model {success}")
188
+ print(f"3. Use in your application with ONNX Runtime")
189
+ else:
190
+ print("\n❌ Export failed.")
191
+ sys.exit(1)
192
+
193
+ if __name__ == '__main__':
194
+ main()
scripts/export_tflite_int8.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export YOLOv8 model to TensorFlow Lite with INT8 quantization.
4
+ Uses Ultralytics YOLOv8 export functionality with calibration dataset.
5
+ """
6
+
7
+ import argparse
8
+ import os
9
+ import sys
10
+ from pathlib import Path
11
+ import yaml
12
+ import numpy as np
13
+ from ultralytics import YOLO
14
+
15
+ def load_config(config_path="config.yaml"):
16
+ """Load configuration from YAML file."""
17
+ if not os.path.exists(config_path):
18
+ print(f"Warning: Config file {config_path} not found. Using defaults.")
19
+ return {}
20
+ with open(config_path, 'r') as f:
21
+ return yaml.safe_load(f)
22
+
23
+ def get_representative_dataset(dataset_path, num_calibration=100):
24
+ """
25
+ Create representative dataset for INT8 calibration.
26
+ Returns a generator that yields normalized images.
27
+ """
28
+ import cv2
29
+ from pathlib import Path
30
+
31
+ # Find validation images
32
+ val_path = Path(dataset_path) / "valid" / "images"
33
+ if not val_path.exists():
34
+ val_path = Path(dataset_path) / "val" / "images"
35
+ if not val_path.exists():
36
+ print(f"Warning: Validation images not found at {val_path}")
37
+ return None
38
+
39
+ image_files = list(val_path.glob("*.jpg")) + list(val_path.glob("*.png"))
40
+ if len(image_files) == 0:
41
+ print("No validation images found for calibration.")
42
+ return None
43
+
44
+ # Limit to num_calibration
45
+ image_files = image_files[:num_calibration]
46
+
47
+ def representative_dataset():
48
+ for img_path in image_files:
49
+ # Load and preprocess image
50
+ img = cv2.imread(str(img_path))
51
+ if img is None:
52
+ continue
53
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
54
+ img = cv2.resize(img, (640, 640))
55
+ img = img.astype(np.float32) / 255.0 # Normalize to [0, 1]
56
+ img = np.expand_dims(img, axis=0) # Add batch dimension
57
+ yield [img]
58
+
59
+ return representative_dataset
60
+
61
+ def export_yolov8_to_tflite_int8(
62
+ model_path,
63
+ output_path,
64
+ dataset_path=None,
65
+ img_size=640,
66
+ int8=True,
67
+ dynamic=False,
68
+ half=False
69
+ ):
70
+ """
71
+ Export YOLOv8 model to TFLite format with optional INT8 quantization.
72
+
73
+ Args:
74
+ model_path: Path to YOLOv8 .pt model
75
+ output_path: Output .tflite file path
76
+ dataset_path: Path to dataset for INT8 calibration
77
+ img_size: Input image size
78
+ int8: Enable INT8 quantization
79
+ dynamic: Enable dynamic range quantization (alternative to INT8)
80
+ half: Enable FP16 quantization
81
+ """
82
+ print(f"Loading YOLOv8 model from {model_path}")
83
+ model = YOLO(model_path)
84
+
85
+ # Check if model is a detection model
86
+ task = model.task if hasattr(model, 'task') else 'detect'
87
+ print(f"Model task: {task}")
88
+
89
+ # Prepare export arguments
90
+ export_args = {
91
+ 'format': 'tflite',
92
+ 'imgsz': img_size,
93
+ 'optimize': True,
94
+ 'int8': int8,
95
+ 'half': half,
96
+ 'dynamic': dynamic,
97
+ }
98
+
99
+ # If INT8 quantization is requested, provide representative dataset
100
+ if int8 and dataset_path:
101
+ print(f"Using dataset at {dataset_path} for INT8 calibration")
102
+ representative_dataset = get_representative_dataset(dataset_path)
103
+ if representative_dataset:
104
+ # Note: Ultralytics YOLOv8 export doesn't directly accept representative_dataset
105
+ # We'll need to use a different approach
106
+ print("INT8 calibration with representative dataset requires custom implementation.")
107
+ print("Falling back to Ultralytics built-in INT8 calibration...")
108
+ # Use built-in calibration images
109
+ export_args['int8'] = True
110
+ else:
111
+ print("Warning: No representative dataset available. Using default calibration.")
112
+ export_args['int8'] = True
113
+ elif int8:
114
+ print("Using built-in calibration images for INT8 quantization")
115
+ export_args['int8'] = True
116
+
117
+ # Export model
118
+ print(f"Exporting model to TFLite with args: {export_args}")
119
+ try:
120
+ # Use Ultralytics export
121
+ exported_path = model.export(**export_args)
122
+
123
+ # The exported file will be in the same directory as the model
124
+ # with a .tflite extension
125
+ exported_files = list(Path(model_path).parent.glob("*.tflite"))
126
+ if exported_files:
127
+ latest_tflite = max(exported_files, key=os.path.getctime)
128
+ # Move to desired output path
129
+ import shutil
130
+ shutil.move(str(latest_tflite), output_path)
131
+ print(f"Model exported to {output_path}")
132
+
133
+ # Print model size
134
+ size_mb = os.path.getsize(output_path) / (1024 * 1024)
135
+ print(f"Model size: {size_mb:.2f} MB")
136
+
137
+ return output_path
138
+ else:
139
+ print("Error: No .tflite file was generated")
140
+ return None
141
+
142
+ except Exception as e:
143
+ print(f"Error during export: {e}")
144
+ return None
145
+
146
+ def main():
147
+ parser = argparse.ArgumentParser(description='Export YOLOv8 model to TFLite with INT8 quantization')
148
+ parser.add_argument('--model', type=str, default='yolov8n.pt',
149
+ help='Path to YOLOv8 model (.pt file)')
150
+ parser.add_argument('--output', type=str, default='model/exports/strawberry_yolov8n_int8.tflite',
151
+ help='Output TFLite file path')
152
+ parser.add_argument('--dataset', type=str, default='model/dataset_strawberry_detect_v3',
153
+ help='Path to dataset for INT8 calibration')
154
+ parser.add_argument('--img-size', type=int, default=640,
155
+ help='Input image size (default: 640)')
156
+ parser.add_argument('--no-int8', action='store_true',
157
+ help='Disable INT8 quantization')
158
+ parser.add_argument('--dynamic', action='store_true',
159
+ help='Use dynamic range quantization')
160
+ parser.add_argument('--half', action='store_true',
161
+ help='Use FP16 quantization')
162
+ parser.add_argument('--config', type=str, default='config.yaml',
163
+ help='Path to config file')
164
+
165
+ args = parser.parse_args()
166
+
167
+ # Load config
168
+ config = load_config(args.config)
169
+
170
+ # Use dataset path from config if not provided
171
+ if args.dataset is None and config:
172
+ args.dataset = config.get('dataset', {}).get('detection', {}).get('path', 'model/dataset_strawberry_detect_v3')
173
+
174
+ # Create output directory if it doesn't exist
175
+ output_dir = os.path.dirname(args.output)
176
+ if output_dir and not os.path.exists(output_dir):
177
+ os.makedirs(output_dir, exist_ok=True)
178
+
179
+ # Export model
180
+ success = export_yolov8_to_tflite_int8(
181
+ model_path=args.model,
182
+ output_path=args.output,
183
+ dataset_path=args.dataset,
184
+ img_size=args.img_size,
185
+ int8=not args.no_int8,
186
+ dynamic=args.dynamic,
187
+ half=args.half
188
+ )
189
+
190
+ if success:
191
+ print(f"\n✅ Successfully exported model to: {success}")
192
+ print("\nUsage:")
193
+ print(f" python detect_realtime.py --model {args.output}")
194
+ print(f" python detect_realtime.py --model {args.output} --mode detection")
195
+ else:
196
+ print("\n❌ Export failed.")
197
+ sys.exit(1)
198
+
199
+ if __name__ == '__main__':
200
+ main()
scripts/get-pip.py ADDED
The diff for this file is too large to render. See raw diff
 
scripts/label_ripeness_dataset.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Strawberry Ripeness Dataset Labeling Tool
4
+ Helps organize and label the 889 unlabeled images into 3 categories:
5
+ - Unripe (green/white/pale pink)
6
+ - Ripe (bright red)
7
+ - Overripe (dark red/maroon/rotting)
8
+ """
9
+
10
+ import os
11
+ import shutil
12
+ import random
13
+ from pathlib import Path
14
+ from PIL import Image
15
+ import argparse
16
+
17
+ def create_labeling_directories(base_path):
18
+ """Create the three labeling directories"""
19
+ labels = ['unripe', 'ripe', 'overripe']
20
+ dirs = {}
21
+
22
+ for label in labels:
23
+ dir_path = base_path / label
24
+ dir_path.mkdir(exist_ok=True)
25
+ dirs[label] = dir_path
26
+
27
+ return dirs
28
+
29
+ def get_image_files(to_label_path):
30
+ """Get all image files from the to_label directory"""
31
+ image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
32
+ image_files = []
33
+
34
+ for file_path in to_label_path.iterdir():
35
+ if file_path.suffix.lower() in image_extensions:
36
+ image_files.append(file_path)
37
+
38
+ return sorted(image_files)
39
+
40
+ def batch_label_images(image_files, dirs, batch_size=50):
41
+ """Label images in batches for easier management"""
42
+ total_images = len(image_files)
43
+ num_batches = (total_images + batch_size - 1) // batch_size
44
+
45
+ print(f"Found {total_images} images to label")
46
+ print(f"Will process in {num_batches} batches of {batch_size} images each")
47
+
48
+ for batch_num in range(num_batches):
49
+ start_idx = batch_num * batch_size
50
+ end_idx = min(start_idx + batch_size, total_images)
51
+ batch_files = image_files[start_idx:end_idx]
52
+
53
+ print(f"\n=== BATCH {batch_num + 1}/{num_batches} ===")
54
+ print(f"Images {start_idx + 1} to {end_idx}")
55
+
56
+ # Show first few images in batch for preview
57
+ for i, img_file in enumerate(batch_files[:5]):
58
+ try:
59
+ with Image.open(img_file) as img:
60
+ print(f" {start_idx + i + 1}. {img_file.name} ({img.size[0]}x{img.size[1]})")
61
+ except Exception as e:
62
+ print(f" {start_idx + i + 1}. {img_file.name} (Error: {e})")
63
+
64
+ if len(batch_files) > 5:
65
+ print(f" ... and {len(batch_files) - 5} more images")
66
+
67
+ # Interactive labeling for this batch
68
+ label_batch(batch_files, dirs)
69
+
70
+ def label_batch(batch_files, dirs):
71
+ """Interactive labeling for a batch of images"""
72
+ print("\nLabeling Instructions:")
73
+ print("1 = Unripe (green/white/pale pink)")
74
+ print("2 = Ripe (bright red)")
75
+ print("3 = Overripe (dark red/maroon/rotting)")
76
+ print("s = Skip this image")
77
+ print("q = Quit labeling")
78
+ print("Enter = Next image")
79
+
80
+ for img_file in batch_files:
81
+ try:
82
+ # Try to display image info
83
+ with Image.open(img_file) as img:
84
+ print(f"\nProcessing: {img_file.name}")
85
+ print(f"Size: {img.size[0]}x{img.size[1]}, Mode: {img.mode}")
86
+
87
+ while True:
88
+ choice = input("Label (1/2/3/s/q/Enter): ").strip().lower()
89
+
90
+ if choice == '1' or choice == 'unripe':
91
+ shutil.move(str(img_file), str(dirs['unripe'] / img_file.name))
92
+ print(f"✓ Moved to unripe/")
93
+ break
94
+ elif choice == '2' or choice == 'ripe':
95
+ shutil.move(str(img_file), str(dirs['ripe'] / img_file.name))
96
+ print(f"✓ Moved to ripe/")
97
+ break
98
+ elif choice == '3' or choice == 'overripe':
99
+ shutil.move(str(img_file), str(dirs['overripe'] / img_file.name))
100
+ print(f"✓ Moved to overripe/")
101
+ break
102
+ elif choice == 's' or choice == 'skip':
103
+ print("⏭️ Skipped")
104
+ break
105
+ elif choice == 'q' or choice == 'quit':
106
+ print("Quitting...")
107
+ return
108
+ elif choice == '':
109
+ print("⏭️ Skipped (no input)")
110
+ break
111
+ else:
112
+ print("Invalid choice. Please enter 1, 2, 3, s, q, or press Enter.")
113
+
114
+ except Exception as e:
115
+ print(f"Error processing {img_file.name}: {e}")
116
+ continue
117
+
118
+ def count_labeled_images(dirs):
119
+ """Count images in each label directory"""
120
+ counts = {}
121
+ for label, dir_path in dirs.items():
122
+ count = len([f for f in dir_path.iterdir() if f.is_file()])
123
+ counts[label] = count
124
+ return counts
125
+
126
+ def main():
127
+ parser = argparse.ArgumentParser(description='Label strawberry ripeness dataset')
128
+ parser.add_argument('--dataset-path', type=str,
129
+ default='model/ripeness_manual_dataset',
130
+ help='Path to the ripeness dataset directory')
131
+ parser.add_argument('--batch-size', type=int, default=50,
132
+ help='Number of images to process in each batch')
133
+ parser.add_argument('--count-only', action='store_true',
134
+ help='Only count current images, do not start labeling')
135
+
136
+ args = parser.parse_args()
137
+
138
+ base_path = Path(args.dataset_path)
139
+ to_label_path = base_path / 'to_label'
140
+
141
+ if not to_label_path.exists():
142
+ print(f"Error: to_label directory not found at {to_label_path}")
143
+ return
144
+
145
+ # Create labeling directories
146
+ dirs = create_labeling_directories(base_path)
147
+
148
+ # Count current state
149
+ remaining_files = get_image_files(to_label_path)
150
+ labeled_counts = count_labeled_images(dirs)
151
+
152
+ print("=== CURRENT STATUS ===")
153
+ print(f"Images remaining to label: {len(remaining_files)}")
154
+ print(f"Already labeled:")
155
+ for label, count in labeled_counts.items():
156
+ print(f" {label}: {count} images")
157
+ print(f"Total labeled: {sum(labeled_counts.values())}")
158
+
159
+ if args.count_only:
160
+ return
161
+
162
+ if len(remaining_files) == 0:
163
+ print("\n✅ All images have been labeled!")
164
+ print("You can now run: python3 train_ripeness_classifier.py")
165
+ return
166
+
167
+ # Ask user if they want to continue
168
+ response = input(f"\nStart labeling {len(remaining_files)} remaining images? (y/n): ")
169
+ if response.lower() != 'y':
170
+ print("Labeling cancelled.")
171
+ return
172
+
173
+ # Start labeling process
174
+ batch_label_images(remaining_files, dirs, args.batch_size)
175
+
176
+ # Final count
177
+ final_counts = count_labeled_images(dirs)
178
+ print("\n=== FINAL COUNTS ===")
179
+ for label, count in final_counts.items():
180
+ print(f"{label}: {count} images")
181
+
182
+ total_labeled = sum(final_counts.values())
183
+ print(f"Total labeled: {total_labeled}")
184
+
185
+ if total_labeled > 0:
186
+ print("\n✅ Labeling complete!")
187
+ print("Next steps:")
188
+ print("1. Review the labeled images for quality")
189
+ print("2. Run: python3 train_ripeness_classifier.py")
190
+
191
+ if __name__ == '__main__':
192
+ main()
scripts/optimization/optimized_onnx_inference.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Optimized ONNX Inference for Raspberry Pi
4
+ High-performance inference with ONNX Runtime optimizations
5
+ """
6
+
7
+ import os
8
+ import cv2
9
+ import numpy as np
10
+ import onnxruntime as ort
11
+ import time
12
+ from pathlib import Path
13
+ from typing import Tuple, List, Optional
14
+
15
+ class OptimizedONNXInference:
16
+ """
17
+ Optimized ONNX inference engine for Raspberry Pi
18
+ """
19
+
20
+ def __init__(self, model_path: str, conf_threshold: float = 0.5):
21
+ """
22
+ Initialize optimized ONNX inference engine
23
+
24
+ Args:
25
+ model_path: Path to ONNX model
26
+ conf_threshold: Confidence threshold for detections
27
+ """
28
+ self.conf_threshold = conf_threshold
29
+ self.model_path = model_path
30
+ self.session = self._create_optimized_session()
31
+ self.input_name = self.session.get_inputs()[0].name
32
+ self.input_shape = self.session.get_inputs()[0].shape
33
+
34
+ # Extract input dimensions
35
+ self.input_height = self.input_shape[2]
36
+ self.input_width = self.input_shape[3]
37
+
38
+ print(f"✅ Optimized ONNX model loaded: {model_path}")
39
+ print(f"📐 Input shape: {self.input_shape}")
40
+ print(f"🎯 Confidence threshold: {conf_threshold}")
41
+
42
+ def _create_optimized_session(self) -> ort.InferenceSession:
43
+ """
44
+ Create ONNX session with Raspberry Pi optimizations
45
+ """
46
+ # Set environment variables for optimization
47
+ os.environ["OMP_NUM_THREADS"] = "4" # Raspberry Pi 4 has 4 cores
48
+ os.environ["OMP_THREAD_LIMIT"] = "4"
49
+ os.environ["OMP_WAIT_POLICY"] = "PASSIVE"
50
+ os.environ["MKL_NUM_THREADS"] = "4"
51
+
52
+ # Session options for maximum performance
53
+ session_options = ort.SessionOptions()
54
+
55
+ # Enable all graph optimizations
56
+ session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
57
+
58
+ # Use sequential execution for consistency
59
+ session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
60
+
61
+ # Optimize thread usage for Raspberry Pi
62
+ session_options.intra_op_num_threads = 4
63
+ session_options.inter_op_num_threads = 1
64
+
65
+ # Enable memory pattern optimization
66
+ session_options.enable_mem_pattern = True
67
+ session_options.enable_mem_reuse = True
68
+
69
+ # CPU execution provider (Raspberry Pi doesn't have CUDA)
70
+ providers = ['CPUExecutionProvider']
71
+
72
+ try:
73
+ session = ort.InferenceSession(
74
+ self.model_path,
75
+ sess_options=session_options,
76
+ providers=providers
77
+ )
78
+ return session
79
+ except Exception as e:
80
+ print(f"❌ Failed to create optimized session: {e}")
81
+ # Fallback to basic session
82
+ return ort.InferenceSession(self.model_path, providers=providers)
83
+
84
+ def preprocess(self, image: np.ndarray) -> np.ndarray:
85
+ """
86
+ Optimized preprocessing for Raspberry Pi
87
+
88
+ Args:
89
+ image: Input image (BGR format)
90
+
91
+ Returns:
92
+ Preprocessed tensor
93
+ """
94
+ # Convert BGR to RGB
95
+ if len(image.shape) == 3:
96
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
97
+
98
+ # Resize with optimization
99
+ image = cv2.resize(image, (self.input_width, self.input_height),
100
+ interpolation=cv2.INTER_LINEAR)
101
+
102
+ # Convert to float32 and normalize
103
+ image = image.astype(np.float32) / 255.0
104
+
105
+ # Transpose to CHW format (ONNX expects this)
106
+ image = np.transpose(image, (2, 0, 1))
107
+
108
+ # Add batch dimension
109
+ image = np.expand_dims(image, axis=0)
110
+
111
+ return image
112
+
113
+ def postprocess(self, outputs: np.ndarray) -> List[dict]:
114
+ """
115
+ Post-process YOLOv8 outputs
116
+
117
+ Args:
118
+ outputs: Raw model outputs
119
+
120
+ Returns:
121
+ List of detections
122
+ """
123
+ detections = []
124
+
125
+ # YOLOv8 output shape: [1, 5, 8400] for 640x640
126
+ # Where 5 = [x, y, w, h, conf] and 8400 = 80x80 + 40x40 + 20x20
127
+
128
+ # Reshape outputs
129
+ outputs = outputs[0] # Remove batch dimension
130
+
131
+ # Filter by confidence
132
+ conf_mask = outputs[4] > self.conf_threshold
133
+ filtered_outputs = outputs[:, conf_mask]
134
+
135
+ if filtered_outputs.shape[1] == 0:
136
+ return detections
137
+
138
+ # Extract boxes and scores
139
+ boxes = filtered_outputs[:4].T # [x, y, w, h]
140
+ scores = filtered_outputs[4] # confidence scores
141
+
142
+ # Convert from center format to corner format
143
+ x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
144
+ x1 = x - w / 2
145
+ y1 = y - h / 2
146
+ x2 = x + w / 2
147
+ y2 = y + h / 2
148
+
149
+ # Clip to image boundaries
150
+ x1 = np.clip(x1, 0, self.input_width)
151
+ y1 = np.clip(y1, 0, self.input_height)
152
+ x2 = np.clip(x2, 0, self.input_width)
153
+ y2 = np.clip(y2, 0, self.input_height)
154
+
155
+ # Create detection dictionaries
156
+ for i in range(len(scores)):
157
+ detection = {
158
+ 'bbox': [float(x1[i]), float(y1[i]), float(x2[i]), float(y2[i])],
159
+ 'confidence': float(scores[i]),
160
+ 'class': 0, # Strawberry class
161
+ 'class_name': 'strawberry'
162
+ }
163
+ detections.append(detection)
164
+
165
+ return detections
166
+
167
+ def predict(self, image: np.ndarray) -> Tuple[List[dict], float]:
168
+ """
169
+ Run optimized inference
170
+
171
+ Args:
172
+ image: Input image
173
+
174
+ Returns:
175
+ Tuple of (detections, inference_time)
176
+ """
177
+ # Preprocess
178
+ input_tensor = self.preprocess(image)
179
+
180
+ # Run inference with timing
181
+ start_time = time.perf_counter()
182
+ outputs = self.session.run(None, {self.input_name: input_tensor})
183
+ inference_time = time.perf_counter() - start_time
184
+
185
+ # Post-process
186
+ detections = self.postprocess(outputs)
187
+
188
+ return detections, inference_time
189
+
190
+ def predict_batch(self, images: List[np.ndarray]) -> Tuple[List[List[dict]], float]:
191
+ """
192
+ Run batch inference for multiple images
193
+
194
+ Args:
195
+ images: List of input images
196
+
197
+ Returns:
198
+ Tuple of (list_of_detections, total_inference_time)
199
+ """
200
+ if not images:
201
+ return [], 0.0
202
+
203
+ # Preprocess all images
204
+ input_tensors = [self.preprocess(img) for img in images]
205
+ batch_tensor = np.concatenate(input_tensors, axis=0)
206
+
207
+ # Run batch inference
208
+ start_time = time.perf_counter()
209
+ outputs = self.session.run(None, {self.input_name: batch_tensor})
210
+ inference_time = time.perf_counter() - start_time
211
+
212
+ # Post-process each image in batch
213
+ all_detections = []
214
+ for i in range(len(images)):
215
+ single_output = outputs[0][i:i+1] # Extract single image output
216
+ detections = self.postprocess([single_output])
217
+ all_detections.append(detections)
218
+
219
+ return all_detections, inference_time
220
+
221
+ def benchmark_model(model_path: str, test_image_path: str, runs: int = 10) -> dict:
222
+ """
223
+ Benchmark model performance
224
+
225
+ Args:
226
+ model_path: Path to ONNX model
227
+ test_image_path: Path to test image
228
+ runs: Number of benchmark runs
229
+
230
+ Returns:
231
+ Benchmark results dictionary
232
+ """
233
+ # Load model
234
+ model = OptimizedONNXInference(model_path)
235
+
236
+ # Load test image
237
+ test_image = cv2.imread(test_image_path)
238
+ if test_image is None:
239
+ raise ValueError(f"Could not load test image: {test_image_path}")
240
+
241
+ # Warmup run
242
+ _ = model.predict(test_image)
243
+
244
+ # Benchmark runs
245
+ times = []
246
+ for _ in range(runs):
247
+ _, inference_time = model.predict(test_image)
248
+ times.append(inference_time * 1000) # Convert to milliseconds
249
+
250
+ # Calculate statistics
251
+ times_array = np.array(times)
252
+ results = {
253
+ 'mean_ms': float(np.mean(times_array)),
254
+ 'median_ms': float(np.median(times_array)),
255
+ 'std_ms': float(np.std(times_array)),
256
+ 'min_ms': float(np.min(times_array)),
257
+ 'max_ms': float(np.max(times_array)),
258
+ 'fps': float(1000 / np.mean(times_array)),
259
+ 'runs': runs
260
+ }
261
+
262
+ return results
263
+
264
+ if __name__ == "__main__":
265
+ # Example usage
266
+ model_path = "model/detection/yolov8n/best_416.onnx"
267
+ test_image = "test_detection_result.jpg"
268
+
269
+ if os.path.exists(model_path) and os.path.exists(test_image):
270
+ print("🚀 Testing Optimized ONNX Inference")
271
+ print("=" * 50)
272
+
273
+ # Load model
274
+ model = OptimizedONNXInference(model_path)
275
+
276
+ # Load and predict
277
+ image = cv2.imread(test_image)
278
+ detections, inference_time = model.predict(image)
279
+
280
+ print(".2f" print(f"📊 Detections found: {len(detections)}")
281
+
282
+ # Benchmark
283
+ print("\n📈 Running benchmark (10 runs)...")
284
+ results = benchmark_model(model_path, test_image, runs=10)
285
+
286
+ print("📊 Benchmark Results:" print(".2f" print(".2f" print(".2f" print(".2f" print(".2f" print(".1f"
287
+ print("\n✅ Optimized inference test complete!")
288
+ else:
289
+ print("❌ Model or test image not found")
290
+ print(f"Model: {model_path} - {'✅' if os.path.exists(model_path) else '❌'}")
291
+ print(f"Image: {test_image} - {'✅' if os.path.exists(test_image) else '❌'}")
scripts/organize_labeled_images.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Organize automatically labeled images from JSON results into proper directories.
4
+ This script processes the auto_labeling_results JSON files and moves images
5
+ from to_label/ to the appropriate subdirectories (unripe/ripe/overripe/).
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import shutil
11
+ from pathlib import Path
12
+ import glob
13
+
14
+ def organize_labeled_images():
15
+ """Organize automatically labeled images based on JSON results."""
16
+
17
+ # Paths
18
+ base_dir = Path("model/ripeness_manual_dataset")
19
+ to_label_dir = base_dir / "to_label"
20
+ unripe_dir = base_dir / "unripe"
21
+ ripe_dir = base_dir / "ripe"
22
+ overripe_dir = base_dir / "overripe"
23
+
24
+ # Find all auto_labeling_results JSON files
25
+ json_files = glob.glob(str(base_dir / "auto_labeling_results_*.json"))
26
+ json_files.sort() # Process in chronological order
27
+
28
+ print(f"Found {len(json_files)} auto_labeling_results files")
29
+
30
+ total_moved = 0
31
+ total_errors = 0
32
+
33
+ for json_file in json_files:
34
+ print(f"\nProcessing: {os.path.basename(json_file)}")
35
+
36
+ try:
37
+ with open(json_file, 'r') as f:
38
+ results = json.load(f)
39
+
40
+ batch_moved = 0
41
+ batch_errors = 0
42
+
43
+ # Process each labeled image (results is a list)
44
+ for label_data in results:
45
+ if isinstance(label_data, dict) and 'image' in label_data and 'label' in label_data:
46
+ image_name = label_data['image']
47
+ label = label_data['label']
48
+ confidence = label_data.get('confidence', 0.0)
49
+
50
+ # Skip if confidence is too low or label is unknown
51
+ if label in ['unknown', 'skip'] or confidence < 0.6:
52
+ continue
53
+
54
+ # Source and destination paths
55
+ src_path = to_label_dir / image_name
56
+
57
+ if not src_path.exists():
58
+ print(f" ⚠️ Source file not found: {image_name}")
59
+ batch_errors += 1
60
+ continue
61
+
62
+ # Determine destination directory
63
+ if label == 'unripe':
64
+ dst_dir = unripe_dir
65
+ elif label == 'ripe':
66
+ dst_dir = ripe_dir
67
+ elif label == 'overripe':
68
+ dst_dir = overripe_dir
69
+ else:
70
+ print(f" ⚠️ Unknown label '{label}': {image_name}")
71
+ batch_errors += 1
72
+ continue
73
+
74
+ dst_path = dst_dir / image_name
75
+
76
+ # Move the file
77
+ try:
78
+ shutil.move(str(src_path), str(dst_path))
79
+ print(f" ✅ {label} ({confidence:.2f}): {image_name}")
80
+ batch_moved += 1
81
+ except Exception as e:
82
+ print(f" ❌ Error moving {image_name}: {e}")
83
+ batch_errors += 1
84
+
85
+ print(f" 📦 Batch results: {batch_moved} moved, {batch_errors} errors")
86
+ total_moved += batch_moved
87
+ total_errors += batch_errors
88
+
89
+ except Exception as e:
90
+ print(f" ❌ Error processing {json_file}: {e}")
91
+ total_errors += 1
92
+
93
+ # Final summary
94
+ print(f"\n=== ORGANIZATION COMPLETE ===")
95
+ print(f"Total images moved: {total_moved}")
96
+ print(f"Total errors: {total_errors}")
97
+
98
+ # Show final counts
99
+ unripe_count = len(list(unripe_dir.glob("*.jpg")))
100
+ ripe_count = len(list(ripe_dir.glob("*.jpg")))
101
+ overripe_count = len(list(overripe_dir.glob("*.jpg")))
102
+ remaining_count = len(list(to_label_dir.glob("*.jpg")))
103
+
104
+ print(f"\nFinal counts:")
105
+ print(f" unripe: {unripe_count} images")
106
+ print(f" ripe: {ripe_count} images")
107
+ print(f" overripe: {overripe_count} images")
108
+ print(f" to_label: {remaining_count} images")
109
+ print(f" TOTAL: {unripe_count + ripe_count + overripe_count + remaining_count} images")
110
+
111
+ completion_rate = (unripe_count + ripe_count + overripe_count) / (unripe_count + ripe_count + overripe_count + remaining_count) * 100
112
+ print(f" Completion: {completion_rate:.1f}%")
113
+
114
+ if __name__ == "__main__":
115
+ organize_labeled_images()
scripts/setup_training.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Setup script for Strawberry Picker ML Training Environment
4
+ This script installs dependencies and validates the training setup.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import subprocess
10
+ import argparse
11
+ from pathlib import Path
12
+
13
+ def run_command(cmd, description=""):
14
+ """Run a shell command and handle errors"""
15
+ print(f"\n{'='*60}")
16
+ print(f"Running: {description or cmd}")
17
+ print(f"{'='*60}\n")
18
+
19
+ try:
20
+ result = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
21
+ if result.stdout:
22
+ print(result.stdout)
23
+ return True
24
+ except subprocess.CalledProcessError as e:
25
+ print(f"ERROR: Command failed with return code {e.returncode}")
26
+ if e.stdout:
27
+ print(f"STDOUT: {e.stdout}")
28
+ if e.stderr:
29
+ print(f"STDERR: {e.stderr}")
30
+ return False
31
+
32
+ def check_python_version():
33
+ """Check if Python version is compatible"""
34
+ print("Checking Python version...")
35
+ version = sys.version_info
36
+ if version.major < 3 or (version.major == 3 and version.minor < 8):
37
+ print(f"ERROR: Python 3.8+ required. Found {version.major}.{version.minor}")
38
+ return False
39
+ print(f"✓ Python {version.major}.{version.minor}.{version.micro}")
40
+ return True
41
+
42
+ def check_pip():
43
+ """Check if pip is available"""
44
+ print("Checking pip availability...")
45
+ return run_command("pip --version", "Check pip version")
46
+
47
+ def install_requirements():
48
+ """Install Python dependencies"""
49
+ print("Installing Python dependencies...")
50
+ requirements_file = Path(__file__).parent / "requirements.txt"
51
+
52
+ if not requirements_file.exists():
53
+ print(f"ERROR: requirements.txt not found at {requirements_file}")
54
+ return False
55
+
56
+ # Upgrade pip first
57
+ if not run_command("pip install --upgrade pip", "Upgrade pip"):
58
+ return False
59
+
60
+ # Install requirements
61
+ return run_command(f"pip install -r {requirements_file}", "Install requirements")
62
+
63
+ def check_ultralytics():
64
+ """Check if ultralytics is installed correctly"""
65
+ print("Checking ultralytics installation...")
66
+ try:
67
+ from ultralytics import YOLO
68
+ print("✓ ultralytics installed successfully")
69
+ return True
70
+ except ImportError as e:
71
+ print(f"ERROR: Failed to import ultralytics: {e}")
72
+ return False
73
+
74
+ def check_torch():
75
+ """Check PyTorch installation and GPU availability"""
76
+ print("Checking PyTorch installation...")
77
+ try:
78
+ import torch
79
+ print(f"✓ PyTorch version: {torch.__version__}")
80
+
81
+ if torch.cuda.is_available():
82
+ print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
83
+ print(f"✓ CUDA version: {torch.version.cuda}")
84
+ else:
85
+ print("⚠ GPU not available, will use CPU for training")
86
+
87
+ return True
88
+ except ImportError as e:
89
+ print(f"ERROR: Failed to import torch: {e}")
90
+ return False
91
+
92
+ def validate_dataset():
93
+ """Validate dataset structure"""
94
+ print("Validating dataset structure...")
95
+
96
+ dataset_path = Path(__file__).parent / "model" / "dataset" / "straw-detect.v1-straw-detect.yolov8"
97
+ data_yaml = dataset_path / "data.yaml"
98
+
99
+ if not data_yaml.exists():
100
+ print(f"ERROR: data.yaml not found at {data_yaml}")
101
+ print("Please ensure your dataset is in the correct location")
102
+ return False
103
+
104
+ try:
105
+ import yaml
106
+ with open(data_yaml, 'r') as f:
107
+ data = yaml.safe_load(f)
108
+
109
+ print(f"✓ Dataset configuration loaded")
110
+ print(f" Classes: {data['nc']}")
111
+ print(f" Names: {data['names']}")
112
+
113
+ # Check training images
114
+ train_path = dataset_path / data['train']
115
+ if train_path.exists():
116
+ train_images = list(train_path.glob('*.jpg')) + list(train_path.glob('*.png'))
117
+ print(f" Training images: {len(train_images)}")
118
+ else:
119
+ print(f"⚠ Training path not found: {train_path}")
120
+
121
+ # Check validation images
122
+ val_path = dataset_path / data['val']
123
+ if val_path.exists():
124
+ val_images = list(val_path.glob('*.jpg')) + list(val_path.glob('*.png'))
125
+ print(f" Validation images: {len(val_images)}")
126
+ else:
127
+ print(f"⚠ Validation path not found: {val_path}")
128
+
129
+ return True
130
+
131
+ except Exception as e:
132
+ print(f"ERROR: Failed to validate dataset: {e}")
133
+ return False
134
+
135
+ def create_directories():
136
+ """Create necessary directories"""
137
+ print("Creating project directories...")
138
+
139
+ base_path = Path(__file__).parent
140
+ dirs = [
141
+ base_path / "model" / "weights",
142
+ base_path / "model" / "results",
143
+ base_path / "model" / "exports"
144
+ ]
145
+
146
+ for dir_path in dirs:
147
+ dir_path.mkdir(parents=True, exist_ok=True)
148
+ print(f"✓ Created: {dir_path}")
149
+
150
+ return True
151
+
152
+ def main():
153
+ parser = argparse.ArgumentParser(description='Setup training environment for strawberry detection')
154
+ parser.add_argument('--skip-install', action='store_true', help='Skip package installation')
155
+ parser.add_argument('--validate-only', action='store_true', help='Only validate setup without installing')
156
+
157
+ args = parser.parse_args()
158
+
159
+ print("="*60)
160
+ print("Strawberry Picker ML Training Environment Setup")
161
+ print("="*60)
162
+
163
+ # Step 1: Check Python version
164
+ if not check_python_version():
165
+ sys.exit(1)
166
+
167
+ # Step 2: Check pip
168
+ if not check_pip():
169
+ sys.exit(1)
170
+
171
+ # Step 3: Install requirements (unless skipped)
172
+ if not args.skip_install and not args.validate_only:
173
+ if not install_requirements():
174
+ print("\n⚠ Installation failed. Please check the errors above.")
175
+ response = input("Continue with validation anyway? (y/n): ")
176
+ if response.lower() != 'y':
177
+ sys.exit(1)
178
+
179
+ # Step 4: Check ultralytics
180
+ if not check_ultralytics():
181
+ sys.exit(1)
182
+
183
+ # Step 5: Check PyTorch
184
+ if not check_torch():
185
+ sys.exit(1)
186
+
187
+ # Step 6: Validate dataset
188
+ if not validate_dataset():
189
+ print("\n⚠ Dataset validation failed. Please fix the issues above.")
190
+ if not args.validate_only:
191
+ response = input("Continue with directory creation anyway? (y/n): ")
192
+ if response.lower() != 'y':
193
+ sys.exit(1)
194
+
195
+ # Step 7: Create directories
196
+ if not args.validate_only:
197
+ if not create_directories():
198
+ sys.exit(1)
199
+
200
+ print("\n" + "="*60)
201
+ if args.validate_only:
202
+ print("Setup validation completed!")
203
+ else:
204
+ print("Setup completed successfully!")
205
+
206
+ print("\nNext steps:")
207
+ print("1. Run training: python train_yolov8.py")
208
+ print("2. Or open train_yolov8_colab.ipynb in Google Colab")
209
+ print("3. Check README.md for detailed instructions")
210
+ print("="*60)
211
+
212
+ if __name__ == '__main__':
213
+ main()
scripts/train_model.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
3
+ from tensorflow.keras.applications import MobileNetV2
4
+ from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
5
+ from tensorflow.keras.models import Model
6
+ import os
7
+
8
+ # Data directories
9
+ data_dir = 'dataset'
10
+ train_datagen = ImageDataGenerator(
11
+ rescale=1./255,
12
+ validation_split=0.2,
13
+ rotation_range=20,
14
+ width_shift_range=0.2,
15
+ height_shift_range=0.2,
16
+ horizontal_flip=True
17
+ )
18
+
19
+ train_generator = train_datagen.flow_from_directory(
20
+ data_dir,
21
+ target_size=(224, 224),
22
+ batch_size=32,
23
+ class_mode='binary',
24
+ subset='training'
25
+ )
26
+
27
+ validation_generator = train_datagen.flow_from_directory(
28
+ data_dir,
29
+ target_size=(224, 224),
30
+ batch_size=32,
31
+ class_mode='binary',
32
+ subset='validation'
33
+ )
34
+
35
+ # Load pre-trained MobileNetV2
36
+ base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
37
+
38
+ # Add custom layers
39
+ x = base_model.output
40
+ x = GlobalAveragePooling2D()(x)
41
+ x = Dense(1024, activation='relu')(x)
42
+ predictions = Dense(1, activation='sigmoid')(x)
43
+
44
+ model = Model(inputs=base_model.input, outputs=predictions)
45
+
46
+ # Freeze base layers
47
+ for layer in base_model.layers:
48
+ layer.trainable = False
49
+
50
+ # Compile
51
+ model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
52
+
53
+ # Train
54
+ model.fit(train_generator, validation_data=validation_generator, epochs=10)
55
+
56
+ # Save model
57
+ model.save('strawberry_model.h5')
58
+
59
+ print("Model trained and saved as strawberry_model.h5")
scripts/train_ripeness_classifier.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Strawberry Ripeness Classification Training Script
4
+ Trains a 3-class classifier (unripe/ripe/overripe) using transfer learning
5
+ """
6
+
7
+ import os
8
+ import argparse
9
+ import json
10
+ import numpy as np
11
+ import pandas as pd
12
+ from pathlib import Path
13
+ import yaml
14
+ from datetime import datetime
15
+ import matplotlib.pyplot as plt
16
+ import seaborn as sns
17
+
18
+ # Deep Learning
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.optim as optim
22
+ from torch.utils.data import Dataset, DataLoader
23
+ import torchvision.transforms as transforms
24
+ import torchvision.models as models
25
+ from torchvision.datasets import ImageFolder
26
+ from sklearn.metrics import classification_report, confusion_matrix
27
+ from sklearn.model_selection import train_test_split
28
+
29
+ # Set random seeds for reproducibility
30
+ torch.manual_seed(42)
31
+ np.random.seed(42)
32
+
33
+ class RipenessDataset(Dataset):
34
+ """Custom dataset for strawberry ripeness classification"""
35
+
36
+ def __init__(self, data_dir, transform=None, split='train'):
37
+ self.data_dir = Path(data_dir)
38
+ self.transform = transform
39
+ self.split = split
40
+
41
+ # Get class names and counts (exclude 'to_label' directory)
42
+ self.classes = sorted([d.name for d in self.data_dir.iterdir()
43
+ if d.is_dir() and d.name != 'to_label'])
44
+ self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
45
+
46
+ # Get all image paths and labels
47
+ self.samples = []
48
+ for class_name in self.classes:
49
+ class_dir = self.data_dir / class_name
50
+ if class_dir.exists():
51
+ for img_path in class_dir.glob('*.jpg'):
52
+ self.samples.append((str(img_path), self.class_to_idx[class_name]))
53
+
54
+ print(f"{split} dataset: {len(self.samples)} samples")
55
+ print(f"Classes: {self.classes}")
56
+
57
+ def __len__(self):
58
+ return len(self.samples)
59
+
60
+ def __getitem__(self, idx):
61
+ img_path, label = self.samples[idx]
62
+
63
+ # Load image
64
+ from PIL import Image
65
+ image = Image.open(img_path).convert('RGB')
66
+
67
+ if self.transform:
68
+ image = self.transform(image)
69
+
70
+ return image, label
71
+
72
+ def get_transforms(img_size=224):
73
+ """Get data transforms for training and validation"""
74
+
75
+ # Training transforms with augmentation
76
+ train_transform = transforms.Compose([
77
+ transforms.Resize((img_size, img_size)),
78
+ transforms.RandomHorizontalFlip(p=0.5),
79
+ transforms.RandomRotation(degrees=15),
80
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
81
+ transforms.ToTensor(),
82
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
83
+ ])
84
+
85
+ # Validation transforms (no augmentation)
86
+ val_transform = transforms.Compose([
87
+ transforms.Resize((img_size, img_size)),
88
+ transforms.ToTensor(),
89
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
90
+ ])
91
+
92
+ return train_transform, val_transform
93
+
94
+ def create_model(num_classes=3, backbone='resnet18', pretrained=True):
95
+ """Create model with transfer learning"""
96
+
97
+ if backbone == 'resnet18':
98
+ model = models.resnet18(pretrained=pretrained)
99
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
100
+ elif backbone == 'resnet50':
101
+ model = models.resnet50(pretrained=pretrained)
102
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
103
+ elif backbone == 'efficientnet_b0':
104
+ model = models.efficientnet_b0(pretrained=pretrained)
105
+ model.classifier = nn.Linear(model.classifier.in_features, num_classes)
106
+ else:
107
+ raise ValueError(f"Unsupported backbone: {backbone}")
108
+
109
+ return model
110
+
111
+ def train_model(model, train_loader, val_loader, device, num_epochs=50, lr=0.001):
112
+ """Train the model"""
113
+
114
+ criterion = nn.CrossEntropyLoss()
115
+ optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
116
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5, factor=0.5)
117
+
118
+ best_val_acc = 0.0
119
+ train_losses = []
120
+ val_accuracies = []
121
+
122
+ for epoch in range(num_epochs):
123
+ # Training phase
124
+ model.train()
125
+ running_loss = 0.0
126
+ correct = 0
127
+ total = 0
128
+
129
+ for batch_idx, (images, labels) in enumerate(train_loader):
130
+ images, labels = images.to(device), labels.to(device)
131
+
132
+ optimizer.zero_grad()
133
+ outputs = model(images)
134
+ loss = criterion(outputs, labels)
135
+ loss.backward()
136
+ optimizer.step()
137
+
138
+ running_loss += loss.item()
139
+ _, predicted = outputs.max(1)
140
+ total += labels.size(0)
141
+ correct += predicted.eq(labels).sum().item()
142
+
143
+ if batch_idx % 10 == 0:
144
+ print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, '
145
+ f'Loss: {loss.item():.4f}, Acc: {100.*correct/total:.2f}%')
146
+
147
+ train_loss = running_loss / len(train_loader)
148
+ train_acc = 100. * correct / total
149
+
150
+ # Validation phase
151
+ model.eval()
152
+ val_correct = 0
153
+ val_total = 0
154
+ val_loss = 0.0
155
+
156
+ with torch.no_grad():
157
+ for images, labels in val_loader:
158
+ images, labels = images.to(device), labels.to(device)
159
+ outputs = model(images)
160
+ loss = criterion(outputs, labels)
161
+
162
+ val_loss += loss.item()
163
+ _, predicted = outputs.max(1)
164
+ val_total += labels.size(0)
165
+ val_correct += predicted.eq(labels).sum().item()
166
+
167
+ val_acc = 100. * val_correct / val_total
168
+ val_loss = val_loss / len(val_loader)
169
+
170
+ train_losses.append(train_loss)
171
+ val_accuracies.append(val_acc)
172
+
173
+ print(f'Epoch {epoch+1}/{num_epochs}:')
174
+ print(f' Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
175
+ print(f' Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
176
+ print('-' * 50)
177
+
178
+ # Save best model
179
+ if val_acc > best_val_acc:
180
+ best_val_acc = val_acc
181
+ torch.save(model.state_dict(), 'model/ripeness_classifier_best.pth')
182
+ print(f'New best model saved! Val Acc: {best_val_acc:.2f}%')
183
+
184
+ scheduler.step(val_acc)
185
+
186
+ return train_losses, val_accuracies, best_val_acc
187
+
188
+ def evaluate_model(model, test_loader, device, class_names):
189
+ """Evaluate model and generate reports"""
190
+
191
+ model.eval()
192
+ all_preds = []
193
+ all_labels = []
194
+
195
+ with torch.no_grad():
196
+ for images, labels in test_loader:
197
+ images, labels = images.to(device), labels.to(device)
198
+ outputs = model(images)
199
+ _, predicted = outputs.max(1)
200
+
201
+ all_preds.extend(predicted.cpu().numpy())
202
+ all_labels.extend(labels.cpu().numpy())
203
+
204
+ # Classification report
205
+ report = classification_report(all_labels, all_preds, target_names=class_names)
206
+ print("Classification Report:")
207
+ print(report)
208
+
209
+ # Confusion matrix
210
+ cm = confusion_matrix(all_labels, all_preds)
211
+
212
+ # Plot confusion matrix
213
+ plt.figure(figsize=(8, 6))
214
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
215
+ xticklabels=class_names, yticklabels=class_names)
216
+ plt.title('Confusion Matrix')
217
+ plt.ylabel('True Label')
218
+ plt.xlabel('Predicted Label')
219
+ plt.savefig('model/ripeness_confusion_matrix.png', dpi=300, bbox_inches='tight')
220
+ plt.close()
221
+
222
+ return report, cm
223
+
224
+ def plot_training_history(train_losses, val_accuracies, save_path):
225
+ """Plot training history"""
226
+
227
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
228
+
229
+ # Plot training loss
230
+ ax1.plot(train_losses)
231
+ ax1.set_title('Training Loss')
232
+ ax1.set_xlabel('Epoch')
233
+ ax1.set_ylabel('Loss')
234
+ ax1.grid(True)
235
+
236
+ # Plot validation accuracy
237
+ ax2.plot(val_accuracies)
238
+ ax2.set_title('Validation Accuracy')
239
+ ax2.set_xlabel('Epoch')
240
+ ax2.set_ylabel('Accuracy (%)')
241
+ ax2.grid(True)
242
+
243
+ plt.tight_layout()
244
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
245
+ plt.close()
246
+
247
+ def main():
248
+ parser = argparse.ArgumentParser(description='Train strawberry ripeness classifier')
249
+ parser.add_argument('--data-dir', default='model/ripeness_manual_dataset',
250
+ help='Directory containing labeled images')
251
+ parser.add_argument('--img-size', type=int, default=224, help='Image size')
252
+ parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
253
+ parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
254
+ parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
255
+ parser.add_argument('--backbone', default='resnet18',
256
+ choices=['resnet18', 'resnet50', 'efficientnet_b0'],
257
+ help='Backbone architecture')
258
+ parser.add_argument('--val-split', type=float, default=0.2, help='Validation split ratio')
259
+ parser.add_argument('--output-dir', default='model/ripeness_classifier',
260
+ help='Output directory for models and results')
261
+
262
+ args = parser.parse_args()
263
+
264
+ # Create output directory
265
+ os.makedirs(args.output_dir, exist_ok=True)
266
+
267
+ # Load config
268
+ with open('config.yaml', 'r') as f:
269
+ config = yaml.safe_load(f)
270
+
271
+ # Set device
272
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
273
+ print(f"Using device: {device}")
274
+
275
+ # Get transforms
276
+ train_transform, val_transform = get_transforms(args.img_size)
277
+
278
+ # Create datasets
279
+ train_dataset = RipenessDataset(args.data_dir, transform=train_transform, split='train')
280
+ val_dataset = RipenessDataset(args.data_dir, transform=val_transform, split='val')
281
+
282
+ # Create data loaders
283
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
284
+ val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
285
+
286
+ # Create model
287
+ num_classes = len(train_dataset.classes)
288
+ model = create_model(num_classes=num_classes, backbone=args.backbone, pretrained=True)
289
+ model = model.to(device)
290
+
291
+ print(f"Model created with {num_classes} classes: {train_dataset.classes}")
292
+ print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
293
+
294
+ # Train model
295
+ print("Starting training...")
296
+ train_losses, val_accuracies, best_val_acc = train_model(
297
+ model, train_loader, val_loader, device,
298
+ num_epochs=args.epochs, lr=args.lr
299
+ )
300
+
301
+ # Load best model for evaluation
302
+ model.load_state_dict(torch.load('model/ripeness_classifier_best.pth'))
303
+
304
+ # Evaluate model
305
+ print("Evaluating model...")
306
+ report, cm = evaluate_model(model, val_loader, device, train_dataset.classes)
307
+
308
+ # Plot training history
309
+ plot_training_history(train_losses, val_accuracies,
310
+ f'{args.output_dir}/training_history.png')
311
+
312
+ # Save results
313
+ results = {
314
+ 'model_architecture': args.backbone,
315
+ 'num_classes': num_classes,
316
+ 'class_names': train_dataset.classes,
317
+ 'best_val_accuracy': best_val_acc,
318
+ 'training_config': {
319
+ 'img_size': args.img_size,
320
+ 'batch_size': args.batch_size,
321
+ 'epochs': args.epochs,
322
+ 'learning_rate': args.lr,
323
+ 'val_split': args.val_split
324
+ },
325
+ 'dataset_info': {
326
+ 'total_samples': len(train_dataset),
327
+ 'class_distribution': {cls: len(list(Path(args.data_dir, cls).glob('*.jpg')))
328
+ for cls in train_dataset.classes}
329
+ }
330
+ }
331
+
332
+ with open(f'{args.output_dir}/training_results.json', 'w') as f:
333
+ json.dump(results, f, indent=2)
334
+
335
+ # Save classification report
336
+ with open(f'{args.output_dir}/classification_report.txt', 'w') as f:
337
+ f.write(report)
338
+
339
+ print(f"\nTraining completed!")
340
+ print(f"Best validation accuracy: {best_val_acc:.2f}%")
341
+ print(f"Results saved to: {args.output_dir}")
342
+ print(f"Model saved to: model/ripeness_classifier_best.pth")
343
+
344
+ if __name__ == '__main__':
345
+ main()
scripts/train_yolov8.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ YOLOv8 Training Script for Strawberry Detection
4
+ Compatible with: Local Python, WSL, Google Colab (VS Code extension)
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import argparse
10
+ from pathlib import Path
11
+ import torch
12
+ import yaml
13
+
14
+ def check_environment():
15
+ """Detect running environment and configure paths accordingly"""
16
+ env_info = {
17
+ 'is_colab': 'COLAB_GPU' in os.environ or '/content' in os.getcwd(),
18
+ 'is_wsl': 'WSL_DISTRO_NAME' in os.environ,
19
+ 'has_gpu': torch.cuda.is_available(),
20
+ 'gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
21
+ }
22
+ return env_info
23
+
24
+ def setup_paths(dataset_path=None):
25
+ """Configure dataset and output paths based on environment"""
26
+ env = check_environment()
27
+
28
+ if env['is_colab']:
29
+ # Google Colab paths
30
+ base_path = Path('/content/strawberry-picker')
31
+ dataset_path = dataset_path or '/content/dataset'
32
+ weights_dir = base_path / 'weights'
33
+ results_dir = base_path / 'results'
34
+ else:
35
+ # Local/WSL paths
36
+ base_path = Path(__file__).parent
37
+ dataset_path = dataset_path or base_path / 'model' / 'dataset' / 'straw-detect.v1-straw-detect.yolov8'
38
+ weights_dir = base_path / 'model' / 'weights'
39
+ results_dir = base_path / 'model' / 'results'
40
+
41
+ # Create directories
42
+ weights_dir.mkdir(parents=True, exist_ok=True)
43
+ results_dir.mkdir(parents=True, exist_ok=True)
44
+
45
+ return {
46
+ 'dataset_path': Path(dataset_path),
47
+ 'weights_dir': weights_dir,
48
+ 'results_dir': results_dir,
49
+ 'base_path': base_path
50
+ }
51
+
52
+ def validate_dataset(dataset_path):
53
+ """Validate YOLO dataset structure"""
54
+ dataset_path = Path(dataset_path)
55
+ data_yaml = dataset_path / 'data.yaml'
56
+
57
+ if not data_yaml.exists():
58
+ raise FileNotFoundError(f"data.yaml not found at {data_yaml}")
59
+
60
+ # Load and validate YAML
61
+ with open(data_yaml, 'r') as f:
62
+ data = yaml.safe_load(f)
63
+
64
+ required_keys = ['train', 'val', 'nc', 'names']
65
+ for key in required_keys:
66
+ if key not in data:
67
+ raise ValueError(f"Missing required key '{key}' in data.yaml")
68
+
69
+ # Check if paths are relative and resolve them
70
+ train_path = dataset_path / data['train']
71
+ val_path = dataset_path / data['val']
72
+
73
+ if not train_path.exists():
74
+ raise FileNotFoundError(f"Training images not found at {train_path}")
75
+ if not val_path.exists():
76
+ raise FileNotFoundError(f"Validation images not found at {val_path}")
77
+
78
+ print(f"✓ Dataset validated: {data['nc']} classes - {data['names']}")
79
+ print(f"✓ Training images: {train_path}")
80
+ print(f"✓ Validation images: {val_path}")
81
+
82
+ return data_yaml
83
+
84
+ def train_model(data_yaml, weights_dir, results_dir, epochs=100, img_size=640, batch_size=16, weights=None, resume=False):
85
+ """Train YOLOv8 model (supports resuming from checkpoints)"""
86
+ try:
87
+ from ultralytics import YOLO
88
+ except ImportError:
89
+ print("ERROR: ultralytics not installed. Run: pip install ultralytics")
90
+ sys.exit(1)
91
+
92
+ env = check_environment()
93
+ print(f"\n{'='*60}")
94
+ print(f"Environment: {'Google Colab' if env['is_colab'] else 'Local/WSL'}")
95
+ print(f"GPU Available: {env['has_gpu']} ({env['gpu_name']})")
96
+ print(f"{'='*60}\n")
97
+
98
+ # Use GPU if available
99
+ device = '0' if env['has_gpu'] else 'cpu'
100
+
101
+ # Load model (custom weights or default YOLOv8n)
102
+ model_source = Path(weights) if weights else 'yolov8n.pt'
103
+ print(f"Loading model from {model_source}...")
104
+ model = YOLO(str(model_source))
105
+
106
+ # Training arguments
107
+ train_args = {
108
+ 'data': str(data_yaml),
109
+ 'epochs': epochs,
110
+ 'imgsz': img_size,
111
+ 'batch': batch_size,
112
+ 'device': device,
113
+ 'project': str(results_dir),
114
+ 'name': 'strawberry_detection',
115
+ 'exist_ok': True,
116
+ 'patience': 20, # Early stopping patience
117
+ 'save': True,
118
+ 'save_period': 10, # Save checkpoint every 10 epochs
119
+ 'cache': True, # Cache images for faster training
120
+ }
121
+
122
+ if resume:
123
+ train_args['resume'] = True
124
+
125
+ # Adjust batch size for Colab's limited RAM
126
+ if env['is_colab'] and batch_size > 16:
127
+ train_args['batch'] = 16
128
+ print(f"Adjusted batch size to 16 for Colab environment")
129
+
130
+ print(f"\nStarting training with arguments:")
131
+ for key, value in train_args.items():
132
+ print(f" {key}: {value}")
133
+
134
+ # Train the model
135
+ print(f"\n{'='*60}")
136
+ print("TRAINING STARTED")
137
+ print(f"{'='*60}\n")
138
+
139
+ results = model.train(**train_args)
140
+
141
+ # Save final model
142
+ final_model_path = weights_dir / 'strawberry_yolov8n.pt'
143
+ model.save(str(final_model_path))
144
+
145
+ print(f"\n{'='*60}")
146
+ print(f"Training completed!")
147
+ print(f"Final model saved to: {final_model_path}")
148
+ print(f"Results saved to: {results_dir / 'strawberry_detection'}")
149
+ print(f"{'='*60}\n")
150
+
151
+ return results, final_model_path
152
+
153
+ def export_model(model_path, weights_dir):
154
+ """Export model to ONNX format"""
155
+ try:
156
+ from ultralytics import YOLO
157
+ except ImportError:
158
+ print("ERROR: ultralytics not installed")
159
+ return None
160
+
161
+ print(f"\nExporting model to ONNX...")
162
+ model = YOLO(str(model_path))
163
+
164
+ # Export to ONNX
165
+ onnx_path = weights_dir / 'strawberry_yolov8n.onnx'
166
+ model.export(format='onnx', imgsz=640, dynamic=True)
167
+
168
+ print(f"ONNX model exported to: {onnx_path}")
169
+ return onnx_path
170
+
171
+ def main():
172
+ parser = argparse.ArgumentParser(description='Train YOLOv8 for strawberry detection')
173
+ parser.add_argument('--dataset', type=str, help='Path to dataset directory')
174
+ parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
175
+ parser.add_argument('--img-size', type=int, default=640, help='Image size for training')
176
+ parser.add_argument('--batch-size', type=int, default=16, help='Batch size for training')
177
+ parser.add_argument('--weights', type=str, help='Path to pretrained weights or checkpoint')
178
+ parser.add_argument('--resume', action='store_true', help='Resume training from the latest checkpoint')
179
+ parser.add_argument('--export-onnx', action='store_true', help='Export to ONNX after training')
180
+ parser.add_argument('--validate-only', action='store_true', help='Only validate dataset without training')
181
+
182
+ args = parser.parse_args()
183
+
184
+ try:
185
+ # Setup paths
186
+ paths = setup_paths(args.dataset)
187
+ print(f"Base path: {paths['base_path']}")
188
+ print(f"Dataset path: {paths['dataset_path']}")
189
+ print(f"Weights directory: {paths['weights_dir']}")
190
+ print(f"Results directory: {paths['results_dir']}")
191
+
192
+ # Validate dataset
193
+ print(f"\nValidating dataset...")
194
+ data_yaml = validate_dataset(paths['dataset_path'])
195
+
196
+ if args.validate_only:
197
+ print("Dataset validation completed. Exiting without training.")
198
+ return
199
+
200
+ # Train model
201
+ results, model_path = train_model(
202
+ data_yaml=data_yaml,
203
+ weights_dir=paths['weights_dir'],
204
+ results_dir=paths['results_dir'],
205
+ epochs=args.epochs,
206
+ img_size=args.img_size,
207
+ batch_size=args.batch_size,
208
+ weights=args.weights,
209
+ resume=args.resume
210
+ )
211
+
212
+ # Export to ONNX if requested
213
+ if args.export_onnx:
214
+ export_model(model_path, paths['weights_dir'])
215
+
216
+ print("\n✓ Training pipeline completed successfully!")
217
+
218
+ except Exception as e:
219
+ print(f"\n✗ Error: {str(e)}")
220
+ sys.exit(1)
221
+
222
+ if __name__ == '__main__':
223
+ main()
scripts/validate_model.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Model Validation Script for Strawberry Ripeness Classification
4
+ Tests the trained model on sample images to verify functionality
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import torch
10
+ import numpy as np
11
+ import cv2
12
+ from pathlib import Path
13
+ import json
14
+ from datetime import datetime
15
+
16
+ # Add current directory to path for imports
17
+ sys.path.append('.')
18
+
19
+ from train_ripeness_classifier import create_model, get_transforms
20
+
21
+ def load_model(model_path):
22
+ """Load the trained classification model"""
23
+ print(f"Loading model from: {model_path}")
24
+
25
+ if not os.path.exists(model_path):
26
+ raise FileNotFoundError(f"Model file not found: {model_path}")
27
+
28
+ # Create model architecture
29
+ model = create_model(num_classes=3, backbone='resnet18', pretrained=False)
30
+
31
+ # Load trained weights
32
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
+ model.load_state_dict(torch.load(model_path, map_location=device))
34
+ model = model.to(device)
35
+ model.eval()
36
+
37
+ print(f"Model loaded successfully on {device}")
38
+ return model, device
39
+
40
+ def get_test_images():
41
+ """Get sample test images from the dataset"""
42
+ test_dirs = [
43
+ 'model/ripeness_manual_dataset/unripe',
44
+ 'model/ripeness_manual_dataset/ripe',
45
+ 'model/ripeness_manual_dataset/overripe'
46
+ ]
47
+
48
+ test_images = []
49
+ for test_dir in test_dirs:
50
+ if os.path.exists(test_dir):
51
+ images = list(Path(test_dir).glob('*.jpg'))[:3] # Get first 3 images from each class
52
+ for img_path in images:
53
+ test_images.append({
54
+ 'path': str(img_path),
55
+ 'true_label': os.path.basename(test_dir),
56
+ 'class_name': os.path.basename(test_dir)
57
+ })
58
+
59
+ return test_images
60
+
61
+ def predict_image(model, device, image_path, transform):
62
+ """Predict ripeness for a single image"""
63
+ try:
64
+ # Load and preprocess image
65
+ image = cv2.imread(image_path)
66
+ if image is None:
67
+ return None, "Failed to load image"
68
+
69
+ # Convert BGR to RGB
70
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
71
+ from PIL import Image
72
+ image_pil = Image.fromarray(image)
73
+
74
+ # Apply transforms
75
+ input_tensor = transform(image_pil).unsqueeze(0).to(device)
76
+
77
+ # Get prediction
78
+ with torch.no_grad():
79
+ outputs = model(input_tensor)
80
+ probabilities = torch.softmax(outputs, dim=1)
81
+ predicted_class_idx = torch.argmax(probabilities, dim=1).item()
82
+ confidence = probabilities[0][predicted_class_idx].item()
83
+
84
+ # Get class names
85
+ class_names = ['overripe', 'ripe', 'unripe']
86
+ predicted_class = class_names[predicted_class_idx]
87
+
88
+ # Get all probabilities
89
+ probs_dict = {
90
+ class_names[i]: float(probabilities[0][i].item())
91
+ for i in range(len(class_names))
92
+ }
93
+
94
+ return {
95
+ 'predicted_class': predicted_class,
96
+ 'confidence': confidence,
97
+ 'probabilities': probs_dict
98
+ }, None
99
+
100
+ except Exception as e:
101
+ return None, str(e)
102
+
103
+ def validate_model():
104
+ """Main validation function"""
105
+ print("=== Strawberry Ripeness Classification Model Validation ===")
106
+ print(f"Validation time: {datetime.now().isoformat()}")
107
+ print()
108
+
109
+ # Load model
110
+ model_path = 'model/ripeness_classifier_best.pth'
111
+ try:
112
+ model, device = load_model(model_path)
113
+ except Exception as e:
114
+ print(f"❌ Failed to load model: {e}")
115
+ return False
116
+
117
+ # Get transforms
118
+ _, transform = get_transforms(img_size=224)
119
+
120
+ # Get test images
121
+ test_images = get_test_images()
122
+ if not test_images:
123
+ print("❌ No test images found")
124
+ return False
125
+
126
+ print(f"Found {len(test_images)} test images")
127
+ print()
128
+
129
+ # Test predictions
130
+ results = []
131
+ correct_predictions = 0
132
+ total_predictions = 0
133
+
134
+ print("Testing predictions...")
135
+ print("-" * 80)
136
+
137
+ for i, test_img in enumerate(test_images):
138
+ image_path = test_img['path']
139
+ true_label = test_img['true_label']
140
+
141
+ # Make prediction
142
+ prediction, error = predict_image(model, device, image_path, transform)
143
+
144
+ if error:
145
+ print(f"❌ Image {i+1}: Error - {error}")
146
+ continue
147
+
148
+ predicted_class = prediction['predicted_class']
149
+ confidence = prediction['confidence']
150
+
151
+ # Check if prediction is correct
152
+ is_correct = predicted_class == true_label
153
+ if is_correct:
154
+ correct_predictions += 1
155
+ total_predictions += 1
156
+
157
+ # Print result
158
+ status = "✅" if is_correct else "❌"
159
+ print(f"{status} Image {i+1}: {os.path.basename(image_path)}")
160
+ print(f" True: {true_label} | Predicted: {predicted_class} ({confidence:.3f})")
161
+ print(f" Probabilities: overripe={prediction['probabilities']['overripe']:.3f}, "
162
+ f"ripe={prediction['probabilities']['ripe']:.3f}, "
163
+ f"unripe={prediction['probabilities']['unripe']:.3f}")
164
+ print()
165
+
166
+ # Store result
167
+ results.append({
168
+ 'image_path': image_path,
169
+ 'true_label': true_label,
170
+ 'predicted_class': predicted_class,
171
+ 'confidence': confidence,
172
+ 'probabilities': prediction['probabilities'],
173
+ 'correct': is_correct
174
+ })
175
+
176
+ # Calculate accuracy
177
+ accuracy = (correct_predictions / total_predictions * 100) if total_predictions > 0 else 0
178
+
179
+ print("=" * 80)
180
+ print("VALIDATION RESULTS")
181
+ print("=" * 80)
182
+ print(f"Total images tested: {total_predictions}")
183
+ print(f"Correct predictions: {correct_predictions}")
184
+ print(f"Accuracy: {accuracy:.1f}%")
185
+ print()
186
+
187
+ # Class-wise analysis
188
+ class_stats = {}
189
+ for result in results:
190
+ true_class = result['true_label']
191
+ if true_class not in class_stats:
192
+ class_stats[true_class] = {'correct': 0, 'total': 0}
193
+ class_stats[true_class]['total'] += 1
194
+ if result['correct']:
195
+ class_stats[true_class]['correct'] += 1
196
+
197
+ print("Class-wise Performance:")
198
+ for class_name, stats in class_stats.items():
199
+ class_accuracy = (stats['correct'] / stats['total'] * 100) if stats['total'] > 0 else 0
200
+ print(f" {class_name}: {stats['correct']}/{stats['total']} ({class_accuracy:.1f}%)")
201
+ print()
202
+
203
+ # Save detailed results
204
+ validation_results = {
205
+ 'validation_time': datetime.now().isoformat(),
206
+ 'model_path': model_path,
207
+ 'device': str(device),
208
+ 'total_images': total_predictions,
209
+ 'correct_predictions': correct_predictions,
210
+ 'accuracy_percent': accuracy,
211
+ 'class_stats': class_stats,
212
+ 'detailed_results': results
213
+ }
214
+
215
+ results_path = 'model_validation_results.json'
216
+ with open(results_path, 'w') as f:
217
+ json.dump(validation_results, f, indent=2)
218
+
219
+ print(f"Detailed results saved to: {results_path}")
220
+
221
+ # Validation verdict
222
+ if accuracy >= 90:
223
+ print("🎉 VALIDATION PASSED: Model performs excellently!")
224
+ return True
225
+ elif accuracy >= 80:
226
+ print("⚠️ VALIDATION WARNING: Model performs moderately well")
227
+ return True
228
+ else:
229
+ print("❌ VALIDATION FAILED: Model performance is poor")
230
+ return False
231
+
232
+ if __name__ == '__main__':
233
+ success = validate_model()
234
+ sys.exit(0 if success else 1)
scripts/webcam_capture.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import time
3
+
4
+ cap = cv2.VideoCapture(0)
5
+
6
+ if not cap.isOpened():
7
+ print("Cannot open camera")
8
+ exit()
9
+
10
+ while True:
11
+ ret, frame = cap.read()
12
+ if ret:
13
+ cv2.imwrite('captured_frame.jpg', frame)
14
+ print("Frame captured and saved as captured_frame.jpg")
15
+ else:
16
+ print("Can't receive frame")
17
+ break
18
+ time.sleep(1)
19
+
20
+ cap.release()
src/arduino_bridge.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Arduino Bridge - Serial Communication for Robotic Arm Control
4
+ Handles communication between Python pipeline and Arduino microcontroller
5
+
6
+ Author: AI Assistant
7
+ Date: 2025-12-15
8
+ """
9
+
10
+ import serial
11
+ import serial.tools.list_ports
12
+ import time
13
+ import logging
14
+ import threading
15
+ from typing import Optional, Tuple, List, Dict
16
+ from dataclasses import dataclass
17
+ import json
18
+ import re
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ @dataclass
23
+ class ServoPosition:
24
+ """Represents a servo position"""
25
+ servo_id: int
26
+ angle: float
27
+ timestamp: float
28
+
29
+ @dataclass
30
+ class SensorData:
31
+ """Represents sensor data from Arduino"""
32
+ limit_switches: Dict[int, bool]
33
+ force_sensor: float
34
+ temperature: float
35
+ timestamp: float
36
+
37
+ class ArduinoBridge:
38
+ """
39
+ Bridge for communication with Arduino-based robotic arm controller
40
+ Handles servo control, sensor reading, and safety monitoring
41
+ """
42
+
43
+ def __init__(self, port: str = '/dev/ttyUSB0', baudrate: int = 115200):
44
+ """Initialize Arduino bridge"""
45
+ self.port = port
46
+ self.baudrate = baudrate
47
+ self.serial_connection: Optional[serial.Serial] = None
48
+ self.connected = False
49
+ self.running = False
50
+
51
+ # Command queue for async communication
52
+ self.command_queue = []
53
+ self.response_queue = []
54
+ self.queue_lock = threading.Lock()
55
+
56
+ # Current servo positions
57
+ self.current_positions = {}
58
+ self.target_positions = {}
59
+
60
+ # Sensor data
61
+ self.latest_sensor_data: Optional[SensorData] = None
62
+ self.sensor_history = []
63
+
64
+ # Safety parameters
65
+ self.max_servo_angle = 180.0
66
+ self.min_servo_angle = 0.0
67
+ self.movement_timeout = 10.0 # seconds
68
+ self.emergency_stop_flag = False
69
+
70
+ # Communication thread
71
+ self.comm_thread: Optional[threading.Thread] = None
72
+
73
+ logger.info(f"Arduino Bridge initialized for port {port} at {baudrate} baud")
74
+
75
+ def connect(self) -> bool:
76
+ """Connect to Arduino via serial port"""
77
+ try:
78
+ # Auto-detect port if not specified
79
+ if self.port == '/dev/ttyUSB0':
80
+ self.port = self._auto_detect_port()
81
+ if not self.port:
82
+ logger.error("Could not auto-detect Arduino port")
83
+ return False
84
+
85
+ # Establish serial connection
86
+ self.serial_connection = serial.Serial(
87
+ port=self.port,
88
+ baudrate=self.baudrate,
89
+ timeout=1.0,
90
+ write_timeout=1.0
91
+ )
92
+
93
+ # Wait for Arduino to initialize
94
+ time.sleep(2.0)
95
+
96
+ # Test connection
97
+ if self._test_connection():
98
+ self.connected = True
99
+ self.running = True
100
+
101
+ # Start communication thread
102
+ self.comm_thread = threading.Thread(target=self._communication_loop, daemon=True)
103
+ self.comm_thread.start()
104
+
105
+ logger.info(f"Successfully connected to Arduino on {self.port}")
106
+ return True
107
+ else:
108
+ logger.error("Arduino connection test failed")
109
+ self.disconnect()
110
+ return False
111
+
112
+ except Exception as e:
113
+ logger.error(f"Failed to connect to Arduino: {e}")
114
+ return False
115
+
116
+ def disconnect(self):
117
+ """Disconnect from Arduino"""
118
+ logger.info("Disconnecting from Arduino...")
119
+
120
+ self.running = False
121
+ self.connected = False
122
+
123
+ # Stop communication thread
124
+ if self.comm_thread and self.comm_thread.is_alive():
125
+ self.comm_thread.join(timeout=2.0)
126
+
127
+ # Close serial connection
128
+ if self.serial_connection and self.serial_connection.is_open:
129
+ self.serial_connection.close()
130
+
131
+ logger.info("Arduino disconnected")
132
+
133
+ def _auto_detect_port(self) -> Optional[str]:
134
+ """Auto-detect Arduino port"""
135
+ ports = serial.tools.list_ports.comports()
136
+
137
+ for port in ports:
138
+ # Look for common Arduino identifiers
139
+ if any(keyword in port.description.lower() for keyword in
140
+ ['arduino', 'ch340', 'cp2102', 'ftdi']):
141
+ logger.info(f"Auto-detected Arduino on port: {port.device}")
142
+ return port.device
143
+
144
+ # If no Arduino found, return first available port
145
+ if ports:
146
+ logger.warning(f"No Arduino detected, using first available port: {ports[0].device}")
147
+ return ports[0].device
148
+
149
+ logger.error("No serial ports available")
150
+ return None
151
+
152
+ def _test_connection(self) -> bool:
153
+ """Test connection with Arduino"""
154
+ try:
155
+ # Send test command
156
+ self._send_command("PING")
157
+
158
+ # Wait for response
159
+ start_time = time.time()
160
+ while time.time() - start_time < 3.0:
161
+ if self._check_response("PONG"):
162
+ return True
163
+ time.sleep(0.1)
164
+
165
+ return False
166
+
167
+ except Exception as e:
168
+ logger.error(f"Connection test failed: {e}")
169
+ return False
170
+
171
+ def _communication_loop(self):
172
+ """Main communication loop for async processing"""
173
+ logger.info("Arduino communication loop started")
174
+
175
+ while self.running and self.connected:
176
+ try:
177
+ # Process outgoing commands
178
+ self._process_command_queue()
179
+
180
+ # Read incoming data
181
+ self._read_serial_data()
182
+
183
+ # Process sensor data
184
+ self._process_sensor_data()
185
+
186
+ time.sleep(0.01) # 10ms loop delay
187
+
188
+ except Exception as e:
189
+ logger.error(f"Communication loop error: {e}")
190
+ time.sleep(0.1)
191
+
192
+ logger.info("Arduino communication loop stopped")
193
+
194
+ def _process_command_queue(self):
195
+ """Process commands in the queue"""
196
+ with self.queue_lock:
197
+ if not self.command_queue:
198
+ return
199
+
200
+ command = self.command_queue.pop(0)
201
+
202
+ try:
203
+ self._send_raw_command(command)
204
+ except Exception as e:
205
+ logger.error(f"Failed to send command {command}: {e}")
206
+
207
+ def _read_serial_data(self):
208
+ """Read and process incoming serial data"""
209
+ if not self.serial_connection or not self.serial_connection.is_open:
210
+ return
211
+
212
+ try:
213
+ if self.serial_connection.in_waiting > 0:
214
+ line = self.serial_connection.readline().decode('utf-8').strip()
215
+ self._process_serial_line(line)
216
+ except Exception as e:
217
+ logger.error(f"Error reading serial data: {e}")
218
+
219
+ def _process_serial_line(self, line: str):
220
+ """Process a single line of serial data"""
221
+ try:
222
+ # Parse different message types
223
+ if line.startswith("SENSOR:"):
224
+ self._parse_sensor_data(line)
225
+ elif line.startswith("STATUS:"):
226
+ self._parse_status_data(line)
227
+ elif line.startswith("ERROR:"):
228
+ logger.error(f"Arduino error: {line}")
229
+ elif line.startswith("DEBUG:"):
230
+ logger.debug(f"Arduino debug: {line}")
231
+ else:
232
+ # Add to response queue
233
+ with self.queue_lock:
234
+ self.response_queue.append(line)
235
+
236
+ except Exception as e:
237
+ logger.error(f"Error processing serial line '{line}': {e}")
238
+
239
+ def _parse_sensor_data(self, line: str):
240
+ """Parse sensor data from Arduino"""
241
+ try:
242
+ # Format: SENSOR:limit_sw1:0,limit_sw2:1,force:45.2,temp:23.5
243
+ data_part = line[7:] # Remove "SENSOR:" prefix
244
+
245
+ sensor_data = {}
246
+ for item in data_part.split(','):
247
+ key, value = item.split(':')
248
+ if key.startswith('limit_sw'):
249
+ sensor_data[key] = bool(int(value))
250
+ elif key == 'force':
251
+ sensor_data[key] = float(value)
252
+ elif key == 'temp':
253
+ sensor_data[key] = float(value)
254
+
255
+ self.latest_sensor_data = SensorData(
256
+ limit_switches={k: v for k, v in sensor_data.items() if k.startswith('limit_sw')},
257
+ force_sensor=sensor_data.get('force', 0.0),
258
+ temperature=sensor_data.get('temp', 0.0),
259
+ timestamp=time.time()
260
+ )
261
+
262
+ # Add to history
263
+ self.sensor_history.append(self.latest_sensor_data)
264
+ if len(self.sensor_history) > 100: # Keep last 100 readings
265
+ self.sensor_history.pop(0)
266
+
267
+ except Exception as e:
268
+ logger.error(f"Error parsing sensor data: {e}")
269
+
270
+ def _parse_status_data(self, line: str):
271
+ """Parse status data from Arduino"""
272
+ try:
273
+ # Format: STATUS:servo1:90.5,servo2:45.0
274
+ data_part = line[7:] # Remove "STATUS:" prefix
275
+
276
+ for item in data_part.split(','):
277
+ servo_id, angle = item.split(':')
278
+ servo_num = int(servo_id.replace('servo', ''))
279
+ self.current_positions[servo_num] = float(angle)
280
+
281
+ except Exception as e:
282
+ logger.error(f"Error parsing status data: {e}")
283
+
284
+ def _send_command(self, command: str):
285
+ """Send command and wait for response"""
286
+ with self.queue_lock:
287
+ self.command_queue.append(command)
288
+
289
+ # Wait for response
290
+ start_time = time.time()
291
+ while time.time() - start_time < 5.0:
292
+ if self._check_response(command):
293
+ return True
294
+ time.sleep(0.1)
295
+
296
+ logger.warning(f"No response received for command: {command}")
297
+ return False
298
+
299
+ def _send_raw_command(self, command: str):
300
+ """Send raw command without queuing"""
301
+ if not self.serial_connection or not self.serial_connection.is_open:
302
+ raise Exception("Serial connection not available")
303
+
304
+ self.serial_connection.write(f"{command}\n".encode('utf-8'))
305
+ self.serial_connection.flush()
306
+
307
+ def _check_response(self, expected_command: str) -> bool:
308
+ """Check if expected response is in queue"""
309
+ with self.queue_lock:
310
+ for i, response in enumerate(self.response_queue):
311
+ if expected_command in response:
312
+ self.response_queue.pop(i)
313
+ return True
314
+ return False
315
+
316
+ def initialize_servos(self):
317
+ """Initialize all servos to home position"""
318
+ logger.info("Initializing servos...")
319
+
320
+ # Home positions for each servo (adjust based on your robot design)
321
+ home_positions = {
322
+ 1: 90.0, # Base rotation
323
+ 2: 45.0, # Shoulder
324
+ 3: 90.0, # Elbow
325
+ 4: 90.0, # Wrist
326
+ 5: 0.0, # Gripper
327
+ }
328
+
329
+ for servo_id, position in home_positions.items():
330
+ self.move_servo(servo_id, position)
331
+ time.sleep(0.5) # Small delay between servo movements
332
+
333
+ logger.info("Servos initialized to home positions")
334
+
335
+ def move_servo(self, servo_id: int, angle: float, speed: float = 1.0) -> bool:
336
+ """Move a specific servo to target angle"""
337
+ if not self.connected:
338
+ logger.error("Not connected to Arduino")
339
+ return False
340
+
341
+ # Validate angle
342
+ if not (self.min_servo_angle <= angle <= self.max_servo_angle):
343
+ logger.error(f"Invalid angle {angle} for servo {servo_id}")
344
+ return False
345
+
346
+ # Check emergency stop
347
+ if self.emergency_stop_flag:
348
+ logger.warning("Emergency stop active, ignoring servo command")
349
+ return False
350
+
351
+ try:
352
+ command = f"MOVE:{servo_id}:{angle:.1f}:{speed:.2f}"
353
+ success = self._send_command(command)
354
+
355
+ if success:
356
+ self.target_positions[servo_id] = angle
357
+ logger.debug(f"Moved servo {servo_id} to {angle} degrees")
358
+
359
+ return success
360
+
361
+ except Exception as e:
362
+ logger.error(f"Failed to move servo {servo_id}: {e}")
363
+ return False
364
+
365
+ def move_to_position(self, x: float, y: float, z: float) -> bool:
366
+ """Move robotic arm to 3D position using inverse kinematics"""
367
+ if not self.connected:
368
+ logger.error("Not connected to Arduino")
369
+ return False
370
+
371
+ try:
372
+ # Convert 3D coordinates to servo angles
373
+ # This is a simplified version - implement proper inverse kinematics
374
+ servo_angles = self._inverse_kinematics(x, y, z)
375
+
376
+ # Move all servos simultaneously
377
+ success = True
378
+ for servo_id, angle in servo_angles.items():
379
+ if not self.move_servo(servo_id, angle):
380
+ success = False
381
+
382
+ if success:
383
+ logger.info(f"Moved to position ({x:.2f}, {y:.2f}, {z:.2f})")
384
+
385
+ return success
386
+
387
+ except Exception as e:
388
+ logger.error(f"Failed to move to position: {e}")
389
+ return False
390
+
391
+ def _inverse_kinematics(self, x: float, y: float, z: float) -> Dict[int, float]:
392
+ """Simple inverse kinematics calculation"""
393
+ # This is a placeholder implementation
394
+ # Replace with proper IK calculations for your robot design
395
+
396
+ # Simplified mapping (adjust based on your robot geometry)
397
+ base_angle = (np.arctan2(y, x) * 180 / np.pi) + 90
398
+ shoulder_angle = 45 + (z * 10) # Simplified mapping
399
+ elbow_angle = 90 - (z * 5)
400
+ wrist_angle = 90
401
+
402
+ return {
403
+ 1: np.clip(base_angle, 0, 180),
404
+ 2: np.clip(shoulder_angle, 0, 180),
405
+ 3: np.clip(elbow_angle, 0, 180),
406
+ 4: np.clip(wrist_angle, 0, 180),
407
+ }
408
+
409
+ def open_gripper(self) -> bool:
410
+ """Open the gripper"""
411
+ return self.move_servo(5, 0.0)
412
+
413
+ def close_gripper(self) -> bool:
414
+ """Close the gripper"""
415
+ return self.move_servo(5, 90.0)
416
+
417
+ def emergency_stop(self):
418
+ """Activate emergency stop"""
419
+ logger.warning("EMERGENCY STOP ACTIVATED")
420
+ self.emergency_stop_flag = True
421
+
422
+ # Send emergency stop command
423
+ try:
424
+ self._send_command("ESTOP")
425
+ except:
426
+ pass
427
+
428
+ # Move all servos to safe positions
429
+ safe_positions = {1: 90, 2: 45, 3: 90, 4: 90, 5: 0}
430
+ for servo_id, position in safe_positions.items():
431
+ try:
432
+ self.move_servo(servo_id, position, speed=2.0)
433
+ except:
434
+ pass
435
+
436
+ def reset_emergency_stop(self):
437
+ """Reset emergency stop"""
438
+ self.emergency_stop_flag = False
439
+ logger.info("Emergency stop reset")
440
+
441
+ def get_sensor_data(self) -> Optional[SensorData]:
442
+ """Get latest sensor data"""
443
+ return self.latest_sensor_data
444
+
445
+ def get_servo_positions(self) -> Dict[int, float]:
446
+ """Get current servo positions"""
447
+ return self.current_positions.copy()
448
+
449
+ def is_movement_complete(self, servo_id: int, tolerance: float = 2.0) -> bool:
450
+ """Check if servo movement is complete"""
451
+ if servo_id not in self.target_positions:
452
+ return True
453
+
454
+ current = self.current_positions.get(servo_id, 0)
455
+ target = self.target_positions[servo_id]
456
+
457
+ return abs(current - target) <= tolerance
458
+
459
+ def wait_for_movement(self, servo_ids: List[int], timeout: float = 10.0) -> bool:
460
+ """Wait for servo movements to complete"""
461
+ start_time = time.time()
462
+
463
+ while time.time() - start_time < timeout:
464
+ all_complete = True
465
+ for servo_id in servo_ids:
466
+ if not self.is_movement_complete(servo_id):
467
+ all_complete = False
468
+ break
469
+
470
+ if all_complete:
471
+ return True
472
+
473
+ time.sleep(0.1)
474
+
475
+ logger.warning(f"Movement timeout after {timeout} seconds")
476
+ return False
477
+
478
+ def get_status(self) -> Dict:
479
+ """Get Arduino bridge status"""
480
+ return {
481
+ 'connected': self.connected,
482
+ 'port': self.port,
483
+ 'baudrate': self.baudrate,
484
+ 'emergency_stop': self.emergency_stop_flag,
485
+ 'current_positions': self.get_servo_positions(),
486
+ 'target_positions': self.target_positions.copy(),
487
+ 'latest_sensor_data': self.latest_sensor_data.__dict__ if self.latest_sensor_data else None,
488
+ 'queue_size': len(self.command_queue)
489
+ }
490
+
491
+ def main():
492
+ """Test Arduino bridge functionality"""
493
+ import argparse
494
+
495
+ parser = argparse.ArgumentParser(description='Test Arduino Bridge')
496
+ parser.add_argument('--port', default='/dev/ttyUSB0', help='Arduino port')
497
+ parser.add_argument('--baudrate', type=int, default=115200, help='Baud rate')
498
+
499
+ args = parser.parse_args()
500
+
501
+ # Create bridge
502
+ bridge = ArduinoBridge(args.port, args.baudrate)
503
+
504
+ try:
505
+ # Connect
506
+ if bridge.connect():
507
+ print("Connected to Arduino successfully")
508
+
509
+ # Initialize servos
510
+ bridge.initialize_servos()
511
+ time.sleep(2)
512
+
513
+ # Test movements
514
+ print("Testing servo movements...")
515
+ bridge.move_servo(1, 45)
516
+ time.sleep(1)
517
+ bridge.move_servo(1, 135)
518
+ time.sleep(1)
519
+ bridge.move_servo(1, 90)
520
+
521
+ # Test gripper
522
+ print("Testing gripper...")
523
+ bridge.open_gripper()
524
+ time.sleep(1)
525
+ bridge.close_gripper()
526
+ time.sleep(1)
527
+ bridge.open_gripper()
528
+
529
+ # Print status
530
+ print("\nArduino Bridge Status:")
531
+ print(json.dumps(bridge.get_status(), indent=2, default=str))
532
+
533
+ else:
534
+ print("Failed to connect to Arduino")
535
+
536
+ except KeyboardInterrupt:
537
+ print("\nInterrupted by user")
538
+
539
+ finally:
540
+ bridge.disconnect()
541
+ print("Arduino bridge disconnected")
542
+
543
+ if __name__ == "__main__":
544
+ main()
src/coordinate_transformer.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Coordinate Transformer - Pixel to Robot Coordinate System
4
+ Handles camera calibration, stereo vision, and coordinate transformations
5
+
6
+ Author: AI Assistant
7
+ Date: 2025-12-15
8
+ """
9
+
10
+ import numpy as np
11
+ import cv2
12
+ import logging
13
+ from pathlib import Path
14
+ from typing import Tuple, Optional, Dict, List
15
+ from dataclasses import dataclass
16
+ import json
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ @dataclass
21
+ class CameraCalibration:
22
+ """Camera calibration parameters"""
23
+ camera_matrix: np.ndarray
24
+ distortion_coeffs: np.ndarray
25
+ image_width: int
26
+ image_height: int
27
+ focal_length: float
28
+ principal_point: Tuple[float, float]
29
+
30
+ @dataclass
31
+ class StereoCalibration:
32
+ """Stereo camera calibration parameters"""
33
+ left_camera_matrix: np.ndarray
34
+ right_camera_matrix: np.ndarray
35
+ left_distortion: np.ndarray
36
+ right_distortion: np.ndarray
37
+ rotation_matrix: np.ndarray
38
+ translation_vector: np.ndarray
39
+ essential_matrix: np.ndarray
40
+ fundamental_matrix: np.ndarray
41
+ rectification_matrix_left: np.ndarray
42
+ rectification_matrix_right: np.ndarray
43
+ projection_matrix_left: np.ndarray
44
+ projection_matrix_right: np.ndarray
45
+ disparity_to_depth_map: np.ndarray
46
+
47
+ class CoordinateTransformer:
48
+ """
49
+ Handles coordinate transformations between pixel coordinates and robot world coordinates
50
+ Supports both monocular and stereo camera systems
51
+ """
52
+
53
+ def __init__(self,
54
+ camera_matrix_path: Optional[str] = None,
55
+ distortion_coeffs_path: Optional[str] = None,
56
+ stereo_calibration_path: Optional[str] = None):
57
+ """Initialize coordinate transformer"""
58
+
59
+ self.camera_calibration: Optional[CameraCalibration] = None
60
+ self.stereo_calibration: Optional[StereoCalibration] = None
61
+ self.stereo_matcher: Optional[cv2.StereoSGBM] = None
62
+
63
+ # Robot coordinate system parameters
64
+ self.robot_origin = (0.0, 0.0, 0.0) # Robot base position
65
+ self.camera_to_robot_transform = np.eye(4) # 4x4 transformation matrix
66
+
67
+ # Load calibrations if provided
68
+ if camera_matrix_path and distortion_coeffs_path:
69
+ self.load_camera_calibration(camera_matrix_path, distortion_coeffs_path)
70
+
71
+ if stereo_calibration_path:
72
+ self.load_stereo_calibration(stereo_calibration_path)
73
+
74
+ logger.info("Coordinate Transformer initialized")
75
+
76
+ def load_camera_calibration(self, camera_matrix_path: str, distortion_coeffs_path: str):
77
+ """Load single camera calibration"""
78
+ try:
79
+ camera_matrix = np.load(camera_matrix_path)
80
+ distortion_coeffs = np.load(distortion_coeffs_path)
81
+
82
+ # Extract calibration parameters
83
+ fx, fy = camera_matrix[0, 0], camera_matrix[1, 1]
84
+ cx, cy = camera_matrix[0, 2], camera_matrix[1, 2]
85
+
86
+ self.camera_calibration = CameraCalibration(
87
+ camera_matrix=camera_matrix,
88
+ distortion_coeffs=distortion_coeffs,
89
+ image_width=int(camera_matrix[0, 2] * 2), # Approximate
90
+ image_height=int(camera_matrix[1, 2] * 2), # Approximate
91
+ focal_length=(fx + fy) / 2,
92
+ principal_point=(cx, cy)
93
+ )
94
+
95
+ logger.info("Camera calibration loaded successfully")
96
+
97
+ except Exception as e:
98
+ logger.error(f"Failed to load camera calibration: {e}")
99
+ raise
100
+
101
+ def load_stereo_calibration(self, stereo_calibration_path: str):
102
+ """Load stereo camera calibration"""
103
+ try:
104
+ calibration_data = np.load(stereo_calibration_path)
105
+
106
+ self.stereo_calibration = StereoCalibration(
107
+ left_camera_matrix=calibration_data['left_camera_matrix'],
108
+ right_camera_matrix=calibration_data['right_camera_matrix'],
109
+ left_distortion=calibration_data['left_distortion'],
110
+ right_distortion=calibration_data['right_distortion'],
111
+ rotation_matrix=calibration_data['rotation_matrix'],
112
+ translation_vector=calibration_data['translation_vector'],
113
+ essential_matrix=calibration_data['essential_matrix'],
114
+ fundamental_matrix=calibration_data['fundamental_matrix'],
115
+ rectification_matrix_left=calibration_data['rectification_matrix_left'],
116
+ rectification_matrix_right=calibration_data['rectification_matrix_right'],
117
+ projection_matrix_left=calibration_data['projection_matrix_left'],
118
+ projection_matrix_right=calibration_data['projection_matrix_right'],
119
+ disparity_to_depth_map=calibration_data['disparity_to_depth_map']
120
+ )
121
+
122
+ # Initialize stereo matcher for real-time depth calculation
123
+ self._initialize_stereo_matcher()
124
+
125
+ logger.info("Stereo calibration loaded successfully")
126
+
127
+ except Exception as e:
128
+ logger.error(f"Failed to load stereo calibration: {e}")
129
+ raise
130
+
131
+ def _initialize_stereo_matcher(self):
132
+ """Initialize stereo matching for depth calculation"""
133
+ if not self.stereo_calibration:
134
+ return
135
+
136
+ # Create stereo block matcher
137
+ self.stereo_matcher = cv2.StereoSGBM_create(
138
+ minDisparity=0,
139
+ numDisparities=64,
140
+ blockSize=9,
141
+ P1=8 * 9 * 9,
142
+ P2=32 * 9 * 9,
143
+ disp12MaxDiff=1,
144
+ uniquenessRatio=10,
145
+ speckleWindowSize=100,
146
+ speckleRange=32
147
+ )
148
+
149
+ logger.info("Stereo matcher initialized")
150
+
151
+ def pixel_to_world(self,
152
+ pixel_x: int,
153
+ pixel_y: int,
154
+ image_shape: Tuple[int, int, int],
155
+ depth: Optional[float] = None) -> Tuple[float, float, float]:
156
+ """
157
+ Convert pixel coordinates to world coordinates
158
+
159
+ Args:
160
+ pixel_x, pixel_y: Pixel coordinates
161
+ image_shape: Shape of the image (height, width, channels)
162
+ depth: Depth in meters (if None, uses stereo or default depth)
163
+
164
+ Returns:
165
+ Tuple of (x, y, z) world coordinates in meters
166
+ """
167
+ if not self.camera_calibration:
168
+ raise ValueError("Camera calibration not loaded")
169
+
170
+ # Get depth if not provided
171
+ if depth is None:
172
+ depth = self._estimate_depth(pixel_x, pixel_y, image_shape)
173
+
174
+ # Convert pixel to normalized coordinates
175
+ fx, fy = self.camera_calibration.camera_matrix[0, 0], self.camera_calibration.camera_matrix[1, 1]
176
+ cx, cy = self.camera_calibration.principal_point
177
+
178
+ # Convert to camera coordinates
179
+ x_cam = (pixel_x - cx) * depth / fx
180
+ y_cam = (pixel_y - cy) * depth / fy
181
+ z_cam = depth
182
+
183
+ # Transform to robot world coordinates
184
+ camera_point = np.array([x_cam, y_cam, z_cam, 1.0])
185
+ world_point = self.camera_to_robot_transform @ camera_point
186
+
187
+ return (float(world_point[0]), float(world_point[1]), float(world_point[2]))
188
+
189
+ def world_to_pixel(self,
190
+ world_x: float,
191
+ world_y: float,
192
+ world_z: float) -> Tuple[int, int]:
193
+ """
194
+ Convert world coordinates to pixel coordinates
195
+
196
+ Returns:
197
+ Tuple of (pixel_x, pixel_y) coordinates
198
+ """
199
+ if not self.camera_calibration:
200
+ raise ValueError("Camera calibration not loaded")
201
+
202
+ # Transform world point to camera coordinates
203
+ world_point = np.array([world_x, world_y, world_z, 1.0])
204
+ camera_point = np.linalg.inv(self.camera_to_robot_transform) @ world_point
205
+
206
+ x_cam, y_cam, z_cam = camera_point[:3]
207
+
208
+ # Convert to pixel coordinates
209
+ fx, fy = self.camera_calibration.camera_matrix[0, 0], self.camera_calibration.camera_matrix[1, 1]
210
+ cx, cy = self.camera_calibration.principal_point
211
+
212
+ pixel_x = int(x_cam * fx / z_cam + cx)
213
+ pixel_y = int(y_cam * fy / z_cam + cy)
214
+
215
+ return (pixel_x, pixel_y)
216
+
217
+ def _estimate_depth(self, pixel_x: int, pixel_y: int, image_shape: Tuple[int, int, int]) -> float:
218
+ """Estimate depth using stereo vision or default depth"""
219
+
220
+ # If stereo calibration is available, try to get depth from disparity
221
+ if self.stereo_calibration and len(image_shape) == 3:
222
+ # This would require stereo images - simplified for now
223
+ pass
224
+
225
+ # Default depth estimation based on image position
226
+ # Assume strawberries are typically 20-50cm from camera
227
+ image_height = image_shape[0]
228
+
229
+ # Simple depth estimation based on vertical position
230
+ # Lower in image = closer to camera
231
+ normalized_y = pixel_y / image_height
232
+ estimated_depth = 0.2 + (0.3 * (1.0 - normalized_y)) # 20-50cm range
233
+
234
+ return estimated_depth
235
+
236
+ def calculate_depth_from_stereo(self,
237
+ left_image: np.ndarray,
238
+ right_image: np.ndarray,
239
+ pixel_x: int,
240
+ pixel_y: int) -> Optional[float]:
241
+ """Calculate depth from stereo images"""
242
+ if not self.stereo_matcher or not self.stereo_calibration:
243
+ return None
244
+
245
+ try:
246
+ # Compute disparity map
247
+ disparity = self.stereo_matcher.compute(left_image, right_image)
248
+
249
+ # Get disparity at specific pixel
250
+ if 0 <= pixel_y < disparity.shape[0] and 0 <= pixel_x < disparity.shape[1]:
251
+ disp_value = disparity[pixel_y, pixel_x]
252
+
253
+ if disp_value > 0:
254
+ # Convert disparity to depth
255
+ depth = self.stereo_calibration.disparity_to_depth_map[disp_value]
256
+ return float(depth)
257
+
258
+ return None
259
+
260
+ except Exception as e:
261
+ logger.error(f"Error calculating stereo depth: {e}")
262
+ return None
263
+
264
+ def undistort_point(self, pixel_x: int, pixel_y: int) -> Tuple[int, int]:
265
+ """Undistort pixel coordinates using camera calibration"""
266
+ if not self.camera_calibration:
267
+ return pixel_x, pixel_y
268
+
269
+ try:
270
+ # Create point array
271
+ points = np.array([[pixel_x, pixel_y]], dtype=np.float32)
272
+
273
+ # Undistort points
274
+ undistorted_points = cv2.undistortPoints(
275
+ points,
276
+ self.camera_calibration.camera_matrix,
277
+ self.camera_calibration.distortion_coeffs
278
+ )
279
+
280
+ return int(undistorted_points[0][0][0]), int(undistorted_points[0][0][1])
281
+
282
+ except Exception as e:
283
+ logger.error(f"Error undistorting point: {e}")
284
+ return pixel_x, pixel_y
285
+
286
+ def set_camera_to_robot_transform(self, transform_matrix: np.ndarray):
287
+ """Set the transformation matrix from camera to robot coordinates"""
288
+ if transform_matrix.shape != (4, 4):
289
+ raise ValueError("Transform matrix must be 4x4")
290
+
291
+ self.camera_to_robot_transform = transform_matrix
292
+ logger.info("Camera to robot transform updated")
293
+
294
+ def calibrate_camera_to_robot(self,
295
+ world_points: List[Tuple[float, float, float]],
296
+ pixel_points: List[Tuple[int, int]]) -> bool:
297
+ """
298
+ Calibrate camera to robot transformation using known correspondences
299
+
300
+ Args:
301
+ world_points: List of (x, y, z) world coordinates
302
+ pixel_points: List of (pixel_x, pixel_y) pixel coordinates
303
+
304
+ Returns:
305
+ True if calibration successful
306
+ """
307
+ if len(world_points) != len(pixel_points) or len(world_points) < 4:
308
+ logger.error("Need at least 4 corresponding points for calibration")
309
+ return False
310
+
311
+ try:
312
+ # Prepare point correspondences
313
+ world_points_3d = np.array(world_points, dtype=np.float32)
314
+ pixel_points_2d = np.array(pixel_points, dtype=np.float32)
315
+
316
+ # Solve PnP problem
317
+ success, rotation_vector, translation_vector, inliers = cv2.solvePnPRansac(
318
+ world_points_3d,
319
+ pixel_points_2d,
320
+ self.camera_calibration.camera_matrix,
321
+ self.camera_calibration.distortion_coeffs
322
+ )
323
+
324
+ if success:
325
+ # Convert rotation vector to rotation matrix
326
+ rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
327
+
328
+ # Create 4x4 transformation matrix
329
+ transform_matrix = np.eye(4)
330
+ transform_matrix[:3, :3] = rotation_matrix
331
+ transform_matrix[:3, 3] = translation_vector.flatten()
332
+
333
+ self.set_camera_to_robot_transform(transform_matrix)
334
+
335
+ logger.info("Camera to robot calibration successful")
336
+ return True
337
+ else:
338
+ logger.error("PnP calibration failed")
339
+ return False
340
+
341
+ except Exception as e:
342
+ logger.error(f"Calibration error: {e}")
343
+ return False
344
+
345
+ def get_workspace_bounds(self) -> Dict[str, Tuple[float, float]]:
346
+ """Get the bounds of the robot workspace in world coordinates"""
347
+ # This should be calibrated for your specific robot setup
348
+ return {
349
+ 'x_min': -0.5, 'x_max': 0.5,
350
+ 'y_min': -0.5, 'y_max': 0.5,
351
+ 'z_min': 0.0, 'z_max': 0.5
352
+ }
353
+
354
+ def is_point_in_workspace(self, x: float, y: float, z: float) -> bool:
355
+ """Check if a point is within the robot workspace"""
356
+ bounds = self.get_workspace_bounds()
357
+
358
+ return (bounds['x_min'] <= x <= bounds['x_max'] and
359
+ bounds['y_min'] <= y <= bounds['y_max'] and
360
+ bounds['z_min'] <= z <= bounds['z_max'])
361
+
362
+ def project_3d_to_2d(self,
363
+ world_points: List[Tuple[float, float, float]],
364
+ image_shape: Tuple[int, int]) -> List[Tuple[int, int]]:
365
+ """Project 3D world points to 2D image coordinates"""
366
+ if not self.camera_calibration:
367
+ raise ValueError("Camera calibration not loaded")
368
+
369
+ projected_points = []
370
+
371
+ for world_point in world_points:
372
+ pixel_x, pixel_y = self.world_to_pixel(*world_point)
373
+
374
+ # Check if point is within image bounds
375
+ if (0 <= pixel_x < image_shape[1] and 0 <= pixel_y < image_shape[0]):
376
+ projected_points.append((pixel_x, pixel_y))
377
+ else:
378
+ projected_points.append((-1, -1)) # Outside image
379
+
380
+ return projected_points
381
+
382
+ def save_calibration(self, filepath: str):
383
+ """Save current calibration to file"""
384
+ calibration_data = {
385
+ 'camera_calibration': {
386
+ 'camera_matrix': self.camera_calibration.camera_matrix.tolist() if self.camera_calibration else None,
387
+ 'distortion_coeffs': self.camera_calibration.distortion_coeffs.tolist() if self.camera_calibration else None,
388
+ 'image_width': self.camera_calibration.image_width if self.camera_calibration else None,
389
+ 'image_height': self.camera_calibration.image_height if self.camera_calibration else None,
390
+ 'focal_length': self.camera_calibration.focal_length if self.camera_calibration else None,
391
+ 'principal_point': self.camera_calibration.principal_point if self.camera_calibration else None
392
+ },
393
+ 'stereo_calibration': {
394
+ 'left_camera_matrix': self.stereo_calibration.left_camera_matrix.tolist() if self.stereo_calibration else None,
395
+ 'right_camera_matrix': self.stereo_calibration.right_camera_matrix.tolist() if self.stereo_calibration else None,
396
+ 'rotation_matrix': self.stereo_calibration.rotation_matrix.tolist() if self.stereo_calibration else None,
397
+ 'translation_vector': self.stereo_calibration.translation_vector.tolist() if self.stereo_calibration else None
398
+ },
399
+ 'camera_to_robot_transform': self.camera_to_robot_transform.tolist(),
400
+ 'robot_origin': self.robot_origin
401
+ }
402
+
403
+ with open(filepath, 'w') as f:
404
+ json.dump(calibration_data, f, indent=2)
405
+
406
+ logger.info(f"Calibration saved to {filepath}")
407
+
408
+ def load_calibration(self, filepath: str):
409
+ """Load calibration from file"""
410
+ try:
411
+ with open(filepath, 'r') as f:
412
+ calibration_data = json.load(f)
413
+
414
+ # Load camera calibration
415
+ if calibration_data['camera_calibration']['camera_matrix']:
416
+ cam_data = calibration_data['camera_calibration']
417
+ self.camera_calibration = CameraCalibration(
418
+ camera_matrix=np.array(cam_data['camera_matrix']),
419
+ distortion_coeffs=np.array(cam_data['distortion_coeffs']),
420
+ image_width=cam_data['image_width'],
421
+ image_height=cam_data['image_height'],
422
+ focal_length=cam_data['focal_length'],
423
+ principal_point=tuple(cam_data['principal_point'])
424
+ )
425
+
426
+ # Load stereo calibration
427
+ if calibration_data['stereo_calibration']['left_camera_matrix']:
428
+ stereo_data = calibration_data['stereo_calibration']
429
+ # Note: This is simplified - you'd need to load all stereo parameters
430
+ logger.warning("Stereo calibration loading not fully implemented")
431
+
432
+ # Load transformation matrix
433
+ self.camera_to_robot_transform = np.array(calibration_data['camera_to_robot_transform'])
434
+ self.robot_origin = tuple(calibration_data['robot_origin'])
435
+
436
+ logger.info(f"Calibration loaded from {filepath}")
437
+
438
+ except Exception as e:
439
+ logger.error(f"Failed to load calibration: {e}")
440
+ raise
441
+
442
+ def main():
443
+ """Test coordinate transformer functionality"""
444
+ import argparse
445
+
446
+ parser = argparse.ArgumentParser(description='Test Coordinate Transformer')
447
+ parser.add_argument('--camera-matrix', help='Camera matrix file path')
448
+ parser.add_argument('--distortion', help='Distortion coefficients file path')
449
+ parser.add_argument('--stereo', help='Stereo calibration file path')
450
+ parser.add_argument('--test-pixel', nargs=2, type=int, metavar=('X', 'Y'),
451
+ help='Test pixel coordinates')
452
+
453
+ args = parser.parse_args()
454
+
455
+ try:
456
+ # Create transformer
457
+ transformer = CoordinateTransformer(
458
+ args.camera_matrix,
459
+ args.distortion,
460
+ args.stereo
461
+ )
462
+
463
+ print("Coordinate Transformer initialized")
464
+
465
+ if args.test_pixel:
466
+ pixel_x, pixel_y = args.test_pixel
467
+ world_coords = transformer.pixel_to_world(pixel_x, pixel_y, (480, 640, 3))
468
+ print(f"Pixel ({pixel_x}, {pixel_y}) -> World {world_coords}")
469
+
470
+ # Convert back
471
+ pixel_x_back, pixel_y_back = transformer.world_to_pixel(*world_coords)
472
+ print(f"World {world_coords} -> Pixel ({pixel_x_back}, {pixel_y_back})")
473
+
474
+ # Print workspace bounds
475
+ bounds = transformer.get_workspace_bounds()
476
+ print(f"Workspace bounds: {bounds}")
477
+
478
+ except Exception as e:
479
+ print(f"Error: {e}")
480
+
481
+ if __name__ == "__main__":
482
+ main()
src/integrated_detection_classification.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Integrated Strawberry Detection and Ripeness Classification Pipeline
4
+ Combines YOLOv8 detection with 3-class ripeness classification
5
+ """
6
+
7
+ import os
8
+ import argparse
9
+ import json
10
+ import time
11
+ import numpy as np
12
+ import cv2
13
+ import torch
14
+ import torchvision.transforms as transforms
15
+ from pathlib import Path
16
+ import yaml
17
+ from datetime import datetime
18
+ import logging
19
+
20
+ # YOLOv8
21
+ from ultralytics import YOLO
22
+
23
+ # Custom imports
24
+ from train_ripeness_classifier import create_model, get_transforms
25
+
26
+ class StrawberryDetectionClassifier:
27
+ """Integrated detection and classification system"""
28
+
29
+ def __init__(self, detection_model_path, classification_model_path, config_path='config.yaml'):
30
+ self.config = self.load_config(config_path)
31
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
+
33
+ # Initialize detection model
34
+ print(f"Loading detection model: {detection_model_path}")
35
+ self.detection_model = YOLO(detection_model_path)
36
+
37
+ # Initialize classification model
38
+ print(f"Loading classification model: {classification_model_path}")
39
+ self.classification_model = self.load_classification_model(classification_model_path)
40
+
41
+ # Get classification transforms
42
+ _, self.classify_transform = get_transforms(img_size=224)
43
+
44
+ # Class names for classification
45
+ self.class_names = ['overripe', 'ripe', 'unripe']
46
+
47
+ # Setup logging
48
+ self.setup_logging()
49
+
50
+ def load_config(self, config_path):
51
+ """Load configuration from YAML file"""
52
+ with open(config_path, 'r') as f:
53
+ return yaml.safe_load(f)
54
+
55
+ def load_classification_model(self, model_path):
56
+ """Load the trained classification model"""
57
+ model = create_model(num_classes=3, backbone='resnet18', pretrained=False)
58
+ model.load_state_dict(torch.load(model_path, map_location=self.device))
59
+ model = model.to(self.device)
60
+ model.eval()
61
+ return model
62
+
63
+ def setup_logging(self):
64
+ """Setup logging configuration"""
65
+ logging.basicConfig(
66
+ level=logging.INFO,
67
+ format='%(asctime)s - %(levelname)s - %(message)s',
68
+ handlers=[
69
+ logging.FileHandler('strawberry_pipeline.log'),
70
+ logging.StreamHandler()
71
+ ]
72
+ )
73
+ self.logger = logging.getLogger(__name__)
74
+
75
+ def detect_strawberries(self, image):
76
+ """Detect strawberries in image using YOLOv8"""
77
+ results = self.detection_model(image)
78
+
79
+ detections = []
80
+ for result in results:
81
+ boxes = result.boxes
82
+ if boxes is not None:
83
+ for box in boxes:
84
+ # Get bounding box coordinates
85
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
86
+ confidence = box.conf[0].cpu().numpy()
87
+
88
+ # Only keep high-confidence detections
89
+ if confidence > 0.5:
90
+ detections.append({
91
+ 'bbox': [int(x1), int(y1), int(x2), int(y2)],
92
+ 'confidence': float(confidence),
93
+ 'class': int(box.cls[0].cpu().numpy())
94
+ })
95
+
96
+ return detections
97
+
98
+ def classify_ripeness(self, image_crop):
99
+ """Classify ripeness of strawberry crop"""
100
+ try:
101
+ # Apply transforms
102
+ if isinstance(image_crop, np.ndarray):
103
+ image_crop = cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB)
104
+ from PIL import Image
105
+ image_crop = Image.fromarray(image_crop)
106
+
107
+ input_tensor = self.classify_transform(image_crop).unsqueeze(0).to(self.device)
108
+
109
+ # Get prediction
110
+ with torch.no_grad():
111
+ outputs = self.classification_model(input_tensor)
112
+ probabilities = torch.softmax(outputs, dim=1)
113
+ predicted_class = torch.argmax(probabilities, dim=1).item()
114
+ confidence = probabilities[0][predicted_class].item()
115
+
116
+ return {
117
+ 'class': self.class_names[predicted_class],
118
+ 'confidence': float(confidence),
119
+ 'probabilities': {
120
+ self.class_names[i]: float(probabilities[0][i].item())
121
+ for i in range(len(self.class_names))
122
+ }
123
+ }
124
+ except Exception as e:
125
+ self.logger.error(f"Classification error: {e}")
126
+ return {
127
+ 'class': 'unknown',
128
+ 'confidence': 0.0,
129
+ 'probabilities': {cls: 0.0 for cls in self.class_names}
130
+ }
131
+
132
+ def process_image(self, image_path, save_annotated=True, output_dir='results'):
133
+ """Process single image with detection and classification"""
134
+ # Load image
135
+ image = cv2.imread(str(image_path))
136
+ if image is None:
137
+ self.logger.error(f"Could not load image: {image_path}")
138
+ return None
139
+
140
+ # Detect strawberries
141
+ detections = self.detect_strawberries(image)
142
+
143
+ results = {
144
+ 'image_path': str(image_path),
145
+ 'timestamp': datetime.now().isoformat(),
146
+ 'detections': [],
147
+ 'summary': {
148
+ 'total_strawberries': len(detections),
149
+ 'ripeness_counts': {'unripe': 0, 'ripe': 0, 'overripe': 0, 'unknown': 0}
150
+ }
151
+ }
152
+
153
+ # Process each detection
154
+ for i, detection in enumerate(detections):
155
+ x1, y1, x2, y2 = detection['bbox']
156
+
157
+ # Crop strawberry
158
+ strawberry_crop = image[y1:y2, x1:x2]
159
+
160
+ # Classify ripeness
161
+ ripeness = self.classify_ripeness(strawberry_crop)
162
+
163
+ # Update summary
164
+ results['summary']['ripeness_counts'][ripeness['class']] += 1
165
+
166
+ # Store result
167
+ result = {
168
+ 'detection_id': i,
169
+ 'bbox': detection['bbox'],
170
+ 'detection_confidence': detection['confidence'],
171
+ 'ripeness': ripeness
172
+ }
173
+ results['detections'].append(result)
174
+
175
+ # Draw annotations if requested
176
+ if save_annotated:
177
+ color = self.get_ripeness_color(ripeness['class'])
178
+ label = f"{ripeness['class']} ({ripeness['confidence']:.2f})"
179
+
180
+ # Draw bounding box
181
+ cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
182
+
183
+ # Draw label
184
+ cv2.putText(image, label, (x1, y1-10),
185
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
186
+
187
+ # Save annotated image
188
+ if save_annotated:
189
+ os.makedirs(output_dir, exist_ok=True)
190
+ output_path = Path(output_dir) / f"annotated_{Path(image_path).name}"
191
+ cv2.imwrite(str(output_path), image)
192
+ results['annotated_image_path'] = str(output_path)
193
+
194
+ return results
195
+
196
+ def get_ripeness_color(self, ripeness_class):
197
+ """Get color for ripeness class"""
198
+ colors = {
199
+ 'unripe': (0, 255, 0), # Green
200
+ 'ripe': (0, 255, 255), # Yellow
201
+ 'overripe': (0, 0, 255), # Red
202
+ 'unknown': (128, 128, 128) # Gray
203
+ }
204
+ return colors.get(ripeness_class, (128, 128, 128))
205
+
206
+ def main():
207
+ parser = argparse.ArgumentParser(description='Integrated strawberry detection and classification')
208
+ parser.add_argument('--detection-model', default='model/weights/best_yolov8n_strawberry.pt',
209
+ help='Path to YOLOv8 detection model')
210
+ parser.add_argument('--classification-model', default='model/ripeness_classifier_best.pth',
211
+ help='Path to ripeness classification model')
212
+ parser.add_argument('--mode', choices=['image', 'video', 'realtime'], required=True,
213
+ help='Processing mode')
214
+ parser.add_argument('--input', required=True, help='Input path (image/video/camera index)')
215
+ parser.add_argument('--output', help='Output path for results')
216
+ parser.add_argument('--save-annotated', action='store_true', help='Save annotated images')
217
+ parser.add_argument('--config', default='config.yaml', help='Configuration file path')
218
+
219
+ args = parser.parse_args()
220
+
221
+ # Initialize system
222
+ system = StrawberryDetectionClassifier(
223
+ args.detection_model,
224
+ args.classification_model,
225
+ args.config
226
+ )
227
+
228
+ if args.mode == 'image':
229
+ # Process single image
230
+ results = system.process_image(
231
+ args.input,
232
+ save_annotated=args.save_annotated,
233
+ output_dir=args.output or 'results'
234
+ )
235
+
236
+ if results:
237
+ # Save results
238
+ results_path = Path(args.output or 'results') / 'detection_results.json'
239
+ results_path.parent.mkdir(exist_ok=True)
240
+ with open(results_path, 'w') as f:
241
+ json.dump(results, f, indent=2)
242
+
243
+ print(f"Results saved to: {results_path}")
244
+ print(f"Found {results['summary']['total_strawberries']} strawberries")
245
+ print(f"Ripeness distribution: {results['summary']['ripeness_counts']}")
246
+
247
+ if __name__ == '__main__':
248
+ main()
src/strawberry_picker_pipeline.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Strawberry Picker Pipeline - End-to-End Real-time System
4
+ Combines detection, classification, and robotic control for automated strawberry picking
5
+
6
+ Author: AI Assistant
7
+ Date: 2025-12-15
8
+ """
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import time
13
+ import json
14
+ import logging
15
+ from pathlib import Path
16
+ from typing import List, Dict, Tuple, Optional
17
+ import argparse
18
+ import yaml
19
+ from dataclasses import dataclass
20
+ from concurrent.futures import ThreadPoolExecutor
21
+ import threading
22
+
23
+ # Import our custom modules
24
+ from integrated_detection_classification import IntegratedDetectorClassifier
25
+ from arduino_bridge import ArduinoBridge
26
+ from coordinate_transformer import CoordinateTransformer
27
+
28
+ # Configure logging
29
+ logging.basicConfig(
30
+ level=logging.INFO,
31
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
32
+ handlers=[
33
+ logging.FileHandler('strawberry_picker.log'),
34
+ logging.StreamHandler()
35
+ ]
36
+ )
37
+ logger = logging.getLogger(__name__)
38
+
39
+ @dataclass
40
+ class PickingTarget:
41
+ """Represents a strawberry target for picking"""
42
+ bbox: Tuple[int, int, int, int] # x, y, w, h
43
+ confidence: float
44
+ ripeness: str # 'unripe', 'ripe', 'overripe'
45
+ ripeness_confidence: float
46
+ pixel_coords: Tuple[int, int] # center pixel coordinates
47
+ world_coords: Tuple[float, float, float] # x, y, z in robot coordinates
48
+ priority: float # calculated priority score
49
+
50
+ class StrawberryPickerPipeline:
51
+ """
52
+ Main pipeline for automated strawberry picking system
53
+ Integrates computer vision, classification, and robotic control
54
+ """
55
+
56
+ def __init__(self, config_path: str = "config.yaml"):
57
+ """Initialize the strawberry picker pipeline"""
58
+ self.config = self._load_config(config_path)
59
+ self.running = False
60
+ self.picking_targets = []
61
+ self.processed_count = 0
62
+ self.successful_picks = 0
63
+ self.failed_picks = 0
64
+
65
+ # Initialize components
66
+ self.detector_classifier = IntegratedDetectorClassifier(
67
+ detection_model_path=self.config['models']['detection_model'],
68
+ classification_model_path=self.config['models']['classification_model'],
69
+ confidence_threshold=self.config['detection']['confidence_threshold']
70
+ )
71
+
72
+ self.arduino = ArduinoBridge(
73
+ port=self.config['arduino']['port'],
74
+ baudrate=self.config['arduino']['baudrate']
75
+ )
76
+
77
+ self.coordinate_transformer = CoordinateTransformer(
78
+ camera_matrix_path=self.config['calibration']['camera_matrix'],
79
+ distortion_coeffs_path=self.config['calibration']['distortion_coeffs'],
80
+ stereo_calibration_path=self.config['calibration']['stereo_calibration']
81
+ )
82
+
83
+ # Threading for real-time processing
84
+ self.executor = ThreadPoolExecutor(max_workers=2)
85
+ self.processing_lock = threading.Lock()
86
+
87
+ logger.info("Strawberry Picker Pipeline initialized successfully")
88
+
89
+ def _load_config(self, config_path: str) -> Dict:
90
+ """Load configuration from YAML file"""
91
+ try:
92
+ with open(config_path, 'r') as f:
93
+ return yaml.safe_load(f)
94
+ except FileNotFoundError:
95
+ logger.warning(f"Config file {config_path} not found, using defaults")
96
+ return self._get_default_config()
97
+
98
+ def _get_default_config(self) -> Dict:
99
+ """Get default configuration"""
100
+ return {
101
+ 'models': {
102
+ 'detection_model': 'model/weights/best.pt',
103
+ 'classification_model': 'model/ripeness_classifier.h5'
104
+ },
105
+ 'detection': {
106
+ 'confidence_threshold': 0.5,
107
+ 'nms_threshold': 0.4
108
+ },
109
+ 'arduino': {
110
+ 'port': '/dev/ttyUSB0',
111
+ 'baudrate': 115200
112
+ },
113
+ 'calibration': {
114
+ 'camera_matrix': 'calibration/camera_matrix.npy',
115
+ 'distortion_coeffs': 'calibration/distortion_coeffs.npy',
116
+ 'stereo_calibration': 'calibration/stereo_calibration.npz'
117
+ },
118
+ 'picking': {
119
+ 'max_targets_per_frame': 5,
120
+ 'min_confidence': 0.7,
121
+ 'pick_delay': 2.0, # seconds between picks
122
+ 'safety_timeout': 30.0 # seconds
123
+ }
124
+ }
125
+
126
+ def start(self):
127
+ """Start the strawberry picker pipeline"""
128
+ logger.info("Starting Strawberry Picker Pipeline...")
129
+
130
+ try:
131
+ # Initialize hardware
132
+ self.arduino.connect()
133
+ self.arduino.initialize_servos()
134
+
135
+ # Start processing
136
+ self.running = True
137
+ self._start_processing_loop()
138
+
139
+ except Exception as e:
140
+ logger.error(f"Failed to start pipeline: {e}")
141
+ self.stop()
142
+ raise
143
+
144
+ def stop(self):
145
+ """Stop the strawberry picker pipeline"""
146
+ logger.info("Stopping Strawberry Picker Pipeline...")
147
+ self.running = False
148
+
149
+ # Close connections
150
+ if hasattr(self, 'arduino'):
151
+ self.arduino.disconnect()
152
+
153
+ if hasattr(self, 'executor'):
154
+ self.executor.shutdown(wait=True)
155
+
156
+ logger.info("Pipeline stopped successfully")
157
+
158
+ def _start_processing_loop(self):
159
+ """Start the main processing loop"""
160
+ # Start camera capture in separate thread
161
+ capture_future = self.executor.submit(self._camera_capture_loop)
162
+
163
+ # Start picking loop
164
+ picking_future = self.executor.submit(self._picking_loop)
165
+
166
+ logger.info("Processing loops started")
167
+
168
+ def _camera_capture_loop(self):
169
+ """Continuous camera capture and processing loop"""
170
+ cap = cv2.VideoCapture(self.config['camera']['index'])
171
+
172
+ if not cap.isOpened():
173
+ logger.error("Failed to open camera")
174
+ return
175
+
176
+ cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.config['camera']['width'])
177
+ cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.config['camera']['height'])
178
+ cap.set(cv2.CAP_PROP_FPS, self.config['camera']['fps'])
179
+
180
+ logger.info("Camera capture started")
181
+
182
+ while self.running:
183
+ ret, frame = cap.read()
184
+ if not ret:
185
+ logger.warning("Failed to capture frame")
186
+ continue
187
+
188
+ # Process frame for detection and classification
189
+ self._process_frame_async(frame)
190
+
191
+ # Display frame with annotations
192
+ self._display_frame(frame)
193
+
194
+ # Control frame rate
195
+ time.sleep(1.0 / self.config['camera']['fps'])
196
+
197
+ cap.release()
198
+ logger.info("Camera capture stopped")
199
+
200
+ def _process_frame_async(self, frame: np.ndarray):
201
+ """Process frame asynchronously for detection and classification"""
202
+ def process():
203
+ try:
204
+ # Detect and classify strawberries
205
+ results = self.detector_classifier.process_frame(frame)
206
+
207
+ # Convert to picking targets
208
+ targets = self._create_picking_targets(results, frame)
209
+
210
+ # Update targets list
211
+ with self.processing_lock:
212
+ self.picking_targets = targets
213
+ self.processed_count += 1
214
+
215
+ except Exception as e:
216
+ logger.error(f"Frame processing error: {e}")
217
+
218
+ # Submit to thread pool
219
+ self.executor.submit(process)
220
+
221
+ def _create_picking_targets(self, detection_results: Dict, frame: np.ndarray) -> List[PickingTarget]:
222
+ """Create picking targets from detection results"""
223
+ targets = []
224
+
225
+ for detection in detection_results.get('detections', []):
226
+ if detection['confidence'] < self.config['picking']['min_confidence']:
227
+ continue
228
+
229
+ # Get classification result
230
+ ripeness = detection.get('ripeness', 'unknown')
231
+ ripeness_confidence = detection.get('ripeness_confidence', 0.0)
232
+
233
+ # Only pick ripe strawberries
234
+ if ripeness != 'ripe':
235
+ continue
236
+
237
+ # Calculate pixel coordinates
238
+ x, y, w, h = detection['bbox']
239
+ center_x = int(x + w / 2)
240
+ center_y = int(y + h / 2)
241
+
242
+ # Transform to world coordinates
243
+ try:
244
+ world_coords = self.coordinate_transformer.pixel_to_world(
245
+ center_x, center_y, frame.shape
246
+ )
247
+ except Exception as e:
248
+ logger.warning(f"Coordinate transformation failed: {e}")
249
+ world_coords = (0.0, 0.0, 0.0)
250
+
251
+ # Calculate priority (higher confidence = higher priority)
252
+ priority = detection['confidence'] * ripeness_confidence
253
+
254
+ target = PickingTarget(
255
+ bbox=detection['bbox'],
256
+ confidence=detection['confidence'],
257
+ ripeness=ripeness,
258
+ ripeness_confidence=ripeness_confidence,
259
+ pixel_coords=(center_x, center_y),
260
+ world_coords=world_coords,
261
+ priority=priority
262
+ )
263
+
264
+ targets.append(target)
265
+
266
+ # Sort by priority and limit number of targets
267
+ targets.sort(key=lambda t: t.priority, reverse=True)
268
+ return targets[:self.config['picking']['max_targets_per_frame']]
269
+
270
+ def _picking_loop(self):
271
+ """Main picking loop"""
272
+ logger.info("Picking loop started")
273
+
274
+ last_pick_time = 0
275
+ safety_timeout = self.config['picking']['safety_timeout']
276
+
277
+ while self.running:
278
+ try:
279
+ current_time = time.time()
280
+
281
+ # Check if enough time has passed since last pick
282
+ if current_time - last_pick_time < self.config['picking']['pick_delay']:
283
+ time.sleep(0.1)
284
+ continue
285
+
286
+ # Get current targets
287
+ with self.processing_lock:
288
+ targets = self.picking_targets.copy()
289
+
290
+ if not targets:
291
+ time.sleep(0.1)
292
+ continue
293
+
294
+ # Select best target
295
+ target = targets[0]
296
+
297
+ # Execute pick
298
+ success = self._execute_pick(target)
299
+
300
+ if success:
301
+ self.successful_picks += 1
302
+ logger.info(f"Successful pick! Total: {self.successful_picks}")
303
+ else:
304
+ self.failed_picks += 1
305
+ logger.warning(f"Failed pick. Total failures: {self.failed_picks}")
306
+
307
+ last_pick_time = current_time
308
+
309
+ # Safety timeout check
310
+ if current_time - last_pick_time > safety_timeout:
311
+ logger.warning("Safety timeout reached, pausing picking")
312
+ time.sleep(5.0)
313
+
314
+ except Exception as e:
315
+ logger.error(f"Picking loop error: {e}")
316
+ time.sleep(1.0)
317
+
318
+ logger.info("Picking loop stopped")
319
+
320
+ def _execute_pick(self, target: PickingTarget) -> bool:
321
+ """Execute a picking action for the given target"""
322
+ try:
323
+ logger.info(f"Executing pick for target at {target.pixel_coords}")
324
+
325
+ # Move to target position
326
+ x, y, z = target.world_coords
327
+ self.arduino.move_to_position(x, y, z)
328
+
329
+ # Wait for movement to complete
330
+ time.sleep(2.0)
331
+
332
+ # Close gripper
333
+ self.arduino.close_gripper()
334
+ time.sleep(1.0)
335
+
336
+ # Lift strawberry
337
+ self.arduino.move_to_position(x, y, z + 0.1)
338
+ time.sleep(1.0)
339
+
340
+ # Move to collection area
341
+ collection_pos = self.config['picking']['collection_position']
342
+ self.arduino.move_to_position(*collection_pos)
343
+ time.sleep(2.0)
344
+
345
+ # Open gripper to release strawberry
346
+ self.arduino.open_gripper()
347
+ time.sleep(1.0)
348
+
349
+ # Return to home position
350
+ home_pos = self.config['picking']['home_position']
351
+ self.arduino.move_to_position(*home_pos)
352
+ time.sleep(2.0)
353
+
354
+ logger.info("Pick sequence completed successfully")
355
+ return True
356
+
357
+ except Exception as e:
358
+ logger.error(f"Pick execution failed: {e}")
359
+ return False
360
+
361
+ def _display_frame(self, frame: np.ndarray):
362
+ """Display frame with annotations"""
363
+ try:
364
+ # Add annotations for current targets
365
+ with self.processing_lock:
366
+ targets = self.picking_targets
367
+
368
+ for i, target in enumerate(targets):
369
+ x, y, w, h = target.bbox
370
+
371
+ # Draw bounding box
372
+ color = (0, 255, 0) if target.ripeness == 'ripe' else (0, 0, 255)
373
+ cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
374
+
375
+ # Add label
376
+ label = f"{target.ripeness} ({target.confidence:.2f})"
377
+ cv2.putText(frame, label, (x, y - 10),
378
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
379
+
380
+ # Add priority indicator
381
+ cv2.putText(frame, f"P{i+1}: {target.priority:.2f}",
382
+ (x, y + h + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
383
+ (255, 255, 0), 2)
384
+
385
+ # Add status information
386
+ status_text = f"Processed: {self.processed_count} | Success: {self.successful_picks} | Failed: {self.failed_picks}"
387
+ cv2.putText(frame, status_text, (10, 30),
388
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
389
+
390
+ # Display frame
391
+ cv2.imshow('Strawberry Picker Pipeline', frame)
392
+ cv2.waitKey(1)
393
+
394
+ except Exception as e:
395
+ logger.error(f"Frame display error: {e}")
396
+
397
+ def get_statistics(self) -> Dict:
398
+ """Get pipeline statistics"""
399
+ with self.processing_lock:
400
+ return {
401
+ 'processed_frames': self.processed_count,
402
+ 'successful_picks': self.successful_picks,
403
+ 'failed_picks': self.failed_picks,
404
+ 'success_rate': self.successful_picks / max(1, self.successful_picks + self.failed_picks),
405
+ 'current_targets': len(self.picking_targets)
406
+ }
407
+
408
+ def save_statistics(self, filepath: str):
409
+ """Save statistics to file"""
410
+ stats = self.get_statistics()
411
+ stats['timestamp'] = time.time()
412
+
413
+ with open(filepath, 'w') as f:
414
+ json.dump(stats, f, indent=2)
415
+
416
+ logger.info(f"Statistics saved to {filepath}")
417
+
418
+ def main():
419
+ """Main function"""
420
+ parser = argparse.ArgumentParser(description='Strawberry Picker Pipeline')
421
+ parser.add_argument('--config', default='config.yaml', help='Configuration file path')
422
+ parser.add_argument('--test', action='store_true', help='Run in test mode without hardware')
423
+ parser.add_argument('--save-stats', help='Save statistics to file')
424
+
425
+ args = parser.parse_args()
426
+
427
+ try:
428
+ # Initialize pipeline
429
+ pipeline = StrawberryPickerPipeline(args.config)
430
+
431
+ if args.test:
432
+ logger.info("Running in test mode - no hardware control")
433
+
434
+ # Start pipeline
435
+ pipeline.start()
436
+
437
+ # Keep running until interrupted
438
+ try:
439
+ while True:
440
+ time.sleep(1)
441
+
442
+ # Print statistics periodically
443
+ stats = pipeline.get_statistics()
444
+ if stats['processed_frames'] % 100 == 0:
445
+ logger.info(f"Statistics: {stats}")
446
+
447
+ except KeyboardInterrupt:
448
+ logger.info("Received interrupt signal")
449
+
450
+ finally:
451
+ pipeline.stop()
452
+
453
+ if args.save_stats:
454
+ pipeline.save_statistics(args.save_stats)
455
+
456
+ logger.info("Pipeline execution completed")
457
+
458
+ except Exception as e:
459
+ logger.error(f"Pipeline execution failed: {e}")
460
+ return 1
461
+
462
+ return 0
463
+
464
+ if __name__ == "__main__":
465
+ exit(main())
sync_to_huggingface.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Sync Strawberry Picker Models to HuggingFace Repository
4
+
5
+ This script automates the process of syncing trained models from strawberryPicker
6
+ to the HuggingFace strawberryPicker repository.
7
+
8
+ Usage:
9
+ python sync_to_huggingface.py [--dry-run]
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import json
15
+ import shutil
16
+ import argparse
17
+ import subprocess
18
+ from pathlib import Path
19
+ from datetime import datetime
20
+ import hashlib
21
+
22
+ def calculate_file_hash(file_path):
23
+ """Calculate SHA256 hash of a file."""
24
+ sha256_hash = hashlib.sha256()
25
+ with open(file_path, "rb") as f:
26
+ for byte_block in iter(lambda: f.read(4096), b""):
27
+ sha256_hash.update(byte_block)
28
+ return sha256_hash.hexdigest()
29
+
30
+ def find_updated_models(repo_path):
31
+ """Find models that have been updated."""
32
+ updated_models = []
33
+
34
+ # Check detection model
35
+ detection_pt = Path(repo_path) / "detection" / "best.pt"
36
+ if detection_pt.exists():
37
+ detection_hash = calculate_file_hash(detection_pt)
38
+ updated_models.append({
39
+ 'component': 'detection',
40
+ 'path': detection_pt,
41
+ 'hash': detection_hash
42
+ })
43
+
44
+ # Check classification model
45
+ classification_pth = Path(repo_path) / "classification" / "best_enhanced_classifier.pth"
46
+ if classification_pth.exists():
47
+ classification_hash = calculate_file_hash(classification_pth)
48
+ updated_models.append({
49
+ 'component': 'classification',
50
+ 'path': classification_pth,
51
+ 'hash': classification_hash
52
+ })
53
+
54
+ return updated_models
55
+
56
+ def export_detection_to_onnx(model_path, output_dir):
57
+ """Export detection model to ONNX format."""
58
+ try:
59
+ cmd = [
60
+ "yolo", "export",
61
+ f"model={model_path}",
62
+ f"dir={output_dir}",
63
+ "format=onnx",
64
+ "opset=12"
65
+ ]
66
+
67
+ print(f"Exporting detection model to ONNX...")
68
+ result = subprocess.run(cmd, capture_output=True, text=True)
69
+
70
+ if result.returncode == 0:
71
+ print("Successfully exported detection model to ONNX")
72
+ return True
73
+ else:
74
+ print(f"Export failed: {result.stderr}")
75
+ return False
76
+
77
+ except Exception as e:
78
+ print(f"Error during ONNX export: {e}")
79
+ return False
80
+
81
+ def update_model_metadata(repo_path, models_info):
82
+ """Update metadata files with sync information."""
83
+ metadata = {
84
+ "last_sync": datetime.now().isoformat(),
85
+ "models": {}
86
+ }
87
+
88
+ for model in models_info:
89
+ metadata["models"][model['component']] = {
90
+ "hash": model['hash'],
91
+ "path": str(model['path']),
92
+ "last_updated": datetime.now().isoformat()
93
+ }
94
+
95
+ metadata_file = Path(repo_path) / "sync_metadata.json"
96
+ with open(metadata_file, 'w') as f:
97
+ json.dump(metadata, f, indent=2)
98
+
99
+ print(f"Updated metadata file: {metadata_file}")
100
+ return True
101
+
102
+ def sync_repository(repo_path, dry_run=False):
103
+ """Main sync function for strawberryPicker repository."""
104
+ print(f"Syncing strawberryPicker repository at {repo_path}")
105
+
106
+ if dry_run:
107
+ print("DRY RUN MODE - No changes will be made")
108
+
109
+ # Find updated models
110
+ updated_models = find_updated_models(repo_path)
111
+
112
+ if not updated_models:
113
+ print("No models found to sync")
114
+ return True
115
+
116
+ print(f"Found {len(updated_models)} model components:")
117
+ for model in updated_models:
118
+ print(f" - {model['component']} (hash: {model['hash'][:8]}...)")
119
+
120
+ if dry_run:
121
+ return True
122
+
123
+ # Export detection model to ONNX if needed
124
+ detection_model = next((m for m in updated_models if m['component'] == 'detection'), None)
125
+ if detection_model:
126
+ detection_dir = Path(repo_path) / "detection"
127
+ export_detection_to_onnx(detection_model['path'], detection_dir)
128
+
129
+ # Update metadata
130
+ update_model_metadata(repo_path, updated_models)
131
+
132
+ print("\nSync completed successfully!")
133
+ print("Remember to:")
134
+ print("1. Review and commit changes to git")
135
+ print("2. Push to HuggingFace: git push origin main")
136
+ print("3. Update READMEs with any new performance metrics")
137
+
138
+ return True
139
+
140
+ def main():
141
+ parser = argparse.ArgumentParser(description="Sync strawberryPicker models to HuggingFace")
142
+ parser.add_argument("--repo-path",
143
+ default="/home/user/machine-learning/HuggingfaceModels/strawberryPicker",
144
+ help="Path to strawberryPicker repository")
145
+ parser.add_argument("--dry-run", action="store_true",
146
+ help="Show what would be done without making changes")
147
+
148
+ args = parser.parse_args()
149
+
150
+ # Validate path
151
+ if not os.path.exists(args.repo_path):
152
+ print(f"Error: Repository path {args.repo_path} does not exist")
153
+ sys.exit(1)
154
+
155
+ # Run sync
156
+ success = sync_repository(args.repo_path, args.dry_run)
157
+
158
+ if not success:
159
+ sys.exit(1)
160
+
161
+ if __name__ == "__main__":
162
+ main()
webcam_inference.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Real-time Strawberry Detection and Ripeness Classification using Webcam
4
+ Optimized for WSL (Windows Subsystem for Linux) environments
5
+ """
6
+
7
+ import torch
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
+ import argparse
12
+ import time
13
+ from pathlib import Path
14
+ import sys
15
+ import warnings
16
+
17
+ # Suppress warnings
18
+ warnings.filterwarnings('ignore')
19
+
20
+ class StrawberryPickerWebcam:
21
+ def __init__(self, detector_path, classifier_path, device='cpu'):
22
+ """
23
+ Initialize the strawberry picker system
24
+
25
+ Args:
26
+ detector_path: Path to YOLOv8 detection model
27
+ classifier_path: Path to EfficientNet classification model
28
+ device: Device to run inference on ('cpu' or 'cuda')
29
+ """
30
+ print("🍓 Initializing Strawberry Picker AI System...")
31
+
32
+ self.device = device
33
+ self.ripeness_classes = ['unripe', 'partially-ripe', 'ripe', 'overripe']
34
+
35
+ # Color mapping for visualization
36
+ self.colors = {
37
+ 'unripe': (0, 255, 0), # Green
38
+ 'partially-ripe': (0, 255, 255), # Yellow
39
+ 'ripe': (0, 0, 255), # Red
40
+ 'overripe': (128, 0, 128) # Purple
41
+ }
42
+
43
+ # Load detection model
44
+ print("Loading detection model...")
45
+ try:
46
+ from ultralytics import YOLO
47
+ self.detector = YOLO(detector_path)
48
+ print("✅ Detection model loaded successfully")
49
+ except Exception as e:
50
+ print(f"❌ Error loading detection model: {e}")
51
+ sys.exit(1)
52
+
53
+ # Load classification model
54
+ print("Loading classification model...")
55
+ try:
56
+ self.classifier = torch.load(classifier_path, map_location=device)
57
+ self.classifier.eval()
58
+ print("✅ Classification model loaded successfully")
59
+ except Exception as e:
60
+ print(f"❌ Error loading classification model: {e}")
61
+ sys.exit(1)
62
+
63
+ # Setup preprocessing
64
+ self.transform = transforms.Compose([
65
+ transforms.Resize((128, 128)),
66
+ transforms.ToTensor(),
67
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
68
+ std=[0.229, 0.224, 0.225])
69
+ ])
70
+
71
+ print("✅ System initialized and ready!")
72
+
73
+ def detect_and_classify(self, frame):
74
+ """
75
+ Detect strawberries and classify their ripeness in a frame
76
+
77
+ Args:
78
+ frame: Input frame (BGR format)
79
+
80
+ Returns:
81
+ results: List of detection/classification results
82
+ visualized_frame: Frame with visualizations
83
+ """
84
+ # Convert to RGB
85
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86
+
87
+ # Detect strawberries
88
+ detection_results = self.detector(frame_rgb)
89
+
90
+ results = []
91
+
92
+ for result in detection_results:
93
+ boxes = result.boxes.xyxy.cpu().numpy()
94
+ confidences = result.boxes.conf.cpu().numpy()
95
+
96
+ for box, conf in zip(boxes, confidences):
97
+ if conf < 0.5: # Confidence threshold
98
+ continue
99
+
100
+ x1, y1, x2, y2 = map(int, box)
101
+
102
+ # Ensure coordinates are within frame bounds
103
+ x1 = max(0, x1)
104
+ y1 = max(0, y1)
105
+ x2 = min(frame.shape[1], x2)
106
+ y2 = min(frame.shape[0], y2)
107
+
108
+ # Crop strawberry
109
+ crop = frame_rgb[y1:y2, x1:x2]
110
+
111
+ if crop.size == 0:
112
+ continue
113
+
114
+ # Classify ripeness
115
+ try:
116
+ crop_pil = Image.fromarray(crop)
117
+ input_tensor = self.transform(crop_pil).unsqueeze(0).to(self.device)
118
+
119
+ with torch.no_grad():
120
+ output = self.classifier(input_tensor)
121
+ probabilities = torch.softmax(output, dim=1)
122
+ predicted_class = torch.argmax(probabilities, dim=1).item()
123
+ confidence = probabilities[0][predicted_class].item()
124
+
125
+ ripeness = self.ripeness_classes[predicted_class]
126
+
127
+ results.append({
128
+ 'bbox': (x1, y1, x2, y2),
129
+ 'ripeness': ripeness,
130
+ 'confidence': confidence,
131
+ 'detection_confidence': float(conf)
132
+ })
133
+
134
+ except Exception as e:
135
+ print(f"Warning: Error classifying crop: {e}")
136
+ continue
137
+
138
+ return results
139
+
140
+ def visualize(self, frame, results):
141
+ """
142
+ Draw bounding boxes and labels on frame
143
+
144
+ Args:
145
+ frame: Input frame
146
+ results: Detection/classification results
147
+
148
+ Returns:
149
+ visualized_frame: Frame with drawings
150
+ """
151
+ vis_frame = frame.copy()
152
+
153
+ for result in results:
154
+ x1, y1, x2, y2 = result['bbox']
155
+ ripeness = result['ripeness']
156
+ conf = result['confidence']
157
+
158
+ # Draw bounding box
159
+ color = self.colors[ripeness]
160
+ cv2.rectangle(vis_frame, (x1, y1), (x2, y2), color, 2)
161
+
162
+ # Draw label background
163
+ label = f"{ripeness} ({conf:.2f})"
164
+ label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
165
+ cv2.rectangle(vis_frame, (x1, y1 - label_size[1] - 10),
166
+ (x1 + label_size[0], y1), color, -1)
167
+
168
+ # Draw label text
169
+ cv2.putText(vis_frame, label, (x1, y1 - 5),
170
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
171
+
172
+ # Add FPS counter
173
+ fps_text = f"FPS: {self.fps:.1f}"
174
+ cv2.putText(vis_frame, fps_text, (10, 30),
175
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
176
+
177
+ # Add title
178
+ title = "Strawberry Picker AI - Press 'q' to quit"
179
+ cv2.putText(vis_frame, title, (10, 60),
180
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
181
+
182
+ return vis_frame
183
+
184
+ def run_webcam(self, camera_index=0, width=640, height=480):
185
+ """
186
+ Run real-time inference on webcam
187
+
188
+ Args:
189
+ camera_index: Camera index (0 for default webcam)
190
+ width: Frame width
191
+ height: Frame height
192
+ """
193
+ print(f"\n📹 Starting webcam (camera {camera_index})...")
194
+ print("Press 'q' to quit, 's' to save screenshot")
195
+ print("Make sure strawberries are well-lit and clearly visible\n")
196
+
197
+ # Try to open webcam
198
+ cap = cv2.VideoCapture(camera_index)
199
+
200
+ if not cap.isOpened():
201
+ print(f"❌ Error: Could not open camera {camera_index}")
202
+ print("\nTroubleshooting tips for WSL:")
203
+ print("1. Install v4l2loopback: sudo apt-get install v4l2loopback-dkms")
204
+ print("2. Load module: sudo modprobe v4l2loopback")
205
+ print("3. Use IP webcam app on phone as alternative")
206
+ print("4. Or use pre-recorded video file")
207
+ return
208
+
209
+ # Set camera properties
210
+ cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
211
+ cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
212
+
213
+ # FPS tracking
214
+ self.fps = 0
215
+ frame_count = 0
216
+ start_time = time.time()
217
+
218
+ # Screenshot counter
219
+ screenshot_count = 0
220
+
221
+ try:
222
+ while True:
223
+ # Read frame
224
+ ret, frame = cap.read()
225
+
226
+ if not ret:
227
+ print("❌ Error: Could not read frame from camera")
228
+ break
229
+
230
+ # Detect and classify
231
+ results = self.detect_and_classify(frame)
232
+
233
+ # Visualize results
234
+ vis_frame = self.visualize(frame, results)
235
+
236
+ # Calculate FPS
237
+ frame_count += 1
238
+ if frame_count % 10 == 0:
239
+ elapsed = time.time() - start_time
240
+ self.fps = frame_count / elapsed
241
+
242
+ # Display frame
243
+ cv2.imshow('Strawberry Picker AI', vis_frame)
244
+
245
+ # Handle keyboard input
246
+ key = cv2.waitKey(1) & 0xFF
247
+
248
+ if key == ord('q'):
249
+ print("\n👋 Quitting...")
250
+ break
251
+ elif key == ord('s'):
252
+ # Save screenshot
253
+ screenshot_path = f"screenshot_{screenshot_count}.jpg"
254
+ cv2.imwrite(screenshot_path, vis_frame)
255
+ print(f"📸 Screenshot saved: {screenshot_path}")
256
+ screenshot_count += 1
257
+
258
+ except KeyboardInterrupt:
259
+ print("\n👋 Interrupted by user")
260
+
261
+ finally:
262
+ # Cleanup
263
+ cap.release()
264
+ cv2.destroyAllWindows()
265
+ print("✅ Webcam session ended")
266
+
267
+ def run_video_file(self, video_path):
268
+ """
269
+ Run inference on a video file
270
+
271
+ Args:
272
+ video_path: Path to video file
273
+ """
274
+ print(f"\n🎬 Processing video: {video_path}")
275
+
276
+ cap = cv2.VideoCapture(video_path)
277
+
278
+ if not cap.isOpened():
279
+ print(f"❌ Error: Could not open video file: {video_path}")
280
+ return
281
+
282
+ # Get video properties
283
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
284
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
285
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
286
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
287
+
288
+ print(f"Video info: {width}x{height}, {fps} FPS, {total_frames} frames")
289
+
290
+ # Setup output video
291
+ output_path = f"output_{Path(video_path).name}"
292
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
293
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
294
+
295
+ frame_count = 0
296
+ start_time = time.time()
297
+
298
+ try:
299
+ while True:
300
+ ret, frame = cap.read()
301
+
302
+ if not ret:
303
+ break
304
+
305
+ # Process frame
306
+ results = self.detect_and_classify(frame)
307
+ vis_frame = self.visualize(frame, results)
308
+
309
+ # Write to output
310
+ out.write(vis_frame)
311
+
312
+ # Display progress
313
+ frame_count += 1
314
+ if frame_count % 30 == 0:
315
+ progress = (frame_count / total_frames) * 100
316
+ elapsed = time.time() - start_time
317
+ print(f"Progress: {progress:.1f}% ({frame_count}/{total_frames}) - "
318
+ f"Time: {elapsed:.1f}s")
319
+
320
+ except KeyboardInterrupt:
321
+ print("\n👋 Interrupted by user")
322
+
323
+ finally:
324
+ cap.release()
325
+ out.release()
326
+ cv2.destroyAllWindows()
327
+ print(f"✅ Video processing complete. Output saved to: {output_path}")
328
+
329
+ def main():
330
+ """Main function with argument parsing"""
331
+ parser = argparse.ArgumentParser(
332
+ description='Real-time Strawberry Detection and Ripeness Classification'
333
+ )
334
+
335
+ parser.add_argument(
336
+ '--detector',
337
+ type=str,
338
+ default='detection_model/best.pt',
339
+ help='Path to YOLOv8 detection model'
340
+ )
341
+
342
+ parser.add_argument(
343
+ '--classifier',
344
+ type=str,
345
+ default='classification_model/best_enhanced_classifier.pth',
346
+ help='Path to EfficientNet classification model'
347
+ )
348
+
349
+ parser.add_argument(
350
+ '--mode',
351
+ type=str,
352
+ choices=['webcam', 'video'],
353
+ default='webcam',
354
+ help='Mode: webcam or video file'
355
+ )
356
+
357
+ parser.add_argument(
358
+ '--input',
359
+ type=str,
360
+ help='Path to video file (if mode=video)'
361
+ )
362
+
363
+ parser.add_argument(
364
+ '--camera',
365
+ type=int,
366
+ default=0,
367
+ help='Camera index (default: 0)'
368
+ )
369
+
370
+ parser.add_argument(
371
+ '--width',
372
+ type=int,
373
+ default=640,
374
+ help='Camera frame width'
375
+ )
376
+
377
+ parser.add_argument(
378
+ '--height',
379
+ type=int,
380
+ default=480,
381
+ help='Camera frame height'
382
+ )
383
+
384
+ parser.add_argument(
385
+ '--device',
386
+ type=str,
387
+ default='auto',
388
+ choices=['auto', 'cpu', 'cuda'],
389
+ help='Device to use for inference'
390
+ )
391
+
392
+ args = parser.parse_args()
393
+
394
+ # Determine device
395
+ if args.device == 'auto':
396
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
397
+ else:
398
+ device = args.device
399
+
400
+ print(f"Using device: {device}")
401
+
402
+ if device == 'cpu':
403
+ print("⚠️ Running on CPU - this will be slower. Consider using GPU if available.")
404
+
405
+ # Initialize system
406
+ try:
407
+ picker = StrawberryPickerWebcam(
408
+ detector_path=args.detector,
409
+ classifier_path=args.classifier,
410
+ device=device
411
+ )
412
+ except Exception as e:
413
+ print(f"❌ Failed to initialize system: {e}")
414
+ sys.exit(1)
415
+
416
+ # Run inference
417
+ if args.mode == 'webcam':
418
+ picker.run_webcam(
419
+ camera_index=args.camera,
420
+ width=args.width,
421
+ height=args.height
422
+ )
423
+ elif args.mode == 'video':
424
+ if not args.input:
425
+ print("❌ Error: --input required for video mode")
426
+ sys.exit(1)
427
+ picker.run_video_file(args.input)
428
+
429
+ if __name__ == "__main__":
430
+ # Check for required libraries
431
+ try:
432
+ import torch
433
+ import cv2
434
+ from PIL import Image
435
+ from torchvision import transforms
436
+ except ImportError as e:
437
+ print(f"❌ Missing required library: {e}")
438
+ print("Install with: pip install torch torchvision opencv-python pillow")
439
+ sys.exit(1)
440
+
441
+ main()
yolov11n/README.md ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - object-detection
4
+ - yolo
5
+ - yolov11
6
+ - strawberry
7
+ - agriculture
8
+ - robotics
9
+ - computer-vision
10
+ - pytorch
11
+ - onnx
12
+ license: mit
13
+ datasets:
14
+ - theonegareth/strawberry-detect
15
+ language:
16
+ - python
17
+ pretty_name: YOLOv11n Strawberry Detection
18
+ description: YOLOv11 Nano model for strawberry detection using latest architecture
19
+ pipeline_tag: object-detection
20
+ ---
21
+
22
+ # YOLOv11n Strawberry Detection Model
23
+
24
+ This directory contains the YOLOv11 Nano model for strawberry detection, utilizing the latest YOLO architecture improvements.
25
+
26
+ ## 📊 Model Performance
27
+
28
+ | Metric | Value |
29
+ |--------|-------|
30
+ | **mAP@50** | TBD |
31
+ | **mAP@50-95** | TBD |
32
+ | **Inference Speed** | TBD |
33
+ | **Model Size** | TBD |
34
+ | **Parameters** | TBD |
35
+
36
+ *Performance metrics will be updated after validation testing*
37
+
38
+ ## 🚀 Quick Start
39
+
40
+ ### Installation
41
+ ```bash
42
+ pip install ultralytics opencv-python
43
+ ```
44
+
45
+ ### Python Inference
46
+ ```python
47
+ from ultralytics import YOLO
48
+
49
+ # Load model
50
+ model = YOLO('strawberry_yolov11n.pt')
51
+
52
+ # Run inference
53
+ results = model('image.jpg', conf=0.25)
54
+
55
+ # Process results
56
+ for result in results:
57
+ boxes = result.boxes
58
+ for box in boxes:
59
+ cls = int(box.cls)
60
+ conf = float(box.conf)
61
+ xyxy = box.xyxy
62
+ print(f"Strawberry detected: {conf:.2f} confidence at {xyxy}")
63
+ ```
64
+
65
+ ### Command Line
66
+ ```bash
67
+ # Single image
68
+ yolo predict model=strawberry_yolov11n.pt source='image.jpg'
69
+
70
+ # Webcam
71
+ yolo predict model=strawberry_yolov11n.pt source=0
72
+
73
+ # Video
74
+ yolo predict model=strawberry_yolov11n.pt source='video.mp4'
75
+ ```
76
+
77
+ ## 📁 Files
78
+
79
+ - `strawberry_yolov11n.pt` - PyTorch model weights
80
+ - `strawberry_yolov11n.onnx` - ONNX model for deployment
81
+
82
+ ## 🎯 Use Cases
83
+
84
+ - **Latest Architecture Testing**: Evaluation of YOLOv11 improvements
85
+ - **Edge Deployment**: Optimized for modern edge devices
86
+ - **Research Applications**: Academic and industrial research
87
+ - **Future Deployment**: Next-generation robotic systems
88
+
89
+ ## 🔧 Technical Details
90
+
91
+ - **Architecture**: YOLOv11n (Nano)
92
+ - **Input Size**: 640x640
93
+ - **Training Dataset**: Enhanced Strawberry Dataset
94
+ - **Training Epochs**: TBD
95
+ - **Batch Size**: TBD
96
+ - **Optimizer**: TBD
97
+ - **Learning Rate**: TBD
98
+
99
+ ## 📈 Training Configuration
100
+
101
+ *Training configuration will be updated after model validation*
102
+
103
+ ```yaml
104
+ model: yolov11n.pt
105
+ epochs: TBD
106
+ batch: TBD
107
+ imgsz: 640
108
+ optimizer: TBD
109
+ lr0: TBD
110
+ # Additional hyperparameters will be added after validation
111
+ ```
112
+
113
+ ## 🔗 Related Models
114
+
115
+ - [YOLOv8n](../yolov8n/) - Proven YOLOv8 nano model
116
+ - [YOLOv8s](../yolov8s/) - Higher accuracy YOLOv8 small model
117
+
118
+ ## 📚 Documentation
119
+
120
+ - [Training Pipeline](https://github.com/theonegareth/strawberryPicker)
121
+ - [Dataset](https://universe.roboflow.com/theonegareth/strawberry-detect)
122
+ - [ROS2 Integration](https://github.com/theonegareth/strawberryPicker/blob/main/ROS2_INTEGRATION_PLAN.md)
123
+
124
+ ## ⚠️ Note
125
+
126
+ This model is currently in testing phase. Performance metrics and training details will be updated after comprehensive validation. For production deployment, consider using the validated [YOLOv8n](../yolov8n/) or [YOLOv8s](../yolov8s/) models.
127
+
128
+ ## 📄 License
129
+
130
+ MIT License - See main repository for details.
131
+
132
+ ---
133
+
134
+ **Model Version**: 0.1.0 (Testing)
135
+ **Training Date**: December 2025
136
+ **Status**: Under validation - Not recommended for production use