Gareth commited on
Commit ·
efb1801
0
Parent(s):
Initial clean commit for Hugging Face
Browse files- Added comprehensive README.md with project documentation
- Included essential Python scripts for inference and detection
- Added Arduino code for robotic control
- Included configuration files and requirements
- Added documentation and example notebooks
- Excluded large datasets, models, and binary files for Hugging Face compatibility
This clean version focuses on code and documentation while maintaining
full project functionality. Large files can be hosted separately via
Hugging Face Datasets or other storage solutions.
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +8 -0
- .gitignore +235 -0
- ArduinoCode/codingservoarm.ino +70 -0
- CITATION.cff +57 -0
- LICENSE +21 -0
- README +309 -0
- README.md +615 -0
- classification/README.md +146 -0
- classification/training_summary.md +21 -0
- classification_model/README.md +98 -0
- classification_model/training_summary.md +21 -0
- config.yaml +179 -0
- detection/README.md +138 -0
- docs/GITHUB_SETUP.md +316 -0
- docs/TRAINING_README.md +234 -0
- git-xet +0 -0
- inference_example.py +65 -0
- notebooks/strawberry_training.ipynb +92 -0
- notebooks/train_yolov8_colab.ipynb +309 -0
- requirements.txt +20 -0
- results.csv +51 -0
- scripts/all_combine3.py +177 -0
- scripts/auto_label_strawberries.py +220 -0
- scripts/benchmark_models.py +342 -0
- scripts/collect_dataset.py +41 -0
- scripts/combine3.py +177 -0
- scripts/complete_final_labeling.py +180 -0
- scripts/convert_tflite.py +119 -0
- scripts/data/preprocess_strawberry_dataset.py +91 -0
- scripts/detect_realtime.py +124 -0
- scripts/download_dataset.py +8 -0
- scripts/export_onnx.py +194 -0
- scripts/export_tflite_int8.py +200 -0
- scripts/get-pip.py +0 -0
- scripts/label_ripeness_dataset.py +192 -0
- scripts/optimization/optimized_onnx_inference.py +291 -0
- scripts/organize_labeled_images.py +115 -0
- scripts/setup_training.py +213 -0
- scripts/train_model.py +59 -0
- scripts/train_ripeness_classifier.py +345 -0
- scripts/train_yolov8.py +223 -0
- scripts/validate_model.py +234 -0
- scripts/webcam_capture.py +20 -0
- src/arduino_bridge.py +544 -0
- src/coordinate_transformer.py +482 -0
- src/integrated_detection_classification.py +248 -0
- src/strawberry_picker_pipeline.py +465 -0
- sync_to_huggingface.py +162 -0
- webcam_inference.py +441 -0
- yolov11n/README.md +136 -0
.gitattributes
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
yolov11n/*.pt filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
yolov11n/*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
yolov8n/*.pt filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
yolov8n/*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
yolov8s/*.pt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================
|
| 2 |
+
# Python
|
| 3 |
+
# ============================================
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.py[cod]
|
| 6 |
+
*$py.class
|
| 7 |
+
*.so
|
| 8 |
+
.Python
|
| 9 |
+
build/
|
| 10 |
+
develop-eggs/
|
| 11 |
+
dist/
|
| 12 |
+
downloads/
|
| 13 |
+
eggs/
|
| 14 |
+
.eggs/
|
| 15 |
+
lib/
|
| 16 |
+
lib64/
|
| 17 |
+
parts/
|
| 18 |
+
sdist/
|
| 19 |
+
var/
|
| 20 |
+
wheels/
|
| 21 |
+
pip-wheel-metadata/
|
| 22 |
+
share/python-wheels/
|
| 23 |
+
*.egg-info/
|
| 24 |
+
.installed.cfg
|
| 25 |
+
*.egg
|
| 26 |
+
MANIFEST
|
| 27 |
+
|
| 28 |
+
# ============================================
|
| 29 |
+
# Virtual Environments
|
| 30 |
+
# ============================================
|
| 31 |
+
.env
|
| 32 |
+
.venv
|
| 33 |
+
env/
|
| 34 |
+
venv/
|
| 35 |
+
ENV/
|
| 36 |
+
env.bak/
|
| 37 |
+
venv.bak/
|
| 38 |
+
pip-log.txt
|
| 39 |
+
pip-delete-this-directory.txt
|
| 40 |
+
|
| 41 |
+
# ============================================
|
| 42 |
+
# IDE & Editor
|
| 43 |
+
# ============================================
|
| 44 |
+
.vscode/
|
| 45 |
+
.idea/
|
| 46 |
+
*.swp
|
| 47 |
+
*.swo
|
| 48 |
+
*~
|
| 49 |
+
.DS_Store
|
| 50 |
+
.DS_Store?
|
| 51 |
+
._*
|
| 52 |
+
.Spotlight-V100
|
| 53 |
+
.Trashes
|
| 54 |
+
ehthumbs.db
|
| 55 |
+
Thumbs.db
|
| 56 |
+
|
| 57 |
+
# ============================================
|
| 58 |
+
# Jupyter Notebook
|
| 59 |
+
# ============================================
|
| 60 |
+
.ipynb_checkpoints
|
| 61 |
+
*/.ipynb_checkpoints/*
|
| 62 |
+
|
| 63 |
+
# ============================================
|
| 64 |
+
# Model Files & Training Artifacts
|
| 65 |
+
# ============================================
|
| 66 |
+
# PyTorch / TensorFlow / Keras
|
| 67 |
+
*.pth
|
| 68 |
+
*.pt
|
| 69 |
+
*.pkl
|
| 70 |
+
*.h5
|
| 71 |
+
*.tflite
|
| 72 |
+
*.onnx
|
| 73 |
+
*.pb
|
| 74 |
+
*.weights
|
| 75 |
+
*.cfg
|
| 76 |
+
*.data
|
| 77 |
+
*.pth.tar
|
| 78 |
+
|
| 79 |
+
# Model directories (auto-generated)
|
| 80 |
+
model/weights/
|
| 81 |
+
model/results/
|
| 82 |
+
model/exports/
|
| 83 |
+
runs/
|
| 84 |
+
detect/
|
| 85 |
+
predict/
|
| 86 |
+
val/
|
| 87 |
+
checkpoint/
|
| 88 |
+
checkpoints/
|
| 89 |
+
|
| 90 |
+
# ============================================
|
| 91 |
+
# Dataset & Large Files
|
| 92 |
+
# ============================================
|
| 93 |
+
# Raw images (can be regenerated)
|
| 94 |
+
model/dataset/**/*.jpg
|
| 95 |
+
model/dataset/**/*.jpeg
|
| 96 |
+
model/dataset/**/*.png
|
| 97 |
+
model/dataset/**/*.zip
|
| 98 |
+
!model/dataset/**/data.yaml
|
| 99 |
+
!model/dataset/**/*.txt
|
| 100 |
+
|
| 101 |
+
# Roboflow exports
|
| 102 |
+
straw-detect.v1-straw-detect.yolov8/
|
| 103 |
+
|
| 104 |
+
# ============================================
|
| 105 |
+
# Logs & Caches
|
| 106 |
+
# ============================================
|
| 107 |
+
*.log
|
| 108 |
+
logs/
|
| 109 |
+
.cache/
|
| 110 |
+
.parquet
|
| 111 |
+
.pytest_cache/
|
| 112 |
+
.coverage
|
| 113 |
+
htmlcov/
|
| 114 |
+
.tox/
|
| 115 |
+
.nox/
|
| 116 |
+
|
| 117 |
+
# ============================================
|
| 118 |
+
# Build & Distribution
|
| 119 |
+
# ============================================
|
| 120 |
+
build/
|
| 121 |
+
develop-eggs/
|
| 122 |
+
dist/
|
| 123 |
+
downloads/
|
| 124 |
+
eggs/
|
| 125 |
+
.eggs/
|
| 126 |
+
lib/
|
| 127 |
+
lib64/
|
| 128 |
+
parts/
|
| 129 |
+
sdist/
|
| 130 |
+
var/
|
| 131 |
+
wheels/
|
| 132 |
+
share/python-wheels/
|
| 133 |
+
*.egg-info/
|
| 134 |
+
.installed.cfg
|
| 135 |
+
*.egg
|
| 136 |
+
MANIFEST
|
| 137 |
+
pyproject.toml
|
| 138 |
+
poetry.lock
|
| 139 |
+
|
| 140 |
+
# ============================================
|
| 141 |
+
# Temporary Files
|
| 142 |
+
# ============================================
|
| 143 |
+
tmp/
|
| 144 |
+
temp/
|
| 145 |
+
*.tmp
|
| 146 |
+
*.temp
|
| 147 |
+
*.swp
|
| 148 |
+
*.swo
|
| 149 |
+
*~
|
| 150 |
+
|
| 151 |
+
# ============================================
|
| 152 |
+
# Media & Output Files
|
| 153 |
+
# ============================================
|
| 154 |
+
*.avi
|
| 155 |
+
*.mp4
|
| 156 |
+
*.mov
|
| 157 |
+
*.mkv
|
| 158 |
+
*.gcode
|
| 159 |
+
*.factory
|
| 160 |
+
|
| 161 |
+
# ============================================
|
| 162 |
+
# ML Experiment Tracking
|
| 163 |
+
# ============================================
|
| 164 |
+
mlruns/
|
| 165 |
+
mlflow.db
|
| 166 |
+
wandb/
|
| 167 |
+
|
| 168 |
+
# ============================================
|
| 169 |
+
# TensorBoard
|
| 170 |
+
# ============================================
|
| 171 |
+
runs/
|
| 172 |
+
*.tfevents.*
|
| 173 |
+
events.out.tfevents.*
|
| 174 |
+
|
| 175 |
+
# ============================================
|
| 176 |
+
# Calibration & Configuration
|
| 177 |
+
# ============================================
|
| 178 |
+
*.calib
|
| 179 |
+
*.yaml
|
| 180 |
+
*.yml
|
| 181 |
+
!data.yaml
|
| 182 |
+
!requirements.txt
|
| 183 |
+
!setup.py
|
| 184 |
+
!pyproject.toml
|
| 185 |
+
|
| 186 |
+
# ============================================
|
| 187 |
+
# Arduino Build Files
|
| 188 |
+
# ============================================
|
| 189 |
+
ArduinoCode/*.hex
|
| 190 |
+
ArduinoCode/*.elf
|
| 191 |
+
ArduinoCode/build/
|
| 192 |
+
ArduinoCode/*.bin
|
| 193 |
+
|
| 194 |
+
# ============================================
|
| 195 |
+
# SolidWorks & CAD Temporary Files
|
| 196 |
+
# ============================================
|
| 197 |
+
assets/solidworks/*.tmp
|
| 198 |
+
assets/solidworks/~$*
|
| 199 |
+
assets/solidworks/*.bak
|
| 200 |
+
assets/solidworks/*.swp
|
| 201 |
+
|
| 202 |
+
# ============================================
|
| 203 |
+
# Documentation Builds
|
| 204 |
+
# ============================================
|
| 205 |
+
docs/_build/
|
| 206 |
+
site/
|
| 207 |
+
*.pdf
|
| 208 |
+
*.docx
|
| 209 |
+
*.doc
|
| 210 |
+
*.xlsx
|
| 211 |
+
*.pptx
|
| 212 |
+
|
| 213 |
+
# ============================================
|
| 214 |
+
# Secrets & Local Configuration
|
| 215 |
+
# ============================================
|
| 216 |
+
config.local.yaml
|
| 217 |
+
secrets.yaml
|
| 218 |
+
credentials.json
|
| 219 |
+
*.key
|
| 220 |
+
*.pem
|
| 221 |
+
|
| 222 |
+
# ============================================
|
| 223 |
+
# Debug & Development
|
| 224 |
+
# ============================================
|
| 225 |
+
debug/
|
| 226 |
+
debug_images/
|
| 227 |
+
*.debug
|
| 228 |
+
|
| 229 |
+
# ============================================
|
| 230 |
+
# Robot Specific
|
| 231 |
+
# ============================================
|
| 232 |
+
robot/logs/
|
| 233 |
+
robot/calibration/
|
| 234 |
+
robot/trajectories/
|
| 235 |
+
*.bag
|
ArduinoCode/codingservoarm.ino
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <Wire.h>
|
| 2 |
+
#include <Adafruit_PWMServoDriver.h>
|
| 3 |
+
|
| 4 |
+
Adafruit_PWMServoDriver pwm = Adafruit_PWMServoDriver();
|
| 5 |
+
|
| 6 |
+
// Servo pulse limits (tune for your specific servos)
|
| 7 |
+
#define SERVOMIN 150 // 0 degrees
|
| 8 |
+
#define SERVOMAX 600 // 180 degrees
|
| 9 |
+
|
| 10 |
+
float currentAngle = 0; // Start at 90 degrees
|
| 11 |
+
float targetAngle = 0;
|
| 12 |
+
|
| 13 |
+
int angleToPulse(float angle) {
|
| 14 |
+
return map((int)angle, 0, 180, SERVOMIN, SERVOMAX);
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
void setup() {
|
| 18 |
+
Serial.begin(9600);
|
| 19 |
+
Serial.println("Type an angle (0–180) and press Enter:");
|
| 20 |
+
|
| 21 |
+
pwm.begin();
|
| 22 |
+
pwm.setPWMFreq(50);
|
| 23 |
+
delay(10);
|
| 24 |
+
|
| 25 |
+
// Initialize both servos to center position
|
| 26 |
+
int pulse = angleToPulse(currentAngle);
|
| 27 |
+
pwm.setPWM(0, 0, pulse);
|
| 28 |
+
pwm.setPWM(1, 0, pulse);
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
void loop() {
|
| 32 |
+
if (Serial.available()) {
|
| 33 |
+
int angle = Serial.parseInt();
|
| 34 |
+
|
| 35 |
+
if (angle >= 0 && angle <= 180) {
|
| 36 |
+
targetAngle = angle;
|
| 37 |
+
Serial.print("Slowly moving servos to ");
|
| 38 |
+
Serial.print(targetAngle);
|
| 39 |
+
Serial.println(" degrees...");
|
| 40 |
+
|
| 41 |
+
// Smoothly step toward target angle
|
| 42 |
+
float step = 0.10; // smaller = smoother and slower
|
| 43 |
+
int delayTime = 5; // larger = slower (try 50 or 70 for extra slow)
|
| 44 |
+
|
| 45 |
+
if (targetAngle > currentAngle) {
|
| 46 |
+
for (float pos = currentAngle; pos <= targetAngle; pos += step) {
|
| 47 |
+
int pulse = angleToPulse(pos);
|
| 48 |
+
pwm.setPWM(0, 0, pulse);
|
| 49 |
+
pwm.setPWM(1, 0, pulse);
|
| 50 |
+
delay(delayTime);
|
| 51 |
+
}
|
| 52 |
+
} else {
|
| 53 |
+
for (float pos = currentAngle; pos >= targetAngle; pos -= step) {
|
| 54 |
+
int pulse = angleToPulse(pos);
|
| 55 |
+
pwm.setPWM(0, 0, pulse);
|
| 56 |
+
pwm.setPWM(1, 0, pulse);
|
| 57 |
+
delay(delayTime);
|
| 58 |
+
}
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
currentAngle = targetAngle;
|
| 62 |
+
Serial.println("Done!");
|
| 63 |
+
} else {
|
| 64 |
+
Serial.println("Please enter an angle between 0 and 180.");
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// Clear serial buffer
|
| 68 |
+
while (Serial.available()) Serial.read();
|
| 69 |
+
}
|
| 70 |
+
}
|
CITATION.cff
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cff-version: 1.2.0
|
| 2 |
+
message: "If you use this software, please cite it as below."
|
| 3 |
+
type: software
|
| 4 |
+
authors:
|
| 5 |
+
- given-names: "Gareth"
|
| 6 |
+
family-names: "theonegareth"
|
| 7 |
+
email: "gareth@example.com"
|
| 8 |
+
title: "Strawberry Picker AI System"
|
| 9 |
+
version: "1.0.0"
|
| 10 |
+
date-released: "2025-12-13"
|
| 11 |
+
url: "https://huggingface.co/theonegareth/strawberryPicker"
|
| 12 |
+
repository-code: "https://github.com/theonegareth/strawberryPicker"
|
| 13 |
+
keywords:
|
| 14 |
+
- computer-vision
|
| 15 |
+
- object-detection
|
| 16 |
+
- image-classification
|
| 17 |
+
- agriculture
|
| 18 |
+
- robotics
|
| 19 |
+
- strawberry
|
| 20 |
+
- ripeness-detection
|
| 21 |
+
- yolov11
|
| 22 |
+
- efficientnet
|
| 23 |
+
license: MIT
|
| 24 |
+
abstract: >
|
| 25 |
+
A complete AI-powered strawberry picking system that combines object detection
|
| 26 |
+
and ripeness classification to identify and pick only ripe strawberries.
|
| 27 |
+
This two-stage pipeline achieves 91.71% accuracy in ripeness classification
|
| 28 |
+
while maintaining real-time performance suitable for robotic harvesting applications.
|
| 29 |
+
|
| 30 |
+
references:
|
| 31 |
+
- type: software
|
| 32 |
+
title: "strawberry-models"
|
| 33 |
+
authors:
|
| 34 |
+
- given-names: "Gareth"
|
| 35 |
+
family-names: "theonegareth"
|
| 36 |
+
url: "https://huggingface.co/theonegareth/strawberry-models"
|
| 37 |
+
version: "1.0.0"
|
| 38 |
+
date-released: "2025-12-13"
|
| 39 |
+
|
| 40 |
+
- type: dataset
|
| 41 |
+
title: "strawberry-detect"
|
| 42 |
+
authors:
|
| 43 |
+
- given-names: "Gareth"
|
| 44 |
+
family-names: "theonegareth"
|
| 45 |
+
url: "https://universe.roboflow.com/theonegareth/strawberry-detect"
|
| 46 |
+
version: "1.0.0"
|
| 47 |
+
date-released: "2025-12-13"
|
| 48 |
+
|
| 49 |
+
preferred-citation:
|
| 50 |
+
type: software
|
| 51 |
+
title: "Strawberry Picker AI System: A Two-Stage Approach for Automated Harvesting"
|
| 52 |
+
authors:
|
| 53 |
+
- given-names: "Gareth"
|
| 54 |
+
family-names: "theonegareth"
|
| 55 |
+
version: "1.0.0"
|
| 56 |
+
url: "https://huggingface.co/theonegareth/strawberryPicker"
|
| 57 |
+
date-released: "2025-12-13"
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Gareth
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Strawberry Picker - AI-Powered Robotic System
|
| 2 |
+
|
| 3 |
+
## 🎯 Project Overview
|
| 4 |
+
|
| 5 |
+
The Strawberry Picker is a sophisticated AI-powered robotic system that automatically detects, classifies, and picks ripe strawberries using computer vision and machine learning. The system combines YOLOv11 object detection with custom ripeness classification to enable precise robotic harvesting.
|
| 6 |
+
|
| 7 |
+
## ✅ Project Status: COMPLETE
|
| 8 |
+
|
| 9 |
+
**Dataset**: 100% Complete (889/889 images labeled)
|
| 10 |
+
**Models**: Trained and validated (94% accuracy)
|
| 11 |
+
**Pipeline**: Fully integrated and tested
|
| 12 |
+
**Hardware Integration**: Arduino communication ready
|
| 13 |
+
|
| 14 |
+
## 📊 Key Achievements
|
| 15 |
+
|
| 16 |
+
### 1. Dataset Completion
|
| 17 |
+
- **Total Images**: 889 labeled images
|
| 18 |
+
- **Classes**: 3-class ripeness classification
|
| 19 |
+
- Unripe: 317 images (35.7%)
|
| 20 |
+
- Ripe: 446 images (50.2%)
|
| 21 |
+
- Overripe: 126 images (14.2%)
|
| 22 |
+
- **Automation**: 82% automated labeling success rate
|
| 23 |
+
|
| 24 |
+
### 2. Machine Learning Models
|
| 25 |
+
- **Detection Model**: YOLOv11n optimized for strawberry detection
|
| 26 |
+
- **Classification Model**: Custom CNN with 94% accuracy
|
| 27 |
+
- **Model Formats**: PyTorch, ONNX, TensorFlow Lite (INT8 quantized)
|
| 28 |
+
- **Performance**: Optimized for Raspberry Pi deployment
|
| 29 |
+
|
| 30 |
+
### 3. Robotic Integration
|
| 31 |
+
- **Coordinate Transformation**: Pixel-to-robot coordinate mapping
|
| 32 |
+
- **Arduino Communication**: Serial bridge for robotic arm control
|
| 33 |
+
- **Real-time Processing**: Live detection and classification pipeline
|
| 34 |
+
- **Error Handling**: Comprehensive recovery mechanisms
|
| 35 |
+
|
| 36 |
+
## 🏗️ Project Structure
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
strawberryPicker/
|
| 40 |
+
├── config.yaml # Unified configuration file
|
| 41 |
+
├── README.md # This file
|
| 42 |
+
├── scripts/ # Utility and training scripts
|
| 43 |
+
│ ├── train_yolov8.py # YOLOv11 training script
|
| 44 |
+
│ ├── train_ripeness_classifier.py # Ripeness classification training
|
| 45 |
+
│ ├── detect_realtime.py # Real-time detection script
|
| 46 |
+
│ ├── auto_label_strawberries.py # Automated labeling tool
|
| 47 |
+
│ ├── benchmark_models.py # Performance benchmarking
|
| 48 |
+
│ └── export_*.py # Model export scripts
|
| 49 |
+
├── model/ # Trained models and datasets
|
| 50 |
+
│ ├── weights/ # YOLOv11 model weights
|
| 51 |
+
│ ├── ripeness_classifier.pkl # Trained classifier
|
| 52 |
+
│ └── dataset_strawberry_detect_v3/ # Detection dataset
|
| 53 |
+
├── src/ # Core pipeline components
|
| 54 |
+
│ ├── strawberry_picker_pipeline.py # Main pipeline
|
| 55 |
+
│ ├── arduino_bridge.py # Arduino communication
|
| 56 |
+
│ └── coordinate_transformer.py # Coordinate mapping
|
| 57 |
+
├── docs/ # Documentation
|
| 58 |
+
│ ├── INTEGRATION_GUIDE.md # ML to Arduino integration
|
| 59 |
+
│ ├── TROUBLESHOOTING.md # Common issues and solutions
|
| 60 |
+
│ └── PERFORMANCE.md # Benchmarks and optimization
|
| 61 |
+
└── ArduinoCode/ # Arduino robotic arm code
|
| 62 |
+
└── codingservoarm.ino # Servo control firmware
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## 🚀 Quick Start
|
| 66 |
+
|
| 67 |
+
### 1. Environment Setup
|
| 68 |
+
```bash
|
| 69 |
+
# Install dependencies
|
| 70 |
+
pip install ultralytics opencv-python numpy scikit-learn pyyaml
|
| 71 |
+
|
| 72 |
+
# Clone and setup
|
| 73 |
+
git clone <repository>
|
| 74 |
+
cd strawberryPicker
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### 2. Configuration
|
| 78 |
+
Edit `config.yaml` with your specific settings:
|
| 79 |
+
```yaml
|
| 80 |
+
detection:
|
| 81 |
+
model_path: model/weights/yolo11n_strawberry_detect_v3.pt
|
| 82 |
+
confidence_threshold: 0.5
|
| 83 |
+
|
| 84 |
+
serial:
|
| 85 |
+
port: /dev/ttyUSB0
|
| 86 |
+
baudrate: 115200
|
| 87 |
+
|
| 88 |
+
robot:
|
| 89 |
+
workspace_bounds:
|
| 90 |
+
x_min: 0, x_max: 300
|
| 91 |
+
y_min: 0, y_max: 200
|
| 92 |
+
z_min: 50, z_max: 150
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
### 3. Run Detection
|
| 96 |
+
```bash
|
| 97 |
+
# Real-time detection
|
| 98 |
+
python3 scripts/detect_realtime.py --model model/weights/yolo11n_strawberry_detect_v3.pt
|
| 99 |
+
|
| 100 |
+
# With ripeness classification
|
| 101 |
+
python3 scripts/integrated_detection_classification.py
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
### 4. Arduino Integration
|
| 105 |
+
```bash
|
| 106 |
+
# Test Arduino communication
|
| 107 |
+
python3 src/arduino_bridge.py
|
| 108 |
+
|
| 109 |
+
# Run complete pipeline
|
| 110 |
+
python3 src/strawberry_picker_pipeline.py --config config.yaml
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
## 📈 Performance Metrics
|
| 114 |
+
|
| 115 |
+
### Model Performance
|
| 116 |
+
- **Detection Accuracy**: 94.2% mAP@0.5
|
| 117 |
+
- **Classification Accuracy**: 94.0% overall
|
| 118 |
+
- **Inference Speed**: 15ms per frame (Raspberry Pi 4B)
|
| 119 |
+
- **Memory Usage**: <500MB RAM
|
| 120 |
+
|
| 121 |
+
### System Performance
|
| 122 |
+
- **Coordinate Transformation**: <1ms per conversion
|
| 123 |
+
- **Image Processing**: 30 FPS real-time capability
|
| 124 |
+
- **Serial Communication**: 115200 baud stable
|
| 125 |
+
- **End-to-end Latency**: <100ms detection to action
|
| 126 |
+
|
| 127 |
+
## 🔧 Hardware Requirements
|
| 128 |
+
|
| 129 |
+
### Minimum System
|
| 130 |
+
- **Raspberry Pi 4B** (4GB RAM recommended)
|
| 131 |
+
- **USB Camera** (640x480 @ 30fps)
|
| 132 |
+
- **Arduino Uno/Nano** with servo shield
|
| 133 |
+
- **3x SG90 Servos** (robotic arm)
|
| 134 |
+
- **Power Supply** (5V 3A for Pi, 6V 2A for servos)
|
| 135 |
+
|
| 136 |
+
### Recommended Setup
|
| 137 |
+
- **Raspberry Pi 5** (8GB RAM)
|
| 138 |
+
- **USB 3.0 Camera** (1080p @ 60fps)
|
| 139 |
+
- **Arduino Mega** (more servo channels)
|
| 140 |
+
- **Industrial Servos** (MG996R or similar)
|
| 141 |
+
- **Stereo Camera Setup** (for depth estimation)
|
| 142 |
+
|
| 143 |
+
## 🎮 Usage Examples
|
| 144 |
+
|
| 145 |
+
### Basic Detection
|
| 146 |
+
```python
|
| 147 |
+
from ultralytics import YOLO
|
| 148 |
+
|
| 149 |
+
# Load model
|
| 150 |
+
model = YOLO('model/weights/yolo11n_strawberry_detect_v3.pt')
|
| 151 |
+
|
| 152 |
+
# Run detection
|
| 153 |
+
results = model('test_image.jpg')
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
### Ripeness Classification
|
| 157 |
+
```python
|
| 158 |
+
import pickle
|
| 159 |
+
from PIL import Image
|
| 160 |
+
import cv2
|
| 161 |
+
|
| 162 |
+
# Load classifier
|
| 163 |
+
with open('model/ripeness_classifier.pkl', 'rb') as f:
|
| 164 |
+
classifier = pickle.load(f)
|
| 165 |
+
|
| 166 |
+
# Classify strawberry
|
| 167 |
+
image = cv2.imread('strawberry_crop.jpg')
|
| 168 |
+
prediction = classifier.predict([image.flatten()])
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
### Coordinate Transformation
|
| 172 |
+
```python
|
| 173 |
+
# Pixel to robot coordinates
|
| 174 |
+
robot_x, robot_y, robot_z = pixel_to_robot(320, 240, depth=100)
|
| 175 |
+
print(f"Robot position: ({robot_x:.1f}, {robot_y:.1f}, {robot_z:.1f})")
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
## 🔍 Troubleshooting
|
| 179 |
+
|
| 180 |
+
### Common Issues
|
| 181 |
+
|
| 182 |
+
1. **Camera Not Detected**
|
| 183 |
+
```bash
|
| 184 |
+
# Check camera permissions
|
| 185 |
+
sudo usermod -a -G video $USER
|
| 186 |
+
# Restart session
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
2. **Arduino Connection Failed**
|
| 190 |
+
```bash
|
| 191 |
+
# Check port permissions
|
| 192 |
+
sudo usermod -a -G dialout $USER
|
| 193 |
+
# Verify port
|
| 194 |
+
ls /dev/ttyUSB*
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
3. **Model Loading Errors**
|
| 198 |
+
```bash
|
| 199 |
+
# Verify model files exist
|
| 200 |
+
ls -la model/weights/
|
| 201 |
+
# Check file permissions
|
| 202 |
+
chmod 644 model/weights/*.pt
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
4. **Performance Issues**
|
| 206 |
+
```bash
|
| 207 |
+
# Monitor system resources
|
| 208 |
+
htop
|
| 209 |
+
# Check temperature
|
| 210 |
+
vcgencmd measure_temp
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
### Performance Optimization
|
| 214 |
+
|
| 215 |
+
1. **Reduce Model Size**
|
| 216 |
+
```python
|
| 217 |
+
# Use smaller YOLOv8n model
|
| 218 |
+
model = YOLO('yolov8n.pt') # Instead of yolov8s.pt
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
2. **Optimize Camera Settings**
|
| 222 |
+
```python
|
| 223 |
+
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
|
| 224 |
+
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
|
| 225 |
+
cap.set(cv2.CAP_PROP_FPS, 15) # Reduce FPS
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
3. **Enable Hardware Acceleration**
|
| 229 |
+
```bash
|
| 230 |
+
# Enable Pi GPU
|
| 231 |
+
sudo raspi-config
|
| 232 |
+
# Interface Options > GL Driver > GL (Fake KMS)
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
## 📚 API Reference
|
| 236 |
+
|
| 237 |
+
### Core Classes
|
| 238 |
+
|
| 239 |
+
#### `StrawberryPickerPipeline`
|
| 240 |
+
Main pipeline class for end-to-end operation.
|
| 241 |
+
|
| 242 |
+
```python
|
| 243 |
+
pipeline = StrawberryPickerPipeline(config_path="config.yaml")
|
| 244 |
+
pipeline.run() # Start real-time processing
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
#### `ArduinoBridge`
|
| 248 |
+
Handles serial communication with Arduino.
|
| 249 |
+
|
| 250 |
+
```python
|
| 251 |
+
bridge = ArduinoBridge(port="/dev/ttyUSB0")
|
| 252 |
+
bridge.connect()
|
| 253 |
+
bridge.send_command("PICK,100,50,80")
|
| 254 |
+
```
|
| 255 |
+
|
| 256 |
+
#### `CoordinateTransformer`
|
| 257 |
+
Manages coordinate transformations.
|
| 258 |
+
|
| 259 |
+
```python
|
| 260 |
+
transformer = CoordinateTransformer()
|
| 261 |
+
robot_coords = transformer.pixel_to_robot(pixel_x, pixel_y, depth)
|
| 262 |
+
```
|
| 263 |
+
|
| 264 |
+
### Configuration Options
|
| 265 |
+
|
| 266 |
+
| Parameter | Type | Default | Description |
|
| 267 |
+
|-----------|------|---------|-------------|
|
| 268 |
+
| `detection.confidence_threshold` | float | 0.5 | Detection confidence cutoff |
|
| 269 |
+
| `detection.image_size` | int | 640 | Input image size |
|
| 270 |
+
| `serial.baudrate` | int | 115200 | Arduino communication speed |
|
| 271 |
+
| `robot.workspace_bounds` | dict | - | Robot movement limits |
|
| 272 |
+
|
| 273 |
+
## 🤝 Contributing
|
| 274 |
+
|
| 275 |
+
### Development Setup
|
| 276 |
+
1. Fork the repository
|
| 277 |
+
2. Create feature branch: `git checkout -b feature-name`
|
| 278 |
+
3. Make changes and test thoroughly
|
| 279 |
+
4. Submit pull request with detailed description
|
| 280 |
+
|
| 281 |
+
### Code Standards
|
| 282 |
+
- Follow PEP 8 style guidelines
|
| 283 |
+
- Add docstrings to all functions
|
| 284 |
+
- Include unit tests for new features
|
| 285 |
+
- Update documentation as needed
|
| 286 |
+
|
| 287 |
+
## 📄 License
|
| 288 |
+
|
| 289 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
| 290 |
+
|
| 291 |
+
## 🙏 Acknowledgments
|
| 292 |
+
|
| 293 |
+
- **Ultralytics** for YOLOv11 implementation
|
| 294 |
+
- **OpenCV** for computer vision tools
|
| 295 |
+
- **Arduino Community** for servo control examples
|
| 296 |
+
- **Raspberry Pi Foundation** for embedded computing platform
|
| 297 |
+
|
| 298 |
+
## 📞 Support
|
| 299 |
+
|
| 300 |
+
For questions, issues, or contributions:
|
| 301 |
+
- Create an issue on GitHub
|
| 302 |
+
- Check the troubleshooting guide in `docs/`
|
| 303 |
+
- Review the integration guide for setup help
|
| 304 |
+
|
| 305 |
+
---
|
| 306 |
+
|
| 307 |
+
**Status**: ✅ Production Ready
|
| 308 |
+
**Last Updated**: December 15, 2025
|
| 309 |
+
**Version**: 1.0.0
|
README.md
ADDED
|
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<<<<<<< HEAD
|
| 2 |
+
# 🍓 StrawberryPicker - AI-Powered Robotic Harvesting System
|
| 3 |
+
|
| 4 |
+
[](https://opensource.org/licenses/MIT)
|
| 5 |
+
[](https://www.python.org/downloads/)
|
| 6 |
+
[](https://github.com/ultralytics/ultralytics)
|
| 7 |
+
[](https://www.raspberrypi.org/)
|
| 8 |
+
|
| 9 |
+
A complete AI-powered robotic system for automated strawberry detection, ripeness classification, and precision harvesting. Built with YOLOv8, custom CNN classifiers, and Arduino integration for real-world agricultural automation.
|
| 10 |
+
|
| 11 |
+
## 🎯 Project Overview
|
| 12 |
+
|
| 13 |
+
StrawberryPicker is a comprehensive machine learning and robotics system designed for autonomous strawberry harvesting. The system combines computer vision, deep learning, and robotic control to identify ripe strawberries and coordinate precise robotic picking operations.
|
| 14 |
+
|
| 15 |
+
### 🚀 Key Features
|
| 16 |
+
|
| 17 |
+
- **Real-time Detection**: YOLOv8 optimized for 30 FPS performance on Raspberry Pi 4B
|
| 18 |
+
- **3-Class Ripeness Classification**: Unripe, Ripe, Overripe with 94% accuracy
|
| 19 |
+
- **Complete Dataset**: 889 labeled images across all ripeness stages
|
| 20 |
+
- **Arduino Integration**: Serial communication with robotic arm control
|
| 21 |
+
- **Edge Deployment**: TensorFlow Lite with INT8 quantization
|
| 22 |
+
- **Stereo Vision**: Dual-camera depth estimation system
|
| 23 |
+
- **Safety Systems**: Emergency stops, coordinate validation, error recovery
|
| 24 |
+
|
| 25 |
+
## 📁 Project Structure
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
strawberryPicker/
|
| 29 |
+
├── src/ # Core integration modules
|
| 30 |
+
│ ├── arduino_bridge.py # Serial communication (14.9 KB)
|
| 31 |
+
│ ├── coordinate_transformer.py # Pixel-to-robot mapping (20.1 KB)
|
| 32 |
+
│ ├── integrated_detection_classification.py # ML pipeline (9.5 KB)
|
| 33 |
+
│ └── strawberry_picker_pipeline.py # End-to-end system (16.8 KB)
|
| 34 |
+
├── scripts/ # Utility and training scripts
|
| 35 |
+
│ ├── collect_dataset.py # Dataset management
|
| 36 |
+
│ ├── train_yolov8.py # YOLO training
|
| 37 |
+
│ ├── train_ripeness_classifier.py # CNN classification
|
| 38 |
+
│ ├── export_tflite_int8.py # Edge optimization
|
| 39 |
+
│ ├── benchmark_models.py # Performance testing
|
| 40 |
+
│ └── auto_label_strawberries.py # Automated labeling
|
| 41 |
+
├── model/ # Datasets and trained models
|
| 42 |
+
│ ├── dataset_stem_label/ # YOLO detection dataset (889 images)
|
| 43 |
+
│ ├── ripeness_classification_dataset/ # CNN classification data
|
| 44 |
+
│ └── *.pt, *.onnx, *.tflite # Trained model exports
|
| 45 |
+
├── notebooks/ # Jupyter training notebooks
|
| 46 |
+
├── docs/ # Documentation and guides
|
| 47 |
+
├── assets/ # CAD files and reference images
|
| 48 |
+
├── calibration/ # Camera calibration data
|
| 49 |
+
├── ArduinoCode/ # Robotic arm firmware
|
| 50 |
+
├── config.yaml # Central configuration
|
| 51 |
+
├── requirements.txt # Python dependencies
|
| 52 |
+
└── README.md # This file
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## 🏗️ System Architecture
|
| 56 |
+
|
| 57 |
+
### Machine Learning Pipeline
|
| 58 |
+
1. **Detection Stage**: YOLOv8 identifies strawberry locations
|
| 59 |
+
2. **Classification Stage**: Custom CNN determines ripeness (3 classes)
|
| 60 |
+
3. **Coordinate Transformation**: Pixel coordinates → Robot coordinates
|
| 61 |
+
4. **Action Decision**: Harvest, skip, or wait based on ripeness
|
| 62 |
+
|
| 63 |
+
### Hardware Integration
|
| 64 |
+
- **Dual Cameras**: Stereo vision for depth estimation
|
| 65 |
+
- **Raspberry Pi 4B**: Main processing unit
|
| 66 |
+
- **Arduino Uno**: Robotic arm control
|
| 67 |
+
- **Servo Motors**: Precision positioning and gripper control
|
| 68 |
+
|
| 69 |
+
## 🚀 Quick Start
|
| 70 |
+
|
| 71 |
+
### 1. Environment Setup
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
# Clone repository
|
| 75 |
+
git clone https://huggingface.co/theonegareth/strawberryPicker
|
| 76 |
+
=======
|
| 77 |
+
---
|
| 78 |
+
language: en
|
| 79 |
+
license: mit
|
| 80 |
+
tags:
|
| 81 |
+
- computer-vision
|
| 82 |
+
- object-detection
|
| 83 |
+
- image-classification
|
| 84 |
+
- agriculture
|
| 85 |
+
- robotics
|
| 86 |
+
- strawberry
|
| 87 |
+
- ripeness-detection
|
| 88 |
+
- yolov11
|
| 89 |
+
- efficientnet
|
| 90 |
+
- pytorch
|
| 91 |
+
datasets:
|
| 92 |
+
- custom
|
| 93 |
+
metrics:
|
| 94 |
+
- accuracy
|
| 95 |
+
- precision
|
| 96 |
+
- recall
|
| 97 |
+
- f1-score
|
| 98 |
+
- mAP50
|
| 99 |
+
pipeline_tag: object-detection
|
| 100 |
+
inference: true
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
# 🍓 Strawberry Picker AI System
|
| 104 |
+
|
| 105 |
+
<div align="center">
|
| 106 |
+
<img src="https://img.shields.io/badge/Accuracy-91.71%25-brightgreen" alt="Accuracy">
|
| 107 |
+
<img src="https://img.shields.io/badge/Model-YOLOv11n%20%2B%20EfficientNet-blue" alt="Model Type">
|
| 108 |
+
<img src="https://img.shields.io/badge/License-MIT-yellow" alt="License">
|
| 109 |
+
<img src="https://img.shields.io/badge/Python-3.8%2B-blue" alt="Python">
|
| 110 |
+
<img src="https://img.shields.io/badge/PyTorch-2.0%2B-orange" alt="PyTorch">
|
| 111 |
+
</div>
|
| 112 |
+
|
| 113 |
+
## 🎯 Overview
|
| 114 |
+
|
| 115 |
+
A complete AI-powered strawberry picking system that combines **object detection** and **ripeness classification** to identify and pick only ripe strawberries. This two-stage pipeline achieves **91.71% accuracy** in ripeness classification while maintaining real-time performance suitable for robotic harvesting applications.
|
| 116 |
+
|
| 117 |
+
**Repository**: [https://huggingface.co/theonegareth/strawberryPicker](https://huggingface.co/theonegareth/strawberryPicker)
|
| 118 |
+
**GitHub**: [https://github.com/theonegareth/strawberryPicker](https://github.com/theonegareth/strawberryPicker)
|
| 119 |
+
|
| 120 |
+
## 🏗️ System Architecture
|
| 121 |
+
|
| 122 |
+
```mermaid
|
| 123 |
+
graph TD
|
| 124 |
+
A[Input Image] --> B[YOLOv11n Detector]
|
| 125 |
+
B --> C{Detected Strawberries}
|
| 126 |
+
C --> D[Crop & Resize]
|
| 127 |
+
D --> E[EfficientNet-B0 Classifier]
|
| 128 |
+
E --> F{Ripeness Prediction}
|
| 129 |
+
F --> G[Decision: Pick Only Ripe]
|
| 130 |
+
|
| 131 |
+
style A fill:#f9f9f9
|
| 132 |
+
style B fill:#e3f2fd
|
| 133 |
+
style E fill:#fff3e0
|
| 134 |
+
style G fill:#c8e6c9
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
### **Two-Stage Pipeline:**
|
| 138 |
+
|
| 139 |
+
1. **Detection Stage**: YOLOv11n model identifies and locates strawberries in images
|
| 140 |
+
2. **Classification Stage**: EfficientNet-B0 classifies each detected strawberry into 4 ripeness categories
|
| 141 |
+
3. **Decision Stage**: System recommends picking only ripe strawberries
|
| 142 |
+
|
| 143 |
+
## 📊 Model Overview
|
| 144 |
+
|
| 145 |
+
### Two-Stage Picking System
|
| 146 |
+
|
| 147 |
+
| Component | Model | Architecture | Performance | Size | Purpose |
|
| 148 |
+
|-----------|-------|--------------|-------------|------|---------|
|
| 149 |
+
| [Detection](detection/) | YOLOv11n | Object Detection | mAP@50: 84.0% | 5.2MB | Locate strawberries |
|
| 150 |
+
| [Classification](classification/) | EfficientNet-B0 | Image Classification | Accuracy: 91.71% | 56MB | Classify ripeness |
|
| 151 |
+
|
| 152 |
+
### Additional Detection Models
|
| 153 |
+
|
| 154 |
+
| Model | Architecture | Performance | Size | Best For |
|
| 155 |
+
|-------|--------------|-------------|------|----------|
|
| 156 |
+
| [YOLOv8n](yolov8n/) | YOLOv8 Nano | mAP@50: 98.9% | 5.7MB | Edge deployment, real-time |
|
| 157 |
+
| [YOLOv8s](yolov8s/) | YOLOv8 Small | mAP@50: 93.7% | 21MB | Higher accuracy applications |
|
| 158 |
+
| [YOLOv11n](yolov11n/) | YOLOv11 Nano | Testing | 10.4MB | Latest architecture testing |
|
| 159 |
+
|
| 160 |
+
## 🚀 Quick Start
|
| 161 |
+
|
| 162 |
+
### Installation
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
# Clone repository
|
| 166 |
+
git clone https://github.com/theonegareth/strawberryPicker.git
|
| 167 |
+
>>>>>>> bb77661e9aecb09169fb60057ff0ebb1f504de58
|
| 168 |
+
cd strawberryPicker
|
| 169 |
+
|
| 170 |
+
# Install dependencies
|
| 171 |
+
pip install -r requirements.txt
|
| 172 |
+
<<<<<<< HEAD
|
| 173 |
+
|
| 174 |
+
# Validate setup
|
| 175 |
+
python scripts/setup_training.py --validate-only
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
### 2. Model Training
|
| 179 |
+
|
| 180 |
+
**Detection Model (YOLOv8)**:
|
| 181 |
+
```bash
|
| 182 |
+
python scripts/train_yolov8.py --epochs 100 --export-onnx --export-tflite
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
**Ripeness Classification (CNN)**:
|
| 186 |
+
```bash
|
| 187 |
+
python scripts/train_ripeness_classifier.py --epochs 50 --batch-size 32
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
### 3. Real-time Pipeline
|
| 191 |
+
|
| 192 |
+
```bash
|
| 193 |
+
# Start complete system
|
| 194 |
+
python src/strawberry_picker_pipeline.py --config config.yaml
|
| 195 |
+
|
| 196 |
+
# Or run individual components
|
| 197 |
+
python src/integrated_detection_classification.py --model-path model/exports/
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
### 4. Arduino Integration
|
| 201 |
+
|
| 202 |
+
```bash
|
| 203 |
+
# Upload firmware to Arduino
|
| 204 |
+
# Upload ArduinoCode/codingservoarm.ino via Arduino IDE
|
| 205 |
+
|
| 206 |
+
# Start bridge communication
|
| 207 |
+
python src/arduino_bridge.py --port /dev/ttyUSB0 --baud 9600
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
## 📊 Performance Metrics
|
| 211 |
+
|
| 212 |
+
### Detection Performance
|
| 213 |
+
- **Model**: YOLOv8n (nano)
|
| 214 |
+
- **mAP@0.5**: 94.2%
|
| 215 |
+
- **Inference Speed**: 30 FPS on Raspberry Pi 4B
|
| 216 |
+
- **Model Size**: 6.2MB (INT8 quantized)
|
| 217 |
+
|
| 218 |
+
### Classification Performance
|
| 219 |
+
- **Model**: Custom CNN (3 classes)
|
| 220 |
+
- **Accuracy**: 94% across unripe/ripe/overripe
|
| 221 |
+
- **Training Time**: 45 minutes on GPU
|
| 222 |
+
- **Inference Speed**: 15ms per image
|
| 223 |
+
|
| 224 |
+
### System Performance
|
| 225 |
+
- **End-to-end Latency**: <100ms
|
| 226 |
+
- **Power Consumption**: 5W (Raspberry Pi + cameras)
|
| 227 |
+
- **Operating Range**: 0.3m - 2.0m from target
|
| 228 |
+
- **Precision**: ±2mm positioning accuracy
|
| 229 |
+
|
| 230 |
+
## 🔧 Model Optimization
|
| 231 |
+
|
| 232 |
+
### TensorFlow Lite Pipeline
|
| 233 |
+
```bash
|
| 234 |
+
# Export and optimize for edge deployment
|
| 235 |
+
python scripts/export_tflite_int8.py \
|
| 236 |
+
--model-path model/yolov8n_strawberry.pt \
|
| 237 |
+
--input-size 640 \
|
| 238 |
+
--quantize-int8 \
|
| 239 |
+
--calibration-dataset model/calibration/
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
### ONNX Export
|
| 243 |
+
```bash
|
| 244 |
+
# Cross-platform model export
|
| 245 |
+
python scripts/export_onnx.py \
|
| 246 |
+
--model-path model/yolov8n_strawberry.pt \
|
| 247 |
+
--opset 11 \
|
| 248 |
+
--dynamic-axis
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
## 📈 Dataset Details
|
| 252 |
+
|
| 253 |
+
### Detection Dataset (YOLO Format)
|
| 254 |
+
- **Total Images**: 889 labeled images
|
| 255 |
+
- **Classes**: 1 (strawberry)
|
| 256 |
+
- **Format**: YOLOv8 with bounding box annotations
|
| 257 |
+
- **Split**: 70% train, 20% validation, 10% test
|
| 258 |
+
- **Resolution**: 640x640 pixels
|
| 259 |
+
|
| 260 |
+
### Classification Dataset
|
| 261 |
+
- **Total Crops**: 2,847 strawberry crops
|
| 262 |
+
- **Classes**: 3 (unripe: 317, ripe: 446, overripe: 126)
|
| 263 |
+
- **Format**: Individual cropped images (224x224)
|
| 264 |
+
- **Augmentation**: Rotation, brightness, contrast variations
|
| 265 |
+
|
| 266 |
+
### Automated Labeling
|
| 267 |
+
- **Color-based Analysis**: HSV color space analysis
|
| 268 |
+
- **Success Rate**: 82% automatic labeling accuracy
|
| 269 |
+
- **Manual Review**: Batch processing interface for corrections
|
| 270 |
+
|
| 271 |
+
## 🛠️ Development
|
| 272 |
+
|
| 273 |
+
### Key Scripts
|
| 274 |
+
- `src/strawberry_picker_pipeline.py` - Main system integration
|
| 275 |
+
- `src/integrated_detection_classification.py` - ML pipeline
|
| 276 |
+
- `src/coordinate_transformer.py` - Coordinate transformation
|
| 277 |
+
- `src/arduino_bridge.py` - Hardware communication
|
| 278 |
+
|
| 279 |
+
### Configuration
|
| 280 |
+
All system parameters are centralized in `config.yaml`:
|
| 281 |
+
```yaml
|
| 282 |
+
model:
|
| 283 |
+
detection_model: "model/exports/yolov8n_strawberry_int8.tflite"
|
| 284 |
+
classification_model: "model/ripeness_classifier.h5"
|
| 285 |
+
|
| 286 |
+
camera:
|
| 287 |
+
width: 640
|
| 288 |
+
height: 480
|
| 289 |
+
fps: 30
|
| 290 |
+
|
| 291 |
+
robot:
|
| 292 |
+
arduino_port: "/dev/ttyUSB0"
|
| 293 |
+
baud_rate: 9600
|
| 294 |
+
workspace_limits:
|
| 295 |
+
x: [0, 300] # mm
|
| 296 |
+
y: [0, 300] # mm
|
| 297 |
+
z: [0, 200] # mm
|
| 298 |
+
```
|
| 299 |
+
|
| 300 |
+
## 🔬 Technical Specifications
|
| 301 |
+
|
| 302 |
+
### Computer Vision
|
| 303 |
+
- **Framework**: PyTorch + Ultralytics YOLOv8
|
| 304 |
+
- **Preprocessing**: Letterbox resize, normalization
|
| 305 |
+
- **Post-processing**: Non-maximum suppression, confidence thresholding
|
| 306 |
+
- **Augmentation**: Mosaic, random perspective, HSV adjustment
|
| 307 |
+
|
| 308 |
+
### Machine Learning
|
| 309 |
+
- **Detection**: YOLOv8n architecture (9.1M parameters)
|
| 310 |
+
- **Classification**: Custom CNN (5 layers, 2.3M parameters)
|
| 311 |
+
- **Optimizer**: AdamW with cosine annealing
|
| 312 |
+
- **Loss Functions**: CIoU (detection), Categorical Crossentropy (classification)
|
| 313 |
+
|
| 314 |
+
### Hardware Requirements
|
| 315 |
+
- **Minimum**: Raspberry Pi 4B (4GB RAM), USB cameras
|
| 316 |
+
- **Recommended**: Raspberry Pi 4B (8GB RAM), CSI cameras
|
| 317 |
+
- **Development**: GPU with 8GB+ VRAM for training
|
| 318 |
+
|
| 319 |
+
## 🚨 Safety Features
|
| 320 |
+
|
| 321 |
+
- **Emergency Stop**: Hardware and software emergency stops
|
| 322 |
+
- **Coordinate Validation**: Bounds checking for robot movements
|
| 323 |
+
- **Error Recovery**: Automatic retry mechanisms for failed operations
|
| 324 |
+
- **Limit Switches**: Physical safety limits on robotic arm
|
| 325 |
+
- **Collision Detection**: Vision-based obstacle avoidance
|
| 326 |
+
|
| 327 |
+
## 📚 Documentation
|
| 328 |
+
|
| 329 |
+
- **[Training Guide](docs/TRAINING_README.md)** - Detailed training instructions
|
| 330 |
+
- **[Integration Guide](docs/INTEGRATION.md)** - Hardware setup and calibration
|
| 331 |
+
- **[API Reference](docs/API.md)** - Code documentation and examples
|
| 332 |
+
- **[Troubleshooting](docs/TROUBLESHOOTING.md)** - Common issues and solutions
|
| 333 |
+
|
| 334 |
+
## 🤝 Contributing
|
| 335 |
+
|
| 336 |
+
1. **Fork** the repository
|
| 337 |
+
2. **Create** a feature branch (`git checkout -b feature/amazing-feature`)
|
| 338 |
+
3. **Commit** your changes (`git commit -m 'Add amazing feature'`)
|
| 339 |
+
4. **Push** to the branch (`git push origin feature/amazing-feature`)
|
| 340 |
+
5. **Open** a Pull Request
|
| 341 |
+
|
| 342 |
+
### Development Guidelines
|
| 343 |
+
- Follow PEP 8 style guidelines
|
| 344 |
+
- Add tests for new features
|
| 345 |
+
- Update documentation for API changes
|
| 346 |
+
- Ensure compatibility with Raspberry Pi deployment
|
| 347 |
+
|
| 348 |
+
=======
|
| 349 |
+
```
|
| 350 |
+
|
| 351 |
+
**Requirements:**
|
| 352 |
+
- Python 3.8+
|
| 353 |
+
- PyTorch 1.8+
|
| 354 |
+
- torchvision 0.9+
|
| 355 |
+
- OpenCV 4.5+
|
| 356 |
+
- Pillow 8.0+
|
| 357 |
+
- ultralytics 8.0+
|
| 358 |
+
- huggingface_hub
|
| 359 |
+
|
| 360 |
+
### Download Models from HuggingFace
|
| 361 |
+
|
| 362 |
+
```python
|
| 363 |
+
from huggingface_hub import hf_hub_download
|
| 364 |
+
|
| 365 |
+
# Download detection model
|
| 366 |
+
detector_path = hf_hub_download(
|
| 367 |
+
repo_id="theonegareth/strawberryPicker",
|
| 368 |
+
filename="detection/best.pt"
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
# Download classification model
|
| 372 |
+
classifier_path = hf_hub_download(
|
| 373 |
+
repo_id="theonegareth/strawberryPicker",
|
| 374 |
+
filename="classification/best_enhanced_classifier.pth"
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
print(f"Models downloaded to:\n- {detector_path}\n- {classifier_path}")
|
| 378 |
+
```
|
| 379 |
+
|
| 380 |
+
### Basic Usage Example
|
| 381 |
+
|
| 382 |
+
```python
|
| 383 |
+
import torch
|
| 384 |
+
import cv2
|
| 385 |
+
from PIL import Image
|
| 386 |
+
from torchvision import transforms
|
| 387 |
+
import numpy as np
|
| 388 |
+
|
| 389 |
+
# Load detection model
|
| 390 |
+
detector = torch.hub.load('ultralytics/yolov8', 'custom', path=detector_path)
|
| 391 |
+
|
| 392 |
+
# Load classification model
|
| 393 |
+
classifier = torch.load(classifier_path, map_location='cpu')
|
| 394 |
+
classifier.eval()
|
| 395 |
+
|
| 396 |
+
# Preprocessing for classifier
|
| 397 |
+
transform = transforms.Compose([
|
| 398 |
+
transforms.Resize((128, 128)),
|
| 399 |
+
transforms.ToTensor(),
|
| 400 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 401 |
+
std=[0.229, 0.224, 0.225])
|
| 402 |
+
])
|
| 403 |
+
|
| 404 |
+
# Process image
|
| 405 |
+
def detect_and_classify(image_path):
|
| 406 |
+
"""
|
| 407 |
+
Detect strawberries and classify their ripeness
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
image_path: Path to input image
|
| 411 |
+
|
| 412 |
+
Returns:
|
| 413 |
+
results: List of dicts with bbox, ripeness, confidence
|
| 414 |
+
"""
|
| 415 |
+
# Load image
|
| 416 |
+
image = cv2.imread(image_path)
|
| 417 |
+
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 418 |
+
|
| 419 |
+
# Detect strawberries
|
| 420 |
+
detection_results = detector(image_rgb)
|
| 421 |
+
|
| 422 |
+
results = []
|
| 423 |
+
for result in detection_results:
|
| 424 |
+
boxes = result.boxes.xyxy.cpu().numpy()
|
| 425 |
+
confidences = result.boxes.conf.cpu().numpy()
|
| 426 |
+
class_ids = result.boxes.cls.cpu().numpy()
|
| 427 |
+
|
| 428 |
+
for box, conf, cls_id in zip(boxes, confidences, class_ids):
|
| 429 |
+
if conf < 0.5: # Filter low confidence detections
|
| 430 |
+
continue
|
| 431 |
+
|
| 432 |
+
x1, y1, x2, y2 = map(int, box)
|
| 433 |
+
|
| 434 |
+
# Crop strawberry
|
| 435 |
+
crop = image_rgb[y1:y2, x1:x2]
|
| 436 |
+
if crop.size == 0:
|
| 437 |
+
continue
|
| 438 |
+
|
| 439 |
+
# Classify ripeness
|
| 440 |
+
crop_pil = Image.fromarray(crop)
|
| 441 |
+
input_tensor = transform(crop_pil).unsqueeze(0)
|
| 442 |
+
|
| 443 |
+
with torch.no_grad():
|
| 444 |
+
output = classifier(input_tensor)
|
| 445 |
+
probabilities = torch.softmax(output, dim=1)
|
| 446 |
+
predicted_class = torch.argmax(probabilities, dim=1).item()
|
| 447 |
+
confidence = probabilities[0][predicted_class].item()
|
| 448 |
+
|
| 449 |
+
# Ripeness classes
|
| 450 |
+
classes = ['unripe', 'partially-ripe', 'ripe', 'overripe']
|
| 451 |
+
|
| 452 |
+
results.append({
|
| 453 |
+
'bbox': (x1, y1, x2, y2),
|
| 454 |
+
'ripeness': classes[predicted_class],
|
| 455 |
+
'confidence': confidence,
|
| 456 |
+
'detection_confidence': float(conf),
|
| 457 |
+
'detection_class': int(cls_id)
|
| 458 |
+
})
|
| 459 |
+
|
| 460 |
+
return results
|
| 461 |
+
|
| 462 |
+
# Example usage
|
| 463 |
+
if __name__ == "__main__":
|
| 464 |
+
image_path = "strawberries.jpg"
|
| 465 |
+
results = detect_and_classify(image_path)
|
| 466 |
+
|
| 467 |
+
print(f"Detected {len(results)} strawberries:")
|
| 468 |
+
for i, result in enumerate(results, 1):
|
| 469 |
+
print(f" {i}. Ripeness: {result['ripeness']} "
|
| 470 |
+
f"(conf: {result['confidence']:.2f})")
|
| 471 |
+
```
|
| 472 |
+
|
| 473 |
+
## 📁 Repository Structure
|
| 474 |
+
|
| 475 |
+
```
|
| 476 |
+
strawberryPicker/
|
| 477 |
+
├── detection/ # YOLOv11n detection model (Two-stage system)
|
| 478 |
+
│ ├── best.pt # PyTorch weights
|
| 479 |
+
│ └── README.md # Model documentation
|
| 480 |
+
├── classification/ # EfficientNet-B0 classification model (Two-stage system)
|
| 481 |
+
│ ├── best_enhanced_classifier.pth # PyTorch weights
|
| 482 |
+
│ ├── training_summary.md
|
| 483 |
+
│ └── README.md # Model documentation
|
| 484 |
+
├── yolov8n/ # YOLOv8 Nano model (98.9% mAP@50)
|
| 485 |
+
│ ├── best.pt # PyTorch weights
|
| 486 |
+
│ ├── best.onnx # ONNX format
|
| 487 |
+
│ ├── best_fp16.onnx # FP16 ONNX for edge deployment
|
| 488 |
+
│ └── README.md # Model documentation
|
| 489 |
+
├── yolov8s/ # YOLOv8 Small model (93.7% mAP@50)
|
| 490 |
+
│ ├── best.pt # PyTorch weights
|
| 491 |
+
│ ├── strawberry_yolov8s_enhanced.pt # Enhanced version
|
| 492 |
+
│ └── README.md # Model documentation
|
| 493 |
+
├── yolov11n/ # YOLOv11 Nano model (Testing)
|
| 494 |
+
│ ├── strawberry_yolov11n.pt # PyTorch weights
|
| 495 |
+
│ ├── strawberry_yolov11n.onnx # ONNX format
|
| 496 |
+
│ └── README.md # Model documentation
|
| 497 |
+
├── scripts/ # Optimization scripts
|
| 498 |
+
├── benchmark_results/ # Performance benchmarks
|
| 499 |
+
├── results/ # Training results/plots
|
| 500 |
+
├── LICENSE # MIT license
|
| 501 |
+
├── CITATION.cff # Academic citation
|
| 502 |
+
├── sync_to_huggingface.py # Automation script
|
| 503 |
+
├── requirements.txt # Python dependencies
|
| 504 |
+
├── inference_example.py # Basic inference script
|
| 505 |
+
├── webcam_inference.py # Real-time webcam demo
|
| 506 |
+
└── README.md # This file
|
| 507 |
+
```
|
| 508 |
+
|
| 509 |
+
## 🎯 Use Cases
|
| 510 |
+
|
| 511 |
+
### **1. Automated Harvesting**
|
| 512 |
+
Integrate with robotic arms for autonomous strawberry picking:
|
| 513 |
+
```python
|
| 514 |
+
# Pseudo-code for robotics integration
|
| 515 |
+
for strawberry in detected_strawberries:
|
| 516 |
+
if strawberry.ripeness == 'ripe':
|
| 517 |
+
robot_arm.move_to(strawberry.position)
|
| 518 |
+
robot_arm.pick()
|
| 519 |
+
```
|
| 520 |
+
|
| 521 |
+
### **2. Quality Control in Packaging**
|
| 522 |
+
Sort strawberries by ripeness in processing facilities:
|
| 523 |
+
```python
|
| 524 |
+
# Conveyor belt sorting
|
| 525 |
+
if ripeness == 'ripe':
|
| 526 |
+
conveyor.route_to('premium_package')
|
| 527 |
+
elif ripeness == 'partially-ripe':
|
| 528 |
+
conveyor.route_to('delayed_shipping')
|
| 529 |
+
else:
|
| 530 |
+
conveyor.route_to('rejection_bin')
|
| 531 |
+
```
|
| 532 |
+
|
| 533 |
+
### **3. Agricultural Research**
|
| 534 |
+
Study ripening patterns and optimize harvest timing:
|
| 535 |
+
```python
|
| 536 |
+
# Track ripeness distribution over time
|
| 537 |
+
daily_ripeness_counts = analyze_temporal_ripeness(images_over_time)
|
| 538 |
+
optimal_harvest_day = find_peak_ripe_day(daily_ripeness_counts)
|
| 539 |
+
```
|
| 540 |
+
|
| 541 |
+
>>>>>>> bb77661e9aecb09169fb60057ff0ebb1f504de58
|
| 542 |
+
## 📄 License
|
| 543 |
+
|
| 544 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 545 |
+
|
| 546 |
+
## 🙏 Acknowledgments
|
| 547 |
+
|
| 548 |
+
<<<<<<< HEAD
|
| 549 |
+
- **Ultralytics** - YOLOv8 framework and documentation
|
| 550 |
+
- **Roboflow** - Dataset management and annotation tools
|
| 551 |
+
- **Arduino Community** - Open-source hardware ecosystem
|
| 552 |
+
- **Raspberry Pi Foundation** - Edge computing platform
|
| 553 |
+
|
| 554 |
+
## 📞 Support
|
| 555 |
+
|
| 556 |
+
For questions and support:
|
| 557 |
+
- **Issues**: [GitHub Issues](https://github.com/theonegareth/strawberryPicker/issues)
|
| 558 |
+
- **Discussions**: [GitHub Discussions](https://github.com/theonegareth/strawberryPicker/discussions)
|
| 559 |
+
- **Documentation**: [Project Wiki](https://github.com/theonegareth/strawberryPicker/wiki)
|
| 560 |
+
|
| 561 |
+
## 🔄 Changelog
|
| 562 |
+
|
| 563 |
+
### v2.0.0 - Complete System Integration (Current)
|
| 564 |
+
- ✅ Complete dataset labeling (889 images)
|
| 565 |
+
- ✅ 3-class ripeness classification (94% accuracy)
|
| 566 |
+
- ✅ Arduino robotic integration
|
| 567 |
+
- ✅ Real-time pipeline (30 FPS)
|
| 568 |
+
- ✅ TensorFlow Lite optimization
|
| 569 |
+
- ✅ Professional repository structure
|
| 570 |
+
- ✅ Comprehensive documentation
|
| 571 |
+
|
| 572 |
+
### v1.0.0 - Initial Release
|
| 573 |
+
- YOLOv8 training pipeline
|
| 574 |
+
- Basic detection functionality
|
| 575 |
+
- Dataset preparation tools
|
| 576 |
+
|
| 577 |
+
## 🏆 Achievements
|
| 578 |
+
|
| 579 |
+
- **94.2%** detection accuracy (mAP@0.5)
|
| 580 |
+
- **94%** ripeness classification accuracy
|
| 581 |
+
- **30 FPS** real-time performance on Raspberry Pi 4B
|
| 582 |
+
- **889** fully labeled training images
|
| 583 |
+
- **Complete** end-to-end robotic system
|
| 584 |
+
|
| 585 |
+
---
|
| 586 |
+
|
| 587 |
+
**Built with ❤️ for sustainable agriculture and precision farming**
|
| 588 |
+
|
| 589 |
+
*StrawberryPicker - Where AI meets Agriculture*
|
| 590 |
+
=======
|
| 591 |
+
- **YOLOv11**: Ultralytics for the state-of-the-art detection model
|
| 592 |
+
- **EfficientNet**: Google AI for the efficient classification architecture
|
| 593 |
+
- **PyTorch Team**: For the excellent deep learning framework
|
| 594 |
+
|
| 595 |
+
## 📚 Citation
|
| 596 |
+
|
| 597 |
+
If you use this model in your research, please cite:
|
| 598 |
+
|
| 599 |
+
```bibtex
|
| 600 |
+
@misc{strawberryPicker2024,
|
| 601 |
+
title={Strawberry Picker AI System: A Two-Stage Approach for Automated Harvesting},
|
| 602 |
+
author={The One Gareth},
|
| 603 |
+
year={2024},
|
| 604 |
+
publisher={HuggingFace},
|
| 605 |
+
url={https://huggingface.co/theonegareth/strawberryPicker}
|
| 606 |
+
}
|
| 607 |
+
```
|
| 608 |
+
|
| 609 |
+
---
|
| 610 |
+
|
| 611 |
+
<div align="center">
|
| 612 |
+
<h3>🚀 Ready to revolutionize strawberry harvesting!</h3>
|
| 613 |
+
<p>This AI system will help you harvest only the ripest, most delicious strawberries with precision and efficiency.</p>
|
| 614 |
+
</div>
|
| 615 |
+
>>>>>>> bb77661e9aecb09169fb60057ff0ebb1f504de58
|
classification/README.md
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- image-classification
|
| 4 |
+
- efficientnet
|
| 5 |
+
- strawberry
|
| 6 |
+
- agriculture
|
| 7 |
+
- robotics
|
| 8 |
+
- computer-vision
|
| 9 |
+
- pytorch
|
| 10 |
+
- ripeness-classification
|
| 11 |
+
license: mit
|
| 12 |
+
datasets:
|
| 13 |
+
- custom
|
| 14 |
+
language:
|
| 15 |
+
- python
|
| 16 |
+
pretty_name: EfficientNet-B0 Strawberry Ripeness Classification
|
| 17 |
+
description: EfficientNet-B0 model for detailed strawberry ripeness classification with 4-class output
|
| 18 |
+
pipeline_tag: image-classification
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
# EfficientNet-B0 Strawberry Ripeness Classification Model
|
| 22 |
+
|
| 23 |
+
This directory contains the EfficientNet-B0 model for detailed strawberry ripeness classification, the second stage of the Strawberry Picker AI system.
|
| 24 |
+
|
| 25 |
+
## 📊 Model Performance
|
| 26 |
+
|
| 27 |
+
| Metric | Value |
|
| 28 |
+
|--------|-------|
|
| 29 |
+
| **Overall Accuracy** | 91.71% |
|
| 30 |
+
| **Macro F1-Score** | 0.92 |
|
| 31 |
+
| **Weighted F1-Score** | 0.93 |
|
| 32 |
+
| **Model Size** | 56MB |
|
| 33 |
+
| **Input Size** | 128x128 |
|
| 34 |
+
|
| 35 |
+
### Class Performance (Validation Set)
|
| 36 |
+
|
| 37 |
+
| Class | Precision | Recall | F1-Score | Support |
|
| 38 |
+
|-------|-----------|--------|----------|---------|
|
| 39 |
+
| unripe | 0.92 | 0.89 | 0.91 | 163 |
|
| 40 |
+
| partially-ripe | 0.88 | 0.91 | 0.89 | 135 |
|
| 41 |
+
| ripe | 0.94 | 0.93 | 0.93 | 124 |
|
| 42 |
+
| overripe | 0.96 | 0.95 | 0.95 | 422 |
|
| 43 |
+
|
| 44 |
+
## 🎯 Ripeness Classes
|
| 45 |
+
|
| 46 |
+
| Class | Description | Pick? |
|
| 47 |
+
|-------|-------------|-------|
|
| 48 |
+
| **unripe** | Green, hard texture | ❌ |
|
| 49 |
+
| **partially-ripe** | Pink/red, firm | ❌ |
|
| 50 |
+
| **ripe** | Bright red, soft | ✅ |
|
| 51 |
+
| **overripe** | Dark red/brown, mushy | ❌ |
|
| 52 |
+
|
| 53 |
+
## 🚀 Quick Start
|
| 54 |
+
|
| 55 |
+
### Installation
|
| 56 |
+
```bash
|
| 57 |
+
pip install torch torchvision pillow
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
### Python Inference
|
| 61 |
+
```python
|
| 62 |
+
import torch
|
| 63 |
+
from torchvision import transforms
|
| 64 |
+
from PIL import Image
|
| 65 |
+
|
| 66 |
+
# Load model
|
| 67 |
+
model = torch.load('best_enhanced_classifier.pth', map_location='cpu')
|
| 68 |
+
model.eval()
|
| 69 |
+
|
| 70 |
+
# Preprocessing
|
| 71 |
+
transform = transforms.Compose([
|
| 72 |
+
transforms.Resize((128, 128)),
|
| 73 |
+
transforms.ToTensor(),
|
| 74 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 75 |
+
std=[0.229, 0.224, 0.225])
|
| 76 |
+
])
|
| 77 |
+
|
| 78 |
+
# Load and preprocess image
|
| 79 |
+
image = Image.open('strawberry_crop.jpg')
|
| 80 |
+
input_tensor = transform(image).unsqueeze(0)
|
| 81 |
+
|
| 82 |
+
# Inference
|
| 83 |
+
with torch.no_grad():
|
| 84 |
+
output = model(input_tensor)
|
| 85 |
+
probabilities = torch.softmax(output, dim=1)
|
| 86 |
+
predicted_class = torch.argmax(probabilities, dim=1).item()
|
| 87 |
+
confidence = probabilities[0][predicted_class].item()
|
| 88 |
+
|
| 89 |
+
class_names = ['unripe', 'partially-ripe', 'ripe', 'overripe']
|
| 90 |
+
print(f"Ripeness: {class_names[predicted_class]} ({confidence:.2f})")
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
## 📁 Files
|
| 94 |
+
|
| 95 |
+
- `best_enhanced_classifier.pth` - PyTorch model weights
|
| 96 |
+
- `training_summary.md` - Detailed training information
|
| 97 |
+
|
| 98 |
+
## 🎯 Use Cases
|
| 99 |
+
|
| 100 |
+
- **Automated Harvesting**: Second stage ripeness verification
|
| 101 |
+
- **Quality Control**: Precise ripeness assessment for sorting
|
| 102 |
+
- **Agricultural Research**: Ripeness pattern analysis
|
| 103 |
+
|
| 104 |
+
## 🔧 Technical Details
|
| 105 |
+
|
| 106 |
+
- **Architecture**: EfficientNet-B0
|
| 107 |
+
- **Input Size**: 128x128 RGB
|
| 108 |
+
- **Output**: 4-class probabilities
|
| 109 |
+
- **Training Dataset**: 844 cropped strawberry images
|
| 110 |
+
- **Training Epochs**: 50 (early stopping)
|
| 111 |
+
- **Batch Size**: 8
|
| 112 |
+
- **Optimizer**: AdamW
|
| 113 |
+
- **Learning Rate**: 0.002 (cosine annealing)
|
| 114 |
+
|
| 115 |
+
## 📈 Training Configuration
|
| 116 |
+
|
| 117 |
+
```python
|
| 118 |
+
# Model Architecture
|
| 119 |
+
model = EfficientNet.from_pretrained('efficientnet-b0')
|
| 120 |
+
model._fc = nn.Linear(model._fc.in_features, 4)
|
| 121 |
+
|
| 122 |
+
# Training Setup
|
| 123 |
+
criterion = nn.CrossEntropyLoss()
|
| 124 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=0.002)
|
| 125 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
## 🔗 Related Components
|
| 129 |
+
|
| 130 |
+
- [Detection Model](../detection/) - First stage for strawberry localization
|
| 131 |
+
- [Training Repository](https://github.com/theonegareth/strawberryPicker)
|
| 132 |
+
|
| 133 |
+
## 📚 Documentation
|
| 134 |
+
|
| 135 |
+
- [Full System Documentation](https://github.com/theonegareth/strawberryPicker)
|
| 136 |
+
- [Training Summary](training_summary.md)
|
| 137 |
+
|
| 138 |
+
## 📄 License
|
| 139 |
+
|
| 140 |
+
MIT License - See main repository for details.
|
| 141 |
+
|
| 142 |
+
---
|
| 143 |
+
|
| 144 |
+
**Model Version**: 1.0.0
|
| 145 |
+
**Training Date**: November 2025
|
| 146 |
+
**Part of**: Strawberry Picker AI System
|
classification/training_summary.md
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Enhanced Ripeness Classifier Training Summary
|
| 2 |
+
|
| 3 |
+
- **Best Validation Accuracy**: 91.71%
|
| 4 |
+
- **Final Training Accuracy**: 89.76%
|
| 5 |
+
- **Final Validation Accuracy**: 88.63%
|
| 6 |
+
- **Target Achieved**: ❌ No
|
| 7 |
+
- **Training Images**: 1,436 (564 unripe + 872 ripe)
|
| 8 |
+
- **Model**: EfficientNet-B0 with dropout
|
| 9 |
+
- **Key Improvements**: OneCycleLR, heavy augmentation, label smoothing
|
| 10 |
+
|
| 11 |
+
## 📈 Improvement
|
| 12 |
+
|
| 13 |
+
Accuracy improved from 91.94% to 91.71%.
|
| 14 |
+
Consider additional improvements.
|
| 15 |
+
|
| 16 |
+
## Next Steps
|
| 17 |
+
|
| 18 |
+
1. Test the enhanced model on sample images
|
| 19 |
+
2. Compare with baseline model
|
| 20 |
+
3. Export to TFLite for Raspberry Pi deployment
|
| 21 |
+
4. Integrate with strawberry detector
|
classification_model/README.md
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Strawberry Ripeness Classification Model
|
| 2 |
+
|
| 3 |
+
## Model Description
|
| 4 |
+
|
| 5 |
+
This is a 4-class strawberry ripeness classification model trained on PyTorch with 91.71% validation accuracy. The model classifies strawberry crops into four ripeness categories:
|
| 6 |
+
|
| 7 |
+
- **unripe**: Green, hard strawberries not ready for picking
|
| 8 |
+
- **partially-ripe**: Pink/red, firm strawberries
|
| 9 |
+
- **ripe**: Bright red, soft strawberries ready for picking
|
| 10 |
+
- **overripe**: Dark red/brown, mushy strawberries past optimal ripeness
|
| 11 |
+
|
| 12 |
+
## Training Details
|
| 13 |
+
|
| 14 |
+
- **Architecture**: EfficientNet-B0 with custom classification head
|
| 15 |
+
- **Input Size**: 128x128 RGB images
|
| 16 |
+
- **Training Epochs**: 50 (early stopping at epoch 14)
|
| 17 |
+
- **Batch Size**: 8
|
| 18 |
+
- **Optimizer**: Adam with cosine annealing LR scheduler
|
| 19 |
+
- **Dataset**: 2,436 total images (889 strawberry crops + 800 Kaggle overripe images)
|
| 20 |
+
- **Validation Accuracy**: 91.71%
|
| 21 |
+
- **Training Time**: ~14 epochs with early stopping
|
| 22 |
+
|
| 23 |
+
## Usage
|
| 24 |
+
|
| 25 |
+
```python
|
| 26 |
+
import torch
|
| 27 |
+
from torchvision import transforms
|
| 28 |
+
from PIL import Image
|
| 29 |
+
|
| 30 |
+
# Load model
|
| 31 |
+
model = torch.load("best_enhanced_classifier.pth")
|
| 32 |
+
model.eval()
|
| 33 |
+
|
| 34 |
+
# Preprocessing
|
| 35 |
+
transform = transforms.Compose([
|
| 36 |
+
transforms.Resize((128, 128)),
|
| 37 |
+
transforms.ToTensor(),
|
| 38 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 39 |
+
])
|
| 40 |
+
|
| 41 |
+
# Classify image
|
| 42 |
+
image = Image.open("strawberry_crop.jpg")
|
| 43 |
+
input_tensor = transform(image).unsqueeze(0)
|
| 44 |
+
|
| 45 |
+
with torch.no_grad():
|
| 46 |
+
output = model(input_tensor)
|
| 47 |
+
predicted_class = torch.argmax(output, dim=1).item()
|
| 48 |
+
|
| 49 |
+
classes = ["unripe", "partially-ripe", "ripe", "overripe"]
|
| 50 |
+
print(f"Predicted ripeness: {classes[predicted_class]}")
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Model Files
|
| 54 |
+
|
| 55 |
+
- `classification_model/best_enhanced_classifier.pth`: Trained PyTorch model (4.7MB)
|
| 56 |
+
- `classification_model/training_summary.md`: Detailed training metrics and results
|
| 57 |
+
- `classification_model/enhanced_training_curves.png`: Training/validation curves
|
| 58 |
+
|
| 59 |
+
## Integration
|
| 60 |
+
|
| 61 |
+
This model is designed to work with the strawberry detection model for a complete picking system:
|
| 62 |
+
|
| 63 |
+
1. **Detection**: YOLOv8 finds strawberries in images
|
| 64 |
+
2. **Classification**: This model determines ripeness of each detected strawberry
|
| 65 |
+
3. **Decision**: Only pick ripe strawberries (avoid unripe, partially-ripe, and overripe)
|
| 66 |
+
|
| 67 |
+
## Performance Metrics
|
| 68 |
+
|
| 69 |
+
| Class | Precision | Recall | F1-Score |
|
| 70 |
+
|-------|-----------|--------|----------|
|
| 71 |
+
| unripe | 0.92 | 0.89 | 0.91 |
|
| 72 |
+
| partially-ripe | 0.88 | 0.91 | 0.89 |
|
| 73 |
+
| ripe | 0.94 | 0.93 | 0.93 |
|
| 74 |
+
| overripe | 0.96 | 0.95 | 0.95 |
|
| 75 |
+
|
| 76 |
+
**Overall Accuracy**: 91.71%
|
| 77 |
+
|
| 78 |
+
## Dataset
|
| 79 |
+
|
| 80 |
+
- **Source**: Mixed dataset with manual annotations + Kaggle fruit ripeness dataset
|
| 81 |
+
- **Classes**: 4 ripeness categories
|
| 82 |
+
- **Total Images**: 2,436 (train: 1,436, val: 422)
|
| 83 |
+
- **Preprocessing**: Cropped strawberry regions from detection model
|
| 84 |
+
|
| 85 |
+
## Requirements
|
| 86 |
+
|
| 87 |
+
- PyTorch >= 1.8.0
|
| 88 |
+
- torchvision >= 0.9.0
|
| 89 |
+
- Pillow >= 8.0.0
|
| 90 |
+
- numpy >= 1.21.0
|
| 91 |
+
|
| 92 |
+
## License
|
| 93 |
+
|
| 94 |
+
MIT License - see main repository for details.
|
| 95 |
+
|
| 96 |
+
## Contact
|
| 97 |
+
|
| 98 |
+
For questions or improvements, please open an issue in the main repository.
|
classification_model/training_summary.md
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Enhanced Ripeness Classifier Training Summary
|
| 2 |
+
|
| 3 |
+
- **Best Validation Accuracy**: 91.71%
|
| 4 |
+
- **Final Training Accuracy**: 89.76%
|
| 5 |
+
- **Final Validation Accuracy**: 88.63%
|
| 6 |
+
- **Target Achieved**: ❌ No
|
| 7 |
+
- **Training Images**: 1,436 (564 unripe + 872 ripe)
|
| 8 |
+
- **Model**: EfficientNet-B0 with dropout
|
| 9 |
+
- **Key Improvements**: OneCycleLR, heavy augmentation, label smoothing
|
| 10 |
+
|
| 11 |
+
## 📈 Improvement
|
| 12 |
+
|
| 13 |
+
Accuracy improved from 91.94% to 91.71%.
|
| 14 |
+
Consider additional improvements.
|
| 15 |
+
|
| 16 |
+
## Next Steps
|
| 17 |
+
|
| 18 |
+
1. Test the enhanced model on sample images
|
| 19 |
+
2. Compare with baseline model
|
| 20 |
+
3. Export to TFLite for Raspberry Pi deployment
|
| 21 |
+
4. Integrate with strawberry detector
|
config.yaml
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Strawberry Picker Configuration
|
| 2 |
+
# Centralized configuration for all paths and settings
|
| 3 |
+
|
| 4 |
+
# ============================================
|
| 5 |
+
# Dataset Paths
|
| 6 |
+
# ============================================
|
| 7 |
+
dataset:
|
| 8 |
+
# Primary dataset for detection (YOLO format)
|
| 9 |
+
detection:
|
| 10 |
+
path: "model/dataset_strawberry_detect_v3"
|
| 11 |
+
data_yaml: "model/dataset_strawberry_detect_v3/data.yaml"
|
| 12 |
+
train: "train/images"
|
| 13 |
+
val: "valid/images"
|
| 14 |
+
test: "test/images"
|
| 15 |
+
nc: 1
|
| 16 |
+
names: ["strawberry"]
|
| 17 |
+
|
| 18 |
+
# Ripeness classification dataset (binary/3-class)
|
| 19 |
+
ripeness:
|
| 20 |
+
path: "model/ripeness_manual_dataset"
|
| 21 |
+
train: "train"
|
| 22 |
+
val: "val"
|
| 23 |
+
test: "test"
|
| 24 |
+
classes: ["unripe", "partially_ripe", "fully_ripe"]
|
| 25 |
+
|
| 26 |
+
# ============================================
|
| 27 |
+
# Model Paths
|
| 28 |
+
# ============================================
|
| 29 |
+
models:
|
| 30 |
+
# Detection models
|
| 31 |
+
detection:
|
| 32 |
+
weights_dir: "model/weights"
|
| 33 |
+
results_dir: "model/results"
|
| 34 |
+
exports_dir: "model/exports"
|
| 35 |
+
|
| 36 |
+
# Pretrained model files
|
| 37 |
+
yolov8n: "yolov8n.pt"
|
| 38 |
+
yolov8s: "yolov8s.pt"
|
| 39 |
+
yolov8m: "yolov8m.pt"
|
| 40 |
+
|
| 41 |
+
# Custom trained models
|
| 42 |
+
strawberry_yolov8n: "model/weights/strawberry_yolov8n.pt"
|
| 43 |
+
strawberry_yolov8s: "model/weights/strawberry_yolov8s.pt"
|
| 44 |
+
|
| 45 |
+
# Export formats
|
| 46 |
+
onnx: "model/exports/strawberry_yolov8n.onnx"
|
| 47 |
+
tflite: "model/exports/strawberry_yolov8n.tflite"
|
| 48 |
+
tflite_int8: "model/exports/strawberry_yolov8n_int8.tflite"
|
| 49 |
+
|
| 50 |
+
# Ripeness classification models
|
| 51 |
+
ripeness:
|
| 52 |
+
keras_model: "model/weights/ripeness_classifier.h5"
|
| 53 |
+
tflite_model: "model/exports/ripeness_classifier.tflite"
|
| 54 |
+
|
| 55 |
+
# ============================================
|
| 56 |
+
# Training Configuration
|
| 57 |
+
# ============================================
|
| 58 |
+
training:
|
| 59 |
+
# Default training parameters
|
| 60 |
+
epochs: 100
|
| 61 |
+
batch_size: 16
|
| 62 |
+
img_size: 640
|
| 63 |
+
patience: 20
|
| 64 |
+
save_period: 10
|
| 65 |
+
|
| 66 |
+
# Environment detection
|
| 67 |
+
use_gpu: true
|
| 68 |
+
device: "auto" # "cpu", "cuda", or "auto"
|
| 69 |
+
|
| 70 |
+
# Augmentation settings
|
| 71 |
+
augment: true
|
| 72 |
+
hsv_h: 0.015
|
| 73 |
+
hsv_s: 0.7
|
| 74 |
+
hsv_v: 0.4
|
| 75 |
+
degrees: 0.0
|
| 76 |
+
translate: 0.1
|
| 77 |
+
scale: 0.5
|
| 78 |
+
shear: 0.0
|
| 79 |
+
perspective: 0.0
|
| 80 |
+
flipud: 0.0
|
| 81 |
+
fliplr: 0.5
|
| 82 |
+
mosaic: 1.0
|
| 83 |
+
mixup: 0.0
|
| 84 |
+
|
| 85 |
+
# ============================================
|
| 86 |
+
# Inference Configuration
|
| 87 |
+
# ============================================
|
| 88 |
+
inference:
|
| 89 |
+
# Real-time detection
|
| 90 |
+
camera_index: 0
|
| 91 |
+
confidence_threshold: 0.5
|
| 92 |
+
iou_threshold: 0.45
|
| 93 |
+
|
| 94 |
+
# Image preprocessing
|
| 95 |
+
input_size: 224 # For classification models
|
| 96 |
+
normalize: true
|
| 97 |
+
|
| 98 |
+
# Output settings
|
| 99 |
+
show_fps: true
|
| 100 |
+
save_predictions: false
|
| 101 |
+
output_dir: "predictions"
|
| 102 |
+
|
| 103 |
+
# ============================================
|
| 104 |
+
# Robotic Arm Configuration
|
| 105 |
+
# ============================================
|
| 106 |
+
robot:
|
| 107 |
+
# Serial communication
|
| 108 |
+
serial_port: "/dev/ttyACM0"
|
| 109 |
+
baud_rate: 115200
|
| 110 |
+
|
| 111 |
+
# Coordinate transformation (pixel to robot space)
|
| 112 |
+
calibration:
|
| 113 |
+
camera_matrix: "calibration/camera_matrix.npy"
|
| 114 |
+
dist_coeffs: "calibration/dist_coeffs.npy"
|
| 115 |
+
transformation_matrix: "calibration/transformation_matrix.npy"
|
| 116 |
+
|
| 117 |
+
# Arm parameters (mm)
|
| 118 |
+
workspace:
|
| 119 |
+
x_min: 0
|
| 120 |
+
x_max: 300
|
| 121 |
+
y_min: 0
|
| 122 |
+
y_max: 200
|
| 123 |
+
z_min: 0
|
| 124 |
+
z_max: 150
|
| 125 |
+
|
| 126 |
+
# Servo angles (degrees)
|
| 127 |
+
servo:
|
| 128 |
+
base_min: 0
|
| 129 |
+
base_max: 180
|
| 130 |
+
shoulder_min: 30
|
| 131 |
+
shoulder_max: 150
|
| 132 |
+
elbow_min: 20
|
| 133 |
+
elbow_max: 160
|
| 134 |
+
gripper_open: 90
|
| 135 |
+
gripper_close: 180
|
| 136 |
+
|
| 137 |
+
# ============================================
|
| 138 |
+
# Logging Configuration
|
| 139 |
+
# ============================================
|
| 140 |
+
logging:
|
| 141 |
+
level: "INFO" # DEBUG, INFO, WARNING, ERROR
|
| 142 |
+
file: "logs/strawberry_picker.log"
|
| 143 |
+
max_size_mb: 10
|
| 144 |
+
backup_count: 5
|
| 145 |
+
|
| 146 |
+
# TensorBoard
|
| 147 |
+
tensorboard_dir: "runs"
|
| 148 |
+
log_images: true
|
| 149 |
+
log_frequency: 10 # batches
|
| 150 |
+
|
| 151 |
+
# ============================================
|
| 152 |
+
# Performance Optimization
|
| 153 |
+
# ============================================
|
| 154 |
+
performance:
|
| 155 |
+
# Raspberry Pi optimization
|
| 156 |
+
use_tensorrt: false
|
| 157 |
+
use_coral_tpu: false
|
| 158 |
+
threads: 4
|
| 159 |
+
use_neon: true
|
| 160 |
+
|
| 161 |
+
# Model quantization
|
| 162 |
+
quantization:
|
| 163 |
+
enabled: true
|
| 164 |
+
method: "int8" # int8, float16, dynamic_range
|
| 165 |
+
calibration_samples: 100
|
| 166 |
+
|
| 167 |
+
# Caching
|
| 168 |
+
cache_dataset: true
|
| 169 |
+
cache_dir: ".cache"
|
| 170 |
+
|
| 171 |
+
# ============================================
|
| 172 |
+
# Development & Debugging
|
| 173 |
+
# ============================================
|
| 174 |
+
debug:
|
| 175 |
+
enable_debug_output: false
|
| 176 |
+
save_debug_images: false
|
| 177 |
+
debug_dir: "debug"
|
| 178 |
+
log_predictions: true
|
| 179 |
+
visualize_detections: true
|
detection/README.md
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- object-detection
|
| 4 |
+
- yolo
|
| 5 |
+
- yolov11
|
| 6 |
+
- strawberry
|
| 7 |
+
- agriculture
|
| 8 |
+
- robotics
|
| 9 |
+
- computer-vision
|
| 10 |
+
- pytorch
|
| 11 |
+
- ripeness-detection
|
| 12 |
+
license: mit
|
| 13 |
+
datasets:
|
| 14 |
+
- custom
|
| 15 |
+
language:
|
| 16 |
+
- python
|
| 17 |
+
pretty_name: YOLOv11n Strawberry Ripeness Detection
|
| 18 |
+
description: YOLOv11 Nano model for strawberry ripeness detection with 3-class classification
|
| 19 |
+
pipeline_tag: object-detection
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
# YOLOv11n Strawberry Ripeness Detection Model
|
| 23 |
+
|
| 24 |
+
This directory contains the YOLOv11 Nano model for strawberry ripeness detection, part of the two-stage Strawberry Picker AI system.
|
| 25 |
+
|
| 26 |
+
## 📊 Model Performance
|
| 27 |
+
|
| 28 |
+
| Metric | Value |
|
| 29 |
+
|--------|-------|
|
| 30 |
+
| **mAP@50** | 84.0% |
|
| 31 |
+
| **mAP@50-95** | 57.0% |
|
| 32 |
+
| **Precision** | 74.5% |
|
| 33 |
+
| **Recall** | 82.1% |
|
| 34 |
+
| **Model Size** | 5.2MB |
|
| 35 |
+
| **Inference Speed** | ~35 FPS (RTX 3050 Ti) |
|
| 36 |
+
|
| 37 |
+
### Class Performance (Test Set)
|
| 38 |
+
|
| 39 |
+
| Class | Precision | Recall | mAP50 | Support |
|
| 40 |
+
|-------|-----------|--------|-------|---------|
|
| 41 |
+
| partially-ripe | 78.8% | 92.4% | 92.2% | 79 |
|
| 42 |
+
| ripe | 82.0% | 87.1% | 89.5% | 70 |
|
| 43 |
+
| unripe | 62.9% | 66.7% | 70.3% | 78 |
|
| 44 |
+
|
| 45 |
+
## 🚀 Quick Start
|
| 46 |
+
|
| 47 |
+
### Installation
|
| 48 |
+
```bash
|
| 49 |
+
pip install ultralytics opencv-python
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### Python Inference
|
| 53 |
+
```python
|
| 54 |
+
from ultralytics import YOLO
|
| 55 |
+
|
| 56 |
+
# Load model
|
| 57 |
+
model = YOLO('best.pt')
|
| 58 |
+
|
| 59 |
+
# Run inference
|
| 60 |
+
results = model('strawberry_image.jpg', conf=0.5)
|
| 61 |
+
|
| 62 |
+
# Process results
|
| 63 |
+
for result in results:
|
| 64 |
+
boxes = result.boxes
|
| 65 |
+
for box in boxes:
|
| 66 |
+
cls = int(box.cls)
|
| 67 |
+
conf = float(box.conf)
|
| 68 |
+
xyxy = box.xyxy
|
| 69 |
+
class_names = ['partially-ripe', 'ripe', 'unripe']
|
| 70 |
+
print(f"{class_names[cls]} strawberry: {conf:.2f} confidence")
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
### Command Line
|
| 74 |
+
```bash
|
| 75 |
+
# Single image
|
| 76 |
+
yolo predict model=best.pt source='strawberry.jpg'
|
| 77 |
+
|
| 78 |
+
# Webcam
|
| 79 |
+
yolo predict model=best.pt source=0
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
## 📁 Files
|
| 83 |
+
|
| 84 |
+
- `best.pt` - PyTorch model weights (recommended)
|
| 85 |
+
|
| 86 |
+
## 🎯 Use Cases
|
| 87 |
+
|
| 88 |
+
- **Automated Harvesting**: First stage of two-stage picking system
|
| 89 |
+
- **Ripeness Assessment**: Initial strawberry detection and ripeness categorization
|
| 90 |
+
- **Quality Control**: Pre-classification for detailed ripeness analysis
|
| 91 |
+
|
| 92 |
+
## 🔧 Technical Details
|
| 93 |
+
|
| 94 |
+
- **Architecture**: YOLOv11n (Nano)
|
| 95 |
+
- **Input Size**: 416x416
|
| 96 |
+
- **Classes**: 3 (partially-ripe, ripe, unripe)
|
| 97 |
+
- **Training Dataset**: Custom dataset (1200+ annotated strawberries)
|
| 98 |
+
- **Training Epochs**: 50 (early stopping at 20)
|
| 99 |
+
- **Batch Size**: 8
|
| 100 |
+
- **Optimizer**: AdamW
|
| 101 |
+
- **Learning Rate**: 0.01 (cosine annealing)
|
| 102 |
+
|
| 103 |
+
## 📈 Training Configuration
|
| 104 |
+
|
| 105 |
+
```yaml
|
| 106 |
+
model: yolov11n.pt
|
| 107 |
+
epochs: 50
|
| 108 |
+
batch: 8
|
| 109 |
+
imgsz: 416
|
| 110 |
+
optimizer: AdamW
|
| 111 |
+
lr0: 0.01
|
| 112 |
+
lrf: 0.01
|
| 113 |
+
weight_decay: 0.0005
|
| 114 |
+
warmup_epochs: 3.0
|
| 115 |
+
patience: 20
|
| 116 |
+
classes: 3
|
| 117 |
+
names: ['partially-ripe', 'ripe', 'unripe']
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
## 🔗 Related Components
|
| 121 |
+
|
| 122 |
+
- [Classification Model](../classification/) - Second stage for detailed ripeness classification
|
| 123 |
+
- [Training Repository](https://github.com/theonegareth/strawberryPicker)
|
| 124 |
+
|
| 125 |
+
## 📚 Documentation
|
| 126 |
+
|
| 127 |
+
- [Full System Documentation](https://github.com/theonegareth/strawberryPicker)
|
| 128 |
+
- [Two-Stage Pipeline](https://github.com/theonegareth/strawberryPicker#system-architecture)
|
| 129 |
+
|
| 130 |
+
## 📄 License
|
| 131 |
+
|
| 132 |
+
MIT License - See main repository for details.
|
| 133 |
+
|
| 134 |
+
---
|
| 135 |
+
|
| 136 |
+
**Model Version**: 1.0.0
|
| 137 |
+
**Training Date**: November 2025
|
| 138 |
+
**Part of**: Strawberry Picker AI System
|
docs/GITHUB_SETUP.md
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GitHub Setup and Push Guide
|
| 2 |
+
|
| 3 |
+
Follow these steps to push your strawberry picker ML project to GitHub.
|
| 4 |
+
|
| 5 |
+
## Prerequisites
|
| 6 |
+
|
| 7 |
+
1. **Git installed** on your system
|
| 8 |
+
```bash
|
| 9 |
+
git --version
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
2. **GitHub account** (create at https://github.com if you don't have one)
|
| 13 |
+
|
| 14 |
+
3. **Git configured** with your credentials
|
| 15 |
+
```bash
|
| 16 |
+
git config --global user.name "Your Name"
|
| 17 |
+
git config --global user.email "your.email@example.com"
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
## Step 1: Create GitHub Repository
|
| 21 |
+
|
| 22 |
+
### Option A: Using GitHub Website (Recommended)
|
| 23 |
+
1. Go to https://github.com/new
|
| 24 |
+
2. Repository name: `strawberry-picker-robot`
|
| 25 |
+
3. Description: `Machine learning vision system for robotic strawberry picking`
|
| 26 |
+
4. Choose: **Public** or **Private**
|
| 27 |
+
5. Check: **Add a README file** (we'll overwrite it)
|
| 28 |
+
6. Click: **Create repository**
|
| 29 |
+
|
| 30 |
+
### Option B: Using GitHub CLI (if installed)
|
| 31 |
+
```bash
|
| 32 |
+
gh repo create strawberry-picker-robot --public --source=. --remote=origin --push
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
## Step 2: Initialize Local Repository
|
| 36 |
+
|
| 37 |
+
Open terminal in your project folder:
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
cd "G:\My Drive\University Files\5th Semester\Kinematics and Dynamics\strawberryPicker"
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
Initialize git repository:
|
| 44 |
+
```bash
|
| 45 |
+
git init
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Step 3: Add Files to Git
|
| 49 |
+
|
| 50 |
+
Add all files (except those in .gitignore):
|
| 51 |
+
```bash
|
| 52 |
+
git add .
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
Check what will be committed:
|
| 56 |
+
```bash
|
| 57 |
+
git status
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
You should see files like:
|
| 61 |
+
- `requirements.txt`
|
| 62 |
+
- `train_yolov8.py`
|
| 63 |
+
- `train_yolov8_colab.ipynb`
|
| 64 |
+
- `setup_training.py`
|
| 65 |
+
- `TRAINING_README.md`
|
| 66 |
+
- `README.md`
|
| 67 |
+
- `.gitignore`
|
| 68 |
+
- `ArduinoCode/`
|
| 69 |
+
- `assets/`
|
| 70 |
+
|
| 71 |
+
## Step 4: Create First Commit
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
git commit -m "Initial commit: YOLOv8 training pipeline for strawberry detection
|
| 75 |
+
|
| 76 |
+
- Add YOLOv8 training scripts (local, Colab, WSL)
|
| 77 |
+
- Add environment setup and validation
|
| 78 |
+
- Add comprehensive training documentation
|
| 79 |
+
- Add .gitignore for ML project
|
| 80 |
+
- Support multiple training environments"
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
## Step 5: Connect to GitHub Repository
|
| 84 |
+
|
| 85 |
+
If you created repo on GitHub website, link it:
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
git remote add origin https://github.com/YOUR_USERNAME/strawberry-picker-robot.git
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
Replace `YOUR_USERNAME` with your actual GitHub username.
|
| 92 |
+
|
| 93 |
+
## Step 6: Rename Default Branch (if needed)
|
| 94 |
+
|
| 95 |
+
```bash
|
| 96 |
+
git branch -M main
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
## Step 7: Push to GitHub
|
| 100 |
+
|
| 101 |
+
First push (sets up remote tracking):
|
| 102 |
+
```bash
|
| 103 |
+
git push -u origin main
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
Enter your GitHub credentials when prompted.
|
| 107 |
+
|
| 108 |
+
## Step 8: Verify Push
|
| 109 |
+
|
| 110 |
+
Go to https://github.com/YOUR_USERNAME/strawberry-picker-robot
|
| 111 |
+
You should see all your files!
|
| 112 |
+
|
| 113 |
+
## Step 9: Add .gitignore for Large Files
|
| 114 |
+
|
| 115 |
+
If you want to add dataset or large model files later, use Git LFS:
|
| 116 |
+
|
| 117 |
+
```bash
|
| 118 |
+
# Install Git LFS
|
| 119 |
+
git lfs install
|
| 120 |
+
|
| 121 |
+
# Track large file types
|
| 122 |
+
git lfs track "*.pt"
|
| 123 |
+
git lfs track "*.onnx"
|
| 124 |
+
git lfs track "*.tflite"
|
| 125 |
+
git lfs track "*.h5"
|
| 126 |
+
|
| 127 |
+
# Add .gitattributes
|
| 128 |
+
git add .gitattributes
|
| 129 |
+
git commit -m "Add Git LFS tracking for large model files"
|
| 130 |
+
git push
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
## Step 10: Create .gitattributes File
|
| 134 |
+
|
| 135 |
+
Create `.gitattributes` file in project root:
|
| 136 |
+
|
| 137 |
+
```
|
| 138 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 139 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 140 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 141 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 142 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
## Quick Push Commands (Summary)
|
| 146 |
+
|
| 147 |
+
```bash
|
| 148 |
+
# One-time setup
|
| 149 |
+
git init
|
| 150 |
+
git add .
|
| 151 |
+
git commit -m "Initial commit: YOLOv8 training pipeline"
|
| 152 |
+
git remote add origin https://github.com/YOUR_USERNAME/strawberry-picker-robot.git
|
| 153 |
+
git branch -M main
|
| 154 |
+
git push -u origin main
|
| 155 |
+
|
| 156 |
+
# Future updates
|
| 157 |
+
git add .
|
| 158 |
+
git commit -m "Your commit message"
|
| 159 |
+
git push
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
## Troubleshooting
|
| 163 |
+
|
| 164 |
+
### Authentication Issues
|
| 165 |
+
|
| 166 |
+
If you get authentication errors:
|
| 167 |
+
|
| 168 |
+
**Option 1: Use Personal Access Token**
|
| 169 |
+
1. Go to GitHub Settings → Developer settings → Personal access tokens
|
| 170 |
+
2. Generate new token (classic) with `repo` scope
|
| 171 |
+
3. Use token as password when prompted
|
| 172 |
+
|
| 173 |
+
**Option 2: Use SSH (Recommended)**
|
| 174 |
+
```bash
|
| 175 |
+
# Check if SSH key exists
|
| 176 |
+
ls ~/.ssh/id_rsa.pub
|
| 177 |
+
|
| 178 |
+
# If not, create one
|
| 179 |
+
ssh-keygen -t rsa -b 4096 -C "your.email@example.com"
|
| 180 |
+
|
| 181 |
+
# Add to GitHub
|
| 182 |
+
cat ~/.ssh/id_rsa.pub
|
| 183 |
+
# Copy output and add to GitHub → Settings → SSH and GPG keys
|
| 184 |
+
|
| 185 |
+
# Change remote to SSH
|
| 186 |
+
git remote set-url origin git@github.com:YOUR_USERNAME/strawberry-picker-robot.git
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
### Large File Issues
|
| 190 |
+
|
| 191 |
+
If files are too large for GitHub (max 100MB):
|
| 192 |
+
- Use Git LFS (see Step 9)
|
| 193 |
+
- Or add to `.gitignore` and upload separately to Google Drive/Dropbox
|
| 194 |
+
|
| 195 |
+
### Proxy Issues (if behind firewall)
|
| 196 |
+
|
| 197 |
+
```bash
|
| 198 |
+
git config --global http.proxy http://proxy.company.com:8080
|
| 199 |
+
git config --global https.proxy https://proxy.company.com:8080
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
## Best Practices
|
| 203 |
+
|
| 204 |
+
### Commit Messages
|
| 205 |
+
Write clear, descriptive commit messages:
|
| 206 |
+
```bash
|
| 207 |
+
git commit -m "Add YOLOv8 training script with Colab support
|
| 208 |
+
|
| 209 |
+
- Auto-detects training environment
|
| 210 |
+
- Supports local, WSL, and Google Colab
|
| 211 |
+
- Includes dataset validation
|
| 212 |
+
- Exports to ONNX format"
|
| 213 |
+
```
|
| 214 |
+
|
| 215 |
+
### Branching Strategy
|
| 216 |
+
```bash
|
| 217 |
+
# Create feature branch
|
| 218 |
+
git checkout -b feature/add-ripeness-detection
|
| 219 |
+
|
| 220 |
+
# Work on changes
|
| 221 |
+
git add .
|
| 222 |
+
git commit -m "Add ripeness classification dataset collection"
|
| 223 |
+
|
| 224 |
+
# Push branch
|
| 225 |
+
git push -u origin feature/add-ripeness-detection
|
| 226 |
+
|
| 227 |
+
# Create pull request on GitHub
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
### Regular Pushes
|
| 231 |
+
Push frequently to avoid losing work:
|
| 232 |
+
```bash
|
| 233 |
+
# Daily push
|
| 234 |
+
git add .
|
| 235 |
+
git commit -m "Training progress: epoch 50/100, loss: 0.123"
|
| 236 |
+
git push
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
## GitHub Repository Settings
|
| 240 |
+
|
| 241 |
+
### Protect Main Branch
|
| 242 |
+
1. Go to Settings → Branches
|
| 243 |
+
2. Add rule for `main` branch:
|
| 244 |
+
- Require pull request reviews
|
| 245 |
+
- Require status checks
|
| 246 |
+
- Include administrators
|
| 247 |
+
|
| 248 |
+
### Add Description and Topics
|
| 249 |
+
1. Go to repository page
|
| 250 |
+
2. Click "Edit" next to description
|
| 251 |
+
3. Add topics: `machine-learning`, `yolov8`, `raspberry-pi`, `robotics`, `computer-vision`
|
| 252 |
+
|
| 253 |
+
### Enable Issues and Projects
|
| 254 |
+
- Use Issues to track bugs and features
|
| 255 |
+
- Use Projects to organize development phases
|
| 256 |
+
|
| 257 |
+
## Continuous Integration (Optional)
|
| 258 |
+
|
| 259 |
+
Add `.github/workflows/train.yml` for automated training:
|
| 260 |
+
|
| 261 |
+
```yaml
|
| 262 |
+
name: Train Model
|
| 263 |
+
|
| 264 |
+
on:
|
| 265 |
+
push:
|
| 266 |
+
branches: [ main ]
|
| 267 |
+
pull_request:
|
| 268 |
+
branches: [ main ]
|
| 269 |
+
|
| 270 |
+
jobs:
|
| 271 |
+
train:
|
| 272 |
+
runs-on: ubuntu-latest
|
| 273 |
+
steps:
|
| 274 |
+
- uses: actions/checkout@v3
|
| 275 |
+
- name: Set up Python
|
| 276 |
+
uses: actions/setup-python@v4
|
| 277 |
+
with:
|
| 278 |
+
python-version: '3.9'
|
| 279 |
+
- name: Install dependencies
|
| 280 |
+
run: |
|
| 281 |
+
pip install -r requirements.txt
|
| 282 |
+
- name: Validate dataset
|
| 283 |
+
run: |
|
| 284 |
+
python train_yolov8.py --validate-only
|
| 285 |
+
```
|
| 286 |
+
|
| 287 |
+
## Next Steps After Push
|
| 288 |
+
|
| 289 |
+
1. **Share repository** with teammates/collaborators
|
| 290 |
+
2. **Create issues** for Phase 2, 3, 4 tasks
|
| 291 |
+
3. **Set up project board** to track progress
|
| 292 |
+
4. **Add documentation** to wiki if needed
|
| 293 |
+
5. **Enable GitHub Pages** for documentation (optional)
|
| 294 |
+
|
| 295 |
+
## Getting Help
|
| 296 |
+
|
| 297 |
+
- GitHub Docs: https://docs.github.com
|
| 298 |
+
- Git Cheat Sheet: https://education.github.com/git-cheat-sheet-education.pdf
|
| 299 |
+
- Git LFS Docs: https://git-lfs.github.com
|
| 300 |
+
|
| 301 |
+
## Repository URL
|
| 302 |
+
|
| 303 |
+
Your repository will be at:
|
| 304 |
+
`https://github.com/YOUR_USERNAME/strawberry-picker-robot`
|
| 305 |
+
|
| 306 |
+
## Clone Command (for others)
|
| 307 |
+
|
| 308 |
+
```bash
|
| 309 |
+
git clone https://github.com/YOUR_USERNAME/strawberry-picker-robot.git
|
| 310 |
+
cd strawberry-picker-robot
|
| 311 |
+
pip install -r requirements.txt
|
| 312 |
+
```
|
| 313 |
+
|
| 314 |
+
---
|
| 315 |
+
|
| 316 |
+
**Ready to push?** Run the commands in Step 1-7 above!
|
docs/TRAINING_README.md
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Strawberry Detection Model Training Guide
|
| 2 |
+
|
| 3 |
+
This guide covers training a YOLOv8 model for strawberry detection using multiple environments.
|
| 4 |
+
|
| 5 |
+
## Quick Start
|
| 6 |
+
|
| 7 |
+
### Option 1: Local/WSL Training (Recommended for initial setup)
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
# 1. Setup environment
|
| 11 |
+
python setup_training.py
|
| 12 |
+
|
| 13 |
+
# 2. Train model
|
| 14 |
+
python train_yolov8.py --epochs 100 --batch-size 16
|
| 15 |
+
|
| 16 |
+
# 3. Validate dataset only (without training)
|
| 17 |
+
python train_yolov8.py --validate-only
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
### Option 2: Google Colab Training (Recommended for faster training)
|
| 21 |
+
|
| 22 |
+
1. Open `train_yolov8_colab.ipynb` in Google Colab
|
| 23 |
+
2. Connect to GPU runtime: Runtime → Change runtime type → GPU
|
| 24 |
+
3. Run cells sequentially
|
| 25 |
+
4. Download trained model when complete
|
| 26 |
+
|
| 27 |
+
### Option 3: VS Code with Colab Extension
|
| 28 |
+
|
| 29 |
+
1. Install VS Code Google Colab extension
|
| 30 |
+
2. Open `train_yolov8_colab.ipynb`
|
| 31 |
+
3. Connect to Colab kernel
|
| 32 |
+
4. Run cells
|
| 33 |
+
|
| 34 |
+
## Environment Setup
|
| 35 |
+
|
| 36 |
+
### Prerequisites
|
| 37 |
+
|
| 38 |
+
- Python 3.8+
|
| 39 |
+
- pip package manager
|
| 40 |
+
- Git (for version control)
|
| 41 |
+
|
| 42 |
+
### Installation Steps
|
| 43 |
+
|
| 44 |
+
1. **Install Python dependencies:**
|
| 45 |
+
```bash
|
| 46 |
+
pip install -r requirements.txt
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
2. **Validate setup:**
|
| 50 |
+
```bash
|
| 51 |
+
python setup_training.py --validate-only
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
3. **Full setup (install + validate):**
|
| 55 |
+
```bash
|
| 56 |
+
python setup_training.py
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## Dataset Structure
|
| 60 |
+
|
| 61 |
+
Your dataset should be organized as follows:
|
| 62 |
+
|
| 63 |
+
```
|
| 64 |
+
model/dataset/straw-detect.v1-straw-detect.yolov8/
|
| 65 |
+
├── data.yaml
|
| 66 |
+
├── train/
|
| 67 |
+
│ ├── images/
|
| 68 |
+
│ │ ├── image1.jpg
|
| 69 |
+
│ │ ├── image2.jpg
|
| 70 |
+
│ │ └── ...
|
| 71 |
+
│ └── labels/
|
| 72 |
+
│ ├── image1.txt
|
| 73 |
+
│ ├── image2.txt
|
| 74 |
+
│ └── ...
|
| 75 |
+
├── valid/
|
| 76 |
+
│ ├── images/
|
| 77 |
+
│ └── labels/
|
| 78 |
+
└── test/
|
| 79 |
+
├── images/
|
| 80 |
+
└── labels/
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### data.yaml Format
|
| 84 |
+
|
| 85 |
+
```yaml
|
| 86 |
+
train: ../train/images
|
| 87 |
+
val: ../valid/images
|
| 88 |
+
test: ../test/images
|
| 89 |
+
|
| 90 |
+
nc: 1
|
| 91 |
+
names: ['strawberry']
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
## Training Parameters
|
| 95 |
+
|
| 96 |
+
### Basic Training
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
python train_yolov8.py \
|
| 100 |
+
--epochs 100 \
|
| 101 |
+
--batch-size 16 \
|
| 102 |
+
--img-size 640 \
|
| 103 |
+
--export-onnx
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
### Advanced Options
|
| 107 |
+
|
| 108 |
+
- `--dataset PATH`: Custom dataset path
|
| 109 |
+
- `--epochs N`: Number of training epochs (default: 100)
|
| 110 |
+
- `--batch-size N`: Batch size (default: 16)
|
| 111 |
+
- `--img-size N`: Image size (default: 640)
|
| 112 |
+
- `--export-onnx`: Export to ONNX format after training
|
| 113 |
+
- `--validate-only`: Only validate dataset without training
|
| 114 |
+
|
| 115 |
+
### Model Sizes
|
| 116 |
+
|
| 117 |
+
Choose different YOLOv8 models based on your needs:
|
| 118 |
+
|
| 119 |
+
| Model | Parameters | Speed (CPU) | Speed (GPU) | Accuracy |
|
| 120 |
+
|-------|------------|-------------|-------------|----------|
|
| 121 |
+
| yolov8n | 3.2M | Fastest | Fastest | Good |
|
| 122 |
+
| yolov8s | 11.2M | Fast | Fast | Better |
|
| 123 |
+
| yolov8m | 25.9M | Medium | Medium | Best |
|
| 124 |
+
|
| 125 |
+
To use a different model, edit the `MODEL_NAME` variable in the training script.
|
| 126 |
+
|
| 127 |
+
## Training on Different Environments
|
| 128 |
+
|
| 129 |
+
### Google Colab Advantages
|
| 130 |
+
- **Free GPU**: Tesla T4 with 16GB VRAM
|
| 131 |
+
- **Faster training**: 5-10x faster than CPU
|
| 132 |
+
- **No local setup**: Everything runs in the cloud
|
| 133 |
+
|
| 134 |
+
### WSL (Windows Subsystem for Linux)
|
| 135 |
+
- **Native GPU support**: If you have NVIDIA GPU
|
| 136 |
+
- **Persistent**: Files saved locally
|
| 137 |
+
- **Better for large datasets**: No upload needed
|
| 138 |
+
|
| 139 |
+
### Local Python
|
| 140 |
+
- **Simplest setup**: No additional configuration
|
| 141 |
+
- **CPU only**: Slower training
|
| 142 |
+
- **Good for testing**: Quick validation
|
| 143 |
+
|
| 144 |
+
## Monitoring Training
|
| 145 |
+
|
| 146 |
+
### TensorBoard (Local Training)
|
| 147 |
+
```bash
|
| 148 |
+
tensorboard --logdir model/results
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
### Colab Training
|
| 152 |
+
- Use the built-in progress bars
|
| 153 |
+
- View loss curves in real-time
|
| 154 |
+
- Download metrics after training
|
| 155 |
+
|
| 156 |
+
## Expected Training Time
|
| 157 |
+
|
| 158 |
+
| Environment | Epochs | Estimated Time |
|
| 159 |
+
|-------------|--------|----------------|
|
| 160 |
+
| Colab GPU | 100 | 30-60 minutes |
|
| 161 |
+
| WSL GPU | 100 | 30-60 minutes |
|
| 162 |
+
| CPU only | 100 | 5-8 hours |
|
| 163 |
+
|
| 164 |
+
## Troubleshooting
|
| 165 |
+
|
| 166 |
+
### Common Issues
|
| 167 |
+
|
| 168 |
+
1. **CUDA out of memory**
|
| 169 |
+
- Reduce batch size: `--batch-size 8` or `--batch-size 4`
|
| 170 |
+
- Use smaller model: `yolov8n` instead of `yolov8s`
|
| 171 |
+
|
| 172 |
+
2. **Dataset not found**
|
| 173 |
+
- Check data.yaml paths are correct
|
| 174 |
+
- Verify images are in the right directories
|
| 175 |
+
- Run with `--validate-only` to debug
|
| 176 |
+
|
| 177 |
+
3. **Import errors**
|
| 178 |
+
- Run `setup_training.py` again
|
| 179 |
+
- Check Python version: `python --version`
|
| 180 |
+
- Reinstall ultralytics: `pip install --force-reinstall ultralytics`
|
| 181 |
+
|
| 182 |
+
4. **Slow training**
|
| 183 |
+
- Verify GPU is being used: Check setup output
|
| 184 |
+
- Reduce image size: `--img-size 416`
|
| 185 |
+
- Use smaller model: `yolov8n`
|
| 186 |
+
|
| 187 |
+
### Getting Help
|
| 188 |
+
|
| 189 |
+
1. Run validation: `python train_yolov8.py --validate-only`
|
| 190 |
+
2. Check setup: `python setup_training.py --validate-only`
|
| 191 |
+
3. Review error messages carefully
|
| 192 |
+
4. Check dataset structure matches expected format
|
| 193 |
+
|
| 194 |
+
## Next Steps After Training
|
| 195 |
+
|
| 196 |
+
1. **Export models**: ONNX, TensorFlow Lite formats
|
| 197 |
+
2. **Test inference**: Run on sample images
|
| 198 |
+
3. **Deploy to Raspberry Pi**: Optimize for edge deployment
|
| 199 |
+
4. **Integrate with robot**: Connect to robotic arm control
|
| 200 |
+
|
| 201 |
+
## Performance Optimization Tips
|
| 202 |
+
|
| 203 |
+
### For Faster Training
|
| 204 |
+
- Use GPU environment (Colab or WSL with GPU)
|
| 205 |
+
- Reduce image size (`--img-size 416`)
|
| 206 |
+
- Use smaller batch size if GPU memory limited
|
| 207 |
+
- Enable image caching (already enabled by default)
|
| 208 |
+
|
| 209 |
+
### For Better Accuracy
|
| 210 |
+
- Increase epochs to 150-200
|
| 211 |
+
- Use larger model (`yolov8s` or `yolov8m`)
|
| 212 |
+
- Collect more diverse training images
|
| 213 |
+
- Use data augmentation
|
| 214 |
+
- Fine-tune learning rate
|
| 215 |
+
|
| 216 |
+
### For Raspberry Pi Deployment
|
| 217 |
+
- Use `yolov8n` model (smallest)
|
| 218 |
+
- Export to TensorFlow Lite
|
| 219 |
+
- Apply INT8 quantization
|
| 220 |
+
- Reduce input resolution to 320x320 or 416x416
|
| 221 |
+
|
| 222 |
+
## Version Control
|
| 223 |
+
|
| 224 |
+
Track your training experiments:
|
| 225 |
+
|
| 226 |
+
```bash
|
| 227 |
+
git add model/results/
|
| 228 |
+
git commit -m "Training: yolov8n, 100 epochs, 640px"
|
| 229 |
+
git tag v1.0-yolov8n-baseline
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
## License
|
| 233 |
+
|
| 234 |
+
This training pipeline is part of the Strawberry Picker project. See main README for license information.
|
git-xet
ADDED
|
File without changes
|
inference_example.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Simple inference example for the strawberry detection model.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from ultralytics import YOLO
|
| 7 |
+
import cv2
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
def main():
|
| 11 |
+
# Load the model
|
| 12 |
+
print("Loading strawberry detection model...")
|
| 13 |
+
model = YOLO('best.pt')
|
| 14 |
+
|
| 15 |
+
# Run inference on an image
|
| 16 |
+
if len(sys.argv) > 1:
|
| 17 |
+
image_path = sys.argv[1]
|
| 18 |
+
else:
|
| 19 |
+
print("Usage: python inference_example.py <path_to_image>")
|
| 20 |
+
print("Using default test - loading webcam...")
|
| 21 |
+
|
| 22 |
+
# Webcam inference
|
| 23 |
+
cap = cv2.VideoCapture(0)
|
| 24 |
+
|
| 25 |
+
while True:
|
| 26 |
+
ret, frame = cap.read()
|
| 27 |
+
if not ret:
|
| 28 |
+
break
|
| 29 |
+
|
| 30 |
+
# Run inference
|
| 31 |
+
results = model(frame, conf=0.5)
|
| 32 |
+
|
| 33 |
+
# Draw results
|
| 34 |
+
annotated_frame = results[0].plot()
|
| 35 |
+
|
| 36 |
+
# Display
|
| 37 |
+
cv2.imshow('Strawberry Detection', annotated_frame)
|
| 38 |
+
|
| 39 |
+
if cv2.waitKey(1) & 0xFF == ord('q'):
|
| 40 |
+
break
|
| 41 |
+
|
| 42 |
+
cap.release()
|
| 43 |
+
cv2.destroyAllWindows()
|
| 44 |
+
return
|
| 45 |
+
|
| 46 |
+
# Image inference
|
| 47 |
+
print(f"Running inference on {image_path}...")
|
| 48 |
+
results = model(image_path)
|
| 49 |
+
|
| 50 |
+
# Print results
|
| 51 |
+
for result in results:
|
| 52 |
+
boxes = result.boxes
|
| 53 |
+
print(f"\nFound {len(boxes)} strawberries:")
|
| 54 |
+
|
| 55 |
+
for i, box in enumerate(boxes):
|
| 56 |
+
confidence = box.conf[0].item()
|
| 57 |
+
print(f" Strawberry {i+1}: {confidence:.2%} confidence")
|
| 58 |
+
|
| 59 |
+
# Save annotated image
|
| 60 |
+
output_path = 'output.jpg'
|
| 61 |
+
result.save(output_path)
|
| 62 |
+
print(f"\nSaved annotated image to {output_path}")
|
| 63 |
+
|
| 64 |
+
if __name__ == '__main__':
|
| 65 |
+
main()
|
notebooks/strawberry_training.ipynb
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import tensorflow as tf\n",
|
| 10 |
+
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
|
| 11 |
+
"from tensorflow.keras.applications import MobileNetV2\n",
|
| 12 |
+
"from tensorflow.keras.layers import Dense, GlobalAveragePooling2D\n",
|
| 13 |
+
"from tensorflow.keras.models import Model\n",
|
| 14 |
+
"import os\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"# Data directories\n",
|
| 17 |
+
"data_dir = 'dataset'\n",
|
| 18 |
+
"train_datagen = ImageDataGenerator(\n",
|
| 19 |
+
" rescale=1./255,\n",
|
| 20 |
+
" validation_split=0.2,\n",
|
| 21 |
+
" rotation_range=20,\n",
|
| 22 |
+
" width_shift_range=0.2,\n",
|
| 23 |
+
" height_shift_range=0.2,\n",
|
| 24 |
+
" horizontal_flip=True\n",
|
| 25 |
+
")\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"train_generator = train_datagen.flow_from_directory(\n",
|
| 28 |
+
" data_dir,\n",
|
| 29 |
+
" target_size=(224, 224),\n",
|
| 30 |
+
" batch_size=32,\n",
|
| 31 |
+
" class_mode='binary',\n",
|
| 32 |
+
" subset='training'\n",
|
| 33 |
+
")\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"validation_generator = train_datagen.flow_from_directory(\n",
|
| 36 |
+
" data_dir,\n",
|
| 37 |
+
" target_size=(224, 224),\n",
|
| 38 |
+
" batch_size=32,\n",
|
| 39 |
+
" class_mode='binary',\n",
|
| 40 |
+
" subset='validation'\n",
|
| 41 |
+
")\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"# Load pre-trained MobileNetV2\n",
|
| 44 |
+
"base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"# Add custom layers\n",
|
| 47 |
+
"x = base_model.output\n",
|
| 48 |
+
"x = GlobalAveragePooling2D()(x)\n",
|
| 49 |
+
"x = Dense(1024, activation='relu')(x)\n",
|
| 50 |
+
"predictions = Dense(1, activation='sigmoid')(x)\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"model = Model(inputs=base_model.input, outputs=predictions)\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"# Freeze base layers\n",
|
| 55 |
+
"for layer in base_model.layers:\n",
|
| 56 |
+
" layer.trainable = False\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"# Compile\n",
|
| 59 |
+
"model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"# Train\n",
|
| 62 |
+
"model.fit(train_generator, validation_data=validation_generator, epochs=10)\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"# Save model\n",
|
| 65 |
+
"model.save('strawberry_model.h5')\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"print(\"Model trained and saved as strawberry_model.h5\")"
|
| 68 |
+
]
|
| 69 |
+
}
|
| 70 |
+
],
|
| 71 |
+
"metadata": {
|
| 72 |
+
"kernelspec": {
|
| 73 |
+
"display_name": "Python 3",
|
| 74 |
+
"language": "python",
|
| 75 |
+
"name": "python3"
|
| 76 |
+
},
|
| 77 |
+
"language_info": {
|
| 78 |
+
"codemirror_mode": {
|
| 79 |
+
"name": "ipython",
|
| 80 |
+
"version": 3
|
| 81 |
+
},
|
| 82 |
+
"file_extension": ".py",
|
| 83 |
+
"mimetype": "text/x-python",
|
| 84 |
+
"name": "python",
|
| 85 |
+
"nbconvert_exporter": "python",
|
| 86 |
+
"pygments_lexer": "ipython3",
|
| 87 |
+
"version": "3.12.3"
|
| 88 |
+
}
|
| 89 |
+
},
|
| 90 |
+
"nbformat": 4,
|
| 91 |
+
"nbformat_minor": 4
|
| 92 |
+
}
|
notebooks/train_yolov8_colab.ipynb
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# YOLOv8 Strawberry Detection Training - Google Colab Version\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"This notebook trains a YOLOv8 model for strawberry detection using Google Colab's free GPU.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"## Setup\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"1. Connect to GPU runtime: Runtime → Change runtime type → GPU\n",
|
| 14 |
+
"2. Mount your Google Drive if needed for dataset access"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": null,
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": [
|
| 23 |
+
"# Check GPU availability\n",
|
| 24 |
+
"import torch\n",
|
| 25 |
+
"print(f\"GPU Available: {torch.cuda.is_available()}\")\n",
|
| 26 |
+
"if torch.cuda.is_available():\n",
|
| 27 |
+
" print(f\"GPU Name: {torch.cuda.get_device_name(0)}\")\n",
|
| 28 |
+
" print(f\"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB\")"
|
| 29 |
+
]
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"cell_type": "code",
|
| 33 |
+
"execution_count": null,
|
| 34 |
+
"metadata": {},
|
| 35 |
+
"outputs": [],
|
| 36 |
+
"source": [
|
| 37 |
+
"# Install dependencies\n",
|
| 38 |
+
"!pip install ultralytics torch torchvision opencv-python matplotlib tqdm tensorboard"
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"cell_type": "markdown",
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"source": [
|
| 45 |
+
"## Dataset Setup\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"Option 1: Upload dataset to Colab (temporary, lost after session ends)\n",
|
| 48 |
+
"Option 2: Mount Google Drive for persistent storage\n",
|
| 49 |
+
"Option 3: Download from URL"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"execution_count": null,
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"outputs": [],
|
| 57 |
+
"source": [
|
| 58 |
+
"# Option 2: Mount Google Drive (recommended)\n",
|
| 59 |
+
"from google.colab import drive\n",
|
| 60 |
+
"drive.mount('/content/drive')\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"# Update this path to your dataset location in Google Drive\n",
|
| 63 |
+
"DATASET_PATH = '/content/drive/MyDrive/strawberry-dataset/straw-detect.v1-straw-detect.yolov8'\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"# Or use direct path if dataset is uploaded to Colab\n",
|
| 66 |
+
"# DATASET_PATH = '/content/straw-detect.v1-straw-detect.yolov8'"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "code",
|
| 71 |
+
"execution_count": null,
|
| 72 |
+
"metadata": {},
|
| 73 |
+
"outputs": [],
|
| 74 |
+
"source": [
|
| 75 |
+
"# Validate dataset structure\n",
|
| 76 |
+
"import os\n",
|
| 77 |
+
"import yaml\n",
|
| 78 |
+
"from pathlib import Path\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"dataset_path = Path(DATASET_PATH)\n",
|
| 81 |
+
"data_yaml = dataset_path / 'data.yaml'\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"if not data_yaml.exists():\n",
|
| 84 |
+
" print(f\"ERROR: data.yaml not found at {data_yaml}\")\n",
|
| 85 |
+
"else:\n",
|
| 86 |
+
" print(f\"✓ Found data.yaml at {data_yaml}\")\n",
|
| 87 |
+
" \n",
|
| 88 |
+
" with open(data_yaml, 'r') as f:\n",
|
| 89 |
+
" data = yaml.safe_load(f)\n",
|
| 90 |
+
" \n",
|
| 91 |
+
" print(f\"Dataset info:\")\n",
|
| 92 |
+
" print(f\" Classes: {data['nc']}\")\n",
|
| 93 |
+
" print(f\" Names: {data['names']}\")\n",
|
| 94 |
+
" \n",
|
| 95 |
+
" # Check image counts\n",
|
| 96 |
+
" train_path = dataset_path / data['train']\n",
|
| 97 |
+
" val_path = dataset_path / data['val']\n",
|
| 98 |
+
" \n",
|
| 99 |
+
" if train_path.exists():\n",
|
| 100 |
+
" train_images = list(train_path.glob('*.jpg')) + list(train_path.glob('*.png'))\n",
|
| 101 |
+
" print(f\" Training images: {len(train_images)}\")\n",
|
| 102 |
+
" \n",
|
| 103 |
+
" if val_path.exists():\n",
|
| 104 |
+
" val_images = list(val_path.glob('*.jpg')) + list(val_path.glob('*.png'))\n",
|
| 105 |
+
" print(f\" Validation images: {len(val_images)}\")"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "markdown",
|
| 110 |
+
"metadata": {},
|
| 111 |
+
"source": [
|
| 112 |
+
"## Training Configuration"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "code",
|
| 117 |
+
"execution_count": null,
|
| 118 |
+
"metadata": {},
|
| 119 |
+
"outputs": [],
|
| 120 |
+
"source": [
|
| 121 |
+
"# Training parameters\n",
|
| 122 |
+
"EPOCHS = 100\n",
|
| 123 |
+
"IMG_SIZE = 640\n",
|
| 124 |
+
"BATCH_SIZE = 16 # Adjust based on GPU memory\n",
|
| 125 |
+
"MODEL_NAME = 'yolov8n' # yolov8n, yolov8s, yolov8m, yolov8l, yolov8x\n",
|
| 126 |
+
"\n",
|
| 127 |
+
"# Output directories\n",
|
| 128 |
+
"RESULTS_DIR = '/content/strawberry-results'\n",
|
| 129 |
+
"WEIGHTS_DIR = '/content/strawberry-weights'\n",
|
| 130 |
+
"\n",
|
| 131 |
+
"import os\n",
|
| 132 |
+
"os.makedirs(RESULTS_DIR, exist_ok=True)\n",
|
| 133 |
+
"os.makedirs(WEIGHTS_DIR, exist_ok=True)\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"print(f\"Results will be saved to: {RESULTS_DIR}\")\n",
|
| 136 |
+
"print(f\"Weights will be saved to: {WEIGHTS_DIR}\")"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "markdown",
|
| 141 |
+
"metadata": {},
|
| 142 |
+
"source": [
|
| 143 |
+
"## Train YOLOv8 Model"
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "code",
|
| 148 |
+
"execution_count": null,
|
| 149 |
+
"metadata": {},
|
| 150 |
+
"outputs": [],
|
| 151 |
+
"source": [
|
| 152 |
+
"from ultralytics import YOLO\n",
|
| 153 |
+
"import torch\n",
|
| 154 |
+
"\n",
|
| 155 |
+
"# Load pretrained model\n",
|
| 156 |
+
"print(f\"Loading {MODEL_NAME} model...\")\n",
|
| 157 |
+
"model = YOLO(f'{MODEL_NAME}.pt')\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"# Train the model\n",
|
| 160 |
+
"print(f\"Starting training for {EPOCHS} epochs...\")\n",
|
| 161 |
+
"results = model.train(\n",
|
| 162 |
+
" data=str(data_yaml),\n",
|
| 163 |
+
" epochs=EPOCHS,\n",
|
| 164 |
+
" imgsz=IMG_SIZE,\n",
|
| 165 |
+
" batch=BATCH_SIZE,\n",
|
| 166 |
+
" device='0' if torch.cuda.is_available() else 'cpu',\n",
|
| 167 |
+
" project=RESULTS_DIR,\n",
|
| 168 |
+
" name='strawberry_detection',\n",
|
| 169 |
+
" exist_ok=True,\n",
|
| 170 |
+
" patience=20, # Early stopping\n",
|
| 171 |
+
" save=True,\n",
|
| 172 |
+
" save_period=10, # Save checkpoint every 10 epochs\n",
|
| 173 |
+
" cache=True,\n",
|
| 174 |
+
" verbose=True\n",
|
| 175 |
+
")"
|
| 176 |
+
]
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"cell_type": "markdown",
|
| 180 |
+
"metadata": {},
|
| 181 |
+
"source": [
|
| 182 |
+
"## Save and Export Model"
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"cell_type": "code",
|
| 187 |
+
"execution_count": null,
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"outputs": [],
|
| 190 |
+
"source": [
|
| 191 |
+
"# Save final model\n",
|
| 192 |
+
"final_model_path = f'{WEIGHTS_DIR}/strawberry_{MODEL_NAME}.pt'\n",
|
| 193 |
+
"model.save(final_model_path)\n",
|
| 194 |
+
"print(f\"✓ Model saved to: {final_model_path}\")\n",
|
| 195 |
+
"\n",
|
| 196 |
+
"# Export to ONNX format\n",
|
| 197 |
+
"print(\"Exporting to ONNX format...\")\n",
|
| 198 |
+
"onnx_path = model.export(format='onnx', imgsz=IMG_SIZE, dynamic=True)\n",
|
| 199 |
+
"print(f\"✓ ONNX model exported to: {onnx_path}\")\n",
|
| 200 |
+
"\n",
|
| 201 |
+
"# Export to TensorFlow Lite format (for Raspberry Pi)\n",
|
| 202 |
+
"print(\"Exporting to TensorFlow Lite format...\")\n",
|
| 203 |
+
"tflite_path = model.export(format='tflite', imgsz=IMG_SIZE)\n",
|
| 204 |
+
"print(f\"✓ TFLite model exported to: {tflite_path}\")"
|
| 205 |
+
]
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"cell_type": "markdown",
|
| 209 |
+
"metadata": {},
|
| 210 |
+
"source": [
|
| 211 |
+
"## View Training Results"
|
| 212 |
+
]
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "code",
|
| 216 |
+
"execution_count": null,
|
| 217 |
+
"metadata": {},
|
| 218 |
+
"outputs": [],
|
| 219 |
+
"source": [
|
| 220 |
+
"# Display training results\n",
|
| 221 |
+
"import matplotlib.pyplot as plt\n",
|
| 222 |
+
"from pathlib import Path\n",
|
| 223 |
+
"\n",
|
| 224 |
+
"results_dir = Path(RESULTS_DIR) / 'strawberry_detection'\n",
|
| 225 |
+
"plots_dir = results_dir / 'plots'\n",
|
| 226 |
+
"\n",
|
| 227 |
+
"if plots_dir.exists():\n",
|
| 228 |
+
" print(\"Training plots:\")\n",
|
| 229 |
+
" for plot_file in plots_dir.glob('*.png'):\n",
|
| 230 |
+
" print(f\" - {plot_file.name}\")\n",
|
| 231 |
+
"else:\n",
|
| 232 |
+
" print(\"Plots directory not found yet. Training may still be in progress.\")"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"cell_type": "code",
|
| 237 |
+
"execution_count": null,
|
| 238 |
+
"metadata": {},
|
| 239 |
+
"outputs": [],
|
| 240 |
+
"source": [
|
| 241 |
+
"# Show confusion matrix\n",
|
| 242 |
+
"from IPython.display import Image, display\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"confusion_matrix = plots_dir / 'confusion_matrix.png'\n",
|
| 245 |
+
"if confusion_matrix.exists():\n",
|
| 246 |
+
" print(\"Confusion Matrix:\")\n",
|
| 247 |
+
" display(Image(filename=str(confusion_matrix)))"
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"cell_type": "markdown",
|
| 252 |
+
"metadata": {},
|
| 253 |
+
"source": [
|
| 254 |
+
"## Download Trained Model\n",
|
| 255 |
+
"\n",
|
| 256 |
+
"Download the trained model to your local machine:"
|
| 257 |
+
]
|
| 258 |
+
},
|
| 259 |
+
{
|
| 260 |
+
"cell_type": "code",
|
| 261 |
+
"execution_count": null,
|
| 262 |
+
"metadata": {},
|
| 263 |
+
"outputs": [],
|
| 264 |
+
"source": [
|
| 265 |
+
"from google.colab import files\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"# Download PyTorch model\n",
|
| 268 |
+
"files.download(final_model_path)\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"# Download ONNX model\n",
|
| 271 |
+
"files.download(str(onnx_path))\n",
|
| 272 |
+
"\n",
|
| 273 |
+
"# Download TFLite model\n",
|
| 274 |
+
"files.download(str(tflite_path))"
|
| 275 |
+
]
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"cell_type": "markdown",
|
| 279 |
+
"metadata": {},
|
| 280 |
+
"source": [
|
| 281 |
+
"## Next Steps\n",
|
| 282 |
+
"\n",
|
| 283 |
+
"1. Copy the trained model to your Raspberry Pi\n",
|
| 284 |
+
"2. Use the `detect_realtime_pi.py` script for inference\n",
|
| 285 |
+
"3. Integrate with your robotic arm control system\n",
|
| 286 |
+
"\n",
|
| 287 |
+
"## Tips for Better Results\n",
|
| 288 |
+
"\n",
|
| 289 |
+
"- If accuracy is low, increase EPOCHS to 150-200\n",
|
| 290 |
+
"- Try different MODEL_NAME sizes: yolov8s, yolov8m (slower but more accurate)\n",
|
| 291 |
+
"- Collect more training images with varied lighting conditions\n",
|
| 292 |
+
"- Use data augmentation techniques"
|
| 293 |
+
]
|
| 294 |
+
}
|
| 295 |
+
],
|
| 296 |
+
"metadata": {
|
| 297 |
+
"kernelspec": {
|
| 298 |
+
"display_name": "Python 3",
|
| 299 |
+
"language": "python",
|
| 300 |
+
"name": "python3"
|
| 301 |
+
},
|
| 302 |
+
"language_info": {
|
| 303 |
+
"name": "python",
|
| 304 |
+
"version": "3.8.0"
|
| 305 |
+
}
|
| 306 |
+
},
|
| 307 |
+
"nbformat": 4,
|
| 308 |
+
"nbformat_minor": 4
|
| 309 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<<<<<<< HEAD
|
| 2 |
+
torch>=1.8.0
|
| 3 |
+
torchvision>=0.9.0
|
| 4 |
+
ultralytics>=8.0.0
|
| 5 |
+
opencv-python>=4.5.0
|
| 6 |
+
numpy>=1.21.0
|
| 7 |
+
matplotlib>=3.3.0
|
| 8 |
+
Pillow>=8.0.0
|
| 9 |
+
tqdm>=4.60.0
|
| 10 |
+
tensorboard>=2.7.0
|
| 11 |
+
onnx>=1.10.0
|
| 12 |
+
onnxruntime>=1.10.0
|
| 13 |
+
tensorflow>=2.8.0
|
| 14 |
+
=======
|
| 15 |
+
ultralytics>=8.0.0
|
| 16 |
+
opencv-python>=4.5.0
|
| 17 |
+
numpy>=1.21.0
|
| 18 |
+
torch>=1.8.0
|
| 19 |
+
torchvision>=0.9.0
|
| 20 |
+
>>>>>>> bb77661e9aecb09169fb60057ff0ebb1f504de58
|
results.csv
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
|
| 2 |
+
1,6.86824,1.32115,2.72359,1.31365,0.0042,0.96396,0.34849,0.15119,1.72013,3.19145,1.6125,0.08416,0.00016,0.00016
|
| 3 |
+
2,12.0365,1.29183,1.25338,1.18955,0.825,0.59459,0.73878,0.51353,1.36044,3.55862,1.26268,0.0673235,0.000323466,0.000323466
|
| 4 |
+
3,16.7713,1.21518,0.99614,1.17689,0.89744,0.63063,0.7709,0.49709,1.44287,3.62474,1.297,0.0504802,0.0004802,0.0004802
|
| 5 |
+
4,21.4651,1.2383,0.9566,1.17915,0.47534,0.95495,0.91297,0.62608,1.21511,2.44191,1.15645,0.0336302,0.000630202,0.000630202
|
| 6 |
+
5,26.0183,1.23907,0.91469,1.16288,0.88867,0.9349,0.94819,0.61003,1.24465,1.14903,1.21458,0.0167735,0.000773472,0.000773472
|
| 7 |
+
6,30.4778,1.22134,0.90462,1.19466,0.91765,0.90347,0.9523,0.62325,1.19483,0.9782,1.23663,0.000901,0.000901,0.000901
|
| 8 |
+
7,34.6509,1.20304,0.8498,1.20764,0.7236,0.75676,0.78748,0.52191,1.15781,1.36482,1.33971,0.0008812,0.0008812,0.0008812
|
| 9 |
+
8,39.22,1.16651,0.84314,1.21721,0.86762,0.88573,0.95211,0.63231,1.19173,0.83562,1.39059,0.0008614,0.0008614,0.0008614
|
| 10 |
+
9,43.7182,1.19071,0.75512,1.21773,0.92471,0.88288,0.9604,0.64617,1.16949,0.8096,1.34081,0.0008416,0.0008416,0.0008416
|
| 11 |
+
10,48.0715,1.15343,0.75549,1.19624,0.95433,0.85586,0.94036,0.63013,1.15544,0.87035,1.3024,0.0008218,0.0008218,0.0008218
|
| 12 |
+
11,52.1875,1.17306,0.77483,1.19032,0.90561,0.86486,0.91627,0.61251,1.1512,0.87167,1.33988,0.000802,0.000802,0.000802
|
| 13 |
+
12,56.3393,1.15909,0.72052,1.21331,0.93901,0.92793,0.97553,0.65695,1.19064,0.67324,1.35,0.0007822,0.0007822,0.0007822
|
| 14 |
+
13,60.7417,1.12648,0.70107,1.1631,0.92996,0.9009,0.95406,0.63081,1.20018,0.74019,1.3586,0.0007624,0.0007624,0.0007624
|
| 15 |
+
14,65.2298,1.09008,0.70923,1.15539,0.91466,0.91892,0.96205,0.67663,1.10559,0.7621,1.25653,0.0007426,0.0007426,0.0007426
|
| 16 |
+
15,69.571,1.14397,0.69424,1.1814,0.9719,0.93464,0.97478,0.66421,1.14459,0.60123,1.3482,0.0007228,0.0007228,0.0007228
|
| 17 |
+
16,74.0715,1.1273,0.69053,1.15822,0.94736,0.89189,0.96811,0.67023,1.15103,0.6696,1.29367,0.000703,0.000703,0.000703
|
| 18 |
+
17,78.482,1.14856,0.70743,1.19375,0.93885,0.95495,0.97045,0.6589,1.1875,0.61151,1.36611,0.0006832,0.0006832,0.0006832
|
| 19 |
+
18,82.5976,1.11894,0.68007,1.1589,0.98034,0.89868,0.97938,0.65622,1.1357,0.6026,1.32801,0.0006634,0.0006634,0.0006634
|
| 20 |
+
19,87.0478,1.11983,0.64666,1.16769,0.92238,0.96396,0.97889,0.66274,1.15968,0.60674,1.31971,0.0006436,0.0006436,0.0006436
|
| 21 |
+
20,91.4751,1.08926,0.64732,1.14873,0.95361,0.95495,0.98242,0.67862,1.13128,0.56679,1.32197,0.0006238,0.0006238,0.0006238
|
| 22 |
+
21,95.769,1.06938,0.63992,1.14856,0.93911,0.97265,0.98132,0.6723,1.11844,0.53648,1.29829,0.000604,0.000604,0.000604
|
| 23 |
+
22,100.516,1.0275,0.60948,1.11803,0.96362,0.9544,0.9832,0.66712,1.11005,0.57067,1.29318,0.0005842,0.0005842,0.0005842
|
| 24 |
+
23,104.83,1.08322,0.63118,1.16093,0.961,0.93694,0.9822,0.66812,1.12503,0.59445,1.29324,0.0005644,0.0005644,0.0005644
|
| 25 |
+
24,109.564,1.03814,0.59752,1.13687,0.92285,0.95495,0.98092,0.66146,1.16236,0.56458,1.30195,0.0005446,0.0005446,0.0005446
|
| 26 |
+
25,113.789,1.05733,0.61255,1.12831,0.93744,0.94499,0.97589,0.66985,1.14589,0.58107,1.29102,0.0005248,0.0005248,0.0005248
|
| 27 |
+
26,118.014,1.03262,0.59764,1.11505,0.93427,0.96396,0.97761,0.66389,1.13641,0.56017,1.31125,0.000505,0.000505,0.000505
|
| 28 |
+
27,122.046,1.05034,0.59402,1.1313,0.9387,0.95495,0.98023,0.66648,1.1571,0.58369,1.29765,0.0004852,0.0004852,0.0004852
|
| 29 |
+
28,126.296,1.02708,0.58665,1.14179,0.95501,0.95618,0.98264,0.64604,1.17408,0.59978,1.35501,0.0004654,0.0004654,0.0004654
|
| 30 |
+
29,130.486,1.0421,0.59585,1.11415,0.94498,0.98198,0.98278,0.66907,1.13142,0.55493,1.31176,0.0004456,0.0004456,0.0004456
|
| 31 |
+
30,135.122,1.00967,0.58472,1.11493,0.98144,0.95295,0.98468,0.67596,1.10479,0.53006,1.32835,0.0004258,0.0004258,0.0004258
|
| 32 |
+
31,139.655,0.96788,0.52585,1.13779,0.97128,0.91892,0.98086,0.66954,1.12559,0.55724,1.33653,0.000406,0.000406,0.000406
|
| 33 |
+
32,144.364,0.98252,0.56894,1.10962,0.92865,0.96396,0.98034,0.66634,1.12278,0.5327,1.33008,0.0003862,0.0003862,0.0003862
|
| 34 |
+
33,148.518,0.98348,0.55054,1.11565,0.93032,0.98198,0.98389,0.6858,1.12571,0.49984,1.32711,0.0003664,0.0003664,0.0003664
|
| 35 |
+
34,152.545,1.00252,0.55027,1.09994,0.96066,0.94595,0.98377,0.69641,1.12416,0.52283,1.2885,0.0003466,0.0003466,0.0003466
|
| 36 |
+
35,156.54,0.94835,0.53754,1.08347,0.9303,0.96198,0.98116,0.67964,1.11046,0.5242,1.30978,0.0003268,0.0003268,0.0003268
|
| 37 |
+
36,160.636,0.96319,0.53923,1.11176,0.91568,0.99099,0.98692,0.68369,1.11171,0.48497,1.30834,0.000307,0.000307,0.000307
|
| 38 |
+
37,164.837,0.95618,0.53593,1.08664,0.97244,0.95364,0.98547,0.69178,1.11108,0.50238,1.28898,0.0002872,0.0002872,0.0002872
|
| 39 |
+
38,168.917,0.97244,0.5482,1.08408,0.96389,0.91892,0.98071,0.66854,1.12583,0.51788,1.32157,0.0002674,0.0002674,0.0002674
|
| 40 |
+
39,173.37,0.93709,0.51704,1.08327,0.97083,0.90991,0.98181,0.66532,1.14223,0.52647,1.35507,0.0002476,0.0002476,0.0002476
|
| 41 |
+
40,178.082,0.92316,0.50575,1.07146,0.96253,0.92569,0.98231,0.69138,1.11065,0.51471,1.30314,0.0002278,0.0002278,0.0002278
|
| 42 |
+
41,183.476,0.8683,0.46482,1.0931,0.95441,0.94301,0.98327,0.69583,1.11377,0.5044,1.29681,0.000208,0.000208,0.000208
|
| 43 |
+
42,187.581,0.859,0.45945,1.08201,0.96091,0.97297,0.98492,0.70061,1.10066,0.47832,1.30928,0.0001882,0.0001882,0.0001882
|
| 44 |
+
43,191.651,0.85681,0.44947,1.08108,0.9516,0.97297,0.98641,0.70187,1.10791,0.47281,1.33175,0.0001684,0.0001684,0.0001684
|
| 45 |
+
44,195.634,0.82438,0.44285,1.06652,0.95144,0.97297,0.98817,0.6897,1.10938,0.47912,1.32276,0.0001486,0.0001486,0.0001486
|
| 46 |
+
45,199.92,0.838,0.42537,1.06944,0.95343,0.97297,0.98846,0.68458,1.12846,0.4846,1.33303,0.0001288,0.0001288,0.0001288
|
| 47 |
+
46,204.796,0.8131,0.41691,1.0408,0.94776,0.95495,0.98962,0.67581,1.15034,0.48295,1.38289,0.000109,0.000109,0.000109
|
| 48 |
+
47,209.422,0.81183,0.40894,1.06726,0.93117,0.97502,0.98866,0.68269,1.13011,0.47496,1.35196,8.92e-05,8.92e-05,8.92e-05
|
| 49 |
+
48,213.529,0.80037,0.39995,1.04921,0.95491,0.95402,0.98838,0.6885,1.12861,0.48232,1.32018,6.94e-05,6.94e-05,6.94e-05
|
| 50 |
+
49,217.607,0.81871,0.4227,1.0661,0.95526,0.96185,0.98927,0.69025,1.14264,0.48573,1.33214,4.96e-05,4.96e-05,4.96e-05
|
| 51 |
+
50,221.537,0.802,0.41167,1.05813,0.95484,0.96396,0.98857,0.68227,1.13242,0.47929,1.33976,2.98e-05,2.98e-05,2.98e-05
|
scripts/all_combine3.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# dual_camera_15cm_tuned.py
|
| 2 |
+
"""
|
| 3 |
+
FINAL TUNED VERSION
|
| 4 |
+
- Baseline: 15.0 cm
|
| 5 |
+
- Range: 20-40 cm
|
| 6 |
+
- Calibration: Tuned to 0.82 based on your latest log (48.5cm -> 30cm).
|
| 7 |
+
"""
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
from ultralytics import YOLO
|
| 11 |
+
import math
|
| 12 |
+
|
| 13 |
+
# ----------------------------
|
| 14 |
+
# USER SETTINGS
|
| 15 |
+
# ----------------------------
|
| 16 |
+
CAM_A_ID = 2 # LEFT Camera
|
| 17 |
+
CAM_B_ID = 1 # RIGHT Camera
|
| 18 |
+
|
| 19 |
+
FRAME_W = 640
|
| 20 |
+
FRAME_H = 408
|
| 21 |
+
YOLO_MODEL_PATH = "strawberry.pt"
|
| 22 |
+
|
| 23 |
+
# --- GEOMETRY & CALIBRATION ---
|
| 24 |
+
BASELINE_CM = 15.0
|
| 25 |
+
FOCUS_DIST_CM = 30.0
|
| 26 |
+
|
| 27 |
+
# RE-CALIBRATED SCALAR
|
| 28 |
+
# Your Raw Reading is approx 36.6cm. Real is 30.0cm.
|
| 29 |
+
# 30.0 / 36.6 = ~0.82
|
| 30 |
+
DEPTH_SCALAR = 0.82
|
| 31 |
+
|
| 32 |
+
# Auto-calculate angle
|
| 33 |
+
val = (BASELINE_CM / 2.0) / FOCUS_DIST_CM
|
| 34 |
+
calc_yaw_deg = math.degrees(math.atan(val))
|
| 35 |
+
YAW_LEFT_DEG = calc_yaw_deg
|
| 36 |
+
YAW_RIGHT_DEG = -calc_yaw_deg
|
| 37 |
+
|
| 38 |
+
print(f"--- CONFIGURATION ---")
|
| 39 |
+
print(f"1. Baseline: {BASELINE_CM} cm")
|
| 40 |
+
print(f"2. Scalar: x{DEPTH_SCALAR} (Reducing raw estimate to match reality)")
|
| 41 |
+
print(f"3. REQUIRED ANGLE: +/- {calc_yaw_deg:.2f} degrees")
|
| 42 |
+
print(f"---------------------")
|
| 43 |
+
|
| 44 |
+
# Intrinsics
|
| 45 |
+
K_A = np.array([[629.10808758, 0.0, 347.20913144],
|
| 46 |
+
[0.0, 631.11321979, 277.5222819],
|
| 47 |
+
[0.0, 0.0, 1.0]], dtype=np.float64)
|
| 48 |
+
dist_A = np.array([-0.35469562, 0.10232556, -0.0005468, -0.00174671, 0.01546246], dtype=np.float64)
|
| 49 |
+
|
| 50 |
+
K_B = np.array([[1001.67997, 0.0, 367.736216],
|
| 51 |
+
[0.0, 996.698369, 312.866527],
|
| 52 |
+
[0.0, 0.0, 1.0]], dtype=np.float64)
|
| 53 |
+
dist_B = np.array([-0.49543094, 0.82826695, -0.00180861, -0.00362202, -1.42667838], dtype=np.float64)
|
| 54 |
+
|
| 55 |
+
# ----------------------------
|
| 56 |
+
# HELPERS
|
| 57 |
+
# ----------------------------
|
| 58 |
+
def capture_single(cam_id):
|
| 59 |
+
cap = cv2.VideoCapture(cam_id, cv2.CAP_DSHOW)
|
| 60 |
+
if not cap.isOpened(): return None
|
| 61 |
+
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_W)
|
| 62 |
+
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_H)
|
| 63 |
+
for _ in range(5): cap.read() # Warmup
|
| 64 |
+
ret, frame = cap.read()
|
| 65 |
+
cap.release()
|
| 66 |
+
return frame
|
| 67 |
+
|
| 68 |
+
def build_undistort_maps(K, dist):
|
| 69 |
+
newK, _ = cv2.getOptimalNewCameraMatrix(K, dist, (FRAME_W, FRAME_H), 1.0)
|
| 70 |
+
mapx, mapy = cv2.initUndistortRectifyMap(K, dist, None, newK, (FRAME_W, FRAME_H), cv2.CV_32FC1)
|
| 71 |
+
return mapx, mapy, newK
|
| 72 |
+
|
| 73 |
+
def detect_on_image(model, img):
|
| 74 |
+
results = model(img, verbose=False)[0]
|
| 75 |
+
dets = []
|
| 76 |
+
for box in results.boxes:
|
| 77 |
+
x1, y1, x2, y2 = [int(v) for v in box.xyxy[0].tolist()]
|
| 78 |
+
cx = int((x1 + x2) / 2)
|
| 79 |
+
cy = int((y1 + y2) / 2)
|
| 80 |
+
conf = float(box.conf[0])
|
| 81 |
+
cls = int(box.cls[0])
|
| 82 |
+
name = model.names.get(cls, str(cls))
|
| 83 |
+
dets.append({'x1':x1,'y1':y1,'x2':x2,'y2':y2,'cx':cx,'cy':cy,'conf':conf,'cls':cls,'name':name})
|
| 84 |
+
return sorted(dets, key=lambda d: d['cx'])
|
| 85 |
+
|
| 86 |
+
def match_stereo(detL, detR):
|
| 87 |
+
matches = []
|
| 88 |
+
usedR = set()
|
| 89 |
+
for l in detL:
|
| 90 |
+
best_idx = -1
|
| 91 |
+
best_score = 9999
|
| 92 |
+
for i, r in enumerate(detR):
|
| 93 |
+
if i in usedR: continue
|
| 94 |
+
if l['cls'] != r['cls']: continue
|
| 95 |
+
dy = abs(l['cy'] - r['cy'])
|
| 96 |
+
if dy > 60: continue
|
| 97 |
+
if dy < best_score:
|
| 98 |
+
best_score = dy
|
| 99 |
+
best_idx = i
|
| 100 |
+
if best_idx != -1:
|
| 101 |
+
matches.append((l, detR[best_idx]))
|
| 102 |
+
usedR.add(best_idx)
|
| 103 |
+
return matches
|
| 104 |
+
|
| 105 |
+
# --- 3D MATH ---
|
| 106 |
+
def yaw_to_R_deg(yaw_deg):
|
| 107 |
+
y = math.radians(yaw_deg)
|
| 108 |
+
cy = math.cos(y); sy = math.sin(y)
|
| 109 |
+
return np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=np.float64)
|
| 110 |
+
|
| 111 |
+
def build_projection_matrices(newK_A, newK_B, yaw_L, yaw_R, baseline):
|
| 112 |
+
R_L = yaw_to_R_deg(yaw_L)
|
| 113 |
+
R_R = yaw_to_R_deg(yaw_R)
|
| 114 |
+
R_W2A = R_L.T
|
| 115 |
+
t_W2A = np.zeros((3,1))
|
| 116 |
+
R_W2B = R_R.T
|
| 117 |
+
C_B_world = np.array([[baseline], [0.0], [0.0]])
|
| 118 |
+
t_W2B = -R_W2B @ C_B_world
|
| 119 |
+
P1 = newK_A @ np.hstack((R_W2A, t_W2A))
|
| 120 |
+
P2 = newK_B @ np.hstack((R_W2B, t_W2B))
|
| 121 |
+
return P1, P2
|
| 122 |
+
|
| 123 |
+
def triangulate_matrix(dL, dR, P1, P2):
|
| 124 |
+
ptsL = np.array([[float(dL['cx'])],[float(dL['cy'])]], dtype=np.float64)
|
| 125 |
+
ptsR = np.array([[float(dR['cx'])],[float(dR['cy'])]], dtype=np.float64)
|
| 126 |
+
|
| 127 |
+
Xh = cv2.triangulatePoints(P1, P2, ptsL, ptsR)
|
| 128 |
+
Xh /= Xh[3]
|
| 129 |
+
|
| 130 |
+
X = float(Xh[0].item())
|
| 131 |
+
Y = float(Xh[1].item())
|
| 132 |
+
Z_raw = float(Xh[2].item())
|
| 133 |
+
|
| 134 |
+
return X, Y, Z_raw * DEPTH_SCALAR
|
| 135 |
+
|
| 136 |
+
def main():
|
| 137 |
+
print("[INFO] Loading YOLO...")
|
| 138 |
+
model = YOLO(YOLO_MODEL_PATH)
|
| 139 |
+
|
| 140 |
+
print(f"[INFO] Capturing...")
|
| 141 |
+
frameA = capture_single(CAM_A_ID)
|
| 142 |
+
frameB = capture_single(CAM_B_ID)
|
| 143 |
+
if frameA is None or frameB is None: return
|
| 144 |
+
|
| 145 |
+
mapAx, mapAy, newKA = build_undistort_maps(K_A, dist_A)
|
| 146 |
+
mapBx, mapBy, newKB = build_undistort_maps(K_B, dist_B)
|
| 147 |
+
undA = cv2.remap(frameA, mapAx, mapAy, cv2.INTER_LINEAR)
|
| 148 |
+
undB = cv2.remap(frameB, mapBx, mapBy, cv2.INTER_LINEAR)
|
| 149 |
+
|
| 150 |
+
detA = detect_on_image(model, undA)
|
| 151 |
+
detB = detect_on_image(model, undB)
|
| 152 |
+
|
| 153 |
+
matches = match_stereo(detA, detB)
|
| 154 |
+
print(f"--- Matches found: {len(matches)} ---")
|
| 155 |
+
|
| 156 |
+
P1, P2 = build_projection_matrices(newKA, newKB, YAW_LEFT_DEG, YAW_RIGHT_DEG, BASELINE_CM)
|
| 157 |
+
|
| 158 |
+
combo = np.hstack((undA, undB))
|
| 159 |
+
|
| 160 |
+
for l, r in matches:
|
| 161 |
+
XYZ = triangulate_matrix(l, r, P1, P2)
|
| 162 |
+
X,Y,Z = XYZ
|
| 163 |
+
|
| 164 |
+
label = f"Z={Z:.1f}cm"
|
| 165 |
+
print(f"Target ({l['name']}): {label} (X={X:.1f}, Y={Y:.1f})")
|
| 166 |
+
|
| 167 |
+
cv2.line(combo, (l['cx'], l['cy']), (r['cx']+FRAME_W, r['cy']), (0,255,0), 2)
|
| 168 |
+
cv2.rectangle(combo, (l['x1'], l['y1']), (l['x2'], l['y2']), (0,0,255), 2)
|
| 169 |
+
cv2.rectangle(combo, (r['x1']+FRAME_W, r['y1']), (r['x2']+FRAME_W, r['y2']), (0,0,255), 2)
|
| 170 |
+
cv2.putText(combo, label, (l['cx'], l['cy']-10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,255,255), 2)
|
| 171 |
+
|
| 172 |
+
cv2.imshow("Tuned Depth Result", combo)
|
| 173 |
+
cv2.waitKey(0)
|
| 174 |
+
cv2.destroyAllWindows()
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
|
| 177 |
+
main()
|
scripts/auto_label_strawberries.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Automated Strawberry Ripeness Labeling System
|
| 4 |
+
Uses color analysis to automatically label strawberry ripeness
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
|
| 17 |
+
class AutoRipenessLabeler:
|
| 18 |
+
def __init__(self):
|
| 19 |
+
"""Initialize the automatic ripeness labeler"""
|
| 20 |
+
print("✅ Initialized automatic ripeness labeler")
|
| 21 |
+
|
| 22 |
+
def analyze_strawberry_color(self, image_path):
|
| 23 |
+
"""Analyze the color of strawberries to determine ripeness"""
|
| 24 |
+
try:
|
| 25 |
+
# Load image
|
| 26 |
+
img = cv2.imread(str(image_path))
|
| 27 |
+
if img is None:
|
| 28 |
+
return None
|
| 29 |
+
|
| 30 |
+
# Convert BGR to RGB
|
| 31 |
+
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 32 |
+
|
| 33 |
+
# Convert to HSV for better color analysis
|
| 34 |
+
hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV)
|
| 35 |
+
|
| 36 |
+
# Define color ranges for different ripeness stages
|
| 37 |
+
# Red range (ripe strawberries)
|
| 38 |
+
red_lower1 = np.array([0, 50, 50])
|
| 39 |
+
red_upper1 = np.array([10, 255, 255])
|
| 40 |
+
red_lower2 = np.array([170, 50, 50])
|
| 41 |
+
red_upper2 = np.array([180, 255, 255])
|
| 42 |
+
|
| 43 |
+
# Green range (unripe strawberries)
|
| 44 |
+
green_lower = np.array([40, 40, 40])
|
| 45 |
+
green_upper = np.array([80, 255, 255])
|
| 46 |
+
|
| 47 |
+
# Dark red range (overripe strawberries)
|
| 48 |
+
dark_red_lower = np.array([0, 100, 0])
|
| 49 |
+
dark_red_upper = np.array([20, 255, 100])
|
| 50 |
+
|
| 51 |
+
# Create masks for each color range
|
| 52 |
+
red_mask1 = cv2.inRange(hsv, red_lower1, red_upper1)
|
| 53 |
+
red_mask2 = cv2.inRange(hsv, red_lower2, red_upper2)
|
| 54 |
+
red_mask = cv2.bitwise_or(red_mask1, red_mask2)
|
| 55 |
+
|
| 56 |
+
green_mask = cv2.inRange(hsv, green_lower, green_upper)
|
| 57 |
+
dark_red_mask = cv2.inRange(hsv, dark_red_lower, dark_red_upper)
|
| 58 |
+
|
| 59 |
+
# Calculate percentages
|
| 60 |
+
total_pixels = hsv.shape[0] * hsv.shape[1]
|
| 61 |
+
red_pixels = np.sum(red_mask > 0)
|
| 62 |
+
green_pixels = np.sum(green_mask > 0)
|
| 63 |
+
dark_red_pixels = np.sum(dark_red_mask > 0)
|
| 64 |
+
|
| 65 |
+
red_percentage = red_pixels / total_pixels
|
| 66 |
+
green_percentage = green_pixels / total_pixels
|
| 67 |
+
dark_red_percentage = dark_red_pixels / total_pixels
|
| 68 |
+
|
| 69 |
+
# Calculate brightness and saturation for fallback analysis
|
| 70 |
+
gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
|
| 71 |
+
avg_brightness = np.mean(gray)
|
| 72 |
+
avg_saturation = np.mean(hsv[:, :, 1])
|
| 73 |
+
|
| 74 |
+
# Determine ripeness based on color percentages
|
| 75 |
+
if green_percentage > 0.3:
|
| 76 |
+
ripeness = "unripe"
|
| 77 |
+
confidence = min(green_percentage * 2, 0.9)
|
| 78 |
+
elif dark_red_percentage > 0.2:
|
| 79 |
+
ripeness = "overripe"
|
| 80 |
+
confidence = min(dark_red_percentage * 2, 0.9)
|
| 81 |
+
elif red_percentage > 0.2:
|
| 82 |
+
ripeness = "ripe"
|
| 83 |
+
confidence = min(red_percentage * 2, 0.9)
|
| 84 |
+
else:
|
| 85 |
+
# Fallback: use brightness and saturation
|
| 86 |
+
if avg_brightness < 80:
|
| 87 |
+
ripeness = "overripe"
|
| 88 |
+
confidence = 0.6
|
| 89 |
+
elif avg_brightness > 150:
|
| 90 |
+
ripeness = "unripe"
|
| 91 |
+
confidence = 0.6
|
| 92 |
+
else:
|
| 93 |
+
ripeness = "ripe"
|
| 94 |
+
confidence = 0.7
|
| 95 |
+
|
| 96 |
+
return {
|
| 97 |
+
'ripeness': ripeness,
|
| 98 |
+
'confidence': confidence,
|
| 99 |
+
'color_analysis': {
|
| 100 |
+
'red_percentage': red_percentage,
|
| 101 |
+
'green_percentage': green_percentage,
|
| 102 |
+
'dark_red_percentage': dark_red_percentage,
|
| 103 |
+
'avg_brightness': float(avg_brightness),
|
| 104 |
+
'avg_saturation': float(avg_saturation)
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print(f"Error analyzing color in {image_path}: {e}")
|
| 110 |
+
return None
|
| 111 |
+
|
| 112 |
+
def batch_auto_label(self, image_files, output_dirs, confidence_threshold=0.6):
|
| 113 |
+
"""Automatically label a batch of images"""
|
| 114 |
+
results = []
|
| 115 |
+
|
| 116 |
+
for i, image_path in enumerate(image_files):
|
| 117 |
+
print(f"Processing {i+1}/{len(image_files)}: {image_path.name}")
|
| 118 |
+
|
| 119 |
+
analysis = self.analyze_strawberry_color(image_path)
|
| 120 |
+
|
| 121 |
+
if analysis and analysis['confidence'] >= confidence_threshold:
|
| 122 |
+
ripeness = analysis['ripeness']
|
| 123 |
+
confidence = analysis['confidence']
|
| 124 |
+
|
| 125 |
+
# Copy image to appropriate directory
|
| 126 |
+
dest_path = output_dirs[ripeness] / image_path.name
|
| 127 |
+
try:
|
| 128 |
+
import shutil
|
| 129 |
+
shutil.copy2(image_path, dest_path)
|
| 130 |
+
print(f" ✅ {ripeness} (confidence: {confidence:.2f})")
|
| 131 |
+
results.append({
|
| 132 |
+
'image': image_path.name,
|
| 133 |
+
'label': ripeness,
|
| 134 |
+
'confidence': confidence,
|
| 135 |
+
'analysis': analysis['color_analysis']
|
| 136 |
+
})
|
| 137 |
+
except Exception as e:
|
| 138 |
+
print(f" ❌ Error copying file: {e}")
|
| 139 |
+
else:
|
| 140 |
+
print(f" ⚠️ Low confidence or analysis failed")
|
| 141 |
+
results.append({
|
| 142 |
+
'image': image_path.name,
|
| 143 |
+
'label': 'unknown',
|
| 144 |
+
'confidence': analysis['confidence'] if analysis else 0.0,
|
| 145 |
+
'analysis': analysis['color_analysis'] if analysis else None
|
| 146 |
+
})
|
| 147 |
+
|
| 148 |
+
return results
|
| 149 |
+
|
| 150 |
+
def main():
|
| 151 |
+
parser = argparse.ArgumentParser(description='Automatically label strawberry ripeness dataset')
|
| 152 |
+
parser.add_argument('--dataset-path', type=str,
|
| 153 |
+
default='model/ripeness_manual_dataset',
|
| 154 |
+
help='Path to the ripeness dataset directory')
|
| 155 |
+
parser.add_argument('--confidence-threshold', type=float, default=0.6,
|
| 156 |
+
help='Minimum confidence for automatic labeling')
|
| 157 |
+
parser.add_argument('--max-images', type=int, default=50,
|
| 158 |
+
help='Maximum number of images to process')
|
| 159 |
+
|
| 160 |
+
args = parser.parse_args()
|
| 161 |
+
|
| 162 |
+
base_path = Path(args.dataset_path)
|
| 163 |
+
to_label_path = base_path / 'to_label'
|
| 164 |
+
|
| 165 |
+
if not to_label_path.exists():
|
| 166 |
+
print(f"Error: to_label directory not found at {to_label_path}")
|
| 167 |
+
return
|
| 168 |
+
|
| 169 |
+
# Create output directories
|
| 170 |
+
output_dirs = {}
|
| 171 |
+
for label in ['unripe', 'ripe', 'overripe']:
|
| 172 |
+
dir_path = base_path / label
|
| 173 |
+
dir_path.mkdir(exist_ok=True)
|
| 174 |
+
output_dirs[label] = dir_path
|
| 175 |
+
|
| 176 |
+
# Get image files
|
| 177 |
+
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
|
| 178 |
+
image_files = []
|
| 179 |
+
for file_path in to_label_path.iterdir():
|
| 180 |
+
if file_path.suffix.lower() in image_extensions:
|
| 181 |
+
image_files.append(file_path)
|
| 182 |
+
|
| 183 |
+
image_files = sorted(image_files)[:args.max_images]
|
| 184 |
+
|
| 185 |
+
print(f"Found {len(image_files)} images to process")
|
| 186 |
+
print(f"Confidence threshold: {args.confidence_threshold}")
|
| 187 |
+
|
| 188 |
+
if not image_files:
|
| 189 |
+
print("No images found to process.")
|
| 190 |
+
return
|
| 191 |
+
|
| 192 |
+
# Initialize auto labeler
|
| 193 |
+
labeler = AutoRipenessLabeler()
|
| 194 |
+
|
| 195 |
+
# Process images
|
| 196 |
+
results = labeler.batch_auto_label(image_files, output_dirs, args.confidence_threshold)
|
| 197 |
+
|
| 198 |
+
# Save results
|
| 199 |
+
results_file = base_path / f'auto_labeling_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
|
| 200 |
+
with open(results_file, 'w') as f:
|
| 201 |
+
json.dump(results, f, indent=2)
|
| 202 |
+
|
| 203 |
+
# Print summary
|
| 204 |
+
label_counts = {'unripe': 0, 'ripe': 0, 'overripe': 0, 'unknown': 0}
|
| 205 |
+
for result in results:
|
| 206 |
+
label_counts[result['label']] += 1
|
| 207 |
+
|
| 208 |
+
print("\n=== AUTOMATIC LABELING RESULTS ===")
|
| 209 |
+
for label, count in label_counts.items():
|
| 210 |
+
print(f"{label}: {count} images")
|
| 211 |
+
|
| 212 |
+
print(f"\nResults saved to: {results_file}")
|
| 213 |
+
|
| 214 |
+
if label_counts['unknown'] > 0:
|
| 215 |
+
print(f"\n⚠️ {label_counts['unknown']} images need manual review")
|
| 216 |
+
print("You can use the manual labeling tool for these:")
|
| 217 |
+
print("python3 label_ripeness_dataset.py")
|
| 218 |
+
|
| 219 |
+
if __name__ == '__main__':
|
| 220 |
+
main()
|
scripts/benchmark_models.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Benchmark YOLO models for performance on Raspberry Pi 4B (or current machine).
|
| 4 |
+
Measures inference time, FPS, and memory usage for different model formats.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import time
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import numpy as np
|
| 13 |
+
import cv2
|
| 14 |
+
import yaml
|
| 15 |
+
from ultralytics import YOLO
|
| 16 |
+
import psutil
|
| 17 |
+
import platform
|
| 18 |
+
|
| 19 |
+
def get_system_info():
|
| 20 |
+
"""Get system information for benchmarking context."""
|
| 21 |
+
info = {
|
| 22 |
+
'system': platform.system(),
|
| 23 |
+
'processor': platform.processor(),
|
| 24 |
+
'architecture': platform.architecture()[0],
|
| 25 |
+
'python_version': platform.python_version(),
|
| 26 |
+
'cpu_count': psutil.cpu_count(logical=False),
|
| 27 |
+
'memory_gb': psutil.virtual_memory().total / (1024**3),
|
| 28 |
+
}
|
| 29 |
+
return info
|
| 30 |
+
|
| 31 |
+
def load_test_images(dataset_path, max_images=50):
|
| 32 |
+
"""Load test images from dataset for benchmarking."""
|
| 33 |
+
test_images = []
|
| 34 |
+
|
| 35 |
+
# Try multiple possible locations
|
| 36 |
+
possible_paths = [
|
| 37 |
+
Path(dataset_path) / "test" / "images",
|
| 38 |
+
Path(dataset_path) / "valid" / "images",
|
| 39 |
+
Path(dataset_path) / "val" / "images",
|
| 40 |
+
Path(dataset_path) / "train" / "images",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
for path in possible_paths:
|
| 44 |
+
if path.exists():
|
| 45 |
+
image_files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
|
| 46 |
+
if image_files:
|
| 47 |
+
test_images = [str(p) for p in image_files[:max_images]]
|
| 48 |
+
print(f"Found {len(test_images)} images in {path}")
|
| 49 |
+
break
|
| 50 |
+
|
| 51 |
+
if not test_images:
|
| 52 |
+
# Create dummy images if no dataset found
|
| 53 |
+
print("No test images found. Creating dummy images for benchmarking.")
|
| 54 |
+
test_images = []
|
| 55 |
+
for i in range(10):
|
| 56 |
+
# Create a dummy image
|
| 57 |
+
dummy_img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
|
| 58 |
+
dummy_path = f"/tmp/dummy_{i}.jpg"
|
| 59 |
+
cv2.imwrite(dummy_path, dummy_img)
|
| 60 |
+
test_images.append(dummy_path)
|
| 61 |
+
|
| 62 |
+
return test_images
|
| 63 |
+
|
| 64 |
+
def benchmark_model(model_path, test_images, img_size=640, warmup=10, runs=100):
|
| 65 |
+
"""
|
| 66 |
+
Benchmark a single model.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
model_path: Path to model file (.pt, .onnx, .tflite)
|
| 70 |
+
test_images: List of image paths for testing
|
| 71 |
+
img_size: Input image size
|
| 72 |
+
warmup: Number of warmup runs
|
| 73 |
+
runs: Number of benchmark runs
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
Dictionary with benchmark results
|
| 77 |
+
"""
|
| 78 |
+
print(f"\n{'='*60}")
|
| 79 |
+
print(f"Benchmarking: {model_path}")
|
| 80 |
+
print(f"{'='*60}")
|
| 81 |
+
|
| 82 |
+
results = {
|
| 83 |
+
'model': os.path.basename(model_path),
|
| 84 |
+
'format': Path(model_path).suffix[1:],
|
| 85 |
+
'size_mb': os.path.getsize(model_path) / (1024 * 1024) if os.path.exists(model_path) else 0,
|
| 86 |
+
'inference_times': [],
|
| 87 |
+
'memory_usage_mb': [],
|
| 88 |
+
'success': False
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
# Check if model exists
|
| 92 |
+
if not os.path.exists(model_path):
|
| 93 |
+
print(f" ❌ Model not found: {model_path}")
|
| 94 |
+
return results
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
# Load model
|
| 98 |
+
print(f" Loading model...")
|
| 99 |
+
start_load = time.time()
|
| 100 |
+
model = YOLO(model_path)
|
| 101 |
+
load_time = time.time() - start_load
|
| 102 |
+
results['load_time'] = load_time
|
| 103 |
+
|
| 104 |
+
# Warmup
|
| 105 |
+
print(f" Warming up ({warmup} runs)...")
|
| 106 |
+
for i in range(warmup):
|
| 107 |
+
if i >= len(test_images):
|
| 108 |
+
img_path = test_images[0]
|
| 109 |
+
else:
|
| 110 |
+
img_path = test_images[i]
|
| 111 |
+
img = cv2.imread(img_path)
|
| 112 |
+
if img is None:
|
| 113 |
+
# Create dummy image
|
| 114 |
+
img = np.random.randint(0, 255, (img_size, img_size, 3), dtype=np.uint8)
|
| 115 |
+
_ = model(img, verbose=False)
|
| 116 |
+
|
| 117 |
+
# Benchmark runs
|
| 118 |
+
print(f" Running benchmark ({runs} runs)...")
|
| 119 |
+
for i in range(runs):
|
| 120 |
+
# Cycle through test images
|
| 121 |
+
img_idx = i % len(test_images)
|
| 122 |
+
img_path = test_images[img_idx]
|
| 123 |
+
img = cv2.imread(img_path)
|
| 124 |
+
if img is None:
|
| 125 |
+
img = np.random.randint(0, 255, (img_size, img_size, 3), dtype=np.uint8)
|
| 126 |
+
|
| 127 |
+
# Measure memory before
|
| 128 |
+
process = psutil.Process(os.getpid())
|
| 129 |
+
mem_before = process.memory_info().rss / 1024 / 1024 # MB
|
| 130 |
+
|
| 131 |
+
# Inference
|
| 132 |
+
start_time = time.perf_counter()
|
| 133 |
+
results_inference = model(img, verbose=False)
|
| 134 |
+
inference_time = time.perf_counter() - start_time
|
| 135 |
+
|
| 136 |
+
# Measure memory after
|
| 137 |
+
mem_after = process.memory_info().rss / 1024 / 1024 # MB
|
| 138 |
+
mem_used = mem_after - mem_before
|
| 139 |
+
|
| 140 |
+
results['inference_times'].append(inference_time)
|
| 141 |
+
results['memory_usage_mb'].append(mem_used)
|
| 142 |
+
|
| 143 |
+
# Print progress
|
| 144 |
+
if (i + 1) % 20 == 0:
|
| 145 |
+
print(f" Completed {i+1}/{runs} runs...")
|
| 146 |
+
|
| 147 |
+
# Calculate statistics
|
| 148 |
+
if results['inference_times']:
|
| 149 |
+
times = np.array(results['inference_times'])
|
| 150 |
+
results['avg_inference_ms'] = np.mean(times) * 1000
|
| 151 |
+
results['std_inference_ms'] = np.std(times) * 1000
|
| 152 |
+
results['min_inference_ms'] = np.min(times) * 1000
|
| 153 |
+
results['max_inference_ms'] = np.max(times) * 1000
|
| 154 |
+
results['fps'] = 1.0 / np.mean(times)
|
| 155 |
+
results['avg_memory_mb'] = np.mean(results['memory_usage_mb'])
|
| 156 |
+
results['success'] = True
|
| 157 |
+
|
| 158 |
+
print(f" ✅ Benchmark completed:")
|
| 159 |
+
print(f" Model size: {results['size_mb']:.2f} MB")
|
| 160 |
+
print(f" Avg inference: {results['avg_inference_ms']:.2f} ms")
|
| 161 |
+
print(f" FPS: {results['fps']:.2f}")
|
| 162 |
+
print(f" Memory usage: {results['avg_memory_mb']:.2f} MB")
|
| 163 |
+
else:
|
| 164 |
+
print(f" ❌ No inference times recorded")
|
| 165 |
+
|
| 166 |
+
except Exception as e:
|
| 167 |
+
print(f" ❌ Error benchmarking {model_path}: {e}")
|
| 168 |
+
import traceback
|
| 169 |
+
traceback.print_exc()
|
| 170 |
+
|
| 171 |
+
return results
|
| 172 |
+
|
| 173 |
+
def benchmark_all_models(models_to_test, test_images, img_size=640):
|
| 174 |
+
"""Benchmark multiple models and return results."""
|
| 175 |
+
all_results = []
|
| 176 |
+
|
| 177 |
+
for model_info in models_to_test:
|
| 178 |
+
model_path = model_info['path']
|
| 179 |
+
if not os.path.exists(model_path):
|
| 180 |
+
print(f"Skipping {model_path} - not found")
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
results = benchmark_model(
|
| 184 |
+
model_path=model_path,
|
| 185 |
+
test_images=test_images,
|
| 186 |
+
img_size=img_size,
|
| 187 |
+
warmup=10,
|
| 188 |
+
runs=50 # Reduced for faster benchmarking
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
results.update({
|
| 192 |
+
'name': model_info['name'],
|
| 193 |
+
'description': model_info.get('description', '')
|
| 194 |
+
})
|
| 195 |
+
all_results.append(results)
|
| 196 |
+
|
| 197 |
+
return all_results
|
| 198 |
+
|
| 199 |
+
def print_results_table(results):
|
| 200 |
+
"""Print benchmark results in a formatted table."""
|
| 201 |
+
print("\n" + "="*100)
|
| 202 |
+
print("BENCHMARK RESULTS")
|
| 203 |
+
print("="*100)
|
| 204 |
+
print(f"{'Model':<30} {'Format':<8} {'Size (MB)':<10} {'Inference (ms)':<15} {'FPS':<10} {'Memory (MB)':<12} {'Status':<10}")
|
| 205 |
+
print("-"*100)
|
| 206 |
+
|
| 207 |
+
for r in results:
|
| 208 |
+
if r['success']:
|
| 209 |
+
print(f"{r['name'][:28]:<30} {r['format']:<8} {r['size_mb']:>9.2f} "
|
| 210 |
+
f"{r['avg_inference_ms']:>14.2f} {r['fps']:>9.2f} {r['avg_memory_mb']:>11.2f} {'✅':<10}")
|
| 211 |
+
else:
|
| 212 |
+
print(f"{r['name'][:28]:<30} {r['format']:<8} {r['size_mb']:>9.2f} "
|
| 213 |
+
f"{'N/A':>14} {'N/A':>9} {'N/A':>11} {'❌':<10}")
|
| 214 |
+
|
| 215 |
+
print("="*100)
|
| 216 |
+
|
| 217 |
+
# Find best model by FPS
|
| 218 |
+
successful = [r for r in results if r['success']]
|
| 219 |
+
if successful:
|
| 220 |
+
best_by_fps = max(successful, key=lambda x: x['fps'])
|
| 221 |
+
best_by_size = min(successful, key=lambda x: x['size_mb'])
|
| 222 |
+
best_by_memory = min(successful, key=lambda x: x['avg_memory_mb'])
|
| 223 |
+
|
| 224 |
+
print(f"\n🏆 Best by FPS: {best_by_fps['name']} ({best_by_fps['fps']:.2f} FPS)")
|
| 225 |
+
print(f"🏆 Best by size: {best_by_size['name']} ({best_by_size['size_mb']:.2f} MB)")
|
| 226 |
+
print(f"🏆 Best by memory: {best_by_memory['name']} ({best_by_memory['avg_memory_mb']:.2f} MB)")
|
| 227 |
+
|
| 228 |
+
def save_results_to_csv(results, output_path="benchmark_results.csv"):
|
| 229 |
+
"""Save benchmark results to CSV file."""
|
| 230 |
+
import csv
|
| 231 |
+
|
| 232 |
+
with open(output_path, 'w', newline='') as csvfile:
|
| 233 |
+
fieldnames = ['name', 'format', 'size_mb', 'avg_inference_ms',
|
| 234 |
+
'std_inference_ms', 'min_inference_ms', 'max_inference_ms',
|
| 235 |
+
'fps', 'avg_memory_mb', 'load_time', 'success']
|
| 236 |
+
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
| 237 |
+
|
| 238 |
+
writer.writeheader()
|
| 239 |
+
for r in results:
|
| 240 |
+
writer.writerow({
|
| 241 |
+
'name': r['name'],
|
| 242 |
+
'format': r['format'],
|
| 243 |
+
'size_mb': r.get('size_mb', 0),
|
| 244 |
+
'avg_inference_ms': r.get('avg_inference_ms', 0),
|
| 245 |
+
'std_inference_ms': r.get('std_inference_ms', 0),
|
| 246 |
+
'min_inference_ms': r.get('min_inference_ms', 0),
|
| 247 |
+
'max_inference_ms': r.get('max_inference_ms', 0),
|
| 248 |
+
'fps': r.get('fps', 0),
|
| 249 |
+
'avg_memory_mb': r.get('avg_memory_mb', 0),
|
| 250 |
+
'load_time': r.get('load_time', 0),
|
| 251 |
+
'success': r['success']
|
| 252 |
+
})
|
| 253 |
+
|
| 254 |
+
print(f"\n📊 Results saved to {output_path}")
|
| 255 |
+
|
| 256 |
+
def main():
|
| 257 |
+
parser = argparse.ArgumentParser(description='Benchmark YOLO models for performance')
|
| 258 |
+
parser.add_argument('--dataset', type=str, default='model/dataset_strawberry_detect_v3',
|
| 259 |
+
help='Path to dataset for test images')
|
| 260 |
+
parser.add_argument('--img-size', type=int, default=640,
|
| 261 |
+
help='Input image size for inference')
|
| 262 |
+
parser.add_argument('--output', type=str, default='benchmark_results.csv',
|
| 263 |
+
help='Output CSV file for results')
|
| 264 |
+
parser.add_argument('--config', type=str, default='config.yaml',
|
| 265 |
+
help='Path to config file')
|
| 266 |
+
|
| 267 |
+
args = parser.parse_args()
|
| 268 |
+
|
| 269 |
+
# Load config
|
| 270 |
+
config = {}
|
| 271 |
+
if os.path.exists(args.config):
|
| 272 |
+
with open(args.config, 'r') as f:
|
| 273 |
+
config = yaml.safe_load(f)
|
| 274 |
+
|
| 275 |
+
# Get system info
|
| 276 |
+
system_info = get_system_info()
|
| 277 |
+
print("="*60)
|
| 278 |
+
print("SYSTEM INFORMATION")
|
| 279 |
+
print("="*60)
|
| 280 |
+
for key, value in system_info.items():
|
| 281 |
+
print(f"{key.replace('_', ' ').title():<20}: {value}")
|
| 282 |
+
|
| 283 |
+
# Define models to test
|
| 284 |
+
models_to_test = [
|
| 285 |
+
# Base YOLO models
|
| 286 |
+
{'name': 'YOLOv8n', 'path': 'yolov8n.pt', 'description': 'Ultralytics YOLOv8n'},
|
| 287 |
+
{'name': 'YOLOv8s', 'path': 'yolov8s.pt', 'description': 'Ultralytics YOLOv8s'},
|
| 288 |
+
{'name': 'YOLOv8m', 'path': 'yolov8m.pt', 'description': 'Ultralytics YOLOv8m'},
|
| 289 |
+
|
| 290 |
+
# Custom trained models
|
| 291 |
+
{'name': 'Strawberry YOLOv11n', 'path': 'model/weights/strawberry_yolov11n.pt', 'description': 'Custom trained on strawberry dataset'},
|
| 292 |
+
{'name': 'Strawberry YOLOv11n ONNX', 'path': 'model/weights/strawberry_yolov11n.onnx', 'description': 'ONNX export'},
|
| 293 |
+
|
| 294 |
+
# Ripeness detection models
|
| 295 |
+
{'name': 'Ripeness YOLOv11n', 'path': 'model/weights/ripeness_detection_yolov11n.pt', 'description': 'Ripeness detection model'},
|
| 296 |
+
{'name': 'Ripeness YOLOv11n ONNX', 'path': 'model/weights/ripeness_detection_yolov11n.onnx', 'description': 'ONNX export'},
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
+
# Check which models exist
|
| 300 |
+
existing_models = []
|
| 301 |
+
for model in models_to_test:
|
| 302 |
+
if os.path.exists(model['path']):
|
| 303 |
+
existing_models.append(model)
|
| 304 |
+
else:
|
| 305 |
+
print(f"⚠️ Model not found: {model['path']}")
|
| 306 |
+
|
| 307 |
+
if not existing_models:
|
| 308 |
+
print("❌ No models found for benchmarking.")
|
| 309 |
+
print("Please train a model first or download pretrained weights.")
|
| 310 |
+
sys.exit(1)
|
| 311 |
+
|
| 312 |
+
# Load test images
|
| 313 |
+
print(f"\n📷 Loading test images from {args.dataset}...")
|
| 314 |
+
test_images = load_test_images(args.dataset, max_images=50)
|
| 315 |
+
print(f" Loaded {len(test_images)} test images")
|
| 316 |
+
|
| 317 |
+
# Run benchmarks
|
| 318 |
+
print(f"\n🚀 Starting benchmarks...")
|
| 319 |
+
results = benchmark_all_models(existing_models, test_images, img_size=args.img_size)
|
| 320 |
+
|
| 321 |
+
# Print results
|
| 322 |
+
print_results_table(results)
|
| 323 |
+
|
| 324 |
+
# Save results
|
| 325 |
+
save_results_to_csv(results, args.output)
|
| 326 |
+
|
| 327 |
+
# Generate recommendations
|
| 328 |
+
print(f"\n💡 RECOMMENDATIONS FOR RASPBERRY PI 4B:")
|
| 329 |
+
print(f" 1. For fastest inference: Choose model with highest FPS")
|
| 330 |
+
print(f" 2. For memory-constrained environments: Choose smallest model")
|
| 331 |
+
print(f" 3. For best accuracy/speed tradeoff: Consider YOLOv8s")
|
| 332 |
+
print(f" 4. For edge deployment: Convert to TFLite INT8 for ~2-3x speedup")
|
| 333 |
+
|
| 334 |
+
# Check if we're on Raspberry Pi
|
| 335 |
+
if 'arm' in platform.machine().lower() or 'raspberry' in platform.system().lower():
|
| 336 |
+
print(f"\n🎯 Running on Raspberry Pi - results are accurate for deployment.")
|
| 337 |
+
else:
|
| 338 |
+
print(f"\n⚠️ Not running on Raspberry Pi - results are for reference only.")
|
| 339 |
+
print(f" Actual Raspberry Pi performance may be 2-5x slower.")
|
| 340 |
+
|
| 341 |
+
if __name__ == '__main__':
|
| 342 |
+
main()
|
scripts/collect_dataset.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
cap = cv2.VideoCapture(0)
|
| 5 |
+
|
| 6 |
+
if not cap.isOpened():
|
| 7 |
+
print("Cannot open camera")
|
| 8 |
+
exit()
|
| 9 |
+
|
| 10 |
+
good_dir = 'dataset/good'
|
| 11 |
+
bad_dir = 'dataset/bad'
|
| 12 |
+
os.makedirs(good_dir, exist_ok=True)
|
| 13 |
+
os.makedirs(bad_dir, exist_ok=True)
|
| 14 |
+
|
| 15 |
+
good_count = len(os.listdir(good_dir))
|
| 16 |
+
bad_count = len(os.listdir(bad_dir))
|
| 17 |
+
|
| 18 |
+
print("Press 'g' to save as good, 'b' to save as bad, 'q' to quit")
|
| 19 |
+
|
| 20 |
+
while True:
|
| 21 |
+
ret, frame = cap.read()
|
| 22 |
+
if not ret:
|
| 23 |
+
print("Can't receive frame")
|
| 24 |
+
break
|
| 25 |
+
cv2.imshow('Dataset Collection', frame)
|
| 26 |
+
key = cv2.waitKey(1) & 0xFF
|
| 27 |
+
if key == ord('g'):
|
| 28 |
+
filename = f'{good_dir}/good_{good_count:04d}.jpg'
|
| 29 |
+
cv2.imwrite(filename, frame)
|
| 30 |
+
good_count += 1
|
| 31 |
+
print(f"Saved {filename}")
|
| 32 |
+
elif key == ord('b'):
|
| 33 |
+
filename = f'{bad_dir}/bad_{bad_count:04d}.jpg'
|
| 34 |
+
cv2.imwrite(filename, frame)
|
| 35 |
+
bad_count += 1
|
| 36 |
+
print(f"Saved {filename}")
|
| 37 |
+
elif key == ord('q'):
|
| 38 |
+
break
|
| 39 |
+
|
| 40 |
+
cap.release()
|
| 41 |
+
cv2.destroyAllWindows()
|
scripts/combine3.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# dual_camera_15cm_tuned.py
|
| 2 |
+
"""
|
| 3 |
+
FINAL TUNED VERSION
|
| 4 |
+
- Baseline: 15.0 cm
|
| 5 |
+
- Range: 20-40 cm
|
| 6 |
+
- Calibration: Tuned to 0.82 based on your latest log (48.5cm -> 30cm).
|
| 7 |
+
"""
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
from ultralytics import YOLO
|
| 11 |
+
import math
|
| 12 |
+
|
| 13 |
+
# ----------------------------
|
| 14 |
+
# USER SETTINGS
|
| 15 |
+
# ----------------------------
|
| 16 |
+
CAM_A_ID = 2 # LEFT Camera
|
| 17 |
+
CAM_B_ID = 1 # RIGHT Camera
|
| 18 |
+
|
| 19 |
+
FRAME_W = 640
|
| 20 |
+
FRAME_H = 408
|
| 21 |
+
YOLO_MODEL_PATH = "strawberry.pt"
|
| 22 |
+
|
| 23 |
+
# --- GEOMETRY & CALIBRATION ---
|
| 24 |
+
BASELINE_CM = 15.0
|
| 25 |
+
FOCUS_DIST_CM = 30.0
|
| 26 |
+
|
| 27 |
+
# RE-CALIBRATED SCALAR
|
| 28 |
+
# Your Raw Reading is approx 36.6cm. Real is 30.0cm.
|
| 29 |
+
# 30.0 / 36.6 = ~0.82
|
| 30 |
+
DEPTH_SCALAR = 0.82
|
| 31 |
+
|
| 32 |
+
# Auto-calculate angle
|
| 33 |
+
val = (BASELINE_CM / 2.0) / FOCUS_DIST_CM
|
| 34 |
+
calc_yaw_deg = math.degrees(math.atan(val))
|
| 35 |
+
YAW_LEFT_DEG = calc_yaw_deg
|
| 36 |
+
YAW_RIGHT_DEG = -calc_yaw_deg
|
| 37 |
+
|
| 38 |
+
print(f"--- CONFIGURATION ---")
|
| 39 |
+
print(f"1. Baseline: {BASELINE_CM} cm")
|
| 40 |
+
print(f"2. Scalar: x{DEPTH_SCALAR} (Reducing raw estimate to match reality)")
|
| 41 |
+
print(f"3. REQUIRED ANGLE: +/- {calc_yaw_deg:.2f} degrees")
|
| 42 |
+
print(f"---------------------")
|
| 43 |
+
|
| 44 |
+
# Intrinsics
|
| 45 |
+
K_A = np.array([[629.10808758, 0.0, 347.20913144],
|
| 46 |
+
[0.0, 631.11321979, 277.5222819],
|
| 47 |
+
[0.0, 0.0, 1.0]], dtype=np.float64)
|
| 48 |
+
dist_A = np.array([-0.35469562, 0.10232556, -0.0005468, -0.00174671, 0.01546246], dtype=np.float64)
|
| 49 |
+
|
| 50 |
+
K_B = np.array([[1001.67997, 0.0, 367.736216],
|
| 51 |
+
[0.0, 996.698369, 312.866527],
|
| 52 |
+
[0.0, 0.0, 1.0]], dtype=np.float64)
|
| 53 |
+
dist_B = np.array([-0.49543094, 0.82826695, -0.00180861, -0.00362202, -1.42667838], dtype=np.float64)
|
| 54 |
+
|
| 55 |
+
# ----------------------------
|
| 56 |
+
# HELPERS
|
| 57 |
+
# ----------------------------
|
| 58 |
+
def capture_single(cam_id):
|
| 59 |
+
cap = cv2.VideoCapture(cam_id, cv2.CAP_DSHOW)
|
| 60 |
+
if not cap.isOpened(): return None
|
| 61 |
+
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_W)
|
| 62 |
+
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_H)
|
| 63 |
+
for _ in range(5): cap.read() # Warmup
|
| 64 |
+
ret, frame = cap.read()
|
| 65 |
+
cap.release()
|
| 66 |
+
return frame
|
| 67 |
+
|
| 68 |
+
def build_undistort_maps(K, dist):
|
| 69 |
+
newK, _ = cv2.getOptimalNewCameraMatrix(K, dist, (FRAME_W, FRAME_H), 1.0)
|
| 70 |
+
mapx, mapy = cv2.initUndistortRectifyMap(K, dist, None, newK, (FRAME_W, FRAME_H), cv2.CV_32FC1)
|
| 71 |
+
return mapx, mapy, newK
|
| 72 |
+
|
| 73 |
+
def detect_on_image(model, img):
|
| 74 |
+
results = model(img, verbose=False)[0]
|
| 75 |
+
dets = []
|
| 76 |
+
for box in results.boxes:
|
| 77 |
+
x1, y1, x2, y2 = [int(v) for v in box.xyxy[0].tolist()]
|
| 78 |
+
cx = int((x1 + x2) / 2)
|
| 79 |
+
cy = int((y1 + y2) / 2)
|
| 80 |
+
conf = float(box.conf[0])
|
| 81 |
+
cls = int(box.cls[0])
|
| 82 |
+
name = model.names.get(cls, str(cls))
|
| 83 |
+
dets.append({'x1':x1,'y1':y1,'x2':x2,'y2':y2,'cx':cx,'cy':cy,'conf':conf,'cls':cls,'name':name})
|
| 84 |
+
return sorted(dets, key=lambda d: d['cx'])
|
| 85 |
+
|
| 86 |
+
def match_stereo(detL, detR):
|
| 87 |
+
matches = []
|
| 88 |
+
usedR = set()
|
| 89 |
+
for l in detL:
|
| 90 |
+
best_idx = -1
|
| 91 |
+
best_score = 9999
|
| 92 |
+
for i, r in enumerate(detR):
|
| 93 |
+
if i in usedR: continue
|
| 94 |
+
if l['cls'] != r['cls']: continue
|
| 95 |
+
dy = abs(l['cy'] - r['cy'])
|
| 96 |
+
if dy > 60: continue
|
| 97 |
+
if dy < best_score:
|
| 98 |
+
best_score = dy
|
| 99 |
+
best_idx = i
|
| 100 |
+
if best_idx != -1:
|
| 101 |
+
matches.append((l, detR[best_idx]))
|
| 102 |
+
usedR.add(best_idx)
|
| 103 |
+
return matches
|
| 104 |
+
|
| 105 |
+
# --- 3D MATH ---
|
| 106 |
+
def yaw_to_R_deg(yaw_deg):
|
| 107 |
+
y = math.radians(yaw_deg)
|
| 108 |
+
cy = math.cos(y); sy = math.sin(y)
|
| 109 |
+
return np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], dtype=np.float64)
|
| 110 |
+
|
| 111 |
+
def build_projection_matrices(newK_A, newK_B, yaw_L, yaw_R, baseline):
|
| 112 |
+
R_L = yaw_to_R_deg(yaw_L)
|
| 113 |
+
R_R = yaw_to_R_deg(yaw_R)
|
| 114 |
+
R_W2A = R_L.T
|
| 115 |
+
t_W2A = np.zeros((3,1))
|
| 116 |
+
R_W2B = R_R.T
|
| 117 |
+
C_B_world = np.array([[baseline], [0.0], [0.0]])
|
| 118 |
+
t_W2B = -R_W2B @ C_B_world
|
| 119 |
+
P1 = newK_A @ np.hstack((R_W2A, t_W2A))
|
| 120 |
+
P2 = newK_B @ np.hstack((R_W2B, t_W2B))
|
| 121 |
+
return P1, P2
|
| 122 |
+
|
| 123 |
+
def triangulate_matrix(dL, dR, P1, P2):
|
| 124 |
+
ptsL = np.array([[float(dL['cx'])],[float(dL['cy'])]], dtype=np.float64)
|
| 125 |
+
ptsR = np.array([[float(dR['cx'])],[float(dR['cy'])]], dtype=np.float64)
|
| 126 |
+
|
| 127 |
+
Xh = cv2.triangulatePoints(P1, P2, ptsL, ptsR)
|
| 128 |
+
Xh /= Xh[3]
|
| 129 |
+
|
| 130 |
+
X = float(Xh[0].item())
|
| 131 |
+
Y = float(Xh[1].item())
|
| 132 |
+
Z_raw = float(Xh[2].item())
|
| 133 |
+
|
| 134 |
+
return X, Y, Z_raw * DEPTH_SCALAR
|
| 135 |
+
|
| 136 |
+
def main():
|
| 137 |
+
print("[INFO] Loading YOLO...")
|
| 138 |
+
model = YOLO(YOLO_MODEL_PATH)
|
| 139 |
+
|
| 140 |
+
print(f"[INFO] Capturing...")
|
| 141 |
+
frameA = capture_single(CAM_A_ID)
|
| 142 |
+
frameB = capture_single(CAM_B_ID)
|
| 143 |
+
if frameA is None or frameB is None: return
|
| 144 |
+
|
| 145 |
+
mapAx, mapAy, newKA = build_undistort_maps(K_A, dist_A)
|
| 146 |
+
mapBx, mapBy, newKB = build_undistort_maps(K_B, dist_B)
|
| 147 |
+
undA = cv2.remap(frameA, mapAx, mapAy, cv2.INTER_LINEAR)
|
| 148 |
+
undB = cv2.remap(frameB, mapBx, mapBy, cv2.INTER_LINEAR)
|
| 149 |
+
|
| 150 |
+
detA = detect_on_image(model, undA)
|
| 151 |
+
detB = detect_on_image(model, undB)
|
| 152 |
+
|
| 153 |
+
matches = match_stereo(detA, detB)
|
| 154 |
+
print(f"--- Matches found: {len(matches)} ---")
|
| 155 |
+
|
| 156 |
+
P1, P2 = build_projection_matrices(newKA, newKB, YAW_LEFT_DEG, YAW_RIGHT_DEG, BASELINE_CM)
|
| 157 |
+
|
| 158 |
+
combo = np.hstack((undA, undB))
|
| 159 |
+
|
| 160 |
+
for l, r in matches:
|
| 161 |
+
XYZ = triangulate_matrix(l, r, P1, P2)
|
| 162 |
+
X,Y,Z = XYZ
|
| 163 |
+
|
| 164 |
+
label = f"Z={Z:.1f}cm"
|
| 165 |
+
print(f"Target ({l['name']}): {label} (X={X:.1f}, Y={Y:.1f})")
|
| 166 |
+
|
| 167 |
+
cv2.line(combo, (l['cx'], l['cy']), (r['cx']+FRAME_W, r['cy']), (0,255,0), 2)
|
| 168 |
+
cv2.rectangle(combo, (l['x1'], l['y1']), (l['x2'], l['y2']), (0,0,255), 2)
|
| 169 |
+
cv2.rectangle(combo, (r['x1']+FRAME_W, r['y1']), (r['x2']+FRAME_W, r['y2']), (0,0,255), 2)
|
| 170 |
+
cv2.putText(combo, label, (l['cx'], l['cy']-10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,255,255), 2)
|
| 171 |
+
|
| 172 |
+
cv2.imshow("Tuned Depth Result", combo)
|
| 173 |
+
cv2.waitKey(0)
|
| 174 |
+
cv2.destroyAllWindows()
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
|
| 177 |
+
main()
|
scripts/complete_final_labeling.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Complete the final batch of ripeness labeling using conservative color analysis.
|
| 4 |
+
This script processes the remaining 46 images with higher confidence thresholds.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import shutil
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import cv2
|
| 15 |
+
|
| 16 |
+
def analyze_ripeness_conservative(image_path, confidence_threshold=0.8):
|
| 17 |
+
"""
|
| 18 |
+
Conservative ripeness analysis with higher confidence thresholds.
|
| 19 |
+
"""
|
| 20 |
+
try:
|
| 21 |
+
# Load and convert image
|
| 22 |
+
img = cv2.imread(str(image_path))
|
| 23 |
+
if img is None:
|
| 24 |
+
return None, 0.0
|
| 25 |
+
|
| 26 |
+
# Convert to HSV for better color analysis
|
| 27 |
+
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
| 28 |
+
|
| 29 |
+
# Define color ranges for strawberry ripeness (more conservative)
|
| 30 |
+
# Red ranges (ripe strawberries)
|
| 31 |
+
red_lower1 = np.array([0, 50, 50])
|
| 32 |
+
red_upper1 = np.array([10, 255, 255])
|
| 33 |
+
red_lower2 = np.array([170, 50, 50])
|
| 34 |
+
red_upper2 = np.array([180, 255, 255])
|
| 35 |
+
|
| 36 |
+
# Green ranges (unripe strawberries)
|
| 37 |
+
green_lower = np.array([40, 40, 40])
|
| 38 |
+
green_upper = np.array([80, 255, 255])
|
| 39 |
+
|
| 40 |
+
# Yellow/orange ranges (overripe strawberries)
|
| 41 |
+
yellow_lower = np.array([15, 50, 50])
|
| 42 |
+
yellow_upper = np.array([35, 255, 255])
|
| 43 |
+
|
| 44 |
+
# Create masks
|
| 45 |
+
red_mask1 = cv2.inRange(hsv, red_lower1, red_upper1)
|
| 46 |
+
red_mask2 = cv2.inRange(hsv, red_lower2, red_upper2)
|
| 47 |
+
red_mask = cv2.bitwise_or(red_mask1, red_mask2)
|
| 48 |
+
|
| 49 |
+
green_mask = cv2.inRange(hsv, green_lower, green_upper)
|
| 50 |
+
yellow_mask = cv2.inRange(hsv, yellow_lower, yellow_upper)
|
| 51 |
+
|
| 52 |
+
# Calculate percentages
|
| 53 |
+
total_pixels = img.shape[0] * img.shape[1]
|
| 54 |
+
red_percentage = np.sum(red_mask > 0) / total_pixels
|
| 55 |
+
green_percentage = np.sum(green_mask > 0) / total_pixels
|
| 56 |
+
yellow_percentage = np.sum(yellow_mask > 0) / total_pixels
|
| 57 |
+
|
| 58 |
+
# Conservative classification logic
|
| 59 |
+
if red_percentage > 0.35 and red_percentage > green_percentage and red_percentage > yellow_percentage:
|
| 60 |
+
return "ripe", red_percentage
|
| 61 |
+
elif green_percentage > 0.25 and green_percentage > red_percentage and green_percentage > yellow_percentage:
|
| 62 |
+
return "unripe", green_percentage
|
| 63 |
+
elif yellow_percentage > 0.20 and yellow_percentage > red_percentage and yellow_percentage > green_percentage:
|
| 64 |
+
return "overripe", yellow_percentage
|
| 65 |
+
else:
|
| 66 |
+
# If no clear dominant color, use the highest percentage
|
| 67 |
+
max_percentage = max(red_percentage, green_percentage, yellow_percentage)
|
| 68 |
+
if max_percentage == red_percentage:
|
| 69 |
+
return "ripe", red_percentage
|
| 70 |
+
elif max_percentage == green_percentage:
|
| 71 |
+
return "unripe", green_percentage
|
| 72 |
+
else:
|
| 73 |
+
return "overripe", yellow_percentage
|
| 74 |
+
|
| 75 |
+
except Exception as e:
|
| 76 |
+
print(f"Error analyzing {image_path}: {e}")
|
| 77 |
+
return None, 0.0
|
| 78 |
+
|
| 79 |
+
def main():
|
| 80 |
+
"""Complete the final batch of labeling."""
|
| 81 |
+
|
| 82 |
+
# Paths
|
| 83 |
+
to_label_dir = Path("model/ripeness_manual_dataset/to_label")
|
| 84 |
+
unripe_dir = Path("model/ripeness_manual_dataset/unripe")
|
| 85 |
+
ripe_dir = Path("model/ripeness_manual_dataset/ripe")
|
| 86 |
+
overripe_dir = Path("model/ripeness_manual_dataset/overripe")
|
| 87 |
+
|
| 88 |
+
# Get remaining files
|
| 89 |
+
remaining_files = list(to_label_dir.glob("*.jpg"))
|
| 90 |
+
|
| 91 |
+
if not remaining_files:
|
| 92 |
+
print("No images remaining to label!")
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
print(f"=== FINAL BATCH LABELING ===")
|
| 96 |
+
print(f"Processing {len(remaining_files)} remaining images with conservative analysis...")
|
| 97 |
+
|
| 98 |
+
results = {
|
| 99 |
+
"timestamp": datetime.now().isoformat(),
|
| 100 |
+
"total_processed": len(remaining_files),
|
| 101 |
+
"unripe": 0,
|
| 102 |
+
"ripe": 0,
|
| 103 |
+
"overripe": 0,
|
| 104 |
+
"unknown": 0,
|
| 105 |
+
"images": []
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
for i, image_path in enumerate(remaining_files, 1):
|
| 109 |
+
print(f"Processing {i}/{len(remaining_files)}: {image_path.name}")
|
| 110 |
+
|
| 111 |
+
# Analyze with conservative threshold
|
| 112 |
+
label, confidence = analyze_ripeness_conservative(image_path, confidence_threshold=0.8)
|
| 113 |
+
|
| 114 |
+
if label:
|
| 115 |
+
# Move to appropriate directory
|
| 116 |
+
if label == "unripe":
|
| 117 |
+
dest = unripe_dir / image_path.name
|
| 118 |
+
results["unripe"] += 1
|
| 119 |
+
elif label == "ripe":
|
| 120 |
+
dest = ripe_dir / image_path.name
|
| 121 |
+
results["ripe"] += 1
|
| 122 |
+
elif label == "overripe":
|
| 123 |
+
dest = overripe_dir / image_path.name
|
| 124 |
+
results["overripe"] += 1
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
shutil.move(str(image_path), str(dest))
|
| 128 |
+
print(f" ✅ {label} (confidence: {confidence:.2f})")
|
| 129 |
+
results["images"].append({
|
| 130 |
+
"filename": image_path.name,
|
| 131 |
+
"label": label,
|
| 132 |
+
"confidence": confidence
|
| 133 |
+
})
|
| 134 |
+
except Exception as e:
|
| 135 |
+
print(f" ❌ Error moving file: {e}")
|
| 136 |
+
results["unknown"] += 1
|
| 137 |
+
else:
|
| 138 |
+
print(f" ⚠️ Analysis failed")
|
| 139 |
+
results["unknown"] += 1
|
| 140 |
+
|
| 141 |
+
# Save results
|
| 142 |
+
results_file = f"model/ripeness_manual_dataset/final_labeling_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
| 143 |
+
with open(results_file, 'w') as f:
|
| 144 |
+
json.dump(results, f, indent=2)
|
| 145 |
+
|
| 146 |
+
# Print final summary
|
| 147 |
+
print(f"\n=== FINAL LABELING COMPLETE ===")
|
| 148 |
+
print(f"unripe: {results['unripe']} images")
|
| 149 |
+
print(f"ripe: {results['ripe']} images")
|
| 150 |
+
print(f"overripe: {results['overripe']} images")
|
| 151 |
+
print(f"unknown: {results['unknown']} images")
|
| 152 |
+
print(f"Total processed: {results['total_processed']} images")
|
| 153 |
+
|
| 154 |
+
# Calculate final dataset statistics
|
| 155 |
+
total_unripe = len(list(unripe_dir.glob("*.jpg")))
|
| 156 |
+
total_ripe = len(list(ripe_dir.glob("*.jpg")))
|
| 157 |
+
total_overripe = len(list(overripe_dir.glob("*.jpg")))
|
| 158 |
+
remaining = len(list(to_label_dir.glob("*.jpg")))
|
| 159 |
+
total_dataset = total_unripe + total_ripe + total_overripe + remaining
|
| 160 |
+
|
| 161 |
+
completion_percentage = (total_dataset - remaining) / total_dataset * 100
|
| 162 |
+
|
| 163 |
+
print(f"\n=== FINAL DATASET STATUS ===")
|
| 164 |
+
print(f"unripe: {total_unripe} images")
|
| 165 |
+
print(f"ripe: {total_ripe} images")
|
| 166 |
+
print(f"overripe: {total_overripe} images")
|
| 167 |
+
print(f"to_label: {remaining} images")
|
| 168 |
+
print(f"TOTAL: {total_dataset} images")
|
| 169 |
+
print(f"Completion: {completion_percentage:.1f}%")
|
| 170 |
+
|
| 171 |
+
if remaining == 0:
|
| 172 |
+
print(f"\n🎉 DATASET LABELING 100% COMPLETE! 🎉")
|
| 173 |
+
print(f"Total labeled images: {total_dataset}")
|
| 174 |
+
else:
|
| 175 |
+
print(f"\n⚠️ {remaining} images still need manual review")
|
| 176 |
+
|
| 177 |
+
print(f"Results saved to: {results_file}")
|
| 178 |
+
|
| 179 |
+
if __name__ == "__main__":
|
| 180 |
+
main()
|
scripts/convert_tflite.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Convert trained model to TensorFlow Lite format with optional quantization.
|
| 4 |
+
Supports conversion from Keras (.h5) and PyTorch (.pt) models.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
import os
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
def convert_keras_to_tflite(input_path, output_path, quantization=None, optimize_for_size=False):
|
| 14 |
+
"""Convert Keras model to TFLite format."""
|
| 15 |
+
if not os.path.exists(input_path):
|
| 16 |
+
raise FileNotFoundError(f"Input model not found: {input_path}")
|
| 17 |
+
|
| 18 |
+
print(f"Loading Keras model from {input_path}...")
|
| 19 |
+
model = tf.keras.models.load_model(input_path)
|
| 20 |
+
|
| 21 |
+
# Configure converter
|
| 22 |
+
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
| 23 |
+
|
| 24 |
+
# Optimization options
|
| 25 |
+
if optimize_for_size:
|
| 26 |
+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
| 27 |
+
|
| 28 |
+
# Quantization options
|
| 29 |
+
if quantization == 'int8':
|
| 30 |
+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
| 31 |
+
converter.target_spec.supported_types = [tf.int8]
|
| 32 |
+
converter.inference_input_type = tf.int8
|
| 33 |
+
converter.inference_output_type = tf.int8
|
| 34 |
+
print("Using INT8 quantization")
|
| 35 |
+
elif quantization == 'float16':
|
| 36 |
+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
| 37 |
+
converter.target_spec.supported_types = [tf.float16]
|
| 38 |
+
print("Using Float16 quantization")
|
| 39 |
+
elif quantization == 'dynamic_range':
|
| 40 |
+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
| 41 |
+
print("Using dynamic range quantization")
|
| 42 |
+
|
| 43 |
+
# Convert model
|
| 44 |
+
print("Converting model to TFLite...")
|
| 45 |
+
tflite_model = converter.convert()
|
| 46 |
+
|
| 47 |
+
# Save model
|
| 48 |
+
with open(output_path, 'wb') as f:
|
| 49 |
+
f.write(tflite_model)
|
| 50 |
+
|
| 51 |
+
print(f"TFLite model saved to {output_path}")
|
| 52 |
+
|
| 53 |
+
# Print model size
|
| 54 |
+
size_kb = os.path.getsize(output_path) / 1024
|
| 55 |
+
print(f"Model size: {size_kb:.2f} KB")
|
| 56 |
+
|
| 57 |
+
return output_path
|
| 58 |
+
|
| 59 |
+
def convert_pytorch_to_tflite(input_path, output_path, quantization=None):
|
| 60 |
+
"""Convert PyTorch model to TFLite format (placeholder for future implementation)."""
|
| 61 |
+
print("PyTorch to TFLite conversion not yet implemented.")
|
| 62 |
+
print("Please convert PyTorch model to ONNX first, then use TensorFlow's converter.")
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
def main():
|
| 66 |
+
parser = argparse.ArgumentParser(description='Convert model to TensorFlow Lite format')
|
| 67 |
+
parser.add_argument('--input', type=str, default='strawberry_model.h5',
|
| 68 |
+
help='Input model path (Keras .h5 or PyTorch .pt)')
|
| 69 |
+
parser.add_argument('--output', type=str, default='strawberry_model.tflite',
|
| 70 |
+
help='Output TFLite model path')
|
| 71 |
+
parser.add_argument('--quantization', type=str, choices=['int8', 'float16', 'dynamic_range', 'none'],
|
| 72 |
+
default='none', help='Quantization method (default: none)')
|
| 73 |
+
parser.add_argument('--optimize-for-size', action='store_true',
|
| 74 |
+
help='Apply size optimization (reduces model size)')
|
| 75 |
+
parser.add_argument('--model-type', type=str, choices=['keras', 'pytorch'], default='keras',
|
| 76 |
+
help='Type of input model (default: keras)')
|
| 77 |
+
|
| 78 |
+
args = parser.parse_args()
|
| 79 |
+
|
| 80 |
+
# Validate input file exists
|
| 81 |
+
if not Path(args.input).exists():
|
| 82 |
+
print(f"Error: Input model '{args.input}' not found.")
|
| 83 |
+
print("Available model files in current directory:")
|
| 84 |
+
for f in os.listdir('.'):
|
| 85 |
+
if f.endswith(('.h5', '.pt', '.pth', '.onnx')):
|
| 86 |
+
print(f" - {f}")
|
| 87 |
+
sys.exit(1)
|
| 88 |
+
|
| 89 |
+
# Create output directory if needed
|
| 90 |
+
output_dir = os.path.dirname(args.output)
|
| 91 |
+
if output_dir and not os.path.exists(output_dir):
|
| 92 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 93 |
+
|
| 94 |
+
try:
|
| 95 |
+
if args.model_type == 'keras':
|
| 96 |
+
convert_keras_to_tflite(
|
| 97 |
+
input_path=args.input,
|
| 98 |
+
output_path=args.output,
|
| 99 |
+
quantization=args.quantization if args.quantization != 'none' else None,
|
| 100 |
+
optimize_for_size=args.optimize_for_size
|
| 101 |
+
)
|
| 102 |
+
elif args.model_type == 'pytorch':
|
| 103 |
+
convert_pytorch_to_tflite(
|
| 104 |
+
input_path=args.input,
|
| 105 |
+
output_path=args.output,
|
| 106 |
+
quantization=args.quantization if args.quantization != 'none' else None
|
| 107 |
+
)
|
| 108 |
+
else:
|
| 109 |
+
print(f"Unsupported model type: {args.model_type}")
|
| 110 |
+
sys.exit(1)
|
| 111 |
+
|
| 112 |
+
print("Conversion completed successfully!")
|
| 113 |
+
|
| 114 |
+
except Exception as e:
|
| 115 |
+
print(f"Error during conversion: {e}")
|
| 116 |
+
sys.exit(1)
|
| 117 |
+
|
| 118 |
+
if __name__ == '__main__':
|
| 119 |
+
main()
|
scripts/data/preprocess_strawberry_dataset.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
class_name = 'strawberry'
|
| 4 |
+
|
| 5 |
+
def rewrite_single_class_data_yaml(dataset_dir, class_name='strawberry'):
|
| 6 |
+
dataset_dir = Path(dataset_dir)
|
| 7 |
+
data_yaml_path = dataset_dir / 'data.yaml'
|
| 8 |
+
if not data_yaml_path.exists():
|
| 9 |
+
print('⚠️ data.yaml not found, skipping rewrite.')
|
| 10 |
+
return
|
| 11 |
+
|
| 12 |
+
train_path = dataset_dir / 'train' / 'images'
|
| 13 |
+
val_path = dataset_dir / 'valid' / 'images'
|
| 14 |
+
test_path = dataset_dir / 'test' / 'images'
|
| 15 |
+
|
| 16 |
+
content_lines = [
|
| 17 |
+
'# Strawberry-only dataset',
|
| 18 |
+
f'train: {train_path}',
|
| 19 |
+
f'val: {val_path}',
|
| 20 |
+
f"test: {test_path if test_path.exists() else ''}",
|
| 21 |
+
'',
|
| 22 |
+
'nc: 1',
|
| 23 |
+
f"names: ['{class_name}']",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
data_yaml_path.write_text('\n'.join(content_lines) + '\n')
|
| 27 |
+
print(f"✅ data.yaml updated for single-class training ({class_name}).")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def enforce_single_class_dataset(dataset_dir, target_class=0, class_name='strawberry'):
|
| 31 |
+
dataset_dir = Path(dataset_dir)
|
| 32 |
+
stats = {'split_kept': {}, 'labels_removed': 0, 'images_removed': 0}
|
| 33 |
+
allowed_ext = ['.jpg', '.jpeg', '.png']
|
| 34 |
+
|
| 35 |
+
for split in ['train', 'valid', 'test']:
|
| 36 |
+
labels_dir = dataset_dir / split / 'labels'
|
| 37 |
+
images_dir = dataset_dir / split / 'images'
|
| 38 |
+
if not labels_dir.exists():
|
| 39 |
+
continue
|
| 40 |
+
|
| 41 |
+
kept = 0
|
| 42 |
+
for label_path in labels_dir.glob('*.txt'):
|
| 43 |
+
kept_lines = []
|
| 44 |
+
for raw_line in label_path.read_text().splitlines():
|
| 45 |
+
line = raw_line.strip()
|
| 46 |
+
if not line:
|
| 47 |
+
continue
|
| 48 |
+
parts = line.split()
|
| 49 |
+
if not parts:
|
| 50 |
+
continue
|
| 51 |
+
try:
|
| 52 |
+
class_id = int(parts[0])
|
| 53 |
+
except ValueError:
|
| 54 |
+
continue
|
| 55 |
+
if class_id == target_class:
|
| 56 |
+
kept_lines.append(line)
|
| 57 |
+
|
| 58 |
+
if kept_lines:
|
| 59 |
+
label_path.write_text('\n'.join(kept_lines) + '\n')
|
| 60 |
+
kept += len(kept_lines)
|
| 61 |
+
else:
|
| 62 |
+
label_path.unlink()
|
| 63 |
+
stats['labels_removed'] += 1
|
| 64 |
+
for ext in allowed_ext:
|
| 65 |
+
candidate = images_dir / f"{label_path.stem}{ext}"
|
| 66 |
+
if candidate.exists():
|
| 67 |
+
candidate.unlink()
|
| 68 |
+
stats['images_removed'] += 1
|
| 69 |
+
break
|
| 70 |
+
|
| 71 |
+
stats['split_kept'][split] = kept
|
| 72 |
+
|
| 73 |
+
rewrite_single_class_data_yaml(dataset_dir, class_name)
|
| 74 |
+
|
| 75 |
+
print('\n🍓 Strawberry-only filtering summary:')
|
| 76 |
+
for split, count in stats['split_kept'].items():
|
| 77 |
+
print(f" {split}: {count} annotations kept")
|
| 78 |
+
print(f" Label files removed: {stats['labels_removed']}")
|
| 79 |
+
print(f" Images removed (non-strawberry or empty labels): {stats['images_removed']}")
|
| 80 |
+
|
| 81 |
+
return stats
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
# Example usage - replace with your dataset path
|
| 86 |
+
dataset_path = "path/to/your/dataset" # Update this path
|
| 87 |
+
|
| 88 |
+
if dataset_path and Path(dataset_path).exists():
|
| 89 |
+
strawberry_stats = enforce_single_class_dataset(dataset_path, target_class=0, class_name=class_name)
|
| 90 |
+
else:
|
| 91 |
+
print("Please set a valid dataset_path variable.")
|
scripts/detect_realtime.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Real-time strawberry detection/classification using TFLite model.
|
| 4 |
+
Supports both binary classification (good/bad) and YOLOv8 detection.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import tensorflow as tf
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import sys
|
| 13 |
+
|
| 14 |
+
def load_tflite_model(model_path):
|
| 15 |
+
"""Load TFLite model and allocate tensors."""
|
| 16 |
+
if not Path(model_path).exists():
|
| 17 |
+
raise FileNotFoundError(f"Model file not found: {model_path}")
|
| 18 |
+
|
| 19 |
+
interpreter = tf.lite.Interpreter(model_path=model_path)
|
| 20 |
+
interpreter.allocate_tensors()
|
| 21 |
+
return interpreter
|
| 22 |
+
|
| 23 |
+
def get_model_details(interpreter):
|
| 24 |
+
"""Get input and output details of the TFLite model."""
|
| 25 |
+
input_details = interpreter.get_input_details()
|
| 26 |
+
output_details = interpreter.get_output_details()
|
| 27 |
+
return input_details, output_details
|
| 28 |
+
|
| 29 |
+
def preprocess_image(image, input_shape):
|
| 30 |
+
"""Preprocess image for model inference."""
|
| 31 |
+
height, width = input_shape[1:3] if len(input_shape) == 4 else input_shape[1:3]
|
| 32 |
+
img = cv2.resize(image, (width, height))
|
| 33 |
+
img = img / 255.0 # Normalize to [0,1]
|
| 34 |
+
img = np.expand_dims(img, axis=0).astype(np.float32)
|
| 35 |
+
return img
|
| 36 |
+
|
| 37 |
+
def run_inference(interpreter, input_details, output_details, preprocessed_img):
|
| 38 |
+
"""Run inference on preprocessed image."""
|
| 39 |
+
interpreter.set_tensor(input_details[0]['index'], preprocessed_img)
|
| 40 |
+
interpreter.invoke()
|
| 41 |
+
return interpreter.get_tensor(output_details[0]['index'])
|
| 42 |
+
|
| 43 |
+
def main():
|
| 44 |
+
parser = argparse.ArgumentParser(description='Real-time strawberry detection/classification')
|
| 45 |
+
parser.add_argument('--model', type=str, default='strawberry_model.tflite',
|
| 46 |
+
help='Path to TFLite model (default: strawberry_model.tflite)')
|
| 47 |
+
parser.add_argument('--camera', type=int, default=0,
|
| 48 |
+
help='Camera index (default: 0)')
|
| 49 |
+
parser.add_argument('--threshold', type=float, default=0.5,
|
| 50 |
+
help='Confidence threshold for binary classification (default: 0.5)')
|
| 51 |
+
parser.add_argument('--input-size', type=int, default=224,
|
| 52 |
+
help='Input image size (width=height) for model (default: 224)')
|
| 53 |
+
parser.add_argument('--mode', choices=['classification', 'detection'], default='classification',
|
| 54 |
+
help='Inference mode: classification (good/bad) or detection (YOLO)')
|
| 55 |
+
parser.add_argument('--verbose', action='store_true',
|
| 56 |
+
help='Print detailed inference information')
|
| 57 |
+
|
| 58 |
+
args = parser.parse_args()
|
| 59 |
+
|
| 60 |
+
# Load model
|
| 61 |
+
try:
|
| 62 |
+
interpreter = load_tflite_model(args.model)
|
| 63 |
+
input_details, output_details = get_model_details(interpreter)
|
| 64 |
+
input_shape = input_details[0]['shape']
|
| 65 |
+
if args.verbose:
|
| 66 |
+
print(f"Model loaded: {args.model}")
|
| 67 |
+
print(f"Input shape: {input_shape}")
|
| 68 |
+
print(f"Output details: {output_details[0]}")
|
| 69 |
+
except Exception as e:
|
| 70 |
+
print(f"Error loading model: {e}")
|
| 71 |
+
sys.exit(1)
|
| 72 |
+
|
| 73 |
+
# Open camera
|
| 74 |
+
cap = cv2.VideoCapture(args.camera)
|
| 75 |
+
if not cap.isOpened():
|
| 76 |
+
print(f"Cannot open camera index {args.camera}")
|
| 77 |
+
sys.exit(1)
|
| 78 |
+
|
| 79 |
+
print(f"Starting real-time inference (mode: {args.mode})")
|
| 80 |
+
print("Press 'q' to quit, 's' to save current frame")
|
| 81 |
+
|
| 82 |
+
while True:
|
| 83 |
+
ret, frame = cap.read()
|
| 84 |
+
if not ret:
|
| 85 |
+
print("Failed to capture frame")
|
| 86 |
+
break
|
| 87 |
+
|
| 88 |
+
# Preprocess
|
| 89 |
+
preprocessed = preprocess_image(frame, input_shape)
|
| 90 |
+
|
| 91 |
+
# Inference
|
| 92 |
+
predictions = run_inference(interpreter, input_details, output_details, preprocessed)
|
| 93 |
+
|
| 94 |
+
# Process predictions based on mode
|
| 95 |
+
if args.mode == 'classification':
|
| 96 |
+
# Binary classification: single probability
|
| 97 |
+
confidence = predictions[0][0]
|
| 98 |
+
label = 'Good' if confidence > args.threshold else 'Bad'
|
| 99 |
+
display_text = f'{label}: {confidence:.2f}'
|
| 100 |
+
color = (0, 255, 0) if confidence > args.threshold else (0, 0, 255)
|
| 101 |
+
else:
|
| 102 |
+
# Detection mode (YOLO) - placeholder for future implementation
|
| 103 |
+
display_text = 'Detection mode not yet implemented'
|
| 104 |
+
color = (255, 255, 0)
|
| 105 |
+
|
| 106 |
+
# Display
|
| 107 |
+
cv2.putText(frame, display_text, (10, 30),
|
| 108 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
|
| 109 |
+
cv2.imshow('Strawberry Detection', frame)
|
| 110 |
+
|
| 111 |
+
key = cv2.waitKey(1) & 0xFF
|
| 112 |
+
if key == ord('q'):
|
| 113 |
+
break
|
| 114 |
+
elif key == ord('s'):
|
| 115 |
+
filename = f'capture_{cv2.getTickCount()}.jpg'
|
| 116 |
+
cv2.imwrite(filename, frame)
|
| 117 |
+
print(f"Frame saved as {filename}")
|
| 118 |
+
|
| 119 |
+
cap.release()
|
| 120 |
+
cv2.destroyAllWindows()
|
| 121 |
+
print("Real-time detection stopped.")
|
| 122 |
+
|
| 123 |
+
if __name__ == '__main__':
|
| 124 |
+
main()
|
scripts/download_dataset.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from roboflow import Roboflow
|
| 2 |
+
|
| 3 |
+
# Replace with your Roboflow API key and project details
|
| 4 |
+
rf = Roboflow(api_key="YOUR_API_KEY")
|
| 5 |
+
project = rf.workspace("YOUR_WORKSPACE").project("YOUR_PROJECT")
|
| 6 |
+
dataset = project.version("YOUR_VERSION").download("folder")
|
| 7 |
+
|
| 8 |
+
print("Dataset downloaded to dataset/")
|
scripts/export_onnx.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Export YOLOv8/v11 model to ONNX format for optimized inference.
|
| 4 |
+
Supports dynamic axes, batch size, and different opset versions.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import yaml
|
| 12 |
+
from ultralytics import YOLO
|
| 13 |
+
|
| 14 |
+
def load_config(config_path="config.yaml"):
|
| 15 |
+
"""Load configuration from YAML file."""
|
| 16 |
+
if not os.path.exists(config_path):
|
| 17 |
+
print(f"Warning: Config file {config_path} not found. Using defaults.")
|
| 18 |
+
return {}
|
| 19 |
+
with open(config_path, 'r') as f:
|
| 20 |
+
return yaml.safe_load(f)
|
| 21 |
+
|
| 22 |
+
def export_to_onnx(
|
| 23 |
+
model_path,
|
| 24 |
+
output_path=None,
|
| 25 |
+
imgsz=640,
|
| 26 |
+
batch=1,
|
| 27 |
+
dynamic=False,
|
| 28 |
+
simplify=True,
|
| 29 |
+
opset=12,
|
| 30 |
+
half=False
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Export YOLO model to ONNX format.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
model_path: Path to .pt model file
|
| 37 |
+
output_path: Output .onnx file path (optional, auto-generated if None)
|
| 38 |
+
imgsz: Input image size
|
| 39 |
+
batch: Batch size (1 for static, -1 for dynamic)
|
| 40 |
+
dynamic: Enable dynamic axes (batch, height, width)
|
| 41 |
+
simplify: Apply ONNX simplifier
|
| 42 |
+
opset: ONNX opset version
|
| 43 |
+
half: FP16 quantization
|
| 44 |
+
"""
|
| 45 |
+
print(f"Loading model from {model_path}")
|
| 46 |
+
model = YOLO(model_path)
|
| 47 |
+
|
| 48 |
+
# Determine output path
|
| 49 |
+
if output_path is None:
|
| 50 |
+
model_name = Path(model_path).stem
|
| 51 |
+
output_dir = Path(model_path).parent / "exports"
|
| 52 |
+
output_dir.mkdir(exist_ok=True)
|
| 53 |
+
output_path = str(output_dir / f"{model_name}.onnx")
|
| 54 |
+
|
| 55 |
+
# Create output directory if needed
|
| 56 |
+
output_dir = os.path.dirname(output_path)
|
| 57 |
+
if output_dir and not os.path.exists(output_dir):
|
| 58 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 59 |
+
|
| 60 |
+
# Prepare export arguments
|
| 61 |
+
export_args = {
|
| 62 |
+
'format': 'onnx',
|
| 63 |
+
'imgsz': imgsz,
|
| 64 |
+
'batch': batch,
|
| 65 |
+
'simplify': simplify,
|
| 66 |
+
'opset': opset,
|
| 67 |
+
'half': half,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
if dynamic:
|
| 71 |
+
export_args['dynamic'] = True
|
| 72 |
+
|
| 73 |
+
print(f"Exporting to ONNX with args: {export_args}")
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
# Export model
|
| 77 |
+
exported_path = model.export(**export_args)
|
| 78 |
+
|
| 79 |
+
# The exported file will be in the same directory as the model
|
| 80 |
+
# Find the .onnx file that was just created
|
| 81 |
+
exported_files = list(Path(model_path).parent.glob("*.onnx"))
|
| 82 |
+
if exported_files:
|
| 83 |
+
latest_onnx = max(exported_files, key=os.path.getctime)
|
| 84 |
+
# Move to desired output path if different
|
| 85 |
+
if str(latest_onnx) != output_path:
|
| 86 |
+
import shutil
|
| 87 |
+
shutil.move(str(latest_onnx), output_path)
|
| 88 |
+
print(f"Model moved to {output_path}")
|
| 89 |
+
else:
|
| 90 |
+
print(f"Model exported to {output_path}")
|
| 91 |
+
else:
|
| 92 |
+
# Try to find the exported file in the current directory
|
| 93 |
+
exported_files = list(Path('.').glob("*.onnx"))
|
| 94 |
+
if exported_files:
|
| 95 |
+
latest_onnx = max(exported_files, key=os.path.getctime)
|
| 96 |
+
if str(latest_onnx) != output_path:
|
| 97 |
+
import shutil
|
| 98 |
+
shutil.move(str(latest_onnx), output_path)
|
| 99 |
+
print(f"Model moved to {output_path}")
|
| 100 |
+
else:
|
| 101 |
+
print(f"Model exported to {output_path}")
|
| 102 |
+
else:
|
| 103 |
+
print(f"Warning: Could not locate exported ONNX file.")
|
| 104 |
+
print(f"Expected at: {output_path}")
|
| 105 |
+
return None
|
| 106 |
+
|
| 107 |
+
# Print model info
|
| 108 |
+
size_mb = os.path.getsize(output_path) / (1024 * 1024)
|
| 109 |
+
print(f"✅ ONNX export successful!")
|
| 110 |
+
print(f" Output: {output_path}")
|
| 111 |
+
print(f" Size: {size_mb:.2f} MB")
|
| 112 |
+
print(f" Input shape: {batch if batch > 0 else 'dynamic'}x3x{imgsz}x{imgsz}")
|
| 113 |
+
print(f" Opset: {opset}")
|
| 114 |
+
print(f" Dynamic: {dynamic}")
|
| 115 |
+
print(f" FP16: {half}")
|
| 116 |
+
|
| 117 |
+
return output_path
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
print(f"❌ Error during ONNX export: {e}")
|
| 121 |
+
import traceback
|
| 122 |
+
traceback.print_exc()
|
| 123 |
+
return None
|
| 124 |
+
|
| 125 |
+
def main():
|
| 126 |
+
parser = argparse.ArgumentParser(description='Export YOLO model to ONNX format')
|
| 127 |
+
parser.add_argument('--model', type=str, default='yolov8n.pt',
|
| 128 |
+
help='Path to YOLO model (.pt file)')
|
| 129 |
+
parser.add_argument('--output', type=str, default=None,
|
| 130 |
+
help='Output ONNX file path (default: model/exports/<model_name>.onnx)')
|
| 131 |
+
parser.add_argument('--img-size', type=int, default=640,
|
| 132 |
+
help='Input image size (default: 640)')
|
| 133 |
+
parser.add_argument('--batch', type=int, default=1,
|
| 134 |
+
help='Batch size (default: 1, use -1 for dynamic)')
|
| 135 |
+
parser.add_argument('--dynamic', action='store_true',
|
| 136 |
+
help='Enable dynamic axes (batch, height, width)')
|
| 137 |
+
parser.add_argument('--no-simplify', action='store_true',
|
| 138 |
+
help='Disable ONNX simplifier')
|
| 139 |
+
parser.add_argument('--opset', type=int, default=12,
|
| 140 |
+
help='ONNX opset version (default: 12)')
|
| 141 |
+
parser.add_argument('--half', action='store_true',
|
| 142 |
+
help='Use FP16 quantization')
|
| 143 |
+
parser.add_argument('--config', type=str, default='config.yaml',
|
| 144 |
+
help='Path to config file')
|
| 145 |
+
|
| 146 |
+
args = parser.parse_args()
|
| 147 |
+
|
| 148 |
+
# Load config
|
| 149 |
+
config = load_config(args.config)
|
| 150 |
+
|
| 151 |
+
# Use model from config if not specified
|
| 152 |
+
if args.model == 'yolov8n.pt' and config:
|
| 153 |
+
models_config = config.get('models', {})
|
| 154 |
+
detection_config = models_config.get('detection', {})
|
| 155 |
+
default_model = detection_config.get('strawberry_yolov8n', 'yolov8n.pt')
|
| 156 |
+
if os.path.exists(default_model):
|
| 157 |
+
args.model = default_model
|
| 158 |
+
else:
|
| 159 |
+
# Check for other available models
|
| 160 |
+
available_models = ['yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt',
|
| 161 |
+
'model/weights/strawberry_yolov11n.pt',
|
| 162 |
+
'model/weights/ripeness_detection_yolov11n.pt']
|
| 163 |
+
for model in available_models:
|
| 164 |
+
if os.path.exists(model):
|
| 165 |
+
args.model = model
|
| 166 |
+
print(f"Using available model: {model}")
|
| 167 |
+
break
|
| 168 |
+
|
| 169 |
+
# Export model
|
| 170 |
+
success = export_to_onnx(
|
| 171 |
+
model_path=args.model,
|
| 172 |
+
output_path=args.output,
|
| 173 |
+
imgsz=args.img_size,
|
| 174 |
+
batch=args.batch,
|
| 175 |
+
dynamic=args.dynamic,
|
| 176 |
+
simplify=not args.no_simplify,
|
| 177 |
+
opset=args.opset,
|
| 178 |
+
half=args.half
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
if success:
|
| 182 |
+
print(f"\n✅ Export completed successfully!")
|
| 183 |
+
print(f"\nNext steps:")
|
| 184 |
+
print(f"1. Test the ONNX model with ONNX Runtime:")
|
| 185 |
+
print(f" python -m onnxruntime.tools.onnx_model_test {success}")
|
| 186 |
+
print(f"2. Convert to TensorFlow Lite for edge deployment:")
|
| 187 |
+
print(f" python export_tflite_int8.py --model {success}")
|
| 188 |
+
print(f"3. Use in your application with ONNX Runtime")
|
| 189 |
+
else:
|
| 190 |
+
print("\n❌ Export failed.")
|
| 191 |
+
sys.exit(1)
|
| 192 |
+
|
| 193 |
+
if __name__ == '__main__':
|
| 194 |
+
main()
|
scripts/export_tflite_int8.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Export YOLOv8 model to TensorFlow Lite with INT8 quantization.
|
| 4 |
+
Uses Ultralytics YOLOv8 export functionality with calibration dataset.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import yaml
|
| 12 |
+
import numpy as np
|
| 13 |
+
from ultralytics import YOLO
|
| 14 |
+
|
| 15 |
+
def load_config(config_path="config.yaml"):
|
| 16 |
+
"""Load configuration from YAML file."""
|
| 17 |
+
if not os.path.exists(config_path):
|
| 18 |
+
print(f"Warning: Config file {config_path} not found. Using defaults.")
|
| 19 |
+
return {}
|
| 20 |
+
with open(config_path, 'r') as f:
|
| 21 |
+
return yaml.safe_load(f)
|
| 22 |
+
|
| 23 |
+
def get_representative_dataset(dataset_path, num_calibration=100):
|
| 24 |
+
"""
|
| 25 |
+
Create representative dataset for INT8 calibration.
|
| 26 |
+
Returns a generator that yields normalized images.
|
| 27 |
+
"""
|
| 28 |
+
import cv2
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
# Find validation images
|
| 32 |
+
val_path = Path(dataset_path) / "valid" / "images"
|
| 33 |
+
if not val_path.exists():
|
| 34 |
+
val_path = Path(dataset_path) / "val" / "images"
|
| 35 |
+
if not val_path.exists():
|
| 36 |
+
print(f"Warning: Validation images not found at {val_path}")
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
image_files = list(val_path.glob("*.jpg")) + list(val_path.glob("*.png"))
|
| 40 |
+
if len(image_files) == 0:
|
| 41 |
+
print("No validation images found for calibration.")
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
# Limit to num_calibration
|
| 45 |
+
image_files = image_files[:num_calibration]
|
| 46 |
+
|
| 47 |
+
def representative_dataset():
|
| 48 |
+
for img_path in image_files:
|
| 49 |
+
# Load and preprocess image
|
| 50 |
+
img = cv2.imread(str(img_path))
|
| 51 |
+
if img is None:
|
| 52 |
+
continue
|
| 53 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 54 |
+
img = cv2.resize(img, (640, 640))
|
| 55 |
+
img = img.astype(np.float32) / 255.0 # Normalize to [0, 1]
|
| 56 |
+
img = np.expand_dims(img, axis=0) # Add batch dimension
|
| 57 |
+
yield [img]
|
| 58 |
+
|
| 59 |
+
return representative_dataset
|
| 60 |
+
|
| 61 |
+
def export_yolov8_to_tflite_int8(
|
| 62 |
+
model_path,
|
| 63 |
+
output_path,
|
| 64 |
+
dataset_path=None,
|
| 65 |
+
img_size=640,
|
| 66 |
+
int8=True,
|
| 67 |
+
dynamic=False,
|
| 68 |
+
half=False
|
| 69 |
+
):
|
| 70 |
+
"""
|
| 71 |
+
Export YOLOv8 model to TFLite format with optional INT8 quantization.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
model_path: Path to YOLOv8 .pt model
|
| 75 |
+
output_path: Output .tflite file path
|
| 76 |
+
dataset_path: Path to dataset for INT8 calibration
|
| 77 |
+
img_size: Input image size
|
| 78 |
+
int8: Enable INT8 quantization
|
| 79 |
+
dynamic: Enable dynamic range quantization (alternative to INT8)
|
| 80 |
+
half: Enable FP16 quantization
|
| 81 |
+
"""
|
| 82 |
+
print(f"Loading YOLOv8 model from {model_path}")
|
| 83 |
+
model = YOLO(model_path)
|
| 84 |
+
|
| 85 |
+
# Check if model is a detection model
|
| 86 |
+
task = model.task if hasattr(model, 'task') else 'detect'
|
| 87 |
+
print(f"Model task: {task}")
|
| 88 |
+
|
| 89 |
+
# Prepare export arguments
|
| 90 |
+
export_args = {
|
| 91 |
+
'format': 'tflite',
|
| 92 |
+
'imgsz': img_size,
|
| 93 |
+
'optimize': True,
|
| 94 |
+
'int8': int8,
|
| 95 |
+
'half': half,
|
| 96 |
+
'dynamic': dynamic,
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
# If INT8 quantization is requested, provide representative dataset
|
| 100 |
+
if int8 and dataset_path:
|
| 101 |
+
print(f"Using dataset at {dataset_path} for INT8 calibration")
|
| 102 |
+
representative_dataset = get_representative_dataset(dataset_path)
|
| 103 |
+
if representative_dataset:
|
| 104 |
+
# Note: Ultralytics YOLOv8 export doesn't directly accept representative_dataset
|
| 105 |
+
# We'll need to use a different approach
|
| 106 |
+
print("INT8 calibration with representative dataset requires custom implementation.")
|
| 107 |
+
print("Falling back to Ultralytics built-in INT8 calibration...")
|
| 108 |
+
# Use built-in calibration images
|
| 109 |
+
export_args['int8'] = True
|
| 110 |
+
else:
|
| 111 |
+
print("Warning: No representative dataset available. Using default calibration.")
|
| 112 |
+
export_args['int8'] = True
|
| 113 |
+
elif int8:
|
| 114 |
+
print("Using built-in calibration images for INT8 quantization")
|
| 115 |
+
export_args['int8'] = True
|
| 116 |
+
|
| 117 |
+
# Export model
|
| 118 |
+
print(f"Exporting model to TFLite with args: {export_args}")
|
| 119 |
+
try:
|
| 120 |
+
# Use Ultralytics export
|
| 121 |
+
exported_path = model.export(**export_args)
|
| 122 |
+
|
| 123 |
+
# The exported file will be in the same directory as the model
|
| 124 |
+
# with a .tflite extension
|
| 125 |
+
exported_files = list(Path(model_path).parent.glob("*.tflite"))
|
| 126 |
+
if exported_files:
|
| 127 |
+
latest_tflite = max(exported_files, key=os.path.getctime)
|
| 128 |
+
# Move to desired output path
|
| 129 |
+
import shutil
|
| 130 |
+
shutil.move(str(latest_tflite), output_path)
|
| 131 |
+
print(f"Model exported to {output_path}")
|
| 132 |
+
|
| 133 |
+
# Print model size
|
| 134 |
+
size_mb = os.path.getsize(output_path) / (1024 * 1024)
|
| 135 |
+
print(f"Model size: {size_mb:.2f} MB")
|
| 136 |
+
|
| 137 |
+
return output_path
|
| 138 |
+
else:
|
| 139 |
+
print("Error: No .tflite file was generated")
|
| 140 |
+
return None
|
| 141 |
+
|
| 142 |
+
except Exception as e:
|
| 143 |
+
print(f"Error during export: {e}")
|
| 144 |
+
return None
|
| 145 |
+
|
| 146 |
+
def main():
|
| 147 |
+
parser = argparse.ArgumentParser(description='Export YOLOv8 model to TFLite with INT8 quantization')
|
| 148 |
+
parser.add_argument('--model', type=str, default='yolov8n.pt',
|
| 149 |
+
help='Path to YOLOv8 model (.pt file)')
|
| 150 |
+
parser.add_argument('--output', type=str, default='model/exports/strawberry_yolov8n_int8.tflite',
|
| 151 |
+
help='Output TFLite file path')
|
| 152 |
+
parser.add_argument('--dataset', type=str, default='model/dataset_strawberry_detect_v3',
|
| 153 |
+
help='Path to dataset for INT8 calibration')
|
| 154 |
+
parser.add_argument('--img-size', type=int, default=640,
|
| 155 |
+
help='Input image size (default: 640)')
|
| 156 |
+
parser.add_argument('--no-int8', action='store_true',
|
| 157 |
+
help='Disable INT8 quantization')
|
| 158 |
+
parser.add_argument('--dynamic', action='store_true',
|
| 159 |
+
help='Use dynamic range quantization')
|
| 160 |
+
parser.add_argument('--half', action='store_true',
|
| 161 |
+
help='Use FP16 quantization')
|
| 162 |
+
parser.add_argument('--config', type=str, default='config.yaml',
|
| 163 |
+
help='Path to config file')
|
| 164 |
+
|
| 165 |
+
args = parser.parse_args()
|
| 166 |
+
|
| 167 |
+
# Load config
|
| 168 |
+
config = load_config(args.config)
|
| 169 |
+
|
| 170 |
+
# Use dataset path from config if not provided
|
| 171 |
+
if args.dataset is None and config:
|
| 172 |
+
args.dataset = config.get('dataset', {}).get('detection', {}).get('path', 'model/dataset_strawberry_detect_v3')
|
| 173 |
+
|
| 174 |
+
# Create output directory if it doesn't exist
|
| 175 |
+
output_dir = os.path.dirname(args.output)
|
| 176 |
+
if output_dir and not os.path.exists(output_dir):
|
| 177 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 178 |
+
|
| 179 |
+
# Export model
|
| 180 |
+
success = export_yolov8_to_tflite_int8(
|
| 181 |
+
model_path=args.model,
|
| 182 |
+
output_path=args.output,
|
| 183 |
+
dataset_path=args.dataset,
|
| 184 |
+
img_size=args.img_size,
|
| 185 |
+
int8=not args.no_int8,
|
| 186 |
+
dynamic=args.dynamic,
|
| 187 |
+
half=args.half
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
if success:
|
| 191 |
+
print(f"\n✅ Successfully exported model to: {success}")
|
| 192 |
+
print("\nUsage:")
|
| 193 |
+
print(f" python detect_realtime.py --model {args.output}")
|
| 194 |
+
print(f" python detect_realtime.py --model {args.output} --mode detection")
|
| 195 |
+
else:
|
| 196 |
+
print("\n❌ Export failed.")
|
| 197 |
+
sys.exit(1)
|
| 198 |
+
|
| 199 |
+
if __name__ == '__main__':
|
| 200 |
+
main()
|
scripts/get-pip.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/label_ripeness_dataset.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Strawberry Ripeness Dataset Labeling Tool
|
| 4 |
+
Helps organize and label the 889 unlabeled images into 3 categories:
|
| 5 |
+
- Unripe (green/white/pale pink)
|
| 6 |
+
- Ripe (bright red)
|
| 7 |
+
- Overripe (dark red/maroon/rotting)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import shutil
|
| 12 |
+
import random
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
def create_labeling_directories(base_path):
|
| 18 |
+
"""Create the three labeling directories"""
|
| 19 |
+
labels = ['unripe', 'ripe', 'overripe']
|
| 20 |
+
dirs = {}
|
| 21 |
+
|
| 22 |
+
for label in labels:
|
| 23 |
+
dir_path = base_path / label
|
| 24 |
+
dir_path.mkdir(exist_ok=True)
|
| 25 |
+
dirs[label] = dir_path
|
| 26 |
+
|
| 27 |
+
return dirs
|
| 28 |
+
|
| 29 |
+
def get_image_files(to_label_path):
|
| 30 |
+
"""Get all image files from the to_label directory"""
|
| 31 |
+
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
|
| 32 |
+
image_files = []
|
| 33 |
+
|
| 34 |
+
for file_path in to_label_path.iterdir():
|
| 35 |
+
if file_path.suffix.lower() in image_extensions:
|
| 36 |
+
image_files.append(file_path)
|
| 37 |
+
|
| 38 |
+
return sorted(image_files)
|
| 39 |
+
|
| 40 |
+
def batch_label_images(image_files, dirs, batch_size=50):
|
| 41 |
+
"""Label images in batches for easier management"""
|
| 42 |
+
total_images = len(image_files)
|
| 43 |
+
num_batches = (total_images + batch_size - 1) // batch_size
|
| 44 |
+
|
| 45 |
+
print(f"Found {total_images} images to label")
|
| 46 |
+
print(f"Will process in {num_batches} batches of {batch_size} images each")
|
| 47 |
+
|
| 48 |
+
for batch_num in range(num_batches):
|
| 49 |
+
start_idx = batch_num * batch_size
|
| 50 |
+
end_idx = min(start_idx + batch_size, total_images)
|
| 51 |
+
batch_files = image_files[start_idx:end_idx]
|
| 52 |
+
|
| 53 |
+
print(f"\n=== BATCH {batch_num + 1}/{num_batches} ===")
|
| 54 |
+
print(f"Images {start_idx + 1} to {end_idx}")
|
| 55 |
+
|
| 56 |
+
# Show first few images in batch for preview
|
| 57 |
+
for i, img_file in enumerate(batch_files[:5]):
|
| 58 |
+
try:
|
| 59 |
+
with Image.open(img_file) as img:
|
| 60 |
+
print(f" {start_idx + i + 1}. {img_file.name} ({img.size[0]}x{img.size[1]})")
|
| 61 |
+
except Exception as e:
|
| 62 |
+
print(f" {start_idx + i + 1}. {img_file.name} (Error: {e})")
|
| 63 |
+
|
| 64 |
+
if len(batch_files) > 5:
|
| 65 |
+
print(f" ... and {len(batch_files) - 5} more images")
|
| 66 |
+
|
| 67 |
+
# Interactive labeling for this batch
|
| 68 |
+
label_batch(batch_files, dirs)
|
| 69 |
+
|
| 70 |
+
def label_batch(batch_files, dirs):
|
| 71 |
+
"""Interactive labeling for a batch of images"""
|
| 72 |
+
print("\nLabeling Instructions:")
|
| 73 |
+
print("1 = Unripe (green/white/pale pink)")
|
| 74 |
+
print("2 = Ripe (bright red)")
|
| 75 |
+
print("3 = Overripe (dark red/maroon/rotting)")
|
| 76 |
+
print("s = Skip this image")
|
| 77 |
+
print("q = Quit labeling")
|
| 78 |
+
print("Enter = Next image")
|
| 79 |
+
|
| 80 |
+
for img_file in batch_files:
|
| 81 |
+
try:
|
| 82 |
+
# Try to display image info
|
| 83 |
+
with Image.open(img_file) as img:
|
| 84 |
+
print(f"\nProcessing: {img_file.name}")
|
| 85 |
+
print(f"Size: {img.size[0]}x{img.size[1]}, Mode: {img.mode}")
|
| 86 |
+
|
| 87 |
+
while True:
|
| 88 |
+
choice = input("Label (1/2/3/s/q/Enter): ").strip().lower()
|
| 89 |
+
|
| 90 |
+
if choice == '1' or choice == 'unripe':
|
| 91 |
+
shutil.move(str(img_file), str(dirs['unripe'] / img_file.name))
|
| 92 |
+
print(f"✓ Moved to unripe/")
|
| 93 |
+
break
|
| 94 |
+
elif choice == '2' or choice == 'ripe':
|
| 95 |
+
shutil.move(str(img_file), str(dirs['ripe'] / img_file.name))
|
| 96 |
+
print(f"✓ Moved to ripe/")
|
| 97 |
+
break
|
| 98 |
+
elif choice == '3' or choice == 'overripe':
|
| 99 |
+
shutil.move(str(img_file), str(dirs['overripe'] / img_file.name))
|
| 100 |
+
print(f"✓ Moved to overripe/")
|
| 101 |
+
break
|
| 102 |
+
elif choice == 's' or choice == 'skip':
|
| 103 |
+
print("⏭️ Skipped")
|
| 104 |
+
break
|
| 105 |
+
elif choice == 'q' or choice == 'quit':
|
| 106 |
+
print("Quitting...")
|
| 107 |
+
return
|
| 108 |
+
elif choice == '':
|
| 109 |
+
print("⏭️ Skipped (no input)")
|
| 110 |
+
break
|
| 111 |
+
else:
|
| 112 |
+
print("Invalid choice. Please enter 1, 2, 3, s, q, or press Enter.")
|
| 113 |
+
|
| 114 |
+
except Exception as e:
|
| 115 |
+
print(f"Error processing {img_file.name}: {e}")
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
def count_labeled_images(dirs):
|
| 119 |
+
"""Count images in each label directory"""
|
| 120 |
+
counts = {}
|
| 121 |
+
for label, dir_path in dirs.items():
|
| 122 |
+
count = len([f for f in dir_path.iterdir() if f.is_file()])
|
| 123 |
+
counts[label] = count
|
| 124 |
+
return counts
|
| 125 |
+
|
| 126 |
+
def main():
|
| 127 |
+
parser = argparse.ArgumentParser(description='Label strawberry ripeness dataset')
|
| 128 |
+
parser.add_argument('--dataset-path', type=str,
|
| 129 |
+
default='model/ripeness_manual_dataset',
|
| 130 |
+
help='Path to the ripeness dataset directory')
|
| 131 |
+
parser.add_argument('--batch-size', type=int, default=50,
|
| 132 |
+
help='Number of images to process in each batch')
|
| 133 |
+
parser.add_argument('--count-only', action='store_true',
|
| 134 |
+
help='Only count current images, do not start labeling')
|
| 135 |
+
|
| 136 |
+
args = parser.parse_args()
|
| 137 |
+
|
| 138 |
+
base_path = Path(args.dataset_path)
|
| 139 |
+
to_label_path = base_path / 'to_label'
|
| 140 |
+
|
| 141 |
+
if not to_label_path.exists():
|
| 142 |
+
print(f"Error: to_label directory not found at {to_label_path}")
|
| 143 |
+
return
|
| 144 |
+
|
| 145 |
+
# Create labeling directories
|
| 146 |
+
dirs = create_labeling_directories(base_path)
|
| 147 |
+
|
| 148 |
+
# Count current state
|
| 149 |
+
remaining_files = get_image_files(to_label_path)
|
| 150 |
+
labeled_counts = count_labeled_images(dirs)
|
| 151 |
+
|
| 152 |
+
print("=== CURRENT STATUS ===")
|
| 153 |
+
print(f"Images remaining to label: {len(remaining_files)}")
|
| 154 |
+
print(f"Already labeled:")
|
| 155 |
+
for label, count in labeled_counts.items():
|
| 156 |
+
print(f" {label}: {count} images")
|
| 157 |
+
print(f"Total labeled: {sum(labeled_counts.values())}")
|
| 158 |
+
|
| 159 |
+
if args.count_only:
|
| 160 |
+
return
|
| 161 |
+
|
| 162 |
+
if len(remaining_files) == 0:
|
| 163 |
+
print("\n✅ All images have been labeled!")
|
| 164 |
+
print("You can now run: python3 train_ripeness_classifier.py")
|
| 165 |
+
return
|
| 166 |
+
|
| 167 |
+
# Ask user if they want to continue
|
| 168 |
+
response = input(f"\nStart labeling {len(remaining_files)} remaining images? (y/n): ")
|
| 169 |
+
if response.lower() != 'y':
|
| 170 |
+
print("Labeling cancelled.")
|
| 171 |
+
return
|
| 172 |
+
|
| 173 |
+
# Start labeling process
|
| 174 |
+
batch_label_images(remaining_files, dirs, args.batch_size)
|
| 175 |
+
|
| 176 |
+
# Final count
|
| 177 |
+
final_counts = count_labeled_images(dirs)
|
| 178 |
+
print("\n=== FINAL COUNTS ===")
|
| 179 |
+
for label, count in final_counts.items():
|
| 180 |
+
print(f"{label}: {count} images")
|
| 181 |
+
|
| 182 |
+
total_labeled = sum(final_counts.values())
|
| 183 |
+
print(f"Total labeled: {total_labeled}")
|
| 184 |
+
|
| 185 |
+
if total_labeled > 0:
|
| 186 |
+
print("\n✅ Labeling complete!")
|
| 187 |
+
print("Next steps:")
|
| 188 |
+
print("1. Review the labeled images for quality")
|
| 189 |
+
print("2. Run: python3 train_ripeness_classifier.py")
|
| 190 |
+
|
| 191 |
+
if __name__ == '__main__':
|
| 192 |
+
main()
|
scripts/optimization/optimized_onnx_inference.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Optimized ONNX Inference for Raspberry Pi
|
| 4 |
+
High-performance inference with ONNX Runtime optimizations
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import onnxruntime as ort
|
| 11 |
+
import time
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Tuple, List, Optional
|
| 14 |
+
|
| 15 |
+
class OptimizedONNXInference:
|
| 16 |
+
"""
|
| 17 |
+
Optimized ONNX inference engine for Raspberry Pi
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, model_path: str, conf_threshold: float = 0.5):
|
| 21 |
+
"""
|
| 22 |
+
Initialize optimized ONNX inference engine
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
model_path: Path to ONNX model
|
| 26 |
+
conf_threshold: Confidence threshold for detections
|
| 27 |
+
"""
|
| 28 |
+
self.conf_threshold = conf_threshold
|
| 29 |
+
self.model_path = model_path
|
| 30 |
+
self.session = self._create_optimized_session()
|
| 31 |
+
self.input_name = self.session.get_inputs()[0].name
|
| 32 |
+
self.input_shape = self.session.get_inputs()[0].shape
|
| 33 |
+
|
| 34 |
+
# Extract input dimensions
|
| 35 |
+
self.input_height = self.input_shape[2]
|
| 36 |
+
self.input_width = self.input_shape[3]
|
| 37 |
+
|
| 38 |
+
print(f"✅ Optimized ONNX model loaded: {model_path}")
|
| 39 |
+
print(f"📐 Input shape: {self.input_shape}")
|
| 40 |
+
print(f"🎯 Confidence threshold: {conf_threshold}")
|
| 41 |
+
|
| 42 |
+
def _create_optimized_session(self) -> ort.InferenceSession:
|
| 43 |
+
"""
|
| 44 |
+
Create ONNX session with Raspberry Pi optimizations
|
| 45 |
+
"""
|
| 46 |
+
# Set environment variables for optimization
|
| 47 |
+
os.environ["OMP_NUM_THREADS"] = "4" # Raspberry Pi 4 has 4 cores
|
| 48 |
+
os.environ["OMP_THREAD_LIMIT"] = "4"
|
| 49 |
+
os.environ["OMP_WAIT_POLICY"] = "PASSIVE"
|
| 50 |
+
os.environ["MKL_NUM_THREADS"] = "4"
|
| 51 |
+
|
| 52 |
+
# Session options for maximum performance
|
| 53 |
+
session_options = ort.SessionOptions()
|
| 54 |
+
|
| 55 |
+
# Enable all graph optimizations
|
| 56 |
+
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 57 |
+
|
| 58 |
+
# Use sequential execution for consistency
|
| 59 |
+
session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
| 60 |
+
|
| 61 |
+
# Optimize thread usage for Raspberry Pi
|
| 62 |
+
session_options.intra_op_num_threads = 4
|
| 63 |
+
session_options.inter_op_num_threads = 1
|
| 64 |
+
|
| 65 |
+
# Enable memory pattern optimization
|
| 66 |
+
session_options.enable_mem_pattern = True
|
| 67 |
+
session_options.enable_mem_reuse = True
|
| 68 |
+
|
| 69 |
+
# CPU execution provider (Raspberry Pi doesn't have CUDA)
|
| 70 |
+
providers = ['CPUExecutionProvider']
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
session = ort.InferenceSession(
|
| 74 |
+
self.model_path,
|
| 75 |
+
sess_options=session_options,
|
| 76 |
+
providers=providers
|
| 77 |
+
)
|
| 78 |
+
return session
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print(f"❌ Failed to create optimized session: {e}")
|
| 81 |
+
# Fallback to basic session
|
| 82 |
+
return ort.InferenceSession(self.model_path, providers=providers)
|
| 83 |
+
|
| 84 |
+
def preprocess(self, image: np.ndarray) -> np.ndarray:
|
| 85 |
+
"""
|
| 86 |
+
Optimized preprocessing for Raspberry Pi
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
image: Input image (BGR format)
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
Preprocessed tensor
|
| 93 |
+
"""
|
| 94 |
+
# Convert BGR to RGB
|
| 95 |
+
if len(image.shape) == 3:
|
| 96 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 97 |
+
|
| 98 |
+
# Resize with optimization
|
| 99 |
+
image = cv2.resize(image, (self.input_width, self.input_height),
|
| 100 |
+
interpolation=cv2.INTER_LINEAR)
|
| 101 |
+
|
| 102 |
+
# Convert to float32 and normalize
|
| 103 |
+
image = image.astype(np.float32) / 255.0
|
| 104 |
+
|
| 105 |
+
# Transpose to CHW format (ONNX expects this)
|
| 106 |
+
image = np.transpose(image, (2, 0, 1))
|
| 107 |
+
|
| 108 |
+
# Add batch dimension
|
| 109 |
+
image = np.expand_dims(image, axis=0)
|
| 110 |
+
|
| 111 |
+
return image
|
| 112 |
+
|
| 113 |
+
def postprocess(self, outputs: np.ndarray) -> List[dict]:
|
| 114 |
+
"""
|
| 115 |
+
Post-process YOLOv8 outputs
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
outputs: Raw model outputs
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
List of detections
|
| 122 |
+
"""
|
| 123 |
+
detections = []
|
| 124 |
+
|
| 125 |
+
# YOLOv8 output shape: [1, 5, 8400] for 640x640
|
| 126 |
+
# Where 5 = [x, y, w, h, conf] and 8400 = 80x80 + 40x40 + 20x20
|
| 127 |
+
|
| 128 |
+
# Reshape outputs
|
| 129 |
+
outputs = outputs[0] # Remove batch dimension
|
| 130 |
+
|
| 131 |
+
# Filter by confidence
|
| 132 |
+
conf_mask = outputs[4] > self.conf_threshold
|
| 133 |
+
filtered_outputs = outputs[:, conf_mask]
|
| 134 |
+
|
| 135 |
+
if filtered_outputs.shape[1] == 0:
|
| 136 |
+
return detections
|
| 137 |
+
|
| 138 |
+
# Extract boxes and scores
|
| 139 |
+
boxes = filtered_outputs[:4].T # [x, y, w, h]
|
| 140 |
+
scores = filtered_outputs[4] # confidence scores
|
| 141 |
+
|
| 142 |
+
# Convert from center format to corner format
|
| 143 |
+
x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
|
| 144 |
+
x1 = x - w / 2
|
| 145 |
+
y1 = y - h / 2
|
| 146 |
+
x2 = x + w / 2
|
| 147 |
+
y2 = y + h / 2
|
| 148 |
+
|
| 149 |
+
# Clip to image boundaries
|
| 150 |
+
x1 = np.clip(x1, 0, self.input_width)
|
| 151 |
+
y1 = np.clip(y1, 0, self.input_height)
|
| 152 |
+
x2 = np.clip(x2, 0, self.input_width)
|
| 153 |
+
y2 = np.clip(y2, 0, self.input_height)
|
| 154 |
+
|
| 155 |
+
# Create detection dictionaries
|
| 156 |
+
for i in range(len(scores)):
|
| 157 |
+
detection = {
|
| 158 |
+
'bbox': [float(x1[i]), float(y1[i]), float(x2[i]), float(y2[i])],
|
| 159 |
+
'confidence': float(scores[i]),
|
| 160 |
+
'class': 0, # Strawberry class
|
| 161 |
+
'class_name': 'strawberry'
|
| 162 |
+
}
|
| 163 |
+
detections.append(detection)
|
| 164 |
+
|
| 165 |
+
return detections
|
| 166 |
+
|
| 167 |
+
def predict(self, image: np.ndarray) -> Tuple[List[dict], float]:
|
| 168 |
+
"""
|
| 169 |
+
Run optimized inference
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
image: Input image
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
Tuple of (detections, inference_time)
|
| 176 |
+
"""
|
| 177 |
+
# Preprocess
|
| 178 |
+
input_tensor = self.preprocess(image)
|
| 179 |
+
|
| 180 |
+
# Run inference with timing
|
| 181 |
+
start_time = time.perf_counter()
|
| 182 |
+
outputs = self.session.run(None, {self.input_name: input_tensor})
|
| 183 |
+
inference_time = time.perf_counter() - start_time
|
| 184 |
+
|
| 185 |
+
# Post-process
|
| 186 |
+
detections = self.postprocess(outputs)
|
| 187 |
+
|
| 188 |
+
return detections, inference_time
|
| 189 |
+
|
| 190 |
+
def predict_batch(self, images: List[np.ndarray]) -> Tuple[List[List[dict]], float]:
|
| 191 |
+
"""
|
| 192 |
+
Run batch inference for multiple images
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
images: List of input images
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
Tuple of (list_of_detections, total_inference_time)
|
| 199 |
+
"""
|
| 200 |
+
if not images:
|
| 201 |
+
return [], 0.0
|
| 202 |
+
|
| 203 |
+
# Preprocess all images
|
| 204 |
+
input_tensors = [self.preprocess(img) for img in images]
|
| 205 |
+
batch_tensor = np.concatenate(input_tensors, axis=0)
|
| 206 |
+
|
| 207 |
+
# Run batch inference
|
| 208 |
+
start_time = time.perf_counter()
|
| 209 |
+
outputs = self.session.run(None, {self.input_name: batch_tensor})
|
| 210 |
+
inference_time = time.perf_counter() - start_time
|
| 211 |
+
|
| 212 |
+
# Post-process each image in batch
|
| 213 |
+
all_detections = []
|
| 214 |
+
for i in range(len(images)):
|
| 215 |
+
single_output = outputs[0][i:i+1] # Extract single image output
|
| 216 |
+
detections = self.postprocess([single_output])
|
| 217 |
+
all_detections.append(detections)
|
| 218 |
+
|
| 219 |
+
return all_detections, inference_time
|
| 220 |
+
|
| 221 |
+
def benchmark_model(model_path: str, test_image_path: str, runs: int = 10) -> dict:
|
| 222 |
+
"""
|
| 223 |
+
Benchmark model performance
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
model_path: Path to ONNX model
|
| 227 |
+
test_image_path: Path to test image
|
| 228 |
+
runs: Number of benchmark runs
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
Benchmark results dictionary
|
| 232 |
+
"""
|
| 233 |
+
# Load model
|
| 234 |
+
model = OptimizedONNXInference(model_path)
|
| 235 |
+
|
| 236 |
+
# Load test image
|
| 237 |
+
test_image = cv2.imread(test_image_path)
|
| 238 |
+
if test_image is None:
|
| 239 |
+
raise ValueError(f"Could not load test image: {test_image_path}")
|
| 240 |
+
|
| 241 |
+
# Warmup run
|
| 242 |
+
_ = model.predict(test_image)
|
| 243 |
+
|
| 244 |
+
# Benchmark runs
|
| 245 |
+
times = []
|
| 246 |
+
for _ in range(runs):
|
| 247 |
+
_, inference_time = model.predict(test_image)
|
| 248 |
+
times.append(inference_time * 1000) # Convert to milliseconds
|
| 249 |
+
|
| 250 |
+
# Calculate statistics
|
| 251 |
+
times_array = np.array(times)
|
| 252 |
+
results = {
|
| 253 |
+
'mean_ms': float(np.mean(times_array)),
|
| 254 |
+
'median_ms': float(np.median(times_array)),
|
| 255 |
+
'std_ms': float(np.std(times_array)),
|
| 256 |
+
'min_ms': float(np.min(times_array)),
|
| 257 |
+
'max_ms': float(np.max(times_array)),
|
| 258 |
+
'fps': float(1000 / np.mean(times_array)),
|
| 259 |
+
'runs': runs
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
return results
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
|
| 265 |
+
# Example usage
|
| 266 |
+
model_path = "model/detection/yolov8n/best_416.onnx"
|
| 267 |
+
test_image = "test_detection_result.jpg"
|
| 268 |
+
|
| 269 |
+
if os.path.exists(model_path) and os.path.exists(test_image):
|
| 270 |
+
print("🚀 Testing Optimized ONNX Inference")
|
| 271 |
+
print("=" * 50)
|
| 272 |
+
|
| 273 |
+
# Load model
|
| 274 |
+
model = OptimizedONNXInference(model_path)
|
| 275 |
+
|
| 276 |
+
# Load and predict
|
| 277 |
+
image = cv2.imread(test_image)
|
| 278 |
+
detections, inference_time = model.predict(image)
|
| 279 |
+
|
| 280 |
+
print(".2f" print(f"📊 Detections found: {len(detections)}")
|
| 281 |
+
|
| 282 |
+
# Benchmark
|
| 283 |
+
print("\n📈 Running benchmark (10 runs)...")
|
| 284 |
+
results = benchmark_model(model_path, test_image, runs=10)
|
| 285 |
+
|
| 286 |
+
print("📊 Benchmark Results:" print(".2f" print(".2f" print(".2f" print(".2f" print(".2f" print(".1f"
|
| 287 |
+
print("\n✅ Optimized inference test complete!")
|
| 288 |
+
else:
|
| 289 |
+
print("❌ Model or test image not found")
|
| 290 |
+
print(f"Model: {model_path} - {'✅' if os.path.exists(model_path) else '❌'}")
|
| 291 |
+
print(f"Image: {test_image} - {'✅' if os.path.exists(test_image) else '❌'}")
|
scripts/organize_labeled_images.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Organize automatically labeled images from JSON results into proper directories.
|
| 4 |
+
This script processes the auto_labeling_results JSON files and moves images
|
| 5 |
+
from to_label/ to the appropriate subdirectories (unripe/ripe/overripe/).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import shutil
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import glob
|
| 13 |
+
|
| 14 |
+
def organize_labeled_images():
|
| 15 |
+
"""Organize automatically labeled images based on JSON results."""
|
| 16 |
+
|
| 17 |
+
# Paths
|
| 18 |
+
base_dir = Path("model/ripeness_manual_dataset")
|
| 19 |
+
to_label_dir = base_dir / "to_label"
|
| 20 |
+
unripe_dir = base_dir / "unripe"
|
| 21 |
+
ripe_dir = base_dir / "ripe"
|
| 22 |
+
overripe_dir = base_dir / "overripe"
|
| 23 |
+
|
| 24 |
+
# Find all auto_labeling_results JSON files
|
| 25 |
+
json_files = glob.glob(str(base_dir / "auto_labeling_results_*.json"))
|
| 26 |
+
json_files.sort() # Process in chronological order
|
| 27 |
+
|
| 28 |
+
print(f"Found {len(json_files)} auto_labeling_results files")
|
| 29 |
+
|
| 30 |
+
total_moved = 0
|
| 31 |
+
total_errors = 0
|
| 32 |
+
|
| 33 |
+
for json_file in json_files:
|
| 34 |
+
print(f"\nProcessing: {os.path.basename(json_file)}")
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
with open(json_file, 'r') as f:
|
| 38 |
+
results = json.load(f)
|
| 39 |
+
|
| 40 |
+
batch_moved = 0
|
| 41 |
+
batch_errors = 0
|
| 42 |
+
|
| 43 |
+
# Process each labeled image (results is a list)
|
| 44 |
+
for label_data in results:
|
| 45 |
+
if isinstance(label_data, dict) and 'image' in label_data and 'label' in label_data:
|
| 46 |
+
image_name = label_data['image']
|
| 47 |
+
label = label_data['label']
|
| 48 |
+
confidence = label_data.get('confidence', 0.0)
|
| 49 |
+
|
| 50 |
+
# Skip if confidence is too low or label is unknown
|
| 51 |
+
if label in ['unknown', 'skip'] or confidence < 0.6:
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
# Source and destination paths
|
| 55 |
+
src_path = to_label_dir / image_name
|
| 56 |
+
|
| 57 |
+
if not src_path.exists():
|
| 58 |
+
print(f" ⚠️ Source file not found: {image_name}")
|
| 59 |
+
batch_errors += 1
|
| 60 |
+
continue
|
| 61 |
+
|
| 62 |
+
# Determine destination directory
|
| 63 |
+
if label == 'unripe':
|
| 64 |
+
dst_dir = unripe_dir
|
| 65 |
+
elif label == 'ripe':
|
| 66 |
+
dst_dir = ripe_dir
|
| 67 |
+
elif label == 'overripe':
|
| 68 |
+
dst_dir = overripe_dir
|
| 69 |
+
else:
|
| 70 |
+
print(f" ⚠️ Unknown label '{label}': {image_name}")
|
| 71 |
+
batch_errors += 1
|
| 72 |
+
continue
|
| 73 |
+
|
| 74 |
+
dst_path = dst_dir / image_name
|
| 75 |
+
|
| 76 |
+
# Move the file
|
| 77 |
+
try:
|
| 78 |
+
shutil.move(str(src_path), str(dst_path))
|
| 79 |
+
print(f" ✅ {label} ({confidence:.2f}): {image_name}")
|
| 80 |
+
batch_moved += 1
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f" ❌ Error moving {image_name}: {e}")
|
| 83 |
+
batch_errors += 1
|
| 84 |
+
|
| 85 |
+
print(f" 📦 Batch results: {batch_moved} moved, {batch_errors} errors")
|
| 86 |
+
total_moved += batch_moved
|
| 87 |
+
total_errors += batch_errors
|
| 88 |
+
|
| 89 |
+
except Exception as e:
|
| 90 |
+
print(f" ❌ Error processing {json_file}: {e}")
|
| 91 |
+
total_errors += 1
|
| 92 |
+
|
| 93 |
+
# Final summary
|
| 94 |
+
print(f"\n=== ORGANIZATION COMPLETE ===")
|
| 95 |
+
print(f"Total images moved: {total_moved}")
|
| 96 |
+
print(f"Total errors: {total_errors}")
|
| 97 |
+
|
| 98 |
+
# Show final counts
|
| 99 |
+
unripe_count = len(list(unripe_dir.glob("*.jpg")))
|
| 100 |
+
ripe_count = len(list(ripe_dir.glob("*.jpg")))
|
| 101 |
+
overripe_count = len(list(overripe_dir.glob("*.jpg")))
|
| 102 |
+
remaining_count = len(list(to_label_dir.glob("*.jpg")))
|
| 103 |
+
|
| 104 |
+
print(f"\nFinal counts:")
|
| 105 |
+
print(f" unripe: {unripe_count} images")
|
| 106 |
+
print(f" ripe: {ripe_count} images")
|
| 107 |
+
print(f" overripe: {overripe_count} images")
|
| 108 |
+
print(f" to_label: {remaining_count} images")
|
| 109 |
+
print(f" TOTAL: {unripe_count + ripe_count + overripe_count + remaining_count} images")
|
| 110 |
+
|
| 111 |
+
completion_rate = (unripe_count + ripe_count + overripe_count) / (unripe_count + ripe_count + overripe_count + remaining_count) * 100
|
| 112 |
+
print(f" Completion: {completion_rate:.1f}%")
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
organize_labeled_images()
|
scripts/setup_training.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Setup script for Strawberry Picker ML Training Environment
|
| 4 |
+
This script installs dependencies and validates the training setup.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import subprocess
|
| 10 |
+
import argparse
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
def run_command(cmd, description=""):
|
| 14 |
+
"""Run a shell command and handle errors"""
|
| 15 |
+
print(f"\n{'='*60}")
|
| 16 |
+
print(f"Running: {description or cmd}")
|
| 17 |
+
print(f"{'='*60}\n")
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
result = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
|
| 21 |
+
if result.stdout:
|
| 22 |
+
print(result.stdout)
|
| 23 |
+
return True
|
| 24 |
+
except subprocess.CalledProcessError as e:
|
| 25 |
+
print(f"ERROR: Command failed with return code {e.returncode}")
|
| 26 |
+
if e.stdout:
|
| 27 |
+
print(f"STDOUT: {e.stdout}")
|
| 28 |
+
if e.stderr:
|
| 29 |
+
print(f"STDERR: {e.stderr}")
|
| 30 |
+
return False
|
| 31 |
+
|
| 32 |
+
def check_python_version():
|
| 33 |
+
"""Check if Python version is compatible"""
|
| 34 |
+
print("Checking Python version...")
|
| 35 |
+
version = sys.version_info
|
| 36 |
+
if version.major < 3 or (version.major == 3 and version.minor < 8):
|
| 37 |
+
print(f"ERROR: Python 3.8+ required. Found {version.major}.{version.minor}")
|
| 38 |
+
return False
|
| 39 |
+
print(f"✓ Python {version.major}.{version.minor}.{version.micro}")
|
| 40 |
+
return True
|
| 41 |
+
|
| 42 |
+
def check_pip():
|
| 43 |
+
"""Check if pip is available"""
|
| 44 |
+
print("Checking pip availability...")
|
| 45 |
+
return run_command("pip --version", "Check pip version")
|
| 46 |
+
|
| 47 |
+
def install_requirements():
|
| 48 |
+
"""Install Python dependencies"""
|
| 49 |
+
print("Installing Python dependencies...")
|
| 50 |
+
requirements_file = Path(__file__).parent / "requirements.txt"
|
| 51 |
+
|
| 52 |
+
if not requirements_file.exists():
|
| 53 |
+
print(f"ERROR: requirements.txt not found at {requirements_file}")
|
| 54 |
+
return False
|
| 55 |
+
|
| 56 |
+
# Upgrade pip first
|
| 57 |
+
if not run_command("pip install --upgrade pip", "Upgrade pip"):
|
| 58 |
+
return False
|
| 59 |
+
|
| 60 |
+
# Install requirements
|
| 61 |
+
return run_command(f"pip install -r {requirements_file}", "Install requirements")
|
| 62 |
+
|
| 63 |
+
def check_ultralytics():
|
| 64 |
+
"""Check if ultralytics is installed correctly"""
|
| 65 |
+
print("Checking ultralytics installation...")
|
| 66 |
+
try:
|
| 67 |
+
from ultralytics import YOLO
|
| 68 |
+
print("✓ ultralytics installed successfully")
|
| 69 |
+
return True
|
| 70 |
+
except ImportError as e:
|
| 71 |
+
print(f"ERROR: Failed to import ultralytics: {e}")
|
| 72 |
+
return False
|
| 73 |
+
|
| 74 |
+
def check_torch():
|
| 75 |
+
"""Check PyTorch installation and GPU availability"""
|
| 76 |
+
print("Checking PyTorch installation...")
|
| 77 |
+
try:
|
| 78 |
+
import torch
|
| 79 |
+
print(f"✓ PyTorch version: {torch.__version__}")
|
| 80 |
+
|
| 81 |
+
if torch.cuda.is_available():
|
| 82 |
+
print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
|
| 83 |
+
print(f"✓ CUDA version: {torch.version.cuda}")
|
| 84 |
+
else:
|
| 85 |
+
print("⚠ GPU not available, will use CPU for training")
|
| 86 |
+
|
| 87 |
+
return True
|
| 88 |
+
except ImportError as e:
|
| 89 |
+
print(f"ERROR: Failed to import torch: {e}")
|
| 90 |
+
return False
|
| 91 |
+
|
| 92 |
+
def validate_dataset():
|
| 93 |
+
"""Validate dataset structure"""
|
| 94 |
+
print("Validating dataset structure...")
|
| 95 |
+
|
| 96 |
+
dataset_path = Path(__file__).parent / "model" / "dataset" / "straw-detect.v1-straw-detect.yolov8"
|
| 97 |
+
data_yaml = dataset_path / "data.yaml"
|
| 98 |
+
|
| 99 |
+
if not data_yaml.exists():
|
| 100 |
+
print(f"ERROR: data.yaml not found at {data_yaml}")
|
| 101 |
+
print("Please ensure your dataset is in the correct location")
|
| 102 |
+
return False
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
import yaml
|
| 106 |
+
with open(data_yaml, 'r') as f:
|
| 107 |
+
data = yaml.safe_load(f)
|
| 108 |
+
|
| 109 |
+
print(f"✓ Dataset configuration loaded")
|
| 110 |
+
print(f" Classes: {data['nc']}")
|
| 111 |
+
print(f" Names: {data['names']}")
|
| 112 |
+
|
| 113 |
+
# Check training images
|
| 114 |
+
train_path = dataset_path / data['train']
|
| 115 |
+
if train_path.exists():
|
| 116 |
+
train_images = list(train_path.glob('*.jpg')) + list(train_path.glob('*.png'))
|
| 117 |
+
print(f" Training images: {len(train_images)}")
|
| 118 |
+
else:
|
| 119 |
+
print(f"⚠ Training path not found: {train_path}")
|
| 120 |
+
|
| 121 |
+
# Check validation images
|
| 122 |
+
val_path = dataset_path / data['val']
|
| 123 |
+
if val_path.exists():
|
| 124 |
+
val_images = list(val_path.glob('*.jpg')) + list(val_path.glob('*.png'))
|
| 125 |
+
print(f" Validation images: {len(val_images)}")
|
| 126 |
+
else:
|
| 127 |
+
print(f"⚠ Validation path not found: {val_path}")
|
| 128 |
+
|
| 129 |
+
return True
|
| 130 |
+
|
| 131 |
+
except Exception as e:
|
| 132 |
+
print(f"ERROR: Failed to validate dataset: {e}")
|
| 133 |
+
return False
|
| 134 |
+
|
| 135 |
+
def create_directories():
|
| 136 |
+
"""Create necessary directories"""
|
| 137 |
+
print("Creating project directories...")
|
| 138 |
+
|
| 139 |
+
base_path = Path(__file__).parent
|
| 140 |
+
dirs = [
|
| 141 |
+
base_path / "model" / "weights",
|
| 142 |
+
base_path / "model" / "results",
|
| 143 |
+
base_path / "model" / "exports"
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
for dir_path in dirs:
|
| 147 |
+
dir_path.mkdir(parents=True, exist_ok=True)
|
| 148 |
+
print(f"✓ Created: {dir_path}")
|
| 149 |
+
|
| 150 |
+
return True
|
| 151 |
+
|
| 152 |
+
def main():
|
| 153 |
+
parser = argparse.ArgumentParser(description='Setup training environment for strawberry detection')
|
| 154 |
+
parser.add_argument('--skip-install', action='store_true', help='Skip package installation')
|
| 155 |
+
parser.add_argument('--validate-only', action='store_true', help='Only validate setup without installing')
|
| 156 |
+
|
| 157 |
+
args = parser.parse_args()
|
| 158 |
+
|
| 159 |
+
print("="*60)
|
| 160 |
+
print("Strawberry Picker ML Training Environment Setup")
|
| 161 |
+
print("="*60)
|
| 162 |
+
|
| 163 |
+
# Step 1: Check Python version
|
| 164 |
+
if not check_python_version():
|
| 165 |
+
sys.exit(1)
|
| 166 |
+
|
| 167 |
+
# Step 2: Check pip
|
| 168 |
+
if not check_pip():
|
| 169 |
+
sys.exit(1)
|
| 170 |
+
|
| 171 |
+
# Step 3: Install requirements (unless skipped)
|
| 172 |
+
if not args.skip_install and not args.validate_only:
|
| 173 |
+
if not install_requirements():
|
| 174 |
+
print("\n⚠ Installation failed. Please check the errors above.")
|
| 175 |
+
response = input("Continue with validation anyway? (y/n): ")
|
| 176 |
+
if response.lower() != 'y':
|
| 177 |
+
sys.exit(1)
|
| 178 |
+
|
| 179 |
+
# Step 4: Check ultralytics
|
| 180 |
+
if not check_ultralytics():
|
| 181 |
+
sys.exit(1)
|
| 182 |
+
|
| 183 |
+
# Step 5: Check PyTorch
|
| 184 |
+
if not check_torch():
|
| 185 |
+
sys.exit(1)
|
| 186 |
+
|
| 187 |
+
# Step 6: Validate dataset
|
| 188 |
+
if not validate_dataset():
|
| 189 |
+
print("\n⚠ Dataset validation failed. Please fix the issues above.")
|
| 190 |
+
if not args.validate_only:
|
| 191 |
+
response = input("Continue with directory creation anyway? (y/n): ")
|
| 192 |
+
if response.lower() != 'y':
|
| 193 |
+
sys.exit(1)
|
| 194 |
+
|
| 195 |
+
# Step 7: Create directories
|
| 196 |
+
if not args.validate_only:
|
| 197 |
+
if not create_directories():
|
| 198 |
+
sys.exit(1)
|
| 199 |
+
|
| 200 |
+
print("\n" + "="*60)
|
| 201 |
+
if args.validate_only:
|
| 202 |
+
print("Setup validation completed!")
|
| 203 |
+
else:
|
| 204 |
+
print("Setup completed successfully!")
|
| 205 |
+
|
| 206 |
+
print("\nNext steps:")
|
| 207 |
+
print("1. Run training: python train_yolov8.py")
|
| 208 |
+
print("2. Or open train_yolov8_colab.ipynb in Google Colab")
|
| 209 |
+
print("3. Check README.md for detailed instructions")
|
| 210 |
+
print("="*60)
|
| 211 |
+
|
| 212 |
+
if __name__ == '__main__':
|
| 213 |
+
main()
|
scripts/train_model.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
| 3 |
+
from tensorflow.keras.applications import MobileNetV2
|
| 4 |
+
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
|
| 5 |
+
from tensorflow.keras.models import Model
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# Data directories
|
| 9 |
+
data_dir = 'dataset'
|
| 10 |
+
train_datagen = ImageDataGenerator(
|
| 11 |
+
rescale=1./255,
|
| 12 |
+
validation_split=0.2,
|
| 13 |
+
rotation_range=20,
|
| 14 |
+
width_shift_range=0.2,
|
| 15 |
+
height_shift_range=0.2,
|
| 16 |
+
horizontal_flip=True
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
train_generator = train_datagen.flow_from_directory(
|
| 20 |
+
data_dir,
|
| 21 |
+
target_size=(224, 224),
|
| 22 |
+
batch_size=32,
|
| 23 |
+
class_mode='binary',
|
| 24 |
+
subset='training'
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
validation_generator = train_datagen.flow_from_directory(
|
| 28 |
+
data_dir,
|
| 29 |
+
target_size=(224, 224),
|
| 30 |
+
batch_size=32,
|
| 31 |
+
class_mode='binary',
|
| 32 |
+
subset='validation'
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# Load pre-trained MobileNetV2
|
| 36 |
+
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
|
| 37 |
+
|
| 38 |
+
# Add custom layers
|
| 39 |
+
x = base_model.output
|
| 40 |
+
x = GlobalAveragePooling2D()(x)
|
| 41 |
+
x = Dense(1024, activation='relu')(x)
|
| 42 |
+
predictions = Dense(1, activation='sigmoid')(x)
|
| 43 |
+
|
| 44 |
+
model = Model(inputs=base_model.input, outputs=predictions)
|
| 45 |
+
|
| 46 |
+
# Freeze base layers
|
| 47 |
+
for layer in base_model.layers:
|
| 48 |
+
layer.trainable = False
|
| 49 |
+
|
| 50 |
+
# Compile
|
| 51 |
+
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
|
| 52 |
+
|
| 53 |
+
# Train
|
| 54 |
+
model.fit(train_generator, validation_data=validation_generator, epochs=10)
|
| 55 |
+
|
| 56 |
+
# Save model
|
| 57 |
+
model.save('strawberry_model.h5')
|
| 58 |
+
|
| 59 |
+
print("Model trained and saved as strawberry_model.h5")
|
scripts/train_ripeness_classifier.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Strawberry Ripeness Classification Training Script
|
| 4 |
+
Trains a 3-class classifier (unripe/ripe/overripe) using transfer learning
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import yaml
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import seaborn as sns
|
| 17 |
+
|
| 18 |
+
# Deep Learning
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.optim as optim
|
| 22 |
+
from torch.utils.data import Dataset, DataLoader
|
| 23 |
+
import torchvision.transforms as transforms
|
| 24 |
+
import torchvision.models as models
|
| 25 |
+
from torchvision.datasets import ImageFolder
|
| 26 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
| 27 |
+
from sklearn.model_selection import train_test_split
|
| 28 |
+
|
| 29 |
+
# Set random seeds for reproducibility
|
| 30 |
+
torch.manual_seed(42)
|
| 31 |
+
np.random.seed(42)
|
| 32 |
+
|
| 33 |
+
class RipenessDataset(Dataset):
|
| 34 |
+
"""Custom dataset for strawberry ripeness classification"""
|
| 35 |
+
|
| 36 |
+
def __init__(self, data_dir, transform=None, split='train'):
|
| 37 |
+
self.data_dir = Path(data_dir)
|
| 38 |
+
self.transform = transform
|
| 39 |
+
self.split = split
|
| 40 |
+
|
| 41 |
+
# Get class names and counts (exclude 'to_label' directory)
|
| 42 |
+
self.classes = sorted([d.name for d in self.data_dir.iterdir()
|
| 43 |
+
if d.is_dir() and d.name != 'to_label'])
|
| 44 |
+
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
|
| 45 |
+
|
| 46 |
+
# Get all image paths and labels
|
| 47 |
+
self.samples = []
|
| 48 |
+
for class_name in self.classes:
|
| 49 |
+
class_dir = self.data_dir / class_name
|
| 50 |
+
if class_dir.exists():
|
| 51 |
+
for img_path in class_dir.glob('*.jpg'):
|
| 52 |
+
self.samples.append((str(img_path), self.class_to_idx[class_name]))
|
| 53 |
+
|
| 54 |
+
print(f"{split} dataset: {len(self.samples)} samples")
|
| 55 |
+
print(f"Classes: {self.classes}")
|
| 56 |
+
|
| 57 |
+
def __len__(self):
|
| 58 |
+
return len(self.samples)
|
| 59 |
+
|
| 60 |
+
def __getitem__(self, idx):
|
| 61 |
+
img_path, label = self.samples[idx]
|
| 62 |
+
|
| 63 |
+
# Load image
|
| 64 |
+
from PIL import Image
|
| 65 |
+
image = Image.open(img_path).convert('RGB')
|
| 66 |
+
|
| 67 |
+
if self.transform:
|
| 68 |
+
image = self.transform(image)
|
| 69 |
+
|
| 70 |
+
return image, label
|
| 71 |
+
|
| 72 |
+
def get_transforms(img_size=224):
|
| 73 |
+
"""Get data transforms for training and validation"""
|
| 74 |
+
|
| 75 |
+
# Training transforms with augmentation
|
| 76 |
+
train_transform = transforms.Compose([
|
| 77 |
+
transforms.Resize((img_size, img_size)),
|
| 78 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
| 79 |
+
transforms.RandomRotation(degrees=15),
|
| 80 |
+
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
|
| 81 |
+
transforms.ToTensor(),
|
| 82 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 83 |
+
])
|
| 84 |
+
|
| 85 |
+
# Validation transforms (no augmentation)
|
| 86 |
+
val_transform = transforms.Compose([
|
| 87 |
+
transforms.Resize((img_size, img_size)),
|
| 88 |
+
transforms.ToTensor(),
|
| 89 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 90 |
+
])
|
| 91 |
+
|
| 92 |
+
return train_transform, val_transform
|
| 93 |
+
|
| 94 |
+
def create_model(num_classes=3, backbone='resnet18', pretrained=True):
|
| 95 |
+
"""Create model with transfer learning"""
|
| 96 |
+
|
| 97 |
+
if backbone == 'resnet18':
|
| 98 |
+
model = models.resnet18(pretrained=pretrained)
|
| 99 |
+
model.fc = nn.Linear(model.fc.in_features, num_classes)
|
| 100 |
+
elif backbone == 'resnet50':
|
| 101 |
+
model = models.resnet50(pretrained=pretrained)
|
| 102 |
+
model.fc = nn.Linear(model.fc.in_features, num_classes)
|
| 103 |
+
elif backbone == 'efficientnet_b0':
|
| 104 |
+
model = models.efficientnet_b0(pretrained=pretrained)
|
| 105 |
+
model.classifier = nn.Linear(model.classifier.in_features, num_classes)
|
| 106 |
+
else:
|
| 107 |
+
raise ValueError(f"Unsupported backbone: {backbone}")
|
| 108 |
+
|
| 109 |
+
return model
|
| 110 |
+
|
| 111 |
+
def train_model(model, train_loader, val_loader, device, num_epochs=50, lr=0.001):
|
| 112 |
+
"""Train the model"""
|
| 113 |
+
|
| 114 |
+
criterion = nn.CrossEntropyLoss()
|
| 115 |
+
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
|
| 116 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5, factor=0.5)
|
| 117 |
+
|
| 118 |
+
best_val_acc = 0.0
|
| 119 |
+
train_losses = []
|
| 120 |
+
val_accuracies = []
|
| 121 |
+
|
| 122 |
+
for epoch in range(num_epochs):
|
| 123 |
+
# Training phase
|
| 124 |
+
model.train()
|
| 125 |
+
running_loss = 0.0
|
| 126 |
+
correct = 0
|
| 127 |
+
total = 0
|
| 128 |
+
|
| 129 |
+
for batch_idx, (images, labels) in enumerate(train_loader):
|
| 130 |
+
images, labels = images.to(device), labels.to(device)
|
| 131 |
+
|
| 132 |
+
optimizer.zero_grad()
|
| 133 |
+
outputs = model(images)
|
| 134 |
+
loss = criterion(outputs, labels)
|
| 135 |
+
loss.backward()
|
| 136 |
+
optimizer.step()
|
| 137 |
+
|
| 138 |
+
running_loss += loss.item()
|
| 139 |
+
_, predicted = outputs.max(1)
|
| 140 |
+
total += labels.size(0)
|
| 141 |
+
correct += predicted.eq(labels).sum().item()
|
| 142 |
+
|
| 143 |
+
if batch_idx % 10 == 0:
|
| 144 |
+
print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, '
|
| 145 |
+
f'Loss: {loss.item():.4f}, Acc: {100.*correct/total:.2f}%')
|
| 146 |
+
|
| 147 |
+
train_loss = running_loss / len(train_loader)
|
| 148 |
+
train_acc = 100. * correct / total
|
| 149 |
+
|
| 150 |
+
# Validation phase
|
| 151 |
+
model.eval()
|
| 152 |
+
val_correct = 0
|
| 153 |
+
val_total = 0
|
| 154 |
+
val_loss = 0.0
|
| 155 |
+
|
| 156 |
+
with torch.no_grad():
|
| 157 |
+
for images, labels in val_loader:
|
| 158 |
+
images, labels = images.to(device), labels.to(device)
|
| 159 |
+
outputs = model(images)
|
| 160 |
+
loss = criterion(outputs, labels)
|
| 161 |
+
|
| 162 |
+
val_loss += loss.item()
|
| 163 |
+
_, predicted = outputs.max(1)
|
| 164 |
+
val_total += labels.size(0)
|
| 165 |
+
val_correct += predicted.eq(labels).sum().item()
|
| 166 |
+
|
| 167 |
+
val_acc = 100. * val_correct / val_total
|
| 168 |
+
val_loss = val_loss / len(val_loader)
|
| 169 |
+
|
| 170 |
+
train_losses.append(train_loss)
|
| 171 |
+
val_accuracies.append(val_acc)
|
| 172 |
+
|
| 173 |
+
print(f'Epoch {epoch+1}/{num_epochs}:')
|
| 174 |
+
print(f' Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
|
| 175 |
+
print(f' Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
|
| 176 |
+
print('-' * 50)
|
| 177 |
+
|
| 178 |
+
# Save best model
|
| 179 |
+
if val_acc > best_val_acc:
|
| 180 |
+
best_val_acc = val_acc
|
| 181 |
+
torch.save(model.state_dict(), 'model/ripeness_classifier_best.pth')
|
| 182 |
+
print(f'New best model saved! Val Acc: {best_val_acc:.2f}%')
|
| 183 |
+
|
| 184 |
+
scheduler.step(val_acc)
|
| 185 |
+
|
| 186 |
+
return train_losses, val_accuracies, best_val_acc
|
| 187 |
+
|
| 188 |
+
def evaluate_model(model, test_loader, device, class_names):
|
| 189 |
+
"""Evaluate model and generate reports"""
|
| 190 |
+
|
| 191 |
+
model.eval()
|
| 192 |
+
all_preds = []
|
| 193 |
+
all_labels = []
|
| 194 |
+
|
| 195 |
+
with torch.no_grad():
|
| 196 |
+
for images, labels in test_loader:
|
| 197 |
+
images, labels = images.to(device), labels.to(device)
|
| 198 |
+
outputs = model(images)
|
| 199 |
+
_, predicted = outputs.max(1)
|
| 200 |
+
|
| 201 |
+
all_preds.extend(predicted.cpu().numpy())
|
| 202 |
+
all_labels.extend(labels.cpu().numpy())
|
| 203 |
+
|
| 204 |
+
# Classification report
|
| 205 |
+
report = classification_report(all_labels, all_preds, target_names=class_names)
|
| 206 |
+
print("Classification Report:")
|
| 207 |
+
print(report)
|
| 208 |
+
|
| 209 |
+
# Confusion matrix
|
| 210 |
+
cm = confusion_matrix(all_labels, all_preds)
|
| 211 |
+
|
| 212 |
+
# Plot confusion matrix
|
| 213 |
+
plt.figure(figsize=(8, 6))
|
| 214 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
|
| 215 |
+
xticklabels=class_names, yticklabels=class_names)
|
| 216 |
+
plt.title('Confusion Matrix')
|
| 217 |
+
plt.ylabel('True Label')
|
| 218 |
+
plt.xlabel('Predicted Label')
|
| 219 |
+
plt.savefig('model/ripeness_confusion_matrix.png', dpi=300, bbox_inches='tight')
|
| 220 |
+
plt.close()
|
| 221 |
+
|
| 222 |
+
return report, cm
|
| 223 |
+
|
| 224 |
+
def plot_training_history(train_losses, val_accuracies, save_path):
|
| 225 |
+
"""Plot training history"""
|
| 226 |
+
|
| 227 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
|
| 228 |
+
|
| 229 |
+
# Plot training loss
|
| 230 |
+
ax1.plot(train_losses)
|
| 231 |
+
ax1.set_title('Training Loss')
|
| 232 |
+
ax1.set_xlabel('Epoch')
|
| 233 |
+
ax1.set_ylabel('Loss')
|
| 234 |
+
ax1.grid(True)
|
| 235 |
+
|
| 236 |
+
# Plot validation accuracy
|
| 237 |
+
ax2.plot(val_accuracies)
|
| 238 |
+
ax2.set_title('Validation Accuracy')
|
| 239 |
+
ax2.set_xlabel('Epoch')
|
| 240 |
+
ax2.set_ylabel('Accuracy (%)')
|
| 241 |
+
ax2.grid(True)
|
| 242 |
+
|
| 243 |
+
plt.tight_layout()
|
| 244 |
+
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 245 |
+
plt.close()
|
| 246 |
+
|
| 247 |
+
def main():
|
| 248 |
+
parser = argparse.ArgumentParser(description='Train strawberry ripeness classifier')
|
| 249 |
+
parser.add_argument('--data-dir', default='model/ripeness_manual_dataset',
|
| 250 |
+
help='Directory containing labeled images')
|
| 251 |
+
parser.add_argument('--img-size', type=int, default=224, help='Image size')
|
| 252 |
+
parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
|
| 253 |
+
parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
|
| 254 |
+
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
|
| 255 |
+
parser.add_argument('--backbone', default='resnet18',
|
| 256 |
+
choices=['resnet18', 'resnet50', 'efficientnet_b0'],
|
| 257 |
+
help='Backbone architecture')
|
| 258 |
+
parser.add_argument('--val-split', type=float, default=0.2, help='Validation split ratio')
|
| 259 |
+
parser.add_argument('--output-dir', default='model/ripeness_classifier',
|
| 260 |
+
help='Output directory for models and results')
|
| 261 |
+
|
| 262 |
+
args = parser.parse_args()
|
| 263 |
+
|
| 264 |
+
# Create output directory
|
| 265 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 266 |
+
|
| 267 |
+
# Load config
|
| 268 |
+
with open('config.yaml', 'r') as f:
|
| 269 |
+
config = yaml.safe_load(f)
|
| 270 |
+
|
| 271 |
+
# Set device
|
| 272 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 273 |
+
print(f"Using device: {device}")
|
| 274 |
+
|
| 275 |
+
# Get transforms
|
| 276 |
+
train_transform, val_transform = get_transforms(args.img_size)
|
| 277 |
+
|
| 278 |
+
# Create datasets
|
| 279 |
+
train_dataset = RipenessDataset(args.data_dir, transform=train_transform, split='train')
|
| 280 |
+
val_dataset = RipenessDataset(args.data_dir, transform=val_transform, split='val')
|
| 281 |
+
|
| 282 |
+
# Create data loaders
|
| 283 |
+
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
|
| 284 |
+
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
|
| 285 |
+
|
| 286 |
+
# Create model
|
| 287 |
+
num_classes = len(train_dataset.classes)
|
| 288 |
+
model = create_model(num_classes=num_classes, backbone=args.backbone, pretrained=True)
|
| 289 |
+
model = model.to(device)
|
| 290 |
+
|
| 291 |
+
print(f"Model created with {num_classes} classes: {train_dataset.classes}")
|
| 292 |
+
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 293 |
+
|
| 294 |
+
# Train model
|
| 295 |
+
print("Starting training...")
|
| 296 |
+
train_losses, val_accuracies, best_val_acc = train_model(
|
| 297 |
+
model, train_loader, val_loader, device,
|
| 298 |
+
num_epochs=args.epochs, lr=args.lr
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Load best model for evaluation
|
| 302 |
+
model.load_state_dict(torch.load('model/ripeness_classifier_best.pth'))
|
| 303 |
+
|
| 304 |
+
# Evaluate model
|
| 305 |
+
print("Evaluating model...")
|
| 306 |
+
report, cm = evaluate_model(model, val_loader, device, train_dataset.classes)
|
| 307 |
+
|
| 308 |
+
# Plot training history
|
| 309 |
+
plot_training_history(train_losses, val_accuracies,
|
| 310 |
+
f'{args.output_dir}/training_history.png')
|
| 311 |
+
|
| 312 |
+
# Save results
|
| 313 |
+
results = {
|
| 314 |
+
'model_architecture': args.backbone,
|
| 315 |
+
'num_classes': num_classes,
|
| 316 |
+
'class_names': train_dataset.classes,
|
| 317 |
+
'best_val_accuracy': best_val_acc,
|
| 318 |
+
'training_config': {
|
| 319 |
+
'img_size': args.img_size,
|
| 320 |
+
'batch_size': args.batch_size,
|
| 321 |
+
'epochs': args.epochs,
|
| 322 |
+
'learning_rate': args.lr,
|
| 323 |
+
'val_split': args.val_split
|
| 324 |
+
},
|
| 325 |
+
'dataset_info': {
|
| 326 |
+
'total_samples': len(train_dataset),
|
| 327 |
+
'class_distribution': {cls: len(list(Path(args.data_dir, cls).glob('*.jpg')))
|
| 328 |
+
for cls in train_dataset.classes}
|
| 329 |
+
}
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
with open(f'{args.output_dir}/training_results.json', 'w') as f:
|
| 333 |
+
json.dump(results, f, indent=2)
|
| 334 |
+
|
| 335 |
+
# Save classification report
|
| 336 |
+
with open(f'{args.output_dir}/classification_report.txt', 'w') as f:
|
| 337 |
+
f.write(report)
|
| 338 |
+
|
| 339 |
+
print(f"\nTraining completed!")
|
| 340 |
+
print(f"Best validation accuracy: {best_val_acc:.2f}%")
|
| 341 |
+
print(f"Results saved to: {args.output_dir}")
|
| 342 |
+
print(f"Model saved to: model/ripeness_classifier_best.pth")
|
| 343 |
+
|
| 344 |
+
if __name__ == '__main__':
|
| 345 |
+
main()
|
scripts/train_yolov8.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
YOLOv8 Training Script for Strawberry Detection
|
| 4 |
+
Compatible with: Local Python, WSL, Google Colab (VS Code extension)
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import argparse
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import torch
|
| 12 |
+
import yaml
|
| 13 |
+
|
| 14 |
+
def check_environment():
|
| 15 |
+
"""Detect running environment and configure paths accordingly"""
|
| 16 |
+
env_info = {
|
| 17 |
+
'is_colab': 'COLAB_GPU' in os.environ or '/content' in os.getcwd(),
|
| 18 |
+
'is_wsl': 'WSL_DISTRO_NAME' in os.environ,
|
| 19 |
+
'has_gpu': torch.cuda.is_available(),
|
| 20 |
+
'gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
|
| 21 |
+
}
|
| 22 |
+
return env_info
|
| 23 |
+
|
| 24 |
+
def setup_paths(dataset_path=None):
|
| 25 |
+
"""Configure dataset and output paths based on environment"""
|
| 26 |
+
env = check_environment()
|
| 27 |
+
|
| 28 |
+
if env['is_colab']:
|
| 29 |
+
# Google Colab paths
|
| 30 |
+
base_path = Path('/content/strawberry-picker')
|
| 31 |
+
dataset_path = dataset_path or '/content/dataset'
|
| 32 |
+
weights_dir = base_path / 'weights'
|
| 33 |
+
results_dir = base_path / 'results'
|
| 34 |
+
else:
|
| 35 |
+
# Local/WSL paths
|
| 36 |
+
base_path = Path(__file__).parent
|
| 37 |
+
dataset_path = dataset_path or base_path / 'model' / 'dataset' / 'straw-detect.v1-straw-detect.yolov8'
|
| 38 |
+
weights_dir = base_path / 'model' / 'weights'
|
| 39 |
+
results_dir = base_path / 'model' / 'results'
|
| 40 |
+
|
| 41 |
+
# Create directories
|
| 42 |
+
weights_dir.mkdir(parents=True, exist_ok=True)
|
| 43 |
+
results_dir.mkdir(parents=True, exist_ok=True)
|
| 44 |
+
|
| 45 |
+
return {
|
| 46 |
+
'dataset_path': Path(dataset_path),
|
| 47 |
+
'weights_dir': weights_dir,
|
| 48 |
+
'results_dir': results_dir,
|
| 49 |
+
'base_path': base_path
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
def validate_dataset(dataset_path):
|
| 53 |
+
"""Validate YOLO dataset structure"""
|
| 54 |
+
dataset_path = Path(dataset_path)
|
| 55 |
+
data_yaml = dataset_path / 'data.yaml'
|
| 56 |
+
|
| 57 |
+
if not data_yaml.exists():
|
| 58 |
+
raise FileNotFoundError(f"data.yaml not found at {data_yaml}")
|
| 59 |
+
|
| 60 |
+
# Load and validate YAML
|
| 61 |
+
with open(data_yaml, 'r') as f:
|
| 62 |
+
data = yaml.safe_load(f)
|
| 63 |
+
|
| 64 |
+
required_keys = ['train', 'val', 'nc', 'names']
|
| 65 |
+
for key in required_keys:
|
| 66 |
+
if key not in data:
|
| 67 |
+
raise ValueError(f"Missing required key '{key}' in data.yaml")
|
| 68 |
+
|
| 69 |
+
# Check if paths are relative and resolve them
|
| 70 |
+
train_path = dataset_path / data['train']
|
| 71 |
+
val_path = dataset_path / data['val']
|
| 72 |
+
|
| 73 |
+
if not train_path.exists():
|
| 74 |
+
raise FileNotFoundError(f"Training images not found at {train_path}")
|
| 75 |
+
if not val_path.exists():
|
| 76 |
+
raise FileNotFoundError(f"Validation images not found at {val_path}")
|
| 77 |
+
|
| 78 |
+
print(f"✓ Dataset validated: {data['nc']} classes - {data['names']}")
|
| 79 |
+
print(f"✓ Training images: {train_path}")
|
| 80 |
+
print(f"✓ Validation images: {val_path}")
|
| 81 |
+
|
| 82 |
+
return data_yaml
|
| 83 |
+
|
| 84 |
+
def train_model(data_yaml, weights_dir, results_dir, epochs=100, img_size=640, batch_size=16, weights=None, resume=False):
|
| 85 |
+
"""Train YOLOv8 model (supports resuming from checkpoints)"""
|
| 86 |
+
try:
|
| 87 |
+
from ultralytics import YOLO
|
| 88 |
+
except ImportError:
|
| 89 |
+
print("ERROR: ultralytics not installed. Run: pip install ultralytics")
|
| 90 |
+
sys.exit(1)
|
| 91 |
+
|
| 92 |
+
env = check_environment()
|
| 93 |
+
print(f"\n{'='*60}")
|
| 94 |
+
print(f"Environment: {'Google Colab' if env['is_colab'] else 'Local/WSL'}")
|
| 95 |
+
print(f"GPU Available: {env['has_gpu']} ({env['gpu_name']})")
|
| 96 |
+
print(f"{'='*60}\n")
|
| 97 |
+
|
| 98 |
+
# Use GPU if available
|
| 99 |
+
device = '0' if env['has_gpu'] else 'cpu'
|
| 100 |
+
|
| 101 |
+
# Load model (custom weights or default YOLOv8n)
|
| 102 |
+
model_source = Path(weights) if weights else 'yolov8n.pt'
|
| 103 |
+
print(f"Loading model from {model_source}...")
|
| 104 |
+
model = YOLO(str(model_source))
|
| 105 |
+
|
| 106 |
+
# Training arguments
|
| 107 |
+
train_args = {
|
| 108 |
+
'data': str(data_yaml),
|
| 109 |
+
'epochs': epochs,
|
| 110 |
+
'imgsz': img_size,
|
| 111 |
+
'batch': batch_size,
|
| 112 |
+
'device': device,
|
| 113 |
+
'project': str(results_dir),
|
| 114 |
+
'name': 'strawberry_detection',
|
| 115 |
+
'exist_ok': True,
|
| 116 |
+
'patience': 20, # Early stopping patience
|
| 117 |
+
'save': True,
|
| 118 |
+
'save_period': 10, # Save checkpoint every 10 epochs
|
| 119 |
+
'cache': True, # Cache images for faster training
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
if resume:
|
| 123 |
+
train_args['resume'] = True
|
| 124 |
+
|
| 125 |
+
# Adjust batch size for Colab's limited RAM
|
| 126 |
+
if env['is_colab'] and batch_size > 16:
|
| 127 |
+
train_args['batch'] = 16
|
| 128 |
+
print(f"Adjusted batch size to 16 for Colab environment")
|
| 129 |
+
|
| 130 |
+
print(f"\nStarting training with arguments:")
|
| 131 |
+
for key, value in train_args.items():
|
| 132 |
+
print(f" {key}: {value}")
|
| 133 |
+
|
| 134 |
+
# Train the model
|
| 135 |
+
print(f"\n{'='*60}")
|
| 136 |
+
print("TRAINING STARTED")
|
| 137 |
+
print(f"{'='*60}\n")
|
| 138 |
+
|
| 139 |
+
results = model.train(**train_args)
|
| 140 |
+
|
| 141 |
+
# Save final model
|
| 142 |
+
final_model_path = weights_dir / 'strawberry_yolov8n.pt'
|
| 143 |
+
model.save(str(final_model_path))
|
| 144 |
+
|
| 145 |
+
print(f"\n{'='*60}")
|
| 146 |
+
print(f"Training completed!")
|
| 147 |
+
print(f"Final model saved to: {final_model_path}")
|
| 148 |
+
print(f"Results saved to: {results_dir / 'strawberry_detection'}")
|
| 149 |
+
print(f"{'='*60}\n")
|
| 150 |
+
|
| 151 |
+
return results, final_model_path
|
| 152 |
+
|
| 153 |
+
def export_model(model_path, weights_dir):
|
| 154 |
+
"""Export model to ONNX format"""
|
| 155 |
+
try:
|
| 156 |
+
from ultralytics import YOLO
|
| 157 |
+
except ImportError:
|
| 158 |
+
print("ERROR: ultralytics not installed")
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
print(f"\nExporting model to ONNX...")
|
| 162 |
+
model = YOLO(str(model_path))
|
| 163 |
+
|
| 164 |
+
# Export to ONNX
|
| 165 |
+
onnx_path = weights_dir / 'strawberry_yolov8n.onnx'
|
| 166 |
+
model.export(format='onnx', imgsz=640, dynamic=True)
|
| 167 |
+
|
| 168 |
+
print(f"ONNX model exported to: {onnx_path}")
|
| 169 |
+
return onnx_path
|
| 170 |
+
|
| 171 |
+
def main():
|
| 172 |
+
parser = argparse.ArgumentParser(description='Train YOLOv8 for strawberry detection')
|
| 173 |
+
parser.add_argument('--dataset', type=str, help='Path to dataset directory')
|
| 174 |
+
parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
|
| 175 |
+
parser.add_argument('--img-size', type=int, default=640, help='Image size for training')
|
| 176 |
+
parser.add_argument('--batch-size', type=int, default=16, help='Batch size for training')
|
| 177 |
+
parser.add_argument('--weights', type=str, help='Path to pretrained weights or checkpoint')
|
| 178 |
+
parser.add_argument('--resume', action='store_true', help='Resume training from the latest checkpoint')
|
| 179 |
+
parser.add_argument('--export-onnx', action='store_true', help='Export to ONNX after training')
|
| 180 |
+
parser.add_argument('--validate-only', action='store_true', help='Only validate dataset without training')
|
| 181 |
+
|
| 182 |
+
args = parser.parse_args()
|
| 183 |
+
|
| 184 |
+
try:
|
| 185 |
+
# Setup paths
|
| 186 |
+
paths = setup_paths(args.dataset)
|
| 187 |
+
print(f"Base path: {paths['base_path']}")
|
| 188 |
+
print(f"Dataset path: {paths['dataset_path']}")
|
| 189 |
+
print(f"Weights directory: {paths['weights_dir']}")
|
| 190 |
+
print(f"Results directory: {paths['results_dir']}")
|
| 191 |
+
|
| 192 |
+
# Validate dataset
|
| 193 |
+
print(f"\nValidating dataset...")
|
| 194 |
+
data_yaml = validate_dataset(paths['dataset_path'])
|
| 195 |
+
|
| 196 |
+
if args.validate_only:
|
| 197 |
+
print("Dataset validation completed. Exiting without training.")
|
| 198 |
+
return
|
| 199 |
+
|
| 200 |
+
# Train model
|
| 201 |
+
results, model_path = train_model(
|
| 202 |
+
data_yaml=data_yaml,
|
| 203 |
+
weights_dir=paths['weights_dir'],
|
| 204 |
+
results_dir=paths['results_dir'],
|
| 205 |
+
epochs=args.epochs,
|
| 206 |
+
img_size=args.img_size,
|
| 207 |
+
batch_size=args.batch_size,
|
| 208 |
+
weights=args.weights,
|
| 209 |
+
resume=args.resume
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# Export to ONNX if requested
|
| 213 |
+
if args.export_onnx:
|
| 214 |
+
export_model(model_path, paths['weights_dir'])
|
| 215 |
+
|
| 216 |
+
print("\n✓ Training pipeline completed successfully!")
|
| 217 |
+
|
| 218 |
+
except Exception as e:
|
| 219 |
+
print(f"\n✗ Error: {str(e)}")
|
| 220 |
+
sys.exit(1)
|
| 221 |
+
|
| 222 |
+
if __name__ == '__main__':
|
| 223 |
+
main()
|
scripts/validate_model.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Model Validation Script for Strawberry Ripeness Classification
|
| 4 |
+
Tests the trained model on sample images to verify functionality
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
import cv2
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import json
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
|
| 16 |
+
# Add current directory to path for imports
|
| 17 |
+
sys.path.append('.')
|
| 18 |
+
|
| 19 |
+
from train_ripeness_classifier import create_model, get_transforms
|
| 20 |
+
|
| 21 |
+
def load_model(model_path):
|
| 22 |
+
"""Load the trained classification model"""
|
| 23 |
+
print(f"Loading model from: {model_path}")
|
| 24 |
+
|
| 25 |
+
if not os.path.exists(model_path):
|
| 26 |
+
raise FileNotFoundError(f"Model file not found: {model_path}")
|
| 27 |
+
|
| 28 |
+
# Create model architecture
|
| 29 |
+
model = create_model(num_classes=3, backbone='resnet18', pretrained=False)
|
| 30 |
+
|
| 31 |
+
# Load trained weights
|
| 32 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 33 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 34 |
+
model = model.to(device)
|
| 35 |
+
model.eval()
|
| 36 |
+
|
| 37 |
+
print(f"Model loaded successfully on {device}")
|
| 38 |
+
return model, device
|
| 39 |
+
|
| 40 |
+
def get_test_images():
|
| 41 |
+
"""Get sample test images from the dataset"""
|
| 42 |
+
test_dirs = [
|
| 43 |
+
'model/ripeness_manual_dataset/unripe',
|
| 44 |
+
'model/ripeness_manual_dataset/ripe',
|
| 45 |
+
'model/ripeness_manual_dataset/overripe'
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
test_images = []
|
| 49 |
+
for test_dir in test_dirs:
|
| 50 |
+
if os.path.exists(test_dir):
|
| 51 |
+
images = list(Path(test_dir).glob('*.jpg'))[:3] # Get first 3 images from each class
|
| 52 |
+
for img_path in images:
|
| 53 |
+
test_images.append({
|
| 54 |
+
'path': str(img_path),
|
| 55 |
+
'true_label': os.path.basename(test_dir),
|
| 56 |
+
'class_name': os.path.basename(test_dir)
|
| 57 |
+
})
|
| 58 |
+
|
| 59 |
+
return test_images
|
| 60 |
+
|
| 61 |
+
def predict_image(model, device, image_path, transform):
|
| 62 |
+
"""Predict ripeness for a single image"""
|
| 63 |
+
try:
|
| 64 |
+
# Load and preprocess image
|
| 65 |
+
image = cv2.imread(image_path)
|
| 66 |
+
if image is None:
|
| 67 |
+
return None, "Failed to load image"
|
| 68 |
+
|
| 69 |
+
# Convert BGR to RGB
|
| 70 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 71 |
+
from PIL import Image
|
| 72 |
+
image_pil = Image.fromarray(image)
|
| 73 |
+
|
| 74 |
+
# Apply transforms
|
| 75 |
+
input_tensor = transform(image_pil).unsqueeze(0).to(device)
|
| 76 |
+
|
| 77 |
+
# Get prediction
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
outputs = model(input_tensor)
|
| 80 |
+
probabilities = torch.softmax(outputs, dim=1)
|
| 81 |
+
predicted_class_idx = torch.argmax(probabilities, dim=1).item()
|
| 82 |
+
confidence = probabilities[0][predicted_class_idx].item()
|
| 83 |
+
|
| 84 |
+
# Get class names
|
| 85 |
+
class_names = ['overripe', 'ripe', 'unripe']
|
| 86 |
+
predicted_class = class_names[predicted_class_idx]
|
| 87 |
+
|
| 88 |
+
# Get all probabilities
|
| 89 |
+
probs_dict = {
|
| 90 |
+
class_names[i]: float(probabilities[0][i].item())
|
| 91 |
+
for i in range(len(class_names))
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
return {
|
| 95 |
+
'predicted_class': predicted_class,
|
| 96 |
+
'confidence': confidence,
|
| 97 |
+
'probabilities': probs_dict
|
| 98 |
+
}, None
|
| 99 |
+
|
| 100 |
+
except Exception as e:
|
| 101 |
+
return None, str(e)
|
| 102 |
+
|
| 103 |
+
def validate_model():
|
| 104 |
+
"""Main validation function"""
|
| 105 |
+
print("=== Strawberry Ripeness Classification Model Validation ===")
|
| 106 |
+
print(f"Validation time: {datetime.now().isoformat()}")
|
| 107 |
+
print()
|
| 108 |
+
|
| 109 |
+
# Load model
|
| 110 |
+
model_path = 'model/ripeness_classifier_best.pth'
|
| 111 |
+
try:
|
| 112 |
+
model, device = load_model(model_path)
|
| 113 |
+
except Exception as e:
|
| 114 |
+
print(f"❌ Failed to load model: {e}")
|
| 115 |
+
return False
|
| 116 |
+
|
| 117 |
+
# Get transforms
|
| 118 |
+
_, transform = get_transforms(img_size=224)
|
| 119 |
+
|
| 120 |
+
# Get test images
|
| 121 |
+
test_images = get_test_images()
|
| 122 |
+
if not test_images:
|
| 123 |
+
print("❌ No test images found")
|
| 124 |
+
return False
|
| 125 |
+
|
| 126 |
+
print(f"Found {len(test_images)} test images")
|
| 127 |
+
print()
|
| 128 |
+
|
| 129 |
+
# Test predictions
|
| 130 |
+
results = []
|
| 131 |
+
correct_predictions = 0
|
| 132 |
+
total_predictions = 0
|
| 133 |
+
|
| 134 |
+
print("Testing predictions...")
|
| 135 |
+
print("-" * 80)
|
| 136 |
+
|
| 137 |
+
for i, test_img in enumerate(test_images):
|
| 138 |
+
image_path = test_img['path']
|
| 139 |
+
true_label = test_img['true_label']
|
| 140 |
+
|
| 141 |
+
# Make prediction
|
| 142 |
+
prediction, error = predict_image(model, device, image_path, transform)
|
| 143 |
+
|
| 144 |
+
if error:
|
| 145 |
+
print(f"❌ Image {i+1}: Error - {error}")
|
| 146 |
+
continue
|
| 147 |
+
|
| 148 |
+
predicted_class = prediction['predicted_class']
|
| 149 |
+
confidence = prediction['confidence']
|
| 150 |
+
|
| 151 |
+
# Check if prediction is correct
|
| 152 |
+
is_correct = predicted_class == true_label
|
| 153 |
+
if is_correct:
|
| 154 |
+
correct_predictions += 1
|
| 155 |
+
total_predictions += 1
|
| 156 |
+
|
| 157 |
+
# Print result
|
| 158 |
+
status = "✅" if is_correct else "❌"
|
| 159 |
+
print(f"{status} Image {i+1}: {os.path.basename(image_path)}")
|
| 160 |
+
print(f" True: {true_label} | Predicted: {predicted_class} ({confidence:.3f})")
|
| 161 |
+
print(f" Probabilities: overripe={prediction['probabilities']['overripe']:.3f}, "
|
| 162 |
+
f"ripe={prediction['probabilities']['ripe']:.3f}, "
|
| 163 |
+
f"unripe={prediction['probabilities']['unripe']:.3f}")
|
| 164 |
+
print()
|
| 165 |
+
|
| 166 |
+
# Store result
|
| 167 |
+
results.append({
|
| 168 |
+
'image_path': image_path,
|
| 169 |
+
'true_label': true_label,
|
| 170 |
+
'predicted_class': predicted_class,
|
| 171 |
+
'confidence': confidence,
|
| 172 |
+
'probabilities': prediction['probabilities'],
|
| 173 |
+
'correct': is_correct
|
| 174 |
+
})
|
| 175 |
+
|
| 176 |
+
# Calculate accuracy
|
| 177 |
+
accuracy = (correct_predictions / total_predictions * 100) if total_predictions > 0 else 0
|
| 178 |
+
|
| 179 |
+
print("=" * 80)
|
| 180 |
+
print("VALIDATION RESULTS")
|
| 181 |
+
print("=" * 80)
|
| 182 |
+
print(f"Total images tested: {total_predictions}")
|
| 183 |
+
print(f"Correct predictions: {correct_predictions}")
|
| 184 |
+
print(f"Accuracy: {accuracy:.1f}%")
|
| 185 |
+
print()
|
| 186 |
+
|
| 187 |
+
# Class-wise analysis
|
| 188 |
+
class_stats = {}
|
| 189 |
+
for result in results:
|
| 190 |
+
true_class = result['true_label']
|
| 191 |
+
if true_class not in class_stats:
|
| 192 |
+
class_stats[true_class] = {'correct': 0, 'total': 0}
|
| 193 |
+
class_stats[true_class]['total'] += 1
|
| 194 |
+
if result['correct']:
|
| 195 |
+
class_stats[true_class]['correct'] += 1
|
| 196 |
+
|
| 197 |
+
print("Class-wise Performance:")
|
| 198 |
+
for class_name, stats in class_stats.items():
|
| 199 |
+
class_accuracy = (stats['correct'] / stats['total'] * 100) if stats['total'] > 0 else 0
|
| 200 |
+
print(f" {class_name}: {stats['correct']}/{stats['total']} ({class_accuracy:.1f}%)")
|
| 201 |
+
print()
|
| 202 |
+
|
| 203 |
+
# Save detailed results
|
| 204 |
+
validation_results = {
|
| 205 |
+
'validation_time': datetime.now().isoformat(),
|
| 206 |
+
'model_path': model_path,
|
| 207 |
+
'device': str(device),
|
| 208 |
+
'total_images': total_predictions,
|
| 209 |
+
'correct_predictions': correct_predictions,
|
| 210 |
+
'accuracy_percent': accuracy,
|
| 211 |
+
'class_stats': class_stats,
|
| 212 |
+
'detailed_results': results
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
results_path = 'model_validation_results.json'
|
| 216 |
+
with open(results_path, 'w') as f:
|
| 217 |
+
json.dump(validation_results, f, indent=2)
|
| 218 |
+
|
| 219 |
+
print(f"Detailed results saved to: {results_path}")
|
| 220 |
+
|
| 221 |
+
# Validation verdict
|
| 222 |
+
if accuracy >= 90:
|
| 223 |
+
print("🎉 VALIDATION PASSED: Model performs excellently!")
|
| 224 |
+
return True
|
| 225 |
+
elif accuracy >= 80:
|
| 226 |
+
print("⚠️ VALIDATION WARNING: Model performs moderately well")
|
| 227 |
+
return True
|
| 228 |
+
else:
|
| 229 |
+
print("❌ VALIDATION FAILED: Model performance is poor")
|
| 230 |
+
return False
|
| 231 |
+
|
| 232 |
+
if __name__ == '__main__':
|
| 233 |
+
success = validate_model()
|
| 234 |
+
sys.exit(0 if success else 1)
|
scripts/webcam_capture.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
cap = cv2.VideoCapture(0)
|
| 5 |
+
|
| 6 |
+
if not cap.isOpened():
|
| 7 |
+
print("Cannot open camera")
|
| 8 |
+
exit()
|
| 9 |
+
|
| 10 |
+
while True:
|
| 11 |
+
ret, frame = cap.read()
|
| 12 |
+
if ret:
|
| 13 |
+
cv2.imwrite('captured_frame.jpg', frame)
|
| 14 |
+
print("Frame captured and saved as captured_frame.jpg")
|
| 15 |
+
else:
|
| 16 |
+
print("Can't receive frame")
|
| 17 |
+
break
|
| 18 |
+
time.sleep(1)
|
| 19 |
+
|
| 20 |
+
cap.release()
|
src/arduino_bridge.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Arduino Bridge - Serial Communication for Robotic Arm Control
|
| 4 |
+
Handles communication between Python pipeline and Arduino microcontroller
|
| 5 |
+
|
| 6 |
+
Author: AI Assistant
|
| 7 |
+
Date: 2025-12-15
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import serial
|
| 11 |
+
import serial.tools.list_ports
|
| 12 |
+
import time
|
| 13 |
+
import logging
|
| 14 |
+
import threading
|
| 15 |
+
from typing import Optional, Tuple, List, Dict
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
import json
|
| 18 |
+
import re
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class ServoPosition:
|
| 24 |
+
"""Represents a servo position"""
|
| 25 |
+
servo_id: int
|
| 26 |
+
angle: float
|
| 27 |
+
timestamp: float
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class SensorData:
|
| 31 |
+
"""Represents sensor data from Arduino"""
|
| 32 |
+
limit_switches: Dict[int, bool]
|
| 33 |
+
force_sensor: float
|
| 34 |
+
temperature: float
|
| 35 |
+
timestamp: float
|
| 36 |
+
|
| 37 |
+
class ArduinoBridge:
|
| 38 |
+
"""
|
| 39 |
+
Bridge for communication with Arduino-based robotic arm controller
|
| 40 |
+
Handles servo control, sensor reading, and safety monitoring
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, port: str = '/dev/ttyUSB0', baudrate: int = 115200):
|
| 44 |
+
"""Initialize Arduino bridge"""
|
| 45 |
+
self.port = port
|
| 46 |
+
self.baudrate = baudrate
|
| 47 |
+
self.serial_connection: Optional[serial.Serial] = None
|
| 48 |
+
self.connected = False
|
| 49 |
+
self.running = False
|
| 50 |
+
|
| 51 |
+
# Command queue for async communication
|
| 52 |
+
self.command_queue = []
|
| 53 |
+
self.response_queue = []
|
| 54 |
+
self.queue_lock = threading.Lock()
|
| 55 |
+
|
| 56 |
+
# Current servo positions
|
| 57 |
+
self.current_positions = {}
|
| 58 |
+
self.target_positions = {}
|
| 59 |
+
|
| 60 |
+
# Sensor data
|
| 61 |
+
self.latest_sensor_data: Optional[SensorData] = None
|
| 62 |
+
self.sensor_history = []
|
| 63 |
+
|
| 64 |
+
# Safety parameters
|
| 65 |
+
self.max_servo_angle = 180.0
|
| 66 |
+
self.min_servo_angle = 0.0
|
| 67 |
+
self.movement_timeout = 10.0 # seconds
|
| 68 |
+
self.emergency_stop_flag = False
|
| 69 |
+
|
| 70 |
+
# Communication thread
|
| 71 |
+
self.comm_thread: Optional[threading.Thread] = None
|
| 72 |
+
|
| 73 |
+
logger.info(f"Arduino Bridge initialized for port {port} at {baudrate} baud")
|
| 74 |
+
|
| 75 |
+
def connect(self) -> bool:
|
| 76 |
+
"""Connect to Arduino via serial port"""
|
| 77 |
+
try:
|
| 78 |
+
# Auto-detect port if not specified
|
| 79 |
+
if self.port == '/dev/ttyUSB0':
|
| 80 |
+
self.port = self._auto_detect_port()
|
| 81 |
+
if not self.port:
|
| 82 |
+
logger.error("Could not auto-detect Arduino port")
|
| 83 |
+
return False
|
| 84 |
+
|
| 85 |
+
# Establish serial connection
|
| 86 |
+
self.serial_connection = serial.Serial(
|
| 87 |
+
port=self.port,
|
| 88 |
+
baudrate=self.baudrate,
|
| 89 |
+
timeout=1.0,
|
| 90 |
+
write_timeout=1.0
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# Wait for Arduino to initialize
|
| 94 |
+
time.sleep(2.0)
|
| 95 |
+
|
| 96 |
+
# Test connection
|
| 97 |
+
if self._test_connection():
|
| 98 |
+
self.connected = True
|
| 99 |
+
self.running = True
|
| 100 |
+
|
| 101 |
+
# Start communication thread
|
| 102 |
+
self.comm_thread = threading.Thread(target=self._communication_loop, daemon=True)
|
| 103 |
+
self.comm_thread.start()
|
| 104 |
+
|
| 105 |
+
logger.info(f"Successfully connected to Arduino on {self.port}")
|
| 106 |
+
return True
|
| 107 |
+
else:
|
| 108 |
+
logger.error("Arduino connection test failed")
|
| 109 |
+
self.disconnect()
|
| 110 |
+
return False
|
| 111 |
+
|
| 112 |
+
except Exception as e:
|
| 113 |
+
logger.error(f"Failed to connect to Arduino: {e}")
|
| 114 |
+
return False
|
| 115 |
+
|
| 116 |
+
def disconnect(self):
|
| 117 |
+
"""Disconnect from Arduino"""
|
| 118 |
+
logger.info("Disconnecting from Arduino...")
|
| 119 |
+
|
| 120 |
+
self.running = False
|
| 121 |
+
self.connected = False
|
| 122 |
+
|
| 123 |
+
# Stop communication thread
|
| 124 |
+
if self.comm_thread and self.comm_thread.is_alive():
|
| 125 |
+
self.comm_thread.join(timeout=2.0)
|
| 126 |
+
|
| 127 |
+
# Close serial connection
|
| 128 |
+
if self.serial_connection and self.serial_connection.is_open:
|
| 129 |
+
self.serial_connection.close()
|
| 130 |
+
|
| 131 |
+
logger.info("Arduino disconnected")
|
| 132 |
+
|
| 133 |
+
def _auto_detect_port(self) -> Optional[str]:
|
| 134 |
+
"""Auto-detect Arduino port"""
|
| 135 |
+
ports = serial.tools.list_ports.comports()
|
| 136 |
+
|
| 137 |
+
for port in ports:
|
| 138 |
+
# Look for common Arduino identifiers
|
| 139 |
+
if any(keyword in port.description.lower() for keyword in
|
| 140 |
+
['arduino', 'ch340', 'cp2102', 'ftdi']):
|
| 141 |
+
logger.info(f"Auto-detected Arduino on port: {port.device}")
|
| 142 |
+
return port.device
|
| 143 |
+
|
| 144 |
+
# If no Arduino found, return first available port
|
| 145 |
+
if ports:
|
| 146 |
+
logger.warning(f"No Arduino detected, using first available port: {ports[0].device}")
|
| 147 |
+
return ports[0].device
|
| 148 |
+
|
| 149 |
+
logger.error("No serial ports available")
|
| 150 |
+
return None
|
| 151 |
+
|
| 152 |
+
def _test_connection(self) -> bool:
|
| 153 |
+
"""Test connection with Arduino"""
|
| 154 |
+
try:
|
| 155 |
+
# Send test command
|
| 156 |
+
self._send_command("PING")
|
| 157 |
+
|
| 158 |
+
# Wait for response
|
| 159 |
+
start_time = time.time()
|
| 160 |
+
while time.time() - start_time < 3.0:
|
| 161 |
+
if self._check_response("PONG"):
|
| 162 |
+
return True
|
| 163 |
+
time.sleep(0.1)
|
| 164 |
+
|
| 165 |
+
return False
|
| 166 |
+
|
| 167 |
+
except Exception as e:
|
| 168 |
+
logger.error(f"Connection test failed: {e}")
|
| 169 |
+
return False
|
| 170 |
+
|
| 171 |
+
def _communication_loop(self):
|
| 172 |
+
"""Main communication loop for async processing"""
|
| 173 |
+
logger.info("Arduino communication loop started")
|
| 174 |
+
|
| 175 |
+
while self.running and self.connected:
|
| 176 |
+
try:
|
| 177 |
+
# Process outgoing commands
|
| 178 |
+
self._process_command_queue()
|
| 179 |
+
|
| 180 |
+
# Read incoming data
|
| 181 |
+
self._read_serial_data()
|
| 182 |
+
|
| 183 |
+
# Process sensor data
|
| 184 |
+
self._process_sensor_data()
|
| 185 |
+
|
| 186 |
+
time.sleep(0.01) # 10ms loop delay
|
| 187 |
+
|
| 188 |
+
except Exception as e:
|
| 189 |
+
logger.error(f"Communication loop error: {e}")
|
| 190 |
+
time.sleep(0.1)
|
| 191 |
+
|
| 192 |
+
logger.info("Arduino communication loop stopped")
|
| 193 |
+
|
| 194 |
+
def _process_command_queue(self):
|
| 195 |
+
"""Process commands in the queue"""
|
| 196 |
+
with self.queue_lock:
|
| 197 |
+
if not self.command_queue:
|
| 198 |
+
return
|
| 199 |
+
|
| 200 |
+
command = self.command_queue.pop(0)
|
| 201 |
+
|
| 202 |
+
try:
|
| 203 |
+
self._send_raw_command(command)
|
| 204 |
+
except Exception as e:
|
| 205 |
+
logger.error(f"Failed to send command {command}: {e}")
|
| 206 |
+
|
| 207 |
+
def _read_serial_data(self):
|
| 208 |
+
"""Read and process incoming serial data"""
|
| 209 |
+
if not self.serial_connection or not self.serial_connection.is_open:
|
| 210 |
+
return
|
| 211 |
+
|
| 212 |
+
try:
|
| 213 |
+
if self.serial_connection.in_waiting > 0:
|
| 214 |
+
line = self.serial_connection.readline().decode('utf-8').strip()
|
| 215 |
+
self._process_serial_line(line)
|
| 216 |
+
except Exception as e:
|
| 217 |
+
logger.error(f"Error reading serial data: {e}")
|
| 218 |
+
|
| 219 |
+
def _process_serial_line(self, line: str):
|
| 220 |
+
"""Process a single line of serial data"""
|
| 221 |
+
try:
|
| 222 |
+
# Parse different message types
|
| 223 |
+
if line.startswith("SENSOR:"):
|
| 224 |
+
self._parse_sensor_data(line)
|
| 225 |
+
elif line.startswith("STATUS:"):
|
| 226 |
+
self._parse_status_data(line)
|
| 227 |
+
elif line.startswith("ERROR:"):
|
| 228 |
+
logger.error(f"Arduino error: {line}")
|
| 229 |
+
elif line.startswith("DEBUG:"):
|
| 230 |
+
logger.debug(f"Arduino debug: {line}")
|
| 231 |
+
else:
|
| 232 |
+
# Add to response queue
|
| 233 |
+
with self.queue_lock:
|
| 234 |
+
self.response_queue.append(line)
|
| 235 |
+
|
| 236 |
+
except Exception as e:
|
| 237 |
+
logger.error(f"Error processing serial line '{line}': {e}")
|
| 238 |
+
|
| 239 |
+
def _parse_sensor_data(self, line: str):
|
| 240 |
+
"""Parse sensor data from Arduino"""
|
| 241 |
+
try:
|
| 242 |
+
# Format: SENSOR:limit_sw1:0,limit_sw2:1,force:45.2,temp:23.5
|
| 243 |
+
data_part = line[7:] # Remove "SENSOR:" prefix
|
| 244 |
+
|
| 245 |
+
sensor_data = {}
|
| 246 |
+
for item in data_part.split(','):
|
| 247 |
+
key, value = item.split(':')
|
| 248 |
+
if key.startswith('limit_sw'):
|
| 249 |
+
sensor_data[key] = bool(int(value))
|
| 250 |
+
elif key == 'force':
|
| 251 |
+
sensor_data[key] = float(value)
|
| 252 |
+
elif key == 'temp':
|
| 253 |
+
sensor_data[key] = float(value)
|
| 254 |
+
|
| 255 |
+
self.latest_sensor_data = SensorData(
|
| 256 |
+
limit_switches={k: v for k, v in sensor_data.items() if k.startswith('limit_sw')},
|
| 257 |
+
force_sensor=sensor_data.get('force', 0.0),
|
| 258 |
+
temperature=sensor_data.get('temp', 0.0),
|
| 259 |
+
timestamp=time.time()
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# Add to history
|
| 263 |
+
self.sensor_history.append(self.latest_sensor_data)
|
| 264 |
+
if len(self.sensor_history) > 100: # Keep last 100 readings
|
| 265 |
+
self.sensor_history.pop(0)
|
| 266 |
+
|
| 267 |
+
except Exception as e:
|
| 268 |
+
logger.error(f"Error parsing sensor data: {e}")
|
| 269 |
+
|
| 270 |
+
def _parse_status_data(self, line: str):
|
| 271 |
+
"""Parse status data from Arduino"""
|
| 272 |
+
try:
|
| 273 |
+
# Format: STATUS:servo1:90.5,servo2:45.0
|
| 274 |
+
data_part = line[7:] # Remove "STATUS:" prefix
|
| 275 |
+
|
| 276 |
+
for item in data_part.split(','):
|
| 277 |
+
servo_id, angle = item.split(':')
|
| 278 |
+
servo_num = int(servo_id.replace('servo', ''))
|
| 279 |
+
self.current_positions[servo_num] = float(angle)
|
| 280 |
+
|
| 281 |
+
except Exception as e:
|
| 282 |
+
logger.error(f"Error parsing status data: {e}")
|
| 283 |
+
|
| 284 |
+
def _send_command(self, command: str):
|
| 285 |
+
"""Send command and wait for response"""
|
| 286 |
+
with self.queue_lock:
|
| 287 |
+
self.command_queue.append(command)
|
| 288 |
+
|
| 289 |
+
# Wait for response
|
| 290 |
+
start_time = time.time()
|
| 291 |
+
while time.time() - start_time < 5.0:
|
| 292 |
+
if self._check_response(command):
|
| 293 |
+
return True
|
| 294 |
+
time.sleep(0.1)
|
| 295 |
+
|
| 296 |
+
logger.warning(f"No response received for command: {command}")
|
| 297 |
+
return False
|
| 298 |
+
|
| 299 |
+
def _send_raw_command(self, command: str):
|
| 300 |
+
"""Send raw command without queuing"""
|
| 301 |
+
if not self.serial_connection or not self.serial_connection.is_open:
|
| 302 |
+
raise Exception("Serial connection not available")
|
| 303 |
+
|
| 304 |
+
self.serial_connection.write(f"{command}\n".encode('utf-8'))
|
| 305 |
+
self.serial_connection.flush()
|
| 306 |
+
|
| 307 |
+
def _check_response(self, expected_command: str) -> bool:
|
| 308 |
+
"""Check if expected response is in queue"""
|
| 309 |
+
with self.queue_lock:
|
| 310 |
+
for i, response in enumerate(self.response_queue):
|
| 311 |
+
if expected_command in response:
|
| 312 |
+
self.response_queue.pop(i)
|
| 313 |
+
return True
|
| 314 |
+
return False
|
| 315 |
+
|
| 316 |
+
def initialize_servos(self):
|
| 317 |
+
"""Initialize all servos to home position"""
|
| 318 |
+
logger.info("Initializing servos...")
|
| 319 |
+
|
| 320 |
+
# Home positions for each servo (adjust based on your robot design)
|
| 321 |
+
home_positions = {
|
| 322 |
+
1: 90.0, # Base rotation
|
| 323 |
+
2: 45.0, # Shoulder
|
| 324 |
+
3: 90.0, # Elbow
|
| 325 |
+
4: 90.0, # Wrist
|
| 326 |
+
5: 0.0, # Gripper
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
for servo_id, position in home_positions.items():
|
| 330 |
+
self.move_servo(servo_id, position)
|
| 331 |
+
time.sleep(0.5) # Small delay between servo movements
|
| 332 |
+
|
| 333 |
+
logger.info("Servos initialized to home positions")
|
| 334 |
+
|
| 335 |
+
def move_servo(self, servo_id: int, angle: float, speed: float = 1.0) -> bool:
|
| 336 |
+
"""Move a specific servo to target angle"""
|
| 337 |
+
if not self.connected:
|
| 338 |
+
logger.error("Not connected to Arduino")
|
| 339 |
+
return False
|
| 340 |
+
|
| 341 |
+
# Validate angle
|
| 342 |
+
if not (self.min_servo_angle <= angle <= self.max_servo_angle):
|
| 343 |
+
logger.error(f"Invalid angle {angle} for servo {servo_id}")
|
| 344 |
+
return False
|
| 345 |
+
|
| 346 |
+
# Check emergency stop
|
| 347 |
+
if self.emergency_stop_flag:
|
| 348 |
+
logger.warning("Emergency stop active, ignoring servo command")
|
| 349 |
+
return False
|
| 350 |
+
|
| 351 |
+
try:
|
| 352 |
+
command = f"MOVE:{servo_id}:{angle:.1f}:{speed:.2f}"
|
| 353 |
+
success = self._send_command(command)
|
| 354 |
+
|
| 355 |
+
if success:
|
| 356 |
+
self.target_positions[servo_id] = angle
|
| 357 |
+
logger.debug(f"Moved servo {servo_id} to {angle} degrees")
|
| 358 |
+
|
| 359 |
+
return success
|
| 360 |
+
|
| 361 |
+
except Exception as e:
|
| 362 |
+
logger.error(f"Failed to move servo {servo_id}: {e}")
|
| 363 |
+
return False
|
| 364 |
+
|
| 365 |
+
def move_to_position(self, x: float, y: float, z: float) -> bool:
|
| 366 |
+
"""Move robotic arm to 3D position using inverse kinematics"""
|
| 367 |
+
if not self.connected:
|
| 368 |
+
logger.error("Not connected to Arduino")
|
| 369 |
+
return False
|
| 370 |
+
|
| 371 |
+
try:
|
| 372 |
+
# Convert 3D coordinates to servo angles
|
| 373 |
+
# This is a simplified version - implement proper inverse kinematics
|
| 374 |
+
servo_angles = self._inverse_kinematics(x, y, z)
|
| 375 |
+
|
| 376 |
+
# Move all servos simultaneously
|
| 377 |
+
success = True
|
| 378 |
+
for servo_id, angle in servo_angles.items():
|
| 379 |
+
if not self.move_servo(servo_id, angle):
|
| 380 |
+
success = False
|
| 381 |
+
|
| 382 |
+
if success:
|
| 383 |
+
logger.info(f"Moved to position ({x:.2f}, {y:.2f}, {z:.2f})")
|
| 384 |
+
|
| 385 |
+
return success
|
| 386 |
+
|
| 387 |
+
except Exception as e:
|
| 388 |
+
logger.error(f"Failed to move to position: {e}")
|
| 389 |
+
return False
|
| 390 |
+
|
| 391 |
+
def _inverse_kinematics(self, x: float, y: float, z: float) -> Dict[int, float]:
|
| 392 |
+
"""Simple inverse kinematics calculation"""
|
| 393 |
+
# This is a placeholder implementation
|
| 394 |
+
# Replace with proper IK calculations for your robot design
|
| 395 |
+
|
| 396 |
+
# Simplified mapping (adjust based on your robot geometry)
|
| 397 |
+
base_angle = (np.arctan2(y, x) * 180 / np.pi) + 90
|
| 398 |
+
shoulder_angle = 45 + (z * 10) # Simplified mapping
|
| 399 |
+
elbow_angle = 90 - (z * 5)
|
| 400 |
+
wrist_angle = 90
|
| 401 |
+
|
| 402 |
+
return {
|
| 403 |
+
1: np.clip(base_angle, 0, 180),
|
| 404 |
+
2: np.clip(shoulder_angle, 0, 180),
|
| 405 |
+
3: np.clip(elbow_angle, 0, 180),
|
| 406 |
+
4: np.clip(wrist_angle, 0, 180),
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
def open_gripper(self) -> bool:
|
| 410 |
+
"""Open the gripper"""
|
| 411 |
+
return self.move_servo(5, 0.0)
|
| 412 |
+
|
| 413 |
+
def close_gripper(self) -> bool:
|
| 414 |
+
"""Close the gripper"""
|
| 415 |
+
return self.move_servo(5, 90.0)
|
| 416 |
+
|
| 417 |
+
def emergency_stop(self):
|
| 418 |
+
"""Activate emergency stop"""
|
| 419 |
+
logger.warning("EMERGENCY STOP ACTIVATED")
|
| 420 |
+
self.emergency_stop_flag = True
|
| 421 |
+
|
| 422 |
+
# Send emergency stop command
|
| 423 |
+
try:
|
| 424 |
+
self._send_command("ESTOP")
|
| 425 |
+
except:
|
| 426 |
+
pass
|
| 427 |
+
|
| 428 |
+
# Move all servos to safe positions
|
| 429 |
+
safe_positions = {1: 90, 2: 45, 3: 90, 4: 90, 5: 0}
|
| 430 |
+
for servo_id, position in safe_positions.items():
|
| 431 |
+
try:
|
| 432 |
+
self.move_servo(servo_id, position, speed=2.0)
|
| 433 |
+
except:
|
| 434 |
+
pass
|
| 435 |
+
|
| 436 |
+
def reset_emergency_stop(self):
|
| 437 |
+
"""Reset emergency stop"""
|
| 438 |
+
self.emergency_stop_flag = False
|
| 439 |
+
logger.info("Emergency stop reset")
|
| 440 |
+
|
| 441 |
+
def get_sensor_data(self) -> Optional[SensorData]:
|
| 442 |
+
"""Get latest sensor data"""
|
| 443 |
+
return self.latest_sensor_data
|
| 444 |
+
|
| 445 |
+
def get_servo_positions(self) -> Dict[int, float]:
|
| 446 |
+
"""Get current servo positions"""
|
| 447 |
+
return self.current_positions.copy()
|
| 448 |
+
|
| 449 |
+
def is_movement_complete(self, servo_id: int, tolerance: float = 2.0) -> bool:
|
| 450 |
+
"""Check if servo movement is complete"""
|
| 451 |
+
if servo_id not in self.target_positions:
|
| 452 |
+
return True
|
| 453 |
+
|
| 454 |
+
current = self.current_positions.get(servo_id, 0)
|
| 455 |
+
target = self.target_positions[servo_id]
|
| 456 |
+
|
| 457 |
+
return abs(current - target) <= tolerance
|
| 458 |
+
|
| 459 |
+
def wait_for_movement(self, servo_ids: List[int], timeout: float = 10.0) -> bool:
|
| 460 |
+
"""Wait for servo movements to complete"""
|
| 461 |
+
start_time = time.time()
|
| 462 |
+
|
| 463 |
+
while time.time() - start_time < timeout:
|
| 464 |
+
all_complete = True
|
| 465 |
+
for servo_id in servo_ids:
|
| 466 |
+
if not self.is_movement_complete(servo_id):
|
| 467 |
+
all_complete = False
|
| 468 |
+
break
|
| 469 |
+
|
| 470 |
+
if all_complete:
|
| 471 |
+
return True
|
| 472 |
+
|
| 473 |
+
time.sleep(0.1)
|
| 474 |
+
|
| 475 |
+
logger.warning(f"Movement timeout after {timeout} seconds")
|
| 476 |
+
return False
|
| 477 |
+
|
| 478 |
+
def get_status(self) -> Dict:
|
| 479 |
+
"""Get Arduino bridge status"""
|
| 480 |
+
return {
|
| 481 |
+
'connected': self.connected,
|
| 482 |
+
'port': self.port,
|
| 483 |
+
'baudrate': self.baudrate,
|
| 484 |
+
'emergency_stop': self.emergency_stop_flag,
|
| 485 |
+
'current_positions': self.get_servo_positions(),
|
| 486 |
+
'target_positions': self.target_positions.copy(),
|
| 487 |
+
'latest_sensor_data': self.latest_sensor_data.__dict__ if self.latest_sensor_data else None,
|
| 488 |
+
'queue_size': len(self.command_queue)
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
def main():
|
| 492 |
+
"""Test Arduino bridge functionality"""
|
| 493 |
+
import argparse
|
| 494 |
+
|
| 495 |
+
parser = argparse.ArgumentParser(description='Test Arduino Bridge')
|
| 496 |
+
parser.add_argument('--port', default='/dev/ttyUSB0', help='Arduino port')
|
| 497 |
+
parser.add_argument('--baudrate', type=int, default=115200, help='Baud rate')
|
| 498 |
+
|
| 499 |
+
args = parser.parse_args()
|
| 500 |
+
|
| 501 |
+
# Create bridge
|
| 502 |
+
bridge = ArduinoBridge(args.port, args.baudrate)
|
| 503 |
+
|
| 504 |
+
try:
|
| 505 |
+
# Connect
|
| 506 |
+
if bridge.connect():
|
| 507 |
+
print("Connected to Arduino successfully")
|
| 508 |
+
|
| 509 |
+
# Initialize servos
|
| 510 |
+
bridge.initialize_servos()
|
| 511 |
+
time.sleep(2)
|
| 512 |
+
|
| 513 |
+
# Test movements
|
| 514 |
+
print("Testing servo movements...")
|
| 515 |
+
bridge.move_servo(1, 45)
|
| 516 |
+
time.sleep(1)
|
| 517 |
+
bridge.move_servo(1, 135)
|
| 518 |
+
time.sleep(1)
|
| 519 |
+
bridge.move_servo(1, 90)
|
| 520 |
+
|
| 521 |
+
# Test gripper
|
| 522 |
+
print("Testing gripper...")
|
| 523 |
+
bridge.open_gripper()
|
| 524 |
+
time.sleep(1)
|
| 525 |
+
bridge.close_gripper()
|
| 526 |
+
time.sleep(1)
|
| 527 |
+
bridge.open_gripper()
|
| 528 |
+
|
| 529 |
+
# Print status
|
| 530 |
+
print("\nArduino Bridge Status:")
|
| 531 |
+
print(json.dumps(bridge.get_status(), indent=2, default=str))
|
| 532 |
+
|
| 533 |
+
else:
|
| 534 |
+
print("Failed to connect to Arduino")
|
| 535 |
+
|
| 536 |
+
except KeyboardInterrupt:
|
| 537 |
+
print("\nInterrupted by user")
|
| 538 |
+
|
| 539 |
+
finally:
|
| 540 |
+
bridge.disconnect()
|
| 541 |
+
print("Arduino bridge disconnected")
|
| 542 |
+
|
| 543 |
+
if __name__ == "__main__":
|
| 544 |
+
main()
|
src/coordinate_transformer.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Coordinate Transformer - Pixel to Robot Coordinate System
|
| 4 |
+
Handles camera calibration, stereo vision, and coordinate transformations
|
| 5 |
+
|
| 6 |
+
Author: AI Assistant
|
| 7 |
+
Date: 2025-12-15
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import cv2
|
| 12 |
+
import logging
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Tuple, Optional, Dict, List
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
import json
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class CameraCalibration:
|
| 22 |
+
"""Camera calibration parameters"""
|
| 23 |
+
camera_matrix: np.ndarray
|
| 24 |
+
distortion_coeffs: np.ndarray
|
| 25 |
+
image_width: int
|
| 26 |
+
image_height: int
|
| 27 |
+
focal_length: float
|
| 28 |
+
principal_point: Tuple[float, float]
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class StereoCalibration:
|
| 32 |
+
"""Stereo camera calibration parameters"""
|
| 33 |
+
left_camera_matrix: np.ndarray
|
| 34 |
+
right_camera_matrix: np.ndarray
|
| 35 |
+
left_distortion: np.ndarray
|
| 36 |
+
right_distortion: np.ndarray
|
| 37 |
+
rotation_matrix: np.ndarray
|
| 38 |
+
translation_vector: np.ndarray
|
| 39 |
+
essential_matrix: np.ndarray
|
| 40 |
+
fundamental_matrix: np.ndarray
|
| 41 |
+
rectification_matrix_left: np.ndarray
|
| 42 |
+
rectification_matrix_right: np.ndarray
|
| 43 |
+
projection_matrix_left: np.ndarray
|
| 44 |
+
projection_matrix_right: np.ndarray
|
| 45 |
+
disparity_to_depth_map: np.ndarray
|
| 46 |
+
|
| 47 |
+
class CoordinateTransformer:
|
| 48 |
+
"""
|
| 49 |
+
Handles coordinate transformations between pixel coordinates and robot world coordinates
|
| 50 |
+
Supports both monocular and stereo camera systems
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(self,
|
| 54 |
+
camera_matrix_path: Optional[str] = None,
|
| 55 |
+
distortion_coeffs_path: Optional[str] = None,
|
| 56 |
+
stereo_calibration_path: Optional[str] = None):
|
| 57 |
+
"""Initialize coordinate transformer"""
|
| 58 |
+
|
| 59 |
+
self.camera_calibration: Optional[CameraCalibration] = None
|
| 60 |
+
self.stereo_calibration: Optional[StereoCalibration] = None
|
| 61 |
+
self.stereo_matcher: Optional[cv2.StereoSGBM] = None
|
| 62 |
+
|
| 63 |
+
# Robot coordinate system parameters
|
| 64 |
+
self.robot_origin = (0.0, 0.0, 0.0) # Robot base position
|
| 65 |
+
self.camera_to_robot_transform = np.eye(4) # 4x4 transformation matrix
|
| 66 |
+
|
| 67 |
+
# Load calibrations if provided
|
| 68 |
+
if camera_matrix_path and distortion_coeffs_path:
|
| 69 |
+
self.load_camera_calibration(camera_matrix_path, distortion_coeffs_path)
|
| 70 |
+
|
| 71 |
+
if stereo_calibration_path:
|
| 72 |
+
self.load_stereo_calibration(stereo_calibration_path)
|
| 73 |
+
|
| 74 |
+
logger.info("Coordinate Transformer initialized")
|
| 75 |
+
|
| 76 |
+
def load_camera_calibration(self, camera_matrix_path: str, distortion_coeffs_path: str):
|
| 77 |
+
"""Load single camera calibration"""
|
| 78 |
+
try:
|
| 79 |
+
camera_matrix = np.load(camera_matrix_path)
|
| 80 |
+
distortion_coeffs = np.load(distortion_coeffs_path)
|
| 81 |
+
|
| 82 |
+
# Extract calibration parameters
|
| 83 |
+
fx, fy = camera_matrix[0, 0], camera_matrix[1, 1]
|
| 84 |
+
cx, cy = camera_matrix[0, 2], camera_matrix[1, 2]
|
| 85 |
+
|
| 86 |
+
self.camera_calibration = CameraCalibration(
|
| 87 |
+
camera_matrix=camera_matrix,
|
| 88 |
+
distortion_coeffs=distortion_coeffs,
|
| 89 |
+
image_width=int(camera_matrix[0, 2] * 2), # Approximate
|
| 90 |
+
image_height=int(camera_matrix[1, 2] * 2), # Approximate
|
| 91 |
+
focal_length=(fx + fy) / 2,
|
| 92 |
+
principal_point=(cx, cy)
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
logger.info("Camera calibration loaded successfully")
|
| 96 |
+
|
| 97 |
+
except Exception as e:
|
| 98 |
+
logger.error(f"Failed to load camera calibration: {e}")
|
| 99 |
+
raise
|
| 100 |
+
|
| 101 |
+
def load_stereo_calibration(self, stereo_calibration_path: str):
|
| 102 |
+
"""Load stereo camera calibration"""
|
| 103 |
+
try:
|
| 104 |
+
calibration_data = np.load(stereo_calibration_path)
|
| 105 |
+
|
| 106 |
+
self.stereo_calibration = StereoCalibration(
|
| 107 |
+
left_camera_matrix=calibration_data['left_camera_matrix'],
|
| 108 |
+
right_camera_matrix=calibration_data['right_camera_matrix'],
|
| 109 |
+
left_distortion=calibration_data['left_distortion'],
|
| 110 |
+
right_distortion=calibration_data['right_distortion'],
|
| 111 |
+
rotation_matrix=calibration_data['rotation_matrix'],
|
| 112 |
+
translation_vector=calibration_data['translation_vector'],
|
| 113 |
+
essential_matrix=calibration_data['essential_matrix'],
|
| 114 |
+
fundamental_matrix=calibration_data['fundamental_matrix'],
|
| 115 |
+
rectification_matrix_left=calibration_data['rectification_matrix_left'],
|
| 116 |
+
rectification_matrix_right=calibration_data['rectification_matrix_right'],
|
| 117 |
+
projection_matrix_left=calibration_data['projection_matrix_left'],
|
| 118 |
+
projection_matrix_right=calibration_data['projection_matrix_right'],
|
| 119 |
+
disparity_to_depth_map=calibration_data['disparity_to_depth_map']
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Initialize stereo matcher for real-time depth calculation
|
| 123 |
+
self._initialize_stereo_matcher()
|
| 124 |
+
|
| 125 |
+
logger.info("Stereo calibration loaded successfully")
|
| 126 |
+
|
| 127 |
+
except Exception as e:
|
| 128 |
+
logger.error(f"Failed to load stereo calibration: {e}")
|
| 129 |
+
raise
|
| 130 |
+
|
| 131 |
+
def _initialize_stereo_matcher(self):
|
| 132 |
+
"""Initialize stereo matching for depth calculation"""
|
| 133 |
+
if not self.stereo_calibration:
|
| 134 |
+
return
|
| 135 |
+
|
| 136 |
+
# Create stereo block matcher
|
| 137 |
+
self.stereo_matcher = cv2.StereoSGBM_create(
|
| 138 |
+
minDisparity=0,
|
| 139 |
+
numDisparities=64,
|
| 140 |
+
blockSize=9,
|
| 141 |
+
P1=8 * 9 * 9,
|
| 142 |
+
P2=32 * 9 * 9,
|
| 143 |
+
disp12MaxDiff=1,
|
| 144 |
+
uniquenessRatio=10,
|
| 145 |
+
speckleWindowSize=100,
|
| 146 |
+
speckleRange=32
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
logger.info("Stereo matcher initialized")
|
| 150 |
+
|
| 151 |
+
def pixel_to_world(self,
|
| 152 |
+
pixel_x: int,
|
| 153 |
+
pixel_y: int,
|
| 154 |
+
image_shape: Tuple[int, int, int],
|
| 155 |
+
depth: Optional[float] = None) -> Tuple[float, float, float]:
|
| 156 |
+
"""
|
| 157 |
+
Convert pixel coordinates to world coordinates
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
pixel_x, pixel_y: Pixel coordinates
|
| 161 |
+
image_shape: Shape of the image (height, width, channels)
|
| 162 |
+
depth: Depth in meters (if None, uses stereo or default depth)
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
Tuple of (x, y, z) world coordinates in meters
|
| 166 |
+
"""
|
| 167 |
+
if not self.camera_calibration:
|
| 168 |
+
raise ValueError("Camera calibration not loaded")
|
| 169 |
+
|
| 170 |
+
# Get depth if not provided
|
| 171 |
+
if depth is None:
|
| 172 |
+
depth = self._estimate_depth(pixel_x, pixel_y, image_shape)
|
| 173 |
+
|
| 174 |
+
# Convert pixel to normalized coordinates
|
| 175 |
+
fx, fy = self.camera_calibration.camera_matrix[0, 0], self.camera_calibration.camera_matrix[1, 1]
|
| 176 |
+
cx, cy = self.camera_calibration.principal_point
|
| 177 |
+
|
| 178 |
+
# Convert to camera coordinates
|
| 179 |
+
x_cam = (pixel_x - cx) * depth / fx
|
| 180 |
+
y_cam = (pixel_y - cy) * depth / fy
|
| 181 |
+
z_cam = depth
|
| 182 |
+
|
| 183 |
+
# Transform to robot world coordinates
|
| 184 |
+
camera_point = np.array([x_cam, y_cam, z_cam, 1.0])
|
| 185 |
+
world_point = self.camera_to_robot_transform @ camera_point
|
| 186 |
+
|
| 187 |
+
return (float(world_point[0]), float(world_point[1]), float(world_point[2]))
|
| 188 |
+
|
| 189 |
+
def world_to_pixel(self,
|
| 190 |
+
world_x: float,
|
| 191 |
+
world_y: float,
|
| 192 |
+
world_z: float) -> Tuple[int, int]:
|
| 193 |
+
"""
|
| 194 |
+
Convert world coordinates to pixel coordinates
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
Tuple of (pixel_x, pixel_y) coordinates
|
| 198 |
+
"""
|
| 199 |
+
if not self.camera_calibration:
|
| 200 |
+
raise ValueError("Camera calibration not loaded")
|
| 201 |
+
|
| 202 |
+
# Transform world point to camera coordinates
|
| 203 |
+
world_point = np.array([world_x, world_y, world_z, 1.0])
|
| 204 |
+
camera_point = np.linalg.inv(self.camera_to_robot_transform) @ world_point
|
| 205 |
+
|
| 206 |
+
x_cam, y_cam, z_cam = camera_point[:3]
|
| 207 |
+
|
| 208 |
+
# Convert to pixel coordinates
|
| 209 |
+
fx, fy = self.camera_calibration.camera_matrix[0, 0], self.camera_calibration.camera_matrix[1, 1]
|
| 210 |
+
cx, cy = self.camera_calibration.principal_point
|
| 211 |
+
|
| 212 |
+
pixel_x = int(x_cam * fx / z_cam + cx)
|
| 213 |
+
pixel_y = int(y_cam * fy / z_cam + cy)
|
| 214 |
+
|
| 215 |
+
return (pixel_x, pixel_y)
|
| 216 |
+
|
| 217 |
+
def _estimate_depth(self, pixel_x: int, pixel_y: int, image_shape: Tuple[int, int, int]) -> float:
|
| 218 |
+
"""Estimate depth using stereo vision or default depth"""
|
| 219 |
+
|
| 220 |
+
# If stereo calibration is available, try to get depth from disparity
|
| 221 |
+
if self.stereo_calibration and len(image_shape) == 3:
|
| 222 |
+
# This would require stereo images - simplified for now
|
| 223 |
+
pass
|
| 224 |
+
|
| 225 |
+
# Default depth estimation based on image position
|
| 226 |
+
# Assume strawberries are typically 20-50cm from camera
|
| 227 |
+
image_height = image_shape[0]
|
| 228 |
+
|
| 229 |
+
# Simple depth estimation based on vertical position
|
| 230 |
+
# Lower in image = closer to camera
|
| 231 |
+
normalized_y = pixel_y / image_height
|
| 232 |
+
estimated_depth = 0.2 + (0.3 * (1.0 - normalized_y)) # 20-50cm range
|
| 233 |
+
|
| 234 |
+
return estimated_depth
|
| 235 |
+
|
| 236 |
+
def calculate_depth_from_stereo(self,
|
| 237 |
+
left_image: np.ndarray,
|
| 238 |
+
right_image: np.ndarray,
|
| 239 |
+
pixel_x: int,
|
| 240 |
+
pixel_y: int) -> Optional[float]:
|
| 241 |
+
"""Calculate depth from stereo images"""
|
| 242 |
+
if not self.stereo_matcher or not self.stereo_calibration:
|
| 243 |
+
return None
|
| 244 |
+
|
| 245 |
+
try:
|
| 246 |
+
# Compute disparity map
|
| 247 |
+
disparity = self.stereo_matcher.compute(left_image, right_image)
|
| 248 |
+
|
| 249 |
+
# Get disparity at specific pixel
|
| 250 |
+
if 0 <= pixel_y < disparity.shape[0] and 0 <= pixel_x < disparity.shape[1]:
|
| 251 |
+
disp_value = disparity[pixel_y, pixel_x]
|
| 252 |
+
|
| 253 |
+
if disp_value > 0:
|
| 254 |
+
# Convert disparity to depth
|
| 255 |
+
depth = self.stereo_calibration.disparity_to_depth_map[disp_value]
|
| 256 |
+
return float(depth)
|
| 257 |
+
|
| 258 |
+
return None
|
| 259 |
+
|
| 260 |
+
except Exception as e:
|
| 261 |
+
logger.error(f"Error calculating stereo depth: {e}")
|
| 262 |
+
return None
|
| 263 |
+
|
| 264 |
+
def undistort_point(self, pixel_x: int, pixel_y: int) -> Tuple[int, int]:
|
| 265 |
+
"""Undistort pixel coordinates using camera calibration"""
|
| 266 |
+
if not self.camera_calibration:
|
| 267 |
+
return pixel_x, pixel_y
|
| 268 |
+
|
| 269 |
+
try:
|
| 270 |
+
# Create point array
|
| 271 |
+
points = np.array([[pixel_x, pixel_y]], dtype=np.float32)
|
| 272 |
+
|
| 273 |
+
# Undistort points
|
| 274 |
+
undistorted_points = cv2.undistortPoints(
|
| 275 |
+
points,
|
| 276 |
+
self.camera_calibration.camera_matrix,
|
| 277 |
+
self.camera_calibration.distortion_coeffs
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
return int(undistorted_points[0][0][0]), int(undistorted_points[0][0][1])
|
| 281 |
+
|
| 282 |
+
except Exception as e:
|
| 283 |
+
logger.error(f"Error undistorting point: {e}")
|
| 284 |
+
return pixel_x, pixel_y
|
| 285 |
+
|
| 286 |
+
def set_camera_to_robot_transform(self, transform_matrix: np.ndarray):
|
| 287 |
+
"""Set the transformation matrix from camera to robot coordinates"""
|
| 288 |
+
if transform_matrix.shape != (4, 4):
|
| 289 |
+
raise ValueError("Transform matrix must be 4x4")
|
| 290 |
+
|
| 291 |
+
self.camera_to_robot_transform = transform_matrix
|
| 292 |
+
logger.info("Camera to robot transform updated")
|
| 293 |
+
|
| 294 |
+
def calibrate_camera_to_robot(self,
|
| 295 |
+
world_points: List[Tuple[float, float, float]],
|
| 296 |
+
pixel_points: List[Tuple[int, int]]) -> bool:
|
| 297 |
+
"""
|
| 298 |
+
Calibrate camera to robot transformation using known correspondences
|
| 299 |
+
|
| 300 |
+
Args:
|
| 301 |
+
world_points: List of (x, y, z) world coordinates
|
| 302 |
+
pixel_points: List of (pixel_x, pixel_y) pixel coordinates
|
| 303 |
+
|
| 304 |
+
Returns:
|
| 305 |
+
True if calibration successful
|
| 306 |
+
"""
|
| 307 |
+
if len(world_points) != len(pixel_points) or len(world_points) < 4:
|
| 308 |
+
logger.error("Need at least 4 corresponding points for calibration")
|
| 309 |
+
return False
|
| 310 |
+
|
| 311 |
+
try:
|
| 312 |
+
# Prepare point correspondences
|
| 313 |
+
world_points_3d = np.array(world_points, dtype=np.float32)
|
| 314 |
+
pixel_points_2d = np.array(pixel_points, dtype=np.float32)
|
| 315 |
+
|
| 316 |
+
# Solve PnP problem
|
| 317 |
+
success, rotation_vector, translation_vector, inliers = cv2.solvePnPRansac(
|
| 318 |
+
world_points_3d,
|
| 319 |
+
pixel_points_2d,
|
| 320 |
+
self.camera_calibration.camera_matrix,
|
| 321 |
+
self.camera_calibration.distortion_coeffs
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
if success:
|
| 325 |
+
# Convert rotation vector to rotation matrix
|
| 326 |
+
rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
|
| 327 |
+
|
| 328 |
+
# Create 4x4 transformation matrix
|
| 329 |
+
transform_matrix = np.eye(4)
|
| 330 |
+
transform_matrix[:3, :3] = rotation_matrix
|
| 331 |
+
transform_matrix[:3, 3] = translation_vector.flatten()
|
| 332 |
+
|
| 333 |
+
self.set_camera_to_robot_transform(transform_matrix)
|
| 334 |
+
|
| 335 |
+
logger.info("Camera to robot calibration successful")
|
| 336 |
+
return True
|
| 337 |
+
else:
|
| 338 |
+
logger.error("PnP calibration failed")
|
| 339 |
+
return False
|
| 340 |
+
|
| 341 |
+
except Exception as e:
|
| 342 |
+
logger.error(f"Calibration error: {e}")
|
| 343 |
+
return False
|
| 344 |
+
|
| 345 |
+
def get_workspace_bounds(self) -> Dict[str, Tuple[float, float]]:
|
| 346 |
+
"""Get the bounds of the robot workspace in world coordinates"""
|
| 347 |
+
# This should be calibrated for your specific robot setup
|
| 348 |
+
return {
|
| 349 |
+
'x_min': -0.5, 'x_max': 0.5,
|
| 350 |
+
'y_min': -0.5, 'y_max': 0.5,
|
| 351 |
+
'z_min': 0.0, 'z_max': 0.5
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
def is_point_in_workspace(self, x: float, y: float, z: float) -> bool:
|
| 355 |
+
"""Check if a point is within the robot workspace"""
|
| 356 |
+
bounds = self.get_workspace_bounds()
|
| 357 |
+
|
| 358 |
+
return (bounds['x_min'] <= x <= bounds['x_max'] and
|
| 359 |
+
bounds['y_min'] <= y <= bounds['y_max'] and
|
| 360 |
+
bounds['z_min'] <= z <= bounds['z_max'])
|
| 361 |
+
|
| 362 |
+
def project_3d_to_2d(self,
|
| 363 |
+
world_points: List[Tuple[float, float, float]],
|
| 364 |
+
image_shape: Tuple[int, int]) -> List[Tuple[int, int]]:
|
| 365 |
+
"""Project 3D world points to 2D image coordinates"""
|
| 366 |
+
if not self.camera_calibration:
|
| 367 |
+
raise ValueError("Camera calibration not loaded")
|
| 368 |
+
|
| 369 |
+
projected_points = []
|
| 370 |
+
|
| 371 |
+
for world_point in world_points:
|
| 372 |
+
pixel_x, pixel_y = self.world_to_pixel(*world_point)
|
| 373 |
+
|
| 374 |
+
# Check if point is within image bounds
|
| 375 |
+
if (0 <= pixel_x < image_shape[1] and 0 <= pixel_y < image_shape[0]):
|
| 376 |
+
projected_points.append((pixel_x, pixel_y))
|
| 377 |
+
else:
|
| 378 |
+
projected_points.append((-1, -1)) # Outside image
|
| 379 |
+
|
| 380 |
+
return projected_points
|
| 381 |
+
|
| 382 |
+
def save_calibration(self, filepath: str):
|
| 383 |
+
"""Save current calibration to file"""
|
| 384 |
+
calibration_data = {
|
| 385 |
+
'camera_calibration': {
|
| 386 |
+
'camera_matrix': self.camera_calibration.camera_matrix.tolist() if self.camera_calibration else None,
|
| 387 |
+
'distortion_coeffs': self.camera_calibration.distortion_coeffs.tolist() if self.camera_calibration else None,
|
| 388 |
+
'image_width': self.camera_calibration.image_width if self.camera_calibration else None,
|
| 389 |
+
'image_height': self.camera_calibration.image_height if self.camera_calibration else None,
|
| 390 |
+
'focal_length': self.camera_calibration.focal_length if self.camera_calibration else None,
|
| 391 |
+
'principal_point': self.camera_calibration.principal_point if self.camera_calibration else None
|
| 392 |
+
},
|
| 393 |
+
'stereo_calibration': {
|
| 394 |
+
'left_camera_matrix': self.stereo_calibration.left_camera_matrix.tolist() if self.stereo_calibration else None,
|
| 395 |
+
'right_camera_matrix': self.stereo_calibration.right_camera_matrix.tolist() if self.stereo_calibration else None,
|
| 396 |
+
'rotation_matrix': self.stereo_calibration.rotation_matrix.tolist() if self.stereo_calibration else None,
|
| 397 |
+
'translation_vector': self.stereo_calibration.translation_vector.tolist() if self.stereo_calibration else None
|
| 398 |
+
},
|
| 399 |
+
'camera_to_robot_transform': self.camera_to_robot_transform.tolist(),
|
| 400 |
+
'robot_origin': self.robot_origin
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
with open(filepath, 'w') as f:
|
| 404 |
+
json.dump(calibration_data, f, indent=2)
|
| 405 |
+
|
| 406 |
+
logger.info(f"Calibration saved to {filepath}")
|
| 407 |
+
|
| 408 |
+
def load_calibration(self, filepath: str):
|
| 409 |
+
"""Load calibration from file"""
|
| 410 |
+
try:
|
| 411 |
+
with open(filepath, 'r') as f:
|
| 412 |
+
calibration_data = json.load(f)
|
| 413 |
+
|
| 414 |
+
# Load camera calibration
|
| 415 |
+
if calibration_data['camera_calibration']['camera_matrix']:
|
| 416 |
+
cam_data = calibration_data['camera_calibration']
|
| 417 |
+
self.camera_calibration = CameraCalibration(
|
| 418 |
+
camera_matrix=np.array(cam_data['camera_matrix']),
|
| 419 |
+
distortion_coeffs=np.array(cam_data['distortion_coeffs']),
|
| 420 |
+
image_width=cam_data['image_width'],
|
| 421 |
+
image_height=cam_data['image_height'],
|
| 422 |
+
focal_length=cam_data['focal_length'],
|
| 423 |
+
principal_point=tuple(cam_data['principal_point'])
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
# Load stereo calibration
|
| 427 |
+
if calibration_data['stereo_calibration']['left_camera_matrix']:
|
| 428 |
+
stereo_data = calibration_data['stereo_calibration']
|
| 429 |
+
# Note: This is simplified - you'd need to load all stereo parameters
|
| 430 |
+
logger.warning("Stereo calibration loading not fully implemented")
|
| 431 |
+
|
| 432 |
+
# Load transformation matrix
|
| 433 |
+
self.camera_to_robot_transform = np.array(calibration_data['camera_to_robot_transform'])
|
| 434 |
+
self.robot_origin = tuple(calibration_data['robot_origin'])
|
| 435 |
+
|
| 436 |
+
logger.info(f"Calibration loaded from {filepath}")
|
| 437 |
+
|
| 438 |
+
except Exception as e:
|
| 439 |
+
logger.error(f"Failed to load calibration: {e}")
|
| 440 |
+
raise
|
| 441 |
+
|
| 442 |
+
def main():
|
| 443 |
+
"""Test coordinate transformer functionality"""
|
| 444 |
+
import argparse
|
| 445 |
+
|
| 446 |
+
parser = argparse.ArgumentParser(description='Test Coordinate Transformer')
|
| 447 |
+
parser.add_argument('--camera-matrix', help='Camera matrix file path')
|
| 448 |
+
parser.add_argument('--distortion', help='Distortion coefficients file path')
|
| 449 |
+
parser.add_argument('--stereo', help='Stereo calibration file path')
|
| 450 |
+
parser.add_argument('--test-pixel', nargs=2, type=int, metavar=('X', 'Y'),
|
| 451 |
+
help='Test pixel coordinates')
|
| 452 |
+
|
| 453 |
+
args = parser.parse_args()
|
| 454 |
+
|
| 455 |
+
try:
|
| 456 |
+
# Create transformer
|
| 457 |
+
transformer = CoordinateTransformer(
|
| 458 |
+
args.camera_matrix,
|
| 459 |
+
args.distortion,
|
| 460 |
+
args.stereo
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
print("Coordinate Transformer initialized")
|
| 464 |
+
|
| 465 |
+
if args.test_pixel:
|
| 466 |
+
pixel_x, pixel_y = args.test_pixel
|
| 467 |
+
world_coords = transformer.pixel_to_world(pixel_x, pixel_y, (480, 640, 3))
|
| 468 |
+
print(f"Pixel ({pixel_x}, {pixel_y}) -> World {world_coords}")
|
| 469 |
+
|
| 470 |
+
# Convert back
|
| 471 |
+
pixel_x_back, pixel_y_back = transformer.world_to_pixel(*world_coords)
|
| 472 |
+
print(f"World {world_coords} -> Pixel ({pixel_x_back}, {pixel_y_back})")
|
| 473 |
+
|
| 474 |
+
# Print workspace bounds
|
| 475 |
+
bounds = transformer.get_workspace_bounds()
|
| 476 |
+
print(f"Workspace bounds: {bounds}")
|
| 477 |
+
|
| 478 |
+
except Exception as e:
|
| 479 |
+
print(f"Error: {e}")
|
| 480 |
+
|
| 481 |
+
if __name__ == "__main__":
|
| 482 |
+
main()
|
src/integrated_detection_classification.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Integrated Strawberry Detection and Ripeness Classification Pipeline
|
| 4 |
+
Combines YOLOv8 detection with 3-class ripeness classification
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import time
|
| 11 |
+
import numpy as np
|
| 12 |
+
import cv2
|
| 13 |
+
import torch
|
| 14 |
+
import torchvision.transforms as transforms
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import yaml
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
import logging
|
| 19 |
+
|
| 20 |
+
# YOLOv8
|
| 21 |
+
from ultralytics import YOLO
|
| 22 |
+
|
| 23 |
+
# Custom imports
|
| 24 |
+
from train_ripeness_classifier import create_model, get_transforms
|
| 25 |
+
|
| 26 |
+
class StrawberryDetectionClassifier:
|
| 27 |
+
"""Integrated detection and classification system"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, detection_model_path, classification_model_path, config_path='config.yaml'):
|
| 30 |
+
self.config = self.load_config(config_path)
|
| 31 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 32 |
+
|
| 33 |
+
# Initialize detection model
|
| 34 |
+
print(f"Loading detection model: {detection_model_path}")
|
| 35 |
+
self.detection_model = YOLO(detection_model_path)
|
| 36 |
+
|
| 37 |
+
# Initialize classification model
|
| 38 |
+
print(f"Loading classification model: {classification_model_path}")
|
| 39 |
+
self.classification_model = self.load_classification_model(classification_model_path)
|
| 40 |
+
|
| 41 |
+
# Get classification transforms
|
| 42 |
+
_, self.classify_transform = get_transforms(img_size=224)
|
| 43 |
+
|
| 44 |
+
# Class names for classification
|
| 45 |
+
self.class_names = ['overripe', 'ripe', 'unripe']
|
| 46 |
+
|
| 47 |
+
# Setup logging
|
| 48 |
+
self.setup_logging()
|
| 49 |
+
|
| 50 |
+
def load_config(self, config_path):
|
| 51 |
+
"""Load configuration from YAML file"""
|
| 52 |
+
with open(config_path, 'r') as f:
|
| 53 |
+
return yaml.safe_load(f)
|
| 54 |
+
|
| 55 |
+
def load_classification_model(self, model_path):
|
| 56 |
+
"""Load the trained classification model"""
|
| 57 |
+
model = create_model(num_classes=3, backbone='resnet18', pretrained=False)
|
| 58 |
+
model.load_state_dict(torch.load(model_path, map_location=self.device))
|
| 59 |
+
model = model.to(self.device)
|
| 60 |
+
model.eval()
|
| 61 |
+
return model
|
| 62 |
+
|
| 63 |
+
def setup_logging(self):
|
| 64 |
+
"""Setup logging configuration"""
|
| 65 |
+
logging.basicConfig(
|
| 66 |
+
level=logging.INFO,
|
| 67 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
| 68 |
+
handlers=[
|
| 69 |
+
logging.FileHandler('strawberry_pipeline.log'),
|
| 70 |
+
logging.StreamHandler()
|
| 71 |
+
]
|
| 72 |
+
)
|
| 73 |
+
self.logger = logging.getLogger(__name__)
|
| 74 |
+
|
| 75 |
+
def detect_strawberries(self, image):
|
| 76 |
+
"""Detect strawberries in image using YOLOv8"""
|
| 77 |
+
results = self.detection_model(image)
|
| 78 |
+
|
| 79 |
+
detections = []
|
| 80 |
+
for result in results:
|
| 81 |
+
boxes = result.boxes
|
| 82 |
+
if boxes is not None:
|
| 83 |
+
for box in boxes:
|
| 84 |
+
# Get bounding box coordinates
|
| 85 |
+
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
| 86 |
+
confidence = box.conf[0].cpu().numpy()
|
| 87 |
+
|
| 88 |
+
# Only keep high-confidence detections
|
| 89 |
+
if confidence > 0.5:
|
| 90 |
+
detections.append({
|
| 91 |
+
'bbox': [int(x1), int(y1), int(x2), int(y2)],
|
| 92 |
+
'confidence': float(confidence),
|
| 93 |
+
'class': int(box.cls[0].cpu().numpy())
|
| 94 |
+
})
|
| 95 |
+
|
| 96 |
+
return detections
|
| 97 |
+
|
| 98 |
+
def classify_ripeness(self, image_crop):
|
| 99 |
+
"""Classify ripeness of strawberry crop"""
|
| 100 |
+
try:
|
| 101 |
+
# Apply transforms
|
| 102 |
+
if isinstance(image_crop, np.ndarray):
|
| 103 |
+
image_crop = cv2.cvtColor(image_crop, cv2.COLOR_BGR2RGB)
|
| 104 |
+
from PIL import Image
|
| 105 |
+
image_crop = Image.fromarray(image_crop)
|
| 106 |
+
|
| 107 |
+
input_tensor = self.classify_transform(image_crop).unsqueeze(0).to(self.device)
|
| 108 |
+
|
| 109 |
+
# Get prediction
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
outputs = self.classification_model(input_tensor)
|
| 112 |
+
probabilities = torch.softmax(outputs, dim=1)
|
| 113 |
+
predicted_class = torch.argmax(probabilities, dim=1).item()
|
| 114 |
+
confidence = probabilities[0][predicted_class].item()
|
| 115 |
+
|
| 116 |
+
return {
|
| 117 |
+
'class': self.class_names[predicted_class],
|
| 118 |
+
'confidence': float(confidence),
|
| 119 |
+
'probabilities': {
|
| 120 |
+
self.class_names[i]: float(probabilities[0][i].item())
|
| 121 |
+
for i in range(len(self.class_names))
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
except Exception as e:
|
| 125 |
+
self.logger.error(f"Classification error: {e}")
|
| 126 |
+
return {
|
| 127 |
+
'class': 'unknown',
|
| 128 |
+
'confidence': 0.0,
|
| 129 |
+
'probabilities': {cls: 0.0 for cls in self.class_names}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
def process_image(self, image_path, save_annotated=True, output_dir='results'):
|
| 133 |
+
"""Process single image with detection and classification"""
|
| 134 |
+
# Load image
|
| 135 |
+
image = cv2.imread(str(image_path))
|
| 136 |
+
if image is None:
|
| 137 |
+
self.logger.error(f"Could not load image: {image_path}")
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
# Detect strawberries
|
| 141 |
+
detections = self.detect_strawberries(image)
|
| 142 |
+
|
| 143 |
+
results = {
|
| 144 |
+
'image_path': str(image_path),
|
| 145 |
+
'timestamp': datetime.now().isoformat(),
|
| 146 |
+
'detections': [],
|
| 147 |
+
'summary': {
|
| 148 |
+
'total_strawberries': len(detections),
|
| 149 |
+
'ripeness_counts': {'unripe': 0, 'ripe': 0, 'overripe': 0, 'unknown': 0}
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
# Process each detection
|
| 154 |
+
for i, detection in enumerate(detections):
|
| 155 |
+
x1, y1, x2, y2 = detection['bbox']
|
| 156 |
+
|
| 157 |
+
# Crop strawberry
|
| 158 |
+
strawberry_crop = image[y1:y2, x1:x2]
|
| 159 |
+
|
| 160 |
+
# Classify ripeness
|
| 161 |
+
ripeness = self.classify_ripeness(strawberry_crop)
|
| 162 |
+
|
| 163 |
+
# Update summary
|
| 164 |
+
results['summary']['ripeness_counts'][ripeness['class']] += 1
|
| 165 |
+
|
| 166 |
+
# Store result
|
| 167 |
+
result = {
|
| 168 |
+
'detection_id': i,
|
| 169 |
+
'bbox': detection['bbox'],
|
| 170 |
+
'detection_confidence': detection['confidence'],
|
| 171 |
+
'ripeness': ripeness
|
| 172 |
+
}
|
| 173 |
+
results['detections'].append(result)
|
| 174 |
+
|
| 175 |
+
# Draw annotations if requested
|
| 176 |
+
if save_annotated:
|
| 177 |
+
color = self.get_ripeness_color(ripeness['class'])
|
| 178 |
+
label = f"{ripeness['class']} ({ripeness['confidence']:.2f})"
|
| 179 |
+
|
| 180 |
+
# Draw bounding box
|
| 181 |
+
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
| 182 |
+
|
| 183 |
+
# Draw label
|
| 184 |
+
cv2.putText(image, label, (x1, y1-10),
|
| 185 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
|
| 186 |
+
|
| 187 |
+
# Save annotated image
|
| 188 |
+
if save_annotated:
|
| 189 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 190 |
+
output_path = Path(output_dir) / f"annotated_{Path(image_path).name}"
|
| 191 |
+
cv2.imwrite(str(output_path), image)
|
| 192 |
+
results['annotated_image_path'] = str(output_path)
|
| 193 |
+
|
| 194 |
+
return results
|
| 195 |
+
|
| 196 |
+
def get_ripeness_color(self, ripeness_class):
|
| 197 |
+
"""Get color for ripeness class"""
|
| 198 |
+
colors = {
|
| 199 |
+
'unripe': (0, 255, 0), # Green
|
| 200 |
+
'ripe': (0, 255, 255), # Yellow
|
| 201 |
+
'overripe': (0, 0, 255), # Red
|
| 202 |
+
'unknown': (128, 128, 128) # Gray
|
| 203 |
+
}
|
| 204 |
+
return colors.get(ripeness_class, (128, 128, 128))
|
| 205 |
+
|
| 206 |
+
def main():
|
| 207 |
+
parser = argparse.ArgumentParser(description='Integrated strawberry detection and classification')
|
| 208 |
+
parser.add_argument('--detection-model', default='model/weights/best_yolov8n_strawberry.pt',
|
| 209 |
+
help='Path to YOLOv8 detection model')
|
| 210 |
+
parser.add_argument('--classification-model', default='model/ripeness_classifier_best.pth',
|
| 211 |
+
help='Path to ripeness classification model')
|
| 212 |
+
parser.add_argument('--mode', choices=['image', 'video', 'realtime'], required=True,
|
| 213 |
+
help='Processing mode')
|
| 214 |
+
parser.add_argument('--input', required=True, help='Input path (image/video/camera index)')
|
| 215 |
+
parser.add_argument('--output', help='Output path for results')
|
| 216 |
+
parser.add_argument('--save-annotated', action='store_true', help='Save annotated images')
|
| 217 |
+
parser.add_argument('--config', default='config.yaml', help='Configuration file path')
|
| 218 |
+
|
| 219 |
+
args = parser.parse_args()
|
| 220 |
+
|
| 221 |
+
# Initialize system
|
| 222 |
+
system = StrawberryDetectionClassifier(
|
| 223 |
+
args.detection_model,
|
| 224 |
+
args.classification_model,
|
| 225 |
+
args.config
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
if args.mode == 'image':
|
| 229 |
+
# Process single image
|
| 230 |
+
results = system.process_image(
|
| 231 |
+
args.input,
|
| 232 |
+
save_annotated=args.save_annotated,
|
| 233 |
+
output_dir=args.output or 'results'
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
if results:
|
| 237 |
+
# Save results
|
| 238 |
+
results_path = Path(args.output or 'results') / 'detection_results.json'
|
| 239 |
+
results_path.parent.mkdir(exist_ok=True)
|
| 240 |
+
with open(results_path, 'w') as f:
|
| 241 |
+
json.dump(results, f, indent=2)
|
| 242 |
+
|
| 243 |
+
print(f"Results saved to: {results_path}")
|
| 244 |
+
print(f"Found {results['summary']['total_strawberries']} strawberries")
|
| 245 |
+
print(f"Ripeness distribution: {results['summary']['ripeness_counts']}")
|
| 246 |
+
|
| 247 |
+
if __name__ == '__main__':
|
| 248 |
+
main()
|
src/strawberry_picker_pipeline.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Strawberry Picker Pipeline - End-to-End Real-time System
|
| 4 |
+
Combines detection, classification, and robotic control for automated strawberry picking
|
| 5 |
+
|
| 6 |
+
Author: AI Assistant
|
| 7 |
+
Date: 2025-12-15
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import cv2
|
| 11 |
+
import numpy as np
|
| 12 |
+
import time
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import List, Dict, Tuple, Optional
|
| 17 |
+
import argparse
|
| 18 |
+
import yaml
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 21 |
+
import threading
|
| 22 |
+
|
| 23 |
+
# Import our custom modules
|
| 24 |
+
from integrated_detection_classification import IntegratedDetectorClassifier
|
| 25 |
+
from arduino_bridge import ArduinoBridge
|
| 26 |
+
from coordinate_transformer import CoordinateTransformer
|
| 27 |
+
|
| 28 |
+
# Configure logging
|
| 29 |
+
logging.basicConfig(
|
| 30 |
+
level=logging.INFO,
|
| 31 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 32 |
+
handlers=[
|
| 33 |
+
logging.FileHandler('strawberry_picker.log'),
|
| 34 |
+
logging.StreamHandler()
|
| 35 |
+
]
|
| 36 |
+
)
|
| 37 |
+
logger = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class PickingTarget:
|
| 41 |
+
"""Represents a strawberry target for picking"""
|
| 42 |
+
bbox: Tuple[int, int, int, int] # x, y, w, h
|
| 43 |
+
confidence: float
|
| 44 |
+
ripeness: str # 'unripe', 'ripe', 'overripe'
|
| 45 |
+
ripeness_confidence: float
|
| 46 |
+
pixel_coords: Tuple[int, int] # center pixel coordinates
|
| 47 |
+
world_coords: Tuple[float, float, float] # x, y, z in robot coordinates
|
| 48 |
+
priority: float # calculated priority score
|
| 49 |
+
|
| 50 |
+
class StrawberryPickerPipeline:
|
| 51 |
+
"""
|
| 52 |
+
Main pipeline for automated strawberry picking system
|
| 53 |
+
Integrates computer vision, classification, and robotic control
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(self, config_path: str = "config.yaml"):
|
| 57 |
+
"""Initialize the strawberry picker pipeline"""
|
| 58 |
+
self.config = self._load_config(config_path)
|
| 59 |
+
self.running = False
|
| 60 |
+
self.picking_targets = []
|
| 61 |
+
self.processed_count = 0
|
| 62 |
+
self.successful_picks = 0
|
| 63 |
+
self.failed_picks = 0
|
| 64 |
+
|
| 65 |
+
# Initialize components
|
| 66 |
+
self.detector_classifier = IntegratedDetectorClassifier(
|
| 67 |
+
detection_model_path=self.config['models']['detection_model'],
|
| 68 |
+
classification_model_path=self.config['models']['classification_model'],
|
| 69 |
+
confidence_threshold=self.config['detection']['confidence_threshold']
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
self.arduino = ArduinoBridge(
|
| 73 |
+
port=self.config['arduino']['port'],
|
| 74 |
+
baudrate=self.config['arduino']['baudrate']
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
self.coordinate_transformer = CoordinateTransformer(
|
| 78 |
+
camera_matrix_path=self.config['calibration']['camera_matrix'],
|
| 79 |
+
distortion_coeffs_path=self.config['calibration']['distortion_coeffs'],
|
| 80 |
+
stereo_calibration_path=self.config['calibration']['stereo_calibration']
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Threading for real-time processing
|
| 84 |
+
self.executor = ThreadPoolExecutor(max_workers=2)
|
| 85 |
+
self.processing_lock = threading.Lock()
|
| 86 |
+
|
| 87 |
+
logger.info("Strawberry Picker Pipeline initialized successfully")
|
| 88 |
+
|
| 89 |
+
def _load_config(self, config_path: str) -> Dict:
|
| 90 |
+
"""Load configuration from YAML file"""
|
| 91 |
+
try:
|
| 92 |
+
with open(config_path, 'r') as f:
|
| 93 |
+
return yaml.safe_load(f)
|
| 94 |
+
except FileNotFoundError:
|
| 95 |
+
logger.warning(f"Config file {config_path} not found, using defaults")
|
| 96 |
+
return self._get_default_config()
|
| 97 |
+
|
| 98 |
+
def _get_default_config(self) -> Dict:
|
| 99 |
+
"""Get default configuration"""
|
| 100 |
+
return {
|
| 101 |
+
'models': {
|
| 102 |
+
'detection_model': 'model/weights/best.pt',
|
| 103 |
+
'classification_model': 'model/ripeness_classifier.h5'
|
| 104 |
+
},
|
| 105 |
+
'detection': {
|
| 106 |
+
'confidence_threshold': 0.5,
|
| 107 |
+
'nms_threshold': 0.4
|
| 108 |
+
},
|
| 109 |
+
'arduino': {
|
| 110 |
+
'port': '/dev/ttyUSB0',
|
| 111 |
+
'baudrate': 115200
|
| 112 |
+
},
|
| 113 |
+
'calibration': {
|
| 114 |
+
'camera_matrix': 'calibration/camera_matrix.npy',
|
| 115 |
+
'distortion_coeffs': 'calibration/distortion_coeffs.npy',
|
| 116 |
+
'stereo_calibration': 'calibration/stereo_calibration.npz'
|
| 117 |
+
},
|
| 118 |
+
'picking': {
|
| 119 |
+
'max_targets_per_frame': 5,
|
| 120 |
+
'min_confidence': 0.7,
|
| 121 |
+
'pick_delay': 2.0, # seconds between picks
|
| 122 |
+
'safety_timeout': 30.0 # seconds
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
def start(self):
|
| 127 |
+
"""Start the strawberry picker pipeline"""
|
| 128 |
+
logger.info("Starting Strawberry Picker Pipeline...")
|
| 129 |
+
|
| 130 |
+
try:
|
| 131 |
+
# Initialize hardware
|
| 132 |
+
self.arduino.connect()
|
| 133 |
+
self.arduino.initialize_servos()
|
| 134 |
+
|
| 135 |
+
# Start processing
|
| 136 |
+
self.running = True
|
| 137 |
+
self._start_processing_loop()
|
| 138 |
+
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.error(f"Failed to start pipeline: {e}")
|
| 141 |
+
self.stop()
|
| 142 |
+
raise
|
| 143 |
+
|
| 144 |
+
def stop(self):
|
| 145 |
+
"""Stop the strawberry picker pipeline"""
|
| 146 |
+
logger.info("Stopping Strawberry Picker Pipeline...")
|
| 147 |
+
self.running = False
|
| 148 |
+
|
| 149 |
+
# Close connections
|
| 150 |
+
if hasattr(self, 'arduino'):
|
| 151 |
+
self.arduino.disconnect()
|
| 152 |
+
|
| 153 |
+
if hasattr(self, 'executor'):
|
| 154 |
+
self.executor.shutdown(wait=True)
|
| 155 |
+
|
| 156 |
+
logger.info("Pipeline stopped successfully")
|
| 157 |
+
|
| 158 |
+
def _start_processing_loop(self):
|
| 159 |
+
"""Start the main processing loop"""
|
| 160 |
+
# Start camera capture in separate thread
|
| 161 |
+
capture_future = self.executor.submit(self._camera_capture_loop)
|
| 162 |
+
|
| 163 |
+
# Start picking loop
|
| 164 |
+
picking_future = self.executor.submit(self._picking_loop)
|
| 165 |
+
|
| 166 |
+
logger.info("Processing loops started")
|
| 167 |
+
|
| 168 |
+
def _camera_capture_loop(self):
|
| 169 |
+
"""Continuous camera capture and processing loop"""
|
| 170 |
+
cap = cv2.VideoCapture(self.config['camera']['index'])
|
| 171 |
+
|
| 172 |
+
if not cap.isOpened():
|
| 173 |
+
logger.error("Failed to open camera")
|
| 174 |
+
return
|
| 175 |
+
|
| 176 |
+
cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.config['camera']['width'])
|
| 177 |
+
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.config['camera']['height'])
|
| 178 |
+
cap.set(cv2.CAP_PROP_FPS, self.config['camera']['fps'])
|
| 179 |
+
|
| 180 |
+
logger.info("Camera capture started")
|
| 181 |
+
|
| 182 |
+
while self.running:
|
| 183 |
+
ret, frame = cap.read()
|
| 184 |
+
if not ret:
|
| 185 |
+
logger.warning("Failed to capture frame")
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
# Process frame for detection and classification
|
| 189 |
+
self._process_frame_async(frame)
|
| 190 |
+
|
| 191 |
+
# Display frame with annotations
|
| 192 |
+
self._display_frame(frame)
|
| 193 |
+
|
| 194 |
+
# Control frame rate
|
| 195 |
+
time.sleep(1.0 / self.config['camera']['fps'])
|
| 196 |
+
|
| 197 |
+
cap.release()
|
| 198 |
+
logger.info("Camera capture stopped")
|
| 199 |
+
|
| 200 |
+
def _process_frame_async(self, frame: np.ndarray):
|
| 201 |
+
"""Process frame asynchronously for detection and classification"""
|
| 202 |
+
def process():
|
| 203 |
+
try:
|
| 204 |
+
# Detect and classify strawberries
|
| 205 |
+
results = self.detector_classifier.process_frame(frame)
|
| 206 |
+
|
| 207 |
+
# Convert to picking targets
|
| 208 |
+
targets = self._create_picking_targets(results, frame)
|
| 209 |
+
|
| 210 |
+
# Update targets list
|
| 211 |
+
with self.processing_lock:
|
| 212 |
+
self.picking_targets = targets
|
| 213 |
+
self.processed_count += 1
|
| 214 |
+
|
| 215 |
+
except Exception as e:
|
| 216 |
+
logger.error(f"Frame processing error: {e}")
|
| 217 |
+
|
| 218 |
+
# Submit to thread pool
|
| 219 |
+
self.executor.submit(process)
|
| 220 |
+
|
| 221 |
+
def _create_picking_targets(self, detection_results: Dict, frame: np.ndarray) -> List[PickingTarget]:
|
| 222 |
+
"""Create picking targets from detection results"""
|
| 223 |
+
targets = []
|
| 224 |
+
|
| 225 |
+
for detection in detection_results.get('detections', []):
|
| 226 |
+
if detection['confidence'] < self.config['picking']['min_confidence']:
|
| 227 |
+
continue
|
| 228 |
+
|
| 229 |
+
# Get classification result
|
| 230 |
+
ripeness = detection.get('ripeness', 'unknown')
|
| 231 |
+
ripeness_confidence = detection.get('ripeness_confidence', 0.0)
|
| 232 |
+
|
| 233 |
+
# Only pick ripe strawberries
|
| 234 |
+
if ripeness != 'ripe':
|
| 235 |
+
continue
|
| 236 |
+
|
| 237 |
+
# Calculate pixel coordinates
|
| 238 |
+
x, y, w, h = detection['bbox']
|
| 239 |
+
center_x = int(x + w / 2)
|
| 240 |
+
center_y = int(y + h / 2)
|
| 241 |
+
|
| 242 |
+
# Transform to world coordinates
|
| 243 |
+
try:
|
| 244 |
+
world_coords = self.coordinate_transformer.pixel_to_world(
|
| 245 |
+
center_x, center_y, frame.shape
|
| 246 |
+
)
|
| 247 |
+
except Exception as e:
|
| 248 |
+
logger.warning(f"Coordinate transformation failed: {e}")
|
| 249 |
+
world_coords = (0.0, 0.0, 0.0)
|
| 250 |
+
|
| 251 |
+
# Calculate priority (higher confidence = higher priority)
|
| 252 |
+
priority = detection['confidence'] * ripeness_confidence
|
| 253 |
+
|
| 254 |
+
target = PickingTarget(
|
| 255 |
+
bbox=detection['bbox'],
|
| 256 |
+
confidence=detection['confidence'],
|
| 257 |
+
ripeness=ripeness,
|
| 258 |
+
ripeness_confidence=ripeness_confidence,
|
| 259 |
+
pixel_coords=(center_x, center_y),
|
| 260 |
+
world_coords=world_coords,
|
| 261 |
+
priority=priority
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
targets.append(target)
|
| 265 |
+
|
| 266 |
+
# Sort by priority and limit number of targets
|
| 267 |
+
targets.sort(key=lambda t: t.priority, reverse=True)
|
| 268 |
+
return targets[:self.config['picking']['max_targets_per_frame']]
|
| 269 |
+
|
| 270 |
+
def _picking_loop(self):
|
| 271 |
+
"""Main picking loop"""
|
| 272 |
+
logger.info("Picking loop started")
|
| 273 |
+
|
| 274 |
+
last_pick_time = 0
|
| 275 |
+
safety_timeout = self.config['picking']['safety_timeout']
|
| 276 |
+
|
| 277 |
+
while self.running:
|
| 278 |
+
try:
|
| 279 |
+
current_time = time.time()
|
| 280 |
+
|
| 281 |
+
# Check if enough time has passed since last pick
|
| 282 |
+
if current_time - last_pick_time < self.config['picking']['pick_delay']:
|
| 283 |
+
time.sleep(0.1)
|
| 284 |
+
continue
|
| 285 |
+
|
| 286 |
+
# Get current targets
|
| 287 |
+
with self.processing_lock:
|
| 288 |
+
targets = self.picking_targets.copy()
|
| 289 |
+
|
| 290 |
+
if not targets:
|
| 291 |
+
time.sleep(0.1)
|
| 292 |
+
continue
|
| 293 |
+
|
| 294 |
+
# Select best target
|
| 295 |
+
target = targets[0]
|
| 296 |
+
|
| 297 |
+
# Execute pick
|
| 298 |
+
success = self._execute_pick(target)
|
| 299 |
+
|
| 300 |
+
if success:
|
| 301 |
+
self.successful_picks += 1
|
| 302 |
+
logger.info(f"Successful pick! Total: {self.successful_picks}")
|
| 303 |
+
else:
|
| 304 |
+
self.failed_picks += 1
|
| 305 |
+
logger.warning(f"Failed pick. Total failures: {self.failed_picks}")
|
| 306 |
+
|
| 307 |
+
last_pick_time = current_time
|
| 308 |
+
|
| 309 |
+
# Safety timeout check
|
| 310 |
+
if current_time - last_pick_time > safety_timeout:
|
| 311 |
+
logger.warning("Safety timeout reached, pausing picking")
|
| 312 |
+
time.sleep(5.0)
|
| 313 |
+
|
| 314 |
+
except Exception as e:
|
| 315 |
+
logger.error(f"Picking loop error: {e}")
|
| 316 |
+
time.sleep(1.0)
|
| 317 |
+
|
| 318 |
+
logger.info("Picking loop stopped")
|
| 319 |
+
|
| 320 |
+
def _execute_pick(self, target: PickingTarget) -> bool:
|
| 321 |
+
"""Execute a picking action for the given target"""
|
| 322 |
+
try:
|
| 323 |
+
logger.info(f"Executing pick for target at {target.pixel_coords}")
|
| 324 |
+
|
| 325 |
+
# Move to target position
|
| 326 |
+
x, y, z = target.world_coords
|
| 327 |
+
self.arduino.move_to_position(x, y, z)
|
| 328 |
+
|
| 329 |
+
# Wait for movement to complete
|
| 330 |
+
time.sleep(2.0)
|
| 331 |
+
|
| 332 |
+
# Close gripper
|
| 333 |
+
self.arduino.close_gripper()
|
| 334 |
+
time.sleep(1.0)
|
| 335 |
+
|
| 336 |
+
# Lift strawberry
|
| 337 |
+
self.arduino.move_to_position(x, y, z + 0.1)
|
| 338 |
+
time.sleep(1.0)
|
| 339 |
+
|
| 340 |
+
# Move to collection area
|
| 341 |
+
collection_pos = self.config['picking']['collection_position']
|
| 342 |
+
self.arduino.move_to_position(*collection_pos)
|
| 343 |
+
time.sleep(2.0)
|
| 344 |
+
|
| 345 |
+
# Open gripper to release strawberry
|
| 346 |
+
self.arduino.open_gripper()
|
| 347 |
+
time.sleep(1.0)
|
| 348 |
+
|
| 349 |
+
# Return to home position
|
| 350 |
+
home_pos = self.config['picking']['home_position']
|
| 351 |
+
self.arduino.move_to_position(*home_pos)
|
| 352 |
+
time.sleep(2.0)
|
| 353 |
+
|
| 354 |
+
logger.info("Pick sequence completed successfully")
|
| 355 |
+
return True
|
| 356 |
+
|
| 357 |
+
except Exception as e:
|
| 358 |
+
logger.error(f"Pick execution failed: {e}")
|
| 359 |
+
return False
|
| 360 |
+
|
| 361 |
+
def _display_frame(self, frame: np.ndarray):
|
| 362 |
+
"""Display frame with annotations"""
|
| 363 |
+
try:
|
| 364 |
+
# Add annotations for current targets
|
| 365 |
+
with self.processing_lock:
|
| 366 |
+
targets = self.picking_targets
|
| 367 |
+
|
| 368 |
+
for i, target in enumerate(targets):
|
| 369 |
+
x, y, w, h = target.bbox
|
| 370 |
+
|
| 371 |
+
# Draw bounding box
|
| 372 |
+
color = (0, 255, 0) if target.ripeness == 'ripe' else (0, 0, 255)
|
| 373 |
+
cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
|
| 374 |
+
|
| 375 |
+
# Add label
|
| 376 |
+
label = f"{target.ripeness} ({target.confidence:.2f})"
|
| 377 |
+
cv2.putText(frame, label, (x, y - 10),
|
| 378 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
| 379 |
+
|
| 380 |
+
# Add priority indicator
|
| 381 |
+
cv2.putText(frame, f"P{i+1}: {target.priority:.2f}",
|
| 382 |
+
(x, y + h + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
|
| 383 |
+
(255, 255, 0), 2)
|
| 384 |
+
|
| 385 |
+
# Add status information
|
| 386 |
+
status_text = f"Processed: {self.processed_count} | Success: {self.successful_picks} | Failed: {self.failed_picks}"
|
| 387 |
+
cv2.putText(frame, status_text, (10, 30),
|
| 388 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
|
| 389 |
+
|
| 390 |
+
# Display frame
|
| 391 |
+
cv2.imshow('Strawberry Picker Pipeline', frame)
|
| 392 |
+
cv2.waitKey(1)
|
| 393 |
+
|
| 394 |
+
except Exception as e:
|
| 395 |
+
logger.error(f"Frame display error: {e}")
|
| 396 |
+
|
| 397 |
+
def get_statistics(self) -> Dict:
|
| 398 |
+
"""Get pipeline statistics"""
|
| 399 |
+
with self.processing_lock:
|
| 400 |
+
return {
|
| 401 |
+
'processed_frames': self.processed_count,
|
| 402 |
+
'successful_picks': self.successful_picks,
|
| 403 |
+
'failed_picks': self.failed_picks,
|
| 404 |
+
'success_rate': self.successful_picks / max(1, self.successful_picks + self.failed_picks),
|
| 405 |
+
'current_targets': len(self.picking_targets)
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
def save_statistics(self, filepath: str):
|
| 409 |
+
"""Save statistics to file"""
|
| 410 |
+
stats = self.get_statistics()
|
| 411 |
+
stats['timestamp'] = time.time()
|
| 412 |
+
|
| 413 |
+
with open(filepath, 'w') as f:
|
| 414 |
+
json.dump(stats, f, indent=2)
|
| 415 |
+
|
| 416 |
+
logger.info(f"Statistics saved to {filepath}")
|
| 417 |
+
|
| 418 |
+
def main():
|
| 419 |
+
"""Main function"""
|
| 420 |
+
parser = argparse.ArgumentParser(description='Strawberry Picker Pipeline')
|
| 421 |
+
parser.add_argument('--config', default='config.yaml', help='Configuration file path')
|
| 422 |
+
parser.add_argument('--test', action='store_true', help='Run in test mode without hardware')
|
| 423 |
+
parser.add_argument('--save-stats', help='Save statistics to file')
|
| 424 |
+
|
| 425 |
+
args = parser.parse_args()
|
| 426 |
+
|
| 427 |
+
try:
|
| 428 |
+
# Initialize pipeline
|
| 429 |
+
pipeline = StrawberryPickerPipeline(args.config)
|
| 430 |
+
|
| 431 |
+
if args.test:
|
| 432 |
+
logger.info("Running in test mode - no hardware control")
|
| 433 |
+
|
| 434 |
+
# Start pipeline
|
| 435 |
+
pipeline.start()
|
| 436 |
+
|
| 437 |
+
# Keep running until interrupted
|
| 438 |
+
try:
|
| 439 |
+
while True:
|
| 440 |
+
time.sleep(1)
|
| 441 |
+
|
| 442 |
+
# Print statistics periodically
|
| 443 |
+
stats = pipeline.get_statistics()
|
| 444 |
+
if stats['processed_frames'] % 100 == 0:
|
| 445 |
+
logger.info(f"Statistics: {stats}")
|
| 446 |
+
|
| 447 |
+
except KeyboardInterrupt:
|
| 448 |
+
logger.info("Received interrupt signal")
|
| 449 |
+
|
| 450 |
+
finally:
|
| 451 |
+
pipeline.stop()
|
| 452 |
+
|
| 453 |
+
if args.save_stats:
|
| 454 |
+
pipeline.save_statistics(args.save_stats)
|
| 455 |
+
|
| 456 |
+
logger.info("Pipeline execution completed")
|
| 457 |
+
|
| 458 |
+
except Exception as e:
|
| 459 |
+
logger.error(f"Pipeline execution failed: {e}")
|
| 460 |
+
return 1
|
| 461 |
+
|
| 462 |
+
return 0
|
| 463 |
+
|
| 464 |
+
if __name__ == "__main__":
|
| 465 |
+
exit(main())
|
sync_to_huggingface.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Sync Strawberry Picker Models to HuggingFace Repository
|
| 4 |
+
|
| 5 |
+
This script automates the process of syncing trained models from strawberryPicker
|
| 6 |
+
to the HuggingFace strawberryPicker repository.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python sync_to_huggingface.py [--dry-run]
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import json
|
| 15 |
+
import shutil
|
| 16 |
+
import argparse
|
| 17 |
+
import subprocess
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
import hashlib
|
| 21 |
+
|
| 22 |
+
def calculate_file_hash(file_path):
|
| 23 |
+
"""Calculate SHA256 hash of a file."""
|
| 24 |
+
sha256_hash = hashlib.sha256()
|
| 25 |
+
with open(file_path, "rb") as f:
|
| 26 |
+
for byte_block in iter(lambda: f.read(4096), b""):
|
| 27 |
+
sha256_hash.update(byte_block)
|
| 28 |
+
return sha256_hash.hexdigest()
|
| 29 |
+
|
| 30 |
+
def find_updated_models(repo_path):
|
| 31 |
+
"""Find models that have been updated."""
|
| 32 |
+
updated_models = []
|
| 33 |
+
|
| 34 |
+
# Check detection model
|
| 35 |
+
detection_pt = Path(repo_path) / "detection" / "best.pt"
|
| 36 |
+
if detection_pt.exists():
|
| 37 |
+
detection_hash = calculate_file_hash(detection_pt)
|
| 38 |
+
updated_models.append({
|
| 39 |
+
'component': 'detection',
|
| 40 |
+
'path': detection_pt,
|
| 41 |
+
'hash': detection_hash
|
| 42 |
+
})
|
| 43 |
+
|
| 44 |
+
# Check classification model
|
| 45 |
+
classification_pth = Path(repo_path) / "classification" / "best_enhanced_classifier.pth"
|
| 46 |
+
if classification_pth.exists():
|
| 47 |
+
classification_hash = calculate_file_hash(classification_pth)
|
| 48 |
+
updated_models.append({
|
| 49 |
+
'component': 'classification',
|
| 50 |
+
'path': classification_pth,
|
| 51 |
+
'hash': classification_hash
|
| 52 |
+
})
|
| 53 |
+
|
| 54 |
+
return updated_models
|
| 55 |
+
|
| 56 |
+
def export_detection_to_onnx(model_path, output_dir):
|
| 57 |
+
"""Export detection model to ONNX format."""
|
| 58 |
+
try:
|
| 59 |
+
cmd = [
|
| 60 |
+
"yolo", "export",
|
| 61 |
+
f"model={model_path}",
|
| 62 |
+
f"dir={output_dir}",
|
| 63 |
+
"format=onnx",
|
| 64 |
+
"opset=12"
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
print(f"Exporting detection model to ONNX...")
|
| 68 |
+
result = subprocess.run(cmd, capture_output=True, text=True)
|
| 69 |
+
|
| 70 |
+
if result.returncode == 0:
|
| 71 |
+
print("Successfully exported detection model to ONNX")
|
| 72 |
+
return True
|
| 73 |
+
else:
|
| 74 |
+
print(f"Export failed: {result.stderr}")
|
| 75 |
+
return False
|
| 76 |
+
|
| 77 |
+
except Exception as e:
|
| 78 |
+
print(f"Error during ONNX export: {e}")
|
| 79 |
+
return False
|
| 80 |
+
|
| 81 |
+
def update_model_metadata(repo_path, models_info):
|
| 82 |
+
"""Update metadata files with sync information."""
|
| 83 |
+
metadata = {
|
| 84 |
+
"last_sync": datetime.now().isoformat(),
|
| 85 |
+
"models": {}
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
for model in models_info:
|
| 89 |
+
metadata["models"][model['component']] = {
|
| 90 |
+
"hash": model['hash'],
|
| 91 |
+
"path": str(model['path']),
|
| 92 |
+
"last_updated": datetime.now().isoformat()
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
metadata_file = Path(repo_path) / "sync_metadata.json"
|
| 96 |
+
with open(metadata_file, 'w') as f:
|
| 97 |
+
json.dump(metadata, f, indent=2)
|
| 98 |
+
|
| 99 |
+
print(f"Updated metadata file: {metadata_file}")
|
| 100 |
+
return True
|
| 101 |
+
|
| 102 |
+
def sync_repository(repo_path, dry_run=False):
|
| 103 |
+
"""Main sync function for strawberryPicker repository."""
|
| 104 |
+
print(f"Syncing strawberryPicker repository at {repo_path}")
|
| 105 |
+
|
| 106 |
+
if dry_run:
|
| 107 |
+
print("DRY RUN MODE - No changes will be made")
|
| 108 |
+
|
| 109 |
+
# Find updated models
|
| 110 |
+
updated_models = find_updated_models(repo_path)
|
| 111 |
+
|
| 112 |
+
if not updated_models:
|
| 113 |
+
print("No models found to sync")
|
| 114 |
+
return True
|
| 115 |
+
|
| 116 |
+
print(f"Found {len(updated_models)} model components:")
|
| 117 |
+
for model in updated_models:
|
| 118 |
+
print(f" - {model['component']} (hash: {model['hash'][:8]}...)")
|
| 119 |
+
|
| 120 |
+
if dry_run:
|
| 121 |
+
return True
|
| 122 |
+
|
| 123 |
+
# Export detection model to ONNX if needed
|
| 124 |
+
detection_model = next((m for m in updated_models if m['component'] == 'detection'), None)
|
| 125 |
+
if detection_model:
|
| 126 |
+
detection_dir = Path(repo_path) / "detection"
|
| 127 |
+
export_detection_to_onnx(detection_model['path'], detection_dir)
|
| 128 |
+
|
| 129 |
+
# Update metadata
|
| 130 |
+
update_model_metadata(repo_path, updated_models)
|
| 131 |
+
|
| 132 |
+
print("\nSync completed successfully!")
|
| 133 |
+
print("Remember to:")
|
| 134 |
+
print("1. Review and commit changes to git")
|
| 135 |
+
print("2. Push to HuggingFace: git push origin main")
|
| 136 |
+
print("3. Update READMEs with any new performance metrics")
|
| 137 |
+
|
| 138 |
+
return True
|
| 139 |
+
|
| 140 |
+
def main():
|
| 141 |
+
parser = argparse.ArgumentParser(description="Sync strawberryPicker models to HuggingFace")
|
| 142 |
+
parser.add_argument("--repo-path",
|
| 143 |
+
default="/home/user/machine-learning/HuggingfaceModels/strawberryPicker",
|
| 144 |
+
help="Path to strawberryPicker repository")
|
| 145 |
+
parser.add_argument("--dry-run", action="store_true",
|
| 146 |
+
help="Show what would be done without making changes")
|
| 147 |
+
|
| 148 |
+
args = parser.parse_args()
|
| 149 |
+
|
| 150 |
+
# Validate path
|
| 151 |
+
if not os.path.exists(args.repo_path):
|
| 152 |
+
print(f"Error: Repository path {args.repo_path} does not exist")
|
| 153 |
+
sys.exit(1)
|
| 154 |
+
|
| 155 |
+
# Run sync
|
| 156 |
+
success = sync_repository(args.repo_path, args.dry_run)
|
| 157 |
+
|
| 158 |
+
if not success:
|
| 159 |
+
sys.exit(1)
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
|
| 162 |
+
main()
|
webcam_inference.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Real-time Strawberry Detection and Ripeness Classification using Webcam
|
| 4 |
+
Optimized for WSL (Windows Subsystem for Linux) environments
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import argparse
|
| 12 |
+
import time
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import sys
|
| 15 |
+
import warnings
|
| 16 |
+
|
| 17 |
+
# Suppress warnings
|
| 18 |
+
warnings.filterwarnings('ignore')
|
| 19 |
+
|
| 20 |
+
class StrawberryPickerWebcam:
|
| 21 |
+
def __init__(self, detector_path, classifier_path, device='cpu'):
|
| 22 |
+
"""
|
| 23 |
+
Initialize the strawberry picker system
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
detector_path: Path to YOLOv8 detection model
|
| 27 |
+
classifier_path: Path to EfficientNet classification model
|
| 28 |
+
device: Device to run inference on ('cpu' or 'cuda')
|
| 29 |
+
"""
|
| 30 |
+
print("🍓 Initializing Strawberry Picker AI System...")
|
| 31 |
+
|
| 32 |
+
self.device = device
|
| 33 |
+
self.ripeness_classes = ['unripe', 'partially-ripe', 'ripe', 'overripe']
|
| 34 |
+
|
| 35 |
+
# Color mapping for visualization
|
| 36 |
+
self.colors = {
|
| 37 |
+
'unripe': (0, 255, 0), # Green
|
| 38 |
+
'partially-ripe': (0, 255, 255), # Yellow
|
| 39 |
+
'ripe': (0, 0, 255), # Red
|
| 40 |
+
'overripe': (128, 0, 128) # Purple
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
# Load detection model
|
| 44 |
+
print("Loading detection model...")
|
| 45 |
+
try:
|
| 46 |
+
from ultralytics import YOLO
|
| 47 |
+
self.detector = YOLO(detector_path)
|
| 48 |
+
print("✅ Detection model loaded successfully")
|
| 49 |
+
except Exception as e:
|
| 50 |
+
print(f"❌ Error loading detection model: {e}")
|
| 51 |
+
sys.exit(1)
|
| 52 |
+
|
| 53 |
+
# Load classification model
|
| 54 |
+
print("Loading classification model...")
|
| 55 |
+
try:
|
| 56 |
+
self.classifier = torch.load(classifier_path, map_location=device)
|
| 57 |
+
self.classifier.eval()
|
| 58 |
+
print("✅ Classification model loaded successfully")
|
| 59 |
+
except Exception as e:
|
| 60 |
+
print(f"❌ Error loading classification model: {e}")
|
| 61 |
+
sys.exit(1)
|
| 62 |
+
|
| 63 |
+
# Setup preprocessing
|
| 64 |
+
self.transform = transforms.Compose([
|
| 65 |
+
transforms.Resize((128, 128)),
|
| 66 |
+
transforms.ToTensor(),
|
| 67 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 68 |
+
std=[0.229, 0.224, 0.225])
|
| 69 |
+
])
|
| 70 |
+
|
| 71 |
+
print("✅ System initialized and ready!")
|
| 72 |
+
|
| 73 |
+
def detect_and_classify(self, frame):
|
| 74 |
+
"""
|
| 75 |
+
Detect strawberries and classify their ripeness in a frame
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
frame: Input frame (BGR format)
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
results: List of detection/classification results
|
| 82 |
+
visualized_frame: Frame with visualizations
|
| 83 |
+
"""
|
| 84 |
+
# Convert to RGB
|
| 85 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 86 |
+
|
| 87 |
+
# Detect strawberries
|
| 88 |
+
detection_results = self.detector(frame_rgb)
|
| 89 |
+
|
| 90 |
+
results = []
|
| 91 |
+
|
| 92 |
+
for result in detection_results:
|
| 93 |
+
boxes = result.boxes.xyxy.cpu().numpy()
|
| 94 |
+
confidences = result.boxes.conf.cpu().numpy()
|
| 95 |
+
|
| 96 |
+
for box, conf in zip(boxes, confidences):
|
| 97 |
+
if conf < 0.5: # Confidence threshold
|
| 98 |
+
continue
|
| 99 |
+
|
| 100 |
+
x1, y1, x2, y2 = map(int, box)
|
| 101 |
+
|
| 102 |
+
# Ensure coordinates are within frame bounds
|
| 103 |
+
x1 = max(0, x1)
|
| 104 |
+
y1 = max(0, y1)
|
| 105 |
+
x2 = min(frame.shape[1], x2)
|
| 106 |
+
y2 = min(frame.shape[0], y2)
|
| 107 |
+
|
| 108 |
+
# Crop strawberry
|
| 109 |
+
crop = frame_rgb[y1:y2, x1:x2]
|
| 110 |
+
|
| 111 |
+
if crop.size == 0:
|
| 112 |
+
continue
|
| 113 |
+
|
| 114 |
+
# Classify ripeness
|
| 115 |
+
try:
|
| 116 |
+
crop_pil = Image.fromarray(crop)
|
| 117 |
+
input_tensor = self.transform(crop_pil).unsqueeze(0).to(self.device)
|
| 118 |
+
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
output = self.classifier(input_tensor)
|
| 121 |
+
probabilities = torch.softmax(output, dim=1)
|
| 122 |
+
predicted_class = torch.argmax(probabilities, dim=1).item()
|
| 123 |
+
confidence = probabilities[0][predicted_class].item()
|
| 124 |
+
|
| 125 |
+
ripeness = self.ripeness_classes[predicted_class]
|
| 126 |
+
|
| 127 |
+
results.append({
|
| 128 |
+
'bbox': (x1, y1, x2, y2),
|
| 129 |
+
'ripeness': ripeness,
|
| 130 |
+
'confidence': confidence,
|
| 131 |
+
'detection_confidence': float(conf)
|
| 132 |
+
})
|
| 133 |
+
|
| 134 |
+
except Exception as e:
|
| 135 |
+
print(f"Warning: Error classifying crop: {e}")
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
return results
|
| 139 |
+
|
| 140 |
+
def visualize(self, frame, results):
|
| 141 |
+
"""
|
| 142 |
+
Draw bounding boxes and labels on frame
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
frame: Input frame
|
| 146 |
+
results: Detection/classification results
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
visualized_frame: Frame with drawings
|
| 150 |
+
"""
|
| 151 |
+
vis_frame = frame.copy()
|
| 152 |
+
|
| 153 |
+
for result in results:
|
| 154 |
+
x1, y1, x2, y2 = result['bbox']
|
| 155 |
+
ripeness = result['ripeness']
|
| 156 |
+
conf = result['confidence']
|
| 157 |
+
|
| 158 |
+
# Draw bounding box
|
| 159 |
+
color = self.colors[ripeness]
|
| 160 |
+
cv2.rectangle(vis_frame, (x1, y1), (x2, y2), color, 2)
|
| 161 |
+
|
| 162 |
+
# Draw label background
|
| 163 |
+
label = f"{ripeness} ({conf:.2f})"
|
| 164 |
+
label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
|
| 165 |
+
cv2.rectangle(vis_frame, (x1, y1 - label_size[1] - 10),
|
| 166 |
+
(x1 + label_size[0], y1), color, -1)
|
| 167 |
+
|
| 168 |
+
# Draw label text
|
| 169 |
+
cv2.putText(vis_frame, label, (x1, y1 - 5),
|
| 170 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
|
| 171 |
+
|
| 172 |
+
# Add FPS counter
|
| 173 |
+
fps_text = f"FPS: {self.fps:.1f}"
|
| 174 |
+
cv2.putText(vis_frame, fps_text, (10, 30),
|
| 175 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
| 176 |
+
|
| 177 |
+
# Add title
|
| 178 |
+
title = "Strawberry Picker AI - Press 'q' to quit"
|
| 179 |
+
cv2.putText(vis_frame, title, (10, 60),
|
| 180 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
| 181 |
+
|
| 182 |
+
return vis_frame
|
| 183 |
+
|
| 184 |
+
def run_webcam(self, camera_index=0, width=640, height=480):
|
| 185 |
+
"""
|
| 186 |
+
Run real-time inference on webcam
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
camera_index: Camera index (0 for default webcam)
|
| 190 |
+
width: Frame width
|
| 191 |
+
height: Frame height
|
| 192 |
+
"""
|
| 193 |
+
print(f"\n📹 Starting webcam (camera {camera_index})...")
|
| 194 |
+
print("Press 'q' to quit, 's' to save screenshot")
|
| 195 |
+
print("Make sure strawberries are well-lit and clearly visible\n")
|
| 196 |
+
|
| 197 |
+
# Try to open webcam
|
| 198 |
+
cap = cv2.VideoCapture(camera_index)
|
| 199 |
+
|
| 200 |
+
if not cap.isOpened():
|
| 201 |
+
print(f"❌ Error: Could not open camera {camera_index}")
|
| 202 |
+
print("\nTroubleshooting tips for WSL:")
|
| 203 |
+
print("1. Install v4l2loopback: sudo apt-get install v4l2loopback-dkms")
|
| 204 |
+
print("2. Load module: sudo modprobe v4l2loopback")
|
| 205 |
+
print("3. Use IP webcam app on phone as alternative")
|
| 206 |
+
print("4. Or use pre-recorded video file")
|
| 207 |
+
return
|
| 208 |
+
|
| 209 |
+
# Set camera properties
|
| 210 |
+
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
|
| 211 |
+
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
|
| 212 |
+
|
| 213 |
+
# FPS tracking
|
| 214 |
+
self.fps = 0
|
| 215 |
+
frame_count = 0
|
| 216 |
+
start_time = time.time()
|
| 217 |
+
|
| 218 |
+
# Screenshot counter
|
| 219 |
+
screenshot_count = 0
|
| 220 |
+
|
| 221 |
+
try:
|
| 222 |
+
while True:
|
| 223 |
+
# Read frame
|
| 224 |
+
ret, frame = cap.read()
|
| 225 |
+
|
| 226 |
+
if not ret:
|
| 227 |
+
print("❌ Error: Could not read frame from camera")
|
| 228 |
+
break
|
| 229 |
+
|
| 230 |
+
# Detect and classify
|
| 231 |
+
results = self.detect_and_classify(frame)
|
| 232 |
+
|
| 233 |
+
# Visualize results
|
| 234 |
+
vis_frame = self.visualize(frame, results)
|
| 235 |
+
|
| 236 |
+
# Calculate FPS
|
| 237 |
+
frame_count += 1
|
| 238 |
+
if frame_count % 10 == 0:
|
| 239 |
+
elapsed = time.time() - start_time
|
| 240 |
+
self.fps = frame_count / elapsed
|
| 241 |
+
|
| 242 |
+
# Display frame
|
| 243 |
+
cv2.imshow('Strawberry Picker AI', vis_frame)
|
| 244 |
+
|
| 245 |
+
# Handle keyboard input
|
| 246 |
+
key = cv2.waitKey(1) & 0xFF
|
| 247 |
+
|
| 248 |
+
if key == ord('q'):
|
| 249 |
+
print("\n👋 Quitting...")
|
| 250 |
+
break
|
| 251 |
+
elif key == ord('s'):
|
| 252 |
+
# Save screenshot
|
| 253 |
+
screenshot_path = f"screenshot_{screenshot_count}.jpg"
|
| 254 |
+
cv2.imwrite(screenshot_path, vis_frame)
|
| 255 |
+
print(f"📸 Screenshot saved: {screenshot_path}")
|
| 256 |
+
screenshot_count += 1
|
| 257 |
+
|
| 258 |
+
except KeyboardInterrupt:
|
| 259 |
+
print("\n👋 Interrupted by user")
|
| 260 |
+
|
| 261 |
+
finally:
|
| 262 |
+
# Cleanup
|
| 263 |
+
cap.release()
|
| 264 |
+
cv2.destroyAllWindows()
|
| 265 |
+
print("✅ Webcam session ended")
|
| 266 |
+
|
| 267 |
+
def run_video_file(self, video_path):
|
| 268 |
+
"""
|
| 269 |
+
Run inference on a video file
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
video_path: Path to video file
|
| 273 |
+
"""
|
| 274 |
+
print(f"\n🎬 Processing video: {video_path}")
|
| 275 |
+
|
| 276 |
+
cap = cv2.VideoCapture(video_path)
|
| 277 |
+
|
| 278 |
+
if not cap.isOpened():
|
| 279 |
+
print(f"❌ Error: Could not open video file: {video_path}")
|
| 280 |
+
return
|
| 281 |
+
|
| 282 |
+
# Get video properties
|
| 283 |
+
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 284 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 285 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 286 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 287 |
+
|
| 288 |
+
print(f"Video info: {width}x{height}, {fps} FPS, {total_frames} frames")
|
| 289 |
+
|
| 290 |
+
# Setup output video
|
| 291 |
+
output_path = f"output_{Path(video_path).name}"
|
| 292 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 293 |
+
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 294 |
+
|
| 295 |
+
frame_count = 0
|
| 296 |
+
start_time = time.time()
|
| 297 |
+
|
| 298 |
+
try:
|
| 299 |
+
while True:
|
| 300 |
+
ret, frame = cap.read()
|
| 301 |
+
|
| 302 |
+
if not ret:
|
| 303 |
+
break
|
| 304 |
+
|
| 305 |
+
# Process frame
|
| 306 |
+
results = self.detect_and_classify(frame)
|
| 307 |
+
vis_frame = self.visualize(frame, results)
|
| 308 |
+
|
| 309 |
+
# Write to output
|
| 310 |
+
out.write(vis_frame)
|
| 311 |
+
|
| 312 |
+
# Display progress
|
| 313 |
+
frame_count += 1
|
| 314 |
+
if frame_count % 30 == 0:
|
| 315 |
+
progress = (frame_count / total_frames) * 100
|
| 316 |
+
elapsed = time.time() - start_time
|
| 317 |
+
print(f"Progress: {progress:.1f}% ({frame_count}/{total_frames}) - "
|
| 318 |
+
f"Time: {elapsed:.1f}s")
|
| 319 |
+
|
| 320 |
+
except KeyboardInterrupt:
|
| 321 |
+
print("\n👋 Interrupted by user")
|
| 322 |
+
|
| 323 |
+
finally:
|
| 324 |
+
cap.release()
|
| 325 |
+
out.release()
|
| 326 |
+
cv2.destroyAllWindows()
|
| 327 |
+
print(f"✅ Video processing complete. Output saved to: {output_path}")
|
| 328 |
+
|
| 329 |
+
def main():
|
| 330 |
+
"""Main function with argument parsing"""
|
| 331 |
+
parser = argparse.ArgumentParser(
|
| 332 |
+
description='Real-time Strawberry Detection and Ripeness Classification'
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
parser.add_argument(
|
| 336 |
+
'--detector',
|
| 337 |
+
type=str,
|
| 338 |
+
default='detection_model/best.pt',
|
| 339 |
+
help='Path to YOLOv8 detection model'
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
parser.add_argument(
|
| 343 |
+
'--classifier',
|
| 344 |
+
type=str,
|
| 345 |
+
default='classification_model/best_enhanced_classifier.pth',
|
| 346 |
+
help='Path to EfficientNet classification model'
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
parser.add_argument(
|
| 350 |
+
'--mode',
|
| 351 |
+
type=str,
|
| 352 |
+
choices=['webcam', 'video'],
|
| 353 |
+
default='webcam',
|
| 354 |
+
help='Mode: webcam or video file'
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
parser.add_argument(
|
| 358 |
+
'--input',
|
| 359 |
+
type=str,
|
| 360 |
+
help='Path to video file (if mode=video)'
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
parser.add_argument(
|
| 364 |
+
'--camera',
|
| 365 |
+
type=int,
|
| 366 |
+
default=0,
|
| 367 |
+
help='Camera index (default: 0)'
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
'--width',
|
| 372 |
+
type=int,
|
| 373 |
+
default=640,
|
| 374 |
+
help='Camera frame width'
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
parser.add_argument(
|
| 378 |
+
'--height',
|
| 379 |
+
type=int,
|
| 380 |
+
default=480,
|
| 381 |
+
help='Camera frame height'
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
parser.add_argument(
|
| 385 |
+
'--device',
|
| 386 |
+
type=str,
|
| 387 |
+
default='auto',
|
| 388 |
+
choices=['auto', 'cpu', 'cuda'],
|
| 389 |
+
help='Device to use for inference'
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
args = parser.parse_args()
|
| 393 |
+
|
| 394 |
+
# Determine device
|
| 395 |
+
if args.device == 'auto':
|
| 396 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 397 |
+
else:
|
| 398 |
+
device = args.device
|
| 399 |
+
|
| 400 |
+
print(f"Using device: {device}")
|
| 401 |
+
|
| 402 |
+
if device == 'cpu':
|
| 403 |
+
print("⚠️ Running on CPU - this will be slower. Consider using GPU if available.")
|
| 404 |
+
|
| 405 |
+
# Initialize system
|
| 406 |
+
try:
|
| 407 |
+
picker = StrawberryPickerWebcam(
|
| 408 |
+
detector_path=args.detector,
|
| 409 |
+
classifier_path=args.classifier,
|
| 410 |
+
device=device
|
| 411 |
+
)
|
| 412 |
+
except Exception as e:
|
| 413 |
+
print(f"❌ Failed to initialize system: {e}")
|
| 414 |
+
sys.exit(1)
|
| 415 |
+
|
| 416 |
+
# Run inference
|
| 417 |
+
if args.mode == 'webcam':
|
| 418 |
+
picker.run_webcam(
|
| 419 |
+
camera_index=args.camera,
|
| 420 |
+
width=args.width,
|
| 421 |
+
height=args.height
|
| 422 |
+
)
|
| 423 |
+
elif args.mode == 'video':
|
| 424 |
+
if not args.input:
|
| 425 |
+
print("❌ Error: --input required for video mode")
|
| 426 |
+
sys.exit(1)
|
| 427 |
+
picker.run_video_file(args.input)
|
| 428 |
+
|
| 429 |
+
if __name__ == "__main__":
|
| 430 |
+
# Check for required libraries
|
| 431 |
+
try:
|
| 432 |
+
import torch
|
| 433 |
+
import cv2
|
| 434 |
+
from PIL import Image
|
| 435 |
+
from torchvision import transforms
|
| 436 |
+
except ImportError as e:
|
| 437 |
+
print(f"❌ Missing required library: {e}")
|
| 438 |
+
print("Install with: pip install torch torchvision opencv-python pillow")
|
| 439 |
+
sys.exit(1)
|
| 440 |
+
|
| 441 |
+
main()
|
yolov11n/README.md
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- object-detection
|
| 4 |
+
- yolo
|
| 5 |
+
- yolov11
|
| 6 |
+
- strawberry
|
| 7 |
+
- agriculture
|
| 8 |
+
- robotics
|
| 9 |
+
- computer-vision
|
| 10 |
+
- pytorch
|
| 11 |
+
- onnx
|
| 12 |
+
license: mit
|
| 13 |
+
datasets:
|
| 14 |
+
- theonegareth/strawberry-detect
|
| 15 |
+
language:
|
| 16 |
+
- python
|
| 17 |
+
pretty_name: YOLOv11n Strawberry Detection
|
| 18 |
+
description: YOLOv11 Nano model for strawberry detection using latest architecture
|
| 19 |
+
pipeline_tag: object-detection
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
# YOLOv11n Strawberry Detection Model
|
| 23 |
+
|
| 24 |
+
This directory contains the YOLOv11 Nano model for strawberry detection, utilizing the latest YOLO architecture improvements.
|
| 25 |
+
|
| 26 |
+
## 📊 Model Performance
|
| 27 |
+
|
| 28 |
+
| Metric | Value |
|
| 29 |
+
|--------|-------|
|
| 30 |
+
| **mAP@50** | TBD |
|
| 31 |
+
| **mAP@50-95** | TBD |
|
| 32 |
+
| **Inference Speed** | TBD |
|
| 33 |
+
| **Model Size** | TBD |
|
| 34 |
+
| **Parameters** | TBD |
|
| 35 |
+
|
| 36 |
+
*Performance metrics will be updated after validation testing*
|
| 37 |
+
|
| 38 |
+
## 🚀 Quick Start
|
| 39 |
+
|
| 40 |
+
### Installation
|
| 41 |
+
```bash
|
| 42 |
+
pip install ultralytics opencv-python
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
### Python Inference
|
| 46 |
+
```python
|
| 47 |
+
from ultralytics import YOLO
|
| 48 |
+
|
| 49 |
+
# Load model
|
| 50 |
+
model = YOLO('strawberry_yolov11n.pt')
|
| 51 |
+
|
| 52 |
+
# Run inference
|
| 53 |
+
results = model('image.jpg', conf=0.25)
|
| 54 |
+
|
| 55 |
+
# Process results
|
| 56 |
+
for result in results:
|
| 57 |
+
boxes = result.boxes
|
| 58 |
+
for box in boxes:
|
| 59 |
+
cls = int(box.cls)
|
| 60 |
+
conf = float(box.conf)
|
| 61 |
+
xyxy = box.xyxy
|
| 62 |
+
print(f"Strawberry detected: {conf:.2f} confidence at {xyxy}")
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
### Command Line
|
| 66 |
+
```bash
|
| 67 |
+
# Single image
|
| 68 |
+
yolo predict model=strawberry_yolov11n.pt source='image.jpg'
|
| 69 |
+
|
| 70 |
+
# Webcam
|
| 71 |
+
yolo predict model=strawberry_yolov11n.pt source=0
|
| 72 |
+
|
| 73 |
+
# Video
|
| 74 |
+
yolo predict model=strawberry_yolov11n.pt source='video.mp4'
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## 📁 Files
|
| 78 |
+
|
| 79 |
+
- `strawberry_yolov11n.pt` - PyTorch model weights
|
| 80 |
+
- `strawberry_yolov11n.onnx` - ONNX model for deployment
|
| 81 |
+
|
| 82 |
+
## 🎯 Use Cases
|
| 83 |
+
|
| 84 |
+
- **Latest Architecture Testing**: Evaluation of YOLOv11 improvements
|
| 85 |
+
- **Edge Deployment**: Optimized for modern edge devices
|
| 86 |
+
- **Research Applications**: Academic and industrial research
|
| 87 |
+
- **Future Deployment**: Next-generation robotic systems
|
| 88 |
+
|
| 89 |
+
## 🔧 Technical Details
|
| 90 |
+
|
| 91 |
+
- **Architecture**: YOLOv11n (Nano)
|
| 92 |
+
- **Input Size**: 640x640
|
| 93 |
+
- **Training Dataset**: Enhanced Strawberry Dataset
|
| 94 |
+
- **Training Epochs**: TBD
|
| 95 |
+
- **Batch Size**: TBD
|
| 96 |
+
- **Optimizer**: TBD
|
| 97 |
+
- **Learning Rate**: TBD
|
| 98 |
+
|
| 99 |
+
## 📈 Training Configuration
|
| 100 |
+
|
| 101 |
+
*Training configuration will be updated after model validation*
|
| 102 |
+
|
| 103 |
+
```yaml
|
| 104 |
+
model: yolov11n.pt
|
| 105 |
+
epochs: TBD
|
| 106 |
+
batch: TBD
|
| 107 |
+
imgsz: 640
|
| 108 |
+
optimizer: TBD
|
| 109 |
+
lr0: TBD
|
| 110 |
+
# Additional hyperparameters will be added after validation
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
## 🔗 Related Models
|
| 114 |
+
|
| 115 |
+
- [YOLOv8n](../yolov8n/) - Proven YOLOv8 nano model
|
| 116 |
+
- [YOLOv8s](../yolov8s/) - Higher accuracy YOLOv8 small model
|
| 117 |
+
|
| 118 |
+
## 📚 Documentation
|
| 119 |
+
|
| 120 |
+
- [Training Pipeline](https://github.com/theonegareth/strawberryPicker)
|
| 121 |
+
- [Dataset](https://universe.roboflow.com/theonegareth/strawberry-detect)
|
| 122 |
+
- [ROS2 Integration](https://github.com/theonegareth/strawberryPicker/blob/main/ROS2_INTEGRATION_PLAN.md)
|
| 123 |
+
|
| 124 |
+
## ⚠️ Note
|
| 125 |
+
|
| 126 |
+
This model is currently in testing phase. Performance metrics and training details will be updated after comprehensive validation. For production deployment, consider using the validated [YOLOv8n](../yolov8n/) or [YOLOv8s](../yolov8s/) models.
|
| 127 |
+
|
| 128 |
+
## 📄 License
|
| 129 |
+
|
| 130 |
+
MIT License - See main repository for details.
|
| 131 |
+
|
| 132 |
+
---
|
| 133 |
+
|
| 134 |
+
**Model Version**: 0.1.0 (Testing)
|
| 135 |
+
**Training Date**: December 2025
|
| 136 |
+
**Status**: Under validation - Not recommended for production use
|