|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import pytest |
|
|
|
|
|
from src.utils import train_test_split_and_feature_extraction |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def big_fake_data():
    """Build a reproducible synthetic DataFrame for the split/extraction test.

    The frame has 100 rows and the following columns, in this order:

    * ``id``: integers 1..100.
    * ``image``: fake file paths of the form ``path/<n>.jpg``.
    * ``image_0`` .. ``image_9``: random floats in [0, 1).
    * ``text_0`` .. ``text_10``: random floats in [0, 1).
    * ``class_id``: one of ``label1``, ``label2``, ``label3``, picked at random.

    Returns:
        pd.DataFrame: the generated frame.
    """
    num_rows = 100
    num_image_columns = 10
    num_text_columns = 11

    # A locally seeded generator keeps the data identical from run to run, so
    # a failing test can be reproduced, and leaves NumPy's global random state
    # untouched for other tests.
    rng = np.random.default_rng(0)

    data = {
        "id": np.arange(1, num_rows + 1),
        "image": [f"path/{i}.jpg" for i in range(1, num_rows + 1)],
    }

    # Insertion order matters: image columns first, then text columns, then
    # the label column last.
    data.update(
        {f"image_{i}": rng.random(num_rows) for i in range(num_image_columns)}
    )
    data.update(
        {f"text_{i}": rng.random(num_rows) for i in range(num_text_columns)}
    )
    data["class_id"] = rng.choice(["label1", "label2", "label3"], size=num_rows)

    return pd.DataFrame(data)
|
|
|
|
|
|
|
|
def test_train_test_split_and_feature_extraction(big_fake_data):
    """Exercise ``train_test_split_and_feature_extraction`` on the fake frame.

    Verifies, in order:

    1. the returned text, image and label column lists;
    2. that the raw ``image`` path column is not listed as an image column;
    3. that the train and test frames have 70 and 30 rows;
    4. that a second call with the same ``random_state`` yields the same
       train and test index order.
    """
    # Both calls below must use exactly the same arguments.
    split_kwargs = {"test_size": 0.3, "random_state": 42}

    first_split = train_test_split_and_feature_extraction(
        big_fake_data, **split_kwargs
    )
    train_df, test_df, text_columns, image_columns, label_columns = first_split

    # --- column discovery ---
    wanted_text_columns = [f"text_{i}" for i in range(11)]
    wanted_image_columns = [f"image_{i}" for i in range(10)]
    wanted_label_columns = ["class_id"]

    assert (
        text_columns == wanted_text_columns
    ), "The text embedding columns extraction is incorrect"
    assert (
        image_columns == wanted_image_columns
    ), "The image embedding columns extraction is incorrect"
    assert (
        label_columns == wanted_label_columns
    ), "The label column extraction is incorrect, should be 'class_id'"

    # The plain "image" column holds paths, not embeddings.
    assert (
        "image" not in image_columns
    ), "'image' column is not part of the embedding columns"

    # --- split sizes ---
    assert len(train_df) == 70, f"Train size should be 70%, but got {len(train_df)}%"
    assert len(test_df) == 30, f"Test size should be 30%, but got {len(test_df)}%"

    # --- reproducibility ---
    expected_train_indices = train_df.index.tolist()
    expected_test_indices = test_df.index.tolist()

    second_split = train_test_split_and_feature_extraction(
        big_fake_data, **split_kwargs
    )
    train_df_recheck, test_df_recheck, _, _, _ = second_split

    assert (
        expected_train_indices == train_df_recheck.index.tolist()
    ), "Train set indices are not consistent with the random state"
    assert (
        expected_test_indices == test_df_recheck.index.tolist()
    ), "Test set indices are not consistent with the random state"
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the session's
    # exit code; forward it so a failing run does not exit with status 0.
    raise SystemExit(pytest.main())
|
|
|