Commit
Β·
238f86d
1
Parent(s):
9470ff7
chore: Upload missing project files
Browse files- Dockerfile.train +39 -0
- data/images/.gitkeep +0 -0
- embeddings/.gitkeep +0 -0
- pytest.ini +5 -0
- src/classifiers_mlp.py +11 -2
- src/vision_embeddings_tf.py +2 -2
Dockerfile.train
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use the official Python 3.9.6 image from DockerHub
|
| 2 |
+
FROM python:3.9.6-slim
|
| 3 |
+
|
| 4 |
+
# Set the working directory in the container
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Copy the requirements file into the container
|
| 8 |
+
COPY requirements.txt .
|
| 9 |
+
|
| 10 |
+
# Install necessary system packages for h5py and TensorFlow
|
| 11 |
+
RUN apt-get update && apt-get install -y \
|
| 12 |
+
build-essential \
|
| 13 |
+
pkg-config \
|
| 14 |
+
libhdf5-dev \
|
| 15 |
+
zlib1g-dev \
|
| 16 |
+
libjpeg-dev \
|
| 17 |
+
liblapack-dev \
|
| 18 |
+
libblas-dev \
|
| 19 |
+
gfortran
|
| 20 |
+
|
| 21 |
+
# Install pip 21.2.3
|
| 22 |
+
RUN pip install --upgrade pip==21.2.3
|
| 23 |
+
|
| 24 |
+
RUN pip install -r requirements.txt
|
| 25 |
+
|
| 26 |
+
# Install Jupyter Notebook
|
| 27 |
+
RUN pip install jupyter
|
| 28 |
+
|
| 29 |
+
# Copy the entire project into the container
|
| 30 |
+
COPY . .
|
| 31 |
+
|
| 32 |
+
# Expose port 8888 for Jupyter Notebook
|
| 33 |
+
EXPOSE 8888
|
| 34 |
+
|
| 35 |
+
# Set environment variable to prevent Python from buffering output
|
| 36 |
+
ENV PYTHONUNBUFFERED=1
|
| 37 |
+
|
| 38 |
+
# Set the default command to start Jupyter Notebook
|
| 39 |
+
CMD ["jupyter", "notebook", "--ip=0.0.0.0", "--port=8888", "--no-browser", "--allow-root"]
|
data/images/.gitkeep
ADDED
|
File without changes
|
embeddings/.gitkeep
ADDED
|
File without changes
|
pytest.ini
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
filterwarnings =
|
| 3 |
+
ignore::DeprecationWarning
|
| 4 |
+
ignore::UserWarning
|
| 5 |
+
ignore::FutureWarning
|
src/classifiers_mlp.py
CHANGED
|
@@ -459,6 +459,7 @@ def train_mlp(
|
|
| 459 |
# Train the model using the training data and validation data
|
| 460 |
history = None
|
| 461 |
if train_model:
|
|
|
|
| 462 |
history = model.fit(
|
| 463 |
train_loader,
|
| 464 |
validation_data=test_loader,
|
|
@@ -469,7 +470,7 @@ def train_mlp(
|
|
| 469 |
)
|
| 470 |
|
| 471 |
if test_mlp_model:
|
| 472 |
-
# Test the model on the test set
|
| 473 |
y_true, y_pred, y_prob = [], [], []
|
| 474 |
for batch in test_loader:
|
| 475 |
features, labels = batch
|
|
@@ -501,7 +502,7 @@ def train_mlp(
|
|
| 501 |
if report:
|
| 502 |
test_model(y_true, y_pred, y_prob, encoder=train_loader.encoder)
|
| 503 |
|
| 504 |
-
# Store results in a dataframe and save in the results folder
|
| 505 |
if text_input_size is not None and image_input_size is not None:
|
| 506 |
model_type = "multimodal"
|
| 507 |
elif text_input_size is not None:
|
|
@@ -516,6 +517,14 @@ def train_mlp(
|
|
| 516 |
# create results folder if it does not exist
|
| 517 |
os.makedirs("results", exist_ok=True)
|
| 518 |
results.to_csv(f"results/{model_type}_results.csv", index=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
else:
|
| 520 |
test_accuracy, f1, macro_auc = None, None, None
|
| 521 |
|
|
|
|
| 459 |
# Train the model using the training data and validation data
|
| 460 |
history = None
|
| 461 |
if train_model:
|
| 462 |
+
# π Train the model
|
| 463 |
history = model.fit(
|
| 464 |
train_loader,
|
| 465 |
validation_data=test_loader,
|
|
|
|
| 470 |
)
|
| 471 |
|
| 472 |
if test_mlp_model:
|
| 473 |
+
# π Test the model on the test set
|
| 474 |
y_true, y_pred, y_prob = [], [], []
|
| 475 |
for batch in test_loader:
|
| 476 |
features, labels = batch
|
|
|
|
| 502 |
if report:
|
| 503 |
test_model(y_true, y_pred, y_prob, encoder=train_loader.encoder)
|
| 504 |
|
| 505 |
+
# π Store results in a dataframe and save in the results folder
|
| 506 |
if text_input_size is not None and image_input_size is not None:
|
| 507 |
model_type = "multimodal"
|
| 508 |
elif text_input_size is not None:
|
|
|
|
| 517 |
# create results folder if it does not exist
|
| 518 |
os.makedirs("results", exist_ok=True)
|
| 519 |
results.to_csv(f"results/{model_type}_results.csv", index=False)
|
| 520 |
+
|
| 521 |
+
# π Save the model
|
| 522 |
+
models_dir = "trained_models"
|
| 523 |
+
os.makedirs(models_dir, exist_ok=True)
|
| 524 |
+
|
| 525 |
+
model_filename = os.path.join(models_dir, f"{model_type}_model")
|
| 526 |
+
model.save(model_filename)
|
| 527 |
+
print(f"β
{model_type} model saved successfully")
|
| 528 |
else:
|
| 529 |
test_accuracy, f1, macro_auc = None, None, None
|
| 530 |
|
src/vision_embeddings_tf.py
CHANGED
|
@@ -372,7 +372,7 @@ def get_embeddings_df(
|
|
| 372 |
path="data/images",
|
| 373 |
dataset_name="",
|
| 374 |
backbone="resnet50",
|
| 375 |
-
directory="
|
| 376 |
image_files=None,
|
| 377 |
):
|
| 378 |
"""
|
|
@@ -394,7 +394,7 @@ def get_embeddings_df(
|
|
| 394 |
The name of the backbone model to use for generating embeddings. The default is 'resnet50'.
|
| 395 |
Other possible options include models like 'convnext_tiny', 'vit_base', etc.
|
| 396 |
directory : str, optional
|
| 397 |
-
The root directory where the embeddings CSV file will be saved. Default is '
|
| 398 |
image_files : list, optional
|
| 399 |
A pre-defined list of image file names to process. If not provided, the function will automatically detect
|
| 400 |
image files in the `path` directory.
|
|
|
|
| 372 |
path="data/images",
|
| 373 |
dataset_name="",
|
| 374 |
backbone="resnet50",
|
| 375 |
+
directory="embeddings",
|
| 376 |
image_files=None,
|
| 377 |
):
|
| 378 |
"""
|
|
|
|
| 394 |
The name of the backbone model to use for generating embeddings. The default is 'resnet50'.
|
| 395 |
Other possible options include models like 'convnext_tiny', 'vit_base', etc.
|
| 396 |
directory : str, optional
|
| 397 |
+
The root directory where the embeddings CSV file will be saved. Default is 'embeddings'.
|
| 398 |
image_files : list, optional
|
| 399 |
A pre-defined list of image file names to process. If not provided, the function will automatically detect
|
| 400 |
image files in the `path` directory.
|