{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Hidden Markov Model for NER\n", "\n", "A supervised HMM tagger for departure (`LOC-DEP`) and arrival (`LOC-ARR`) locations in French travel requests. The start, transition and emission probabilities are estimated by counting on labelled BIO data and then given to `hmmlearn`'s `CategoricalHMM`, which decodes new sentences.\n" ] },
{ "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "from app.travel_resolver.libs.nlp.data_processing import from_bio_file_to_examples\n", "\n", "BIO_FILE = \"data/bio/fr.bio/1k_samples.bio\"\n", "\n", "# The third returned value is not needed here; `unique_labels` is used as tag name -> id\n", "sentences, labels, _, unique_labels = from_bio_file_to_examples(BIO_FILE)" ] },
{ "cell_type": "code", "execution_count": 139, "metadata": {}, "outputs": [], "source": [
"from app.travel_resolver.libs.nlp.data_processing import process_sentence\n",
"\n",
"processed_sentences, processed_labels = [], []\n",
"\n",
"for raw_sentence, raw_label in zip(sentences, labels):\n",
"    # Stem and drop stopwords; passing `labels_to_adapt` makes process_sentence also\n",
"    # return the labels adjusted to the processed tokens\n",
"    sentence_out, label_out = process_sentence(\n",
"        raw_sentence, stemming=True, labels_to_adapt=raw_label, rm_stopwords=True\n",
"    )\n",
"    processed_sentences.append(sentence_out)\n",
"    processed_labels.append(label_out)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Splitting the data between **training** and **test** set. 
We will do an `80/20` split, with a fixed random seed so that the split is reproducible.\n" ] },
{ "cell_type": "code", "execution_count": 140, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "SEED = 42\n", "\n", "train_sentences, test_sentences, y_train, y_test = train_test_split(\n", "    processed_sentences, processed_labels, test_size=0.2, random_state=SEED\n", ")" ] },
{ "cell_type": "code", "execution_count": 141, "metadata": {}, "outputs": [], "source": [
"import numbers\n",
"\n",
"\n",
"def t2_given_t1(\n",
"    t2: str | int,\n",
"    t1: str | int,\n",
"    train_bag=y_train,\n",
"    unique_labels_mapping: dict = unique_labels,\n",
"):\n",
"    \"\"\"\n",
"    Get the probability of getting t2 given t1 in the given labels, i.e. P(t2 | t1)\n",
"\n",
"    The estimate is count(t1 immediately followed by t2) / count(t1 followed by any tag);\n",
"    a t1 in the last position of a sequence has no successor and is not counted.\n",
"\n",
"    Args:\n",
"        t2: str | int, the next tag (tag name or integer id)\n",
"        t1: str | int, the previous tag (tag name or integer id)\n",
"        train_bag: list, the label sequences (one list of integer tag ids per sentence);\n",
"            the default is bound to `y_train` when this cell is executed\n",
"        unique_labels_mapping: dict, mapping from tag name to integer id\n",
"\n",
"    Returns:\n",
"        float, the probability of getting t2 given t1, or 0.0 if t1 never has a successor\n",
"    \"\"\"\n",
"    # numbers.Integral also accepts numpy integer ids, not only built-in ints\n",
"    t1 = t1 if isinstance(t1, numbers.Integral) else unique_labels_mapping[t1]\n",
"    t2 = t2 if isinstance(t2, numbers.Integral) else unique_labels_mapping[t2]\n",
"    count_t1 = 0\n",
"    count_t2_t1 = 0\n",
"    for row in train_bag:\n",
"        # Consecutive (previous, next) tag pairs of the sequence\n",
"        for prev_tag, next_tag in zip(row, row[1:]):\n",
"            if prev_tag == t1:\n",
"                count_t1 += 1\n",
"                if next_tag == t2:\n",
"                    count_t2_t1 += 1\n",
"    # Guard against ZeroDivisionError when t1 never precedes another tag\n",
"    if count_t1 == 0:\n",
"        return 0.0\n",
"    return count_t2_t1 / count_t1" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In the next part, we compute the **transition matrix**, which holds the _probability_ of transitioning from one state to another, $P(S_2 | S_1)$. 
In our case this is, for example, $P(O | \\text{LOC-ARR})$: the probability that a word outside any entity follows an arrival location. Note that the matrix below is stored as `trans_matrix[next, previous]`, so each **column** (previous tag) sums to 1.\n" ] },
{ "cell_type": "code", "execution_count": 142, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.74208238, 0.53819444, 0.91217799],\n", " [0.16145726, 0.06712963, 0.03278689],\n", " [0.09646036, 0.39467593, 0.05503513]])" ] }, "execution_count": 142, "metadata": {}, "output_type": "execute_result" } ], "source": [
"import numpy as np\n",
"\n",
"tags = list(unique_labels.keys())\n",
"n_tags = len(tags)\n",
"\n",
"# trans_matrix[i][j] = P(tags[i] | tags[j]): rows are the *next* tag and columns the\n",
"# *previous* tag, so every column sums to 1. The HMM expects rows = previous tag,\n",
"# which is why the model cell assigns `trans_matrix.T`.\n",
"trans_matrix = np.zeros((n_tags, n_tags))\n",
"\n",
"for next_i in range(n_tags):\n",
"    for prev_i in range(n_tags):\n",
"        trans_matrix[next_i][prev_i] = t2_given_t1(tags[next_i], tags[prev_i])\n",
"\n",
"trans_matrix" ] },
{ "cell_type": "code", "execution_count": 143, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
OLOC-DEPLOC-ARR
O0.7420820.5381940.912178
LOC-DEP0.1614570.0671300.032787
LOC-ARR0.0964600.3946760.055035
\n", "
" ], "text/plain": [ " O LOC-DEP LOC-ARR\n", "O 0.742082 0.538194 0.912178\n", "LOC-DEP 0.161457 0.067130 0.032787\n", "LOC-ARR 0.096460 0.394676 0.055035" ] }, "execution_count": 143, "metadata": {}, "output_type": "execute_result" } ], "source": [
"import pandas as pd\n",
"\n",
"# Rows: next tag, columns: previous tag (same orientation as `trans_matrix`)\n",
"trans_matrix_df = pd.DataFrame(data=trans_matrix, index=tags, columns=tags)\n",
"\n",
"trans_matrix_df" ] },
{ "cell_type": "code", "execution_count": 144, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"%matplotlib inline\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"plt.figure(figsize=(8, 5))\n",
"\n",
"sns.set(font_scale=1.2)\n",
"\n",
"sns.heatmap(trans_matrix_df, annot=True, square=True, annot_kws={\"fontsize\": 12}, cmap=\"Blues\")\n",
"\n",
"# trans_matrix_df is indexed [next tag, previous tag]\n",
"plt.xlabel(\"Previous tag\")\n",
"plt.ylabel(\"Next tag\")\n",
"plt.title(\"Transition Matrix\")\n",
"\n",
"plt.show()" ] },
{ "cell_type": "code", "execution_count": 145, "metadata": {}, "outputs": [], "source": [
"import math\n",
"from fractions import Fraction\n",
"\n",
"\n",
"def round_down(x, a):\n",
"    \"\"\"Round x down to the nearest multiple of a.\"\"\"\n",
"    return math.floor(x / a) * a\n",
"\n",
"\n",
"def get_emission_prob_dict(\n",
"    sentences: list[str] = train_sentences,\n",
"    labels: list[list[int]] = y_train,\n",
"    tags: dict = unique_labels,\n",
"    min_count: int = 3,\n",
"    unk_token: str = \"\",\n",
"):\n",
"    \"\"\"\n",
"    Given a bunch of sentences and their labels, get the emission probability of each word given each tag\n",
"\n",
"    Words whose total count is below `min_count` are merged into a single\n",
"    unknown-word entry keyed by `unk_token`; that entry is always the first key.\n",
"\n",
"    Parameters:\n",
"        sentences (list): The list of space-separated sentences\n",
"        labels (list): One list of integer tag ids per sentence, aligned with its words\n",
"        tags (dict): The dictionary of tags (tag name -> id); its key order gives the tag ids\n",
"        min_count (int): Minimum total count for a word to keep its own entry\n",
"        unk_token (str): Key of the unknown-word entry. NOTE(review): it is the empty\n",
"            string in this copy of the notebook while the rendered tables show `<UNK>`;\n",
"            it must match the token `encode_sentence` falls back to.\n",
"\n",
"    Returns:\n",
"        dict: word -> {tag: P(word | tag)}; for every tag that occurs, the values sum to 1\n",
"    \"\"\"\n",
"    tag_names = list(tags.keys())\n",
"    # The unknown-word entry collects the counts of rare words\n",
"    word_tag_count = {unk_token: {tag: 0 for tag in tag_names}}\n",
"    tag_count = {tag: 0 for tag in tag_names}\n",
"\n",
"    # Count how many times each word appears with each tag\n",
"    for sentence, sentence_labels in zip(sentences, labels):\n",
"        for word, label in zip(sentence.split(\" \"), sentence_labels):\n",
"            if word not in word_tag_count:\n",
"                word_tag_count[word] = {tag: 0 for tag in tag_names}\n",
"            word_tag_count[word][tag_names[label]] += 1\n",
"            tag_count[tag_names[label]] += 1\n",
"\n",
"    # Merge the least represented words into the unknown-word entry\n",
"    rare_words = set()\n",
"    for word, counts in word_tag_count.items():\n",
"        if word != unk_token and sum(counts.values()) < min_count:\n",
"            for tag, count in counts.items():\n",
"                word_tag_count[unk_token][tag] += count\n",
"            rare_words.add(word)\n",
"\n",
"    word_tag_count = {\n",
"        word: counts\n",
"        for word, counts in word_tag_count.items()\n",
"        if word not in rare_words\n",
"    }\n",
"\n",
"    # Turn the counts into P(word | tag); zero counts stay 0\n",
"    for counts in word_tag_count.values():\n",
"        for tag, count in counts.items():\n",
"            if count:\n",
"                counts[tag] = count / tag_count[tag]\n",
"\n",
"    return word_tag_count" ] },
{ "cell_type": "code", "execution_count": 146, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['je', 'veux', 'part', 'de', 'montpelli', 'à', 'paris', '.', '']" ] }, "execution_count": 146, "metadata": {}, "output_type": "execute_result" } ], "source": [
"sentence_test = \"Je veux partir de Montpellier à Paris.\"\n",
"\n",
"# Append the unknown-word key so that its row is displayed as well\n",
"s_t = process_sentence(\n",
"    sentence_test, stemming=True, rm_stopwords=False, return_tokens=True\n",
") + [\"\"]\n",
"\n",
"s_t" ] },
{ "cell_type": "code", "execution_count": 147, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
OLOC-DEPLOC-ARR
je0.0508180.0000000.000000
veux0.0019260.0000000.000000
part0.0054560.0000000.000000
de0.1086980.0011570.001160
à0.0625870.0000000.000000
paris0.0000000.0115740.011601
.0.0680430.0000000.000000
<UNK>0.0010700.7789350.778422
\n", "
" ], "text/plain": [ " O LOC-DEP LOC-ARR\n", "je 0.050818 0.000000 0.000000\n", "veux 0.001926 0.000000 0.000000\n", "part 0.005456 0.000000 0.000000\n", "de 0.108698 0.001157 0.001160\n", "à 0.062587 0.000000 0.000000\n", "paris 0.000000 0.011574 0.011601\n", ". 0.068043 0.000000 0.000000\n", " 0.001070 0.778935 0.778422" ] }, "execution_count": 147, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# `em_prob_df` is otherwise only created by the next cell; build it here so the\n",
"# notebook runs top to bottom on a fresh kernel\n",
"em_prob_df = pd.DataFrame(get_emission_prob_dict()).T\n",
"\n",
"em_prob_df.filter(items=s_t, axis=0)" ] },
{ "cell_type": "code", "execution_count": 148, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'O': 2, 'LOC-DEP': 674, 'LOC-ARR': 664}\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
OLOC-DEPLOC-ARR
<UNK>0.0003560.7782910.777518
je0.0833780.0000000.000000
voudr0.0426670.0000000.000000
trouv0.0426670.0000000.000000
vol0.0375110.0000000.000000
ver0.0428440.0000000.000000
compagn0.0288000.0000000.000000
aérien0.0288000.0000000.000000
d'hélicopter0.0008890.0000000.000000
.0.1136000.0000000.000000
\n", "
" ], "text/plain": [ " O LOC-DEP LOC-ARR\n", " 0.000356 0.778291 0.777518\n", "je 0.083378 0.000000 0.000000\n", "voudr 0.042667 0.000000 0.000000\n", "trouv 0.042667 0.000000 0.000000\n", "vol 0.037511 0.000000 0.000000\n", "ver 0.042844 0.000000 0.000000\n", "compagn 0.028800 0.000000 0.000000\n", "aérien 0.028800 0.000000 0.000000\n", "d'hélicopter 0.000889 0.000000 0.000000\n", ". 0.113600 0.000000 0.000000" ] }, "execution_count": 148, "metadata": {}, "output_type": "execute_result" } ], "source": [
"em_prob_dict = get_emission_prob_dict()\n",
"\n",
"# Rows: words (the first one is the unknown-word entry), columns: tags\n",
"em_prob_df = pd.DataFrame(em_prob_dict).T\n",
"\n",
"em_prob_df.head(10)" ] },
{ "cell_type": "code", "execution_count": 149, "metadata": {}, "outputs": [], "source": [
"em_prob = em_prob_df.to_numpy()\n",
"\n",
"# The HMM needs, for every tag, emission probabilities summing to 1 over the vocabulary\n",
"assert np.allclose(em_prob.sum(axis=0), 1.0), \"emission probabilities must sum to 1 per tag\"" ] },
{ "cell_type": "code", "execution_count": 150, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"fig = plt.figure(figsize=(5, 5))\n",
"ax = fig.add_subplot(projection=\"3d\")\n",
"\n",
"xx = em_prob[:, 0]\n",
"yy = em_prob[:, 1]\n",
"zz = em_prob[:, 2]\n",
"\n",
"xlabel = list(em_prob_df.columns)[0]\n",
"ylabel = list(em_prob_df.columns)[1]\n",
"zlabel = list(em_prob_df.columns)[2]\n",
"\n",
"# Words with the highest probability on each axis\n",
"x_highest_i = np.argmax(em_prob[:, 0])\n",
"y_highest_i = np.argmax(em_prob[:, 1])\n",
"z_highest_i = np.argmax(em_prob[:, 2])\n",
"\n",
"for word_i in (x_highest_i, y_highest_i, z_highest_i):\n",
"    ax.text(\n",
"        xx[word_i],\n",
"        yy[word_i],\n",
"        zz[word_i],\n",
"        f\"'{em_prob_df.index[word_i]}'\\n({xx[word_i]:.2f}, {yy[word_i]:.2f}, {zz[word_i]:.2f})\",\n",
"        fontsize=12,\n",
"        ha=\"center\",\n",
"        va=\"center\",\n",
"    )\n",
"\n",
"ax.scatter(xx, yy, zz)\n",
"\n",
"ax.get_xaxis().set_ticklabels([])\n",
"ax.get_yaxis().set_ticklabels([])\n",
"ax.get_zaxis().set_ticklabels([])\n",
"\n",
"ax.set_xlabel(xlabel)\n",
"ax.set_ylabel(ylabel)\n",
"ax.set_zlabel(zlabel)\n",
"\n",
"# Move zlabel closer to plot\n",
"ax.zaxis.labelpad = -10\n",
"\n",
"plt.show()" ] },
{ "cell_type": "code", "execution_count": 151, "metadata": {}, "outputs": [], "source": [
"def get_start_prob(labels: list[int] = y_train, tags: dict = unique_labels):\n",
"    \"\"\"\n",
"    Get the start probability of each tag\n",
"\n",
"    The start probability of a tag is the fraction of non-empty label sequences\n",
"    whose first label is that tag.\n",
"\n",
"    Parameters:\n",
"        labels (list): The label sequences (one list of integer tag ids per sentence)\n",
"        tags (dict): The dictionary of tags (tag name -> id); its key order gives the tag ids\n",
"\n",
"    Returns:\n",
"        dict: The dictionary of start probabilities (all 0 if every sequence is empty)\n",
"    \"\"\"\n",
"    tag_names = list(tags.keys())\n",
"    start_prob = {tag: 0 for tag in tag_names}\n",
"    for label in labels:\n",
"        # An empty sequence has no first tag\n",
"        if len(label) == 0:\n",
"            continue\n",
"        start_prob[tag_names[label[0]]] += 1\n",
"    total_count = sum(start_prob.values())\n",
"    # Avoid ZeroDivisionError when there is no non-empty sequence\n",
"    if total_count == 0:\n",
"        return start_prob\n",
"    for tag in start_prob:\n",
"        start_prob[tag] /= total_count\n",
"    return start_prob" ] },
{ "cell_type": "code", "execution_count": 152, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1., 0., 0.])" ] }, "execution_count": 152, "metadata": {}, "output_type": "execute_result" } ], "source": [
"startprob_dict = get_start_prob()\n",
"\n",
"startprob = np.array(list(startprob_dict.values()))\n",
"\n",
"startprob" ] },
{ "cell_type": "code", "execution_count": 153, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Start Probability
O1.0
LOC-DEP0.0
LOC-ARR0.0
\n", "
" ], "text/plain": [ " Start Probability\n", "O 1.0\n", "LOC-DEP 0.0\n", "LOC-ARR 0.0" ] }, "execution_count": 153, "metadata": {}, "output_type": "execute_result" } ], "source": [ "startprob_df = pd.DataFrame(\n", "    startprob, index=unique_labels.keys(), columns=[\"Start Probability\"]\n", ")\n", "\n", "startprob_df" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Since we already have the model parameters, we don't need the automatic (unsupervised) estimation performed by `fit()`. Therefore, we pass the probabilities directly to the model and move on to the predictions. The model is created with `init_params=\"\"` so that hmmlearn never re-initialises these parameters.\n" ] },
{ "cell_type": "code", "execution_count": 154, "metadata": {}, "outputs": [], "source": [
"from hmmlearn.hmm import CategoricalHMM\n",
"\n",
"vocab = list(em_prob_dict.keys())\n",
"\n",
"# init_params=\"\" keeps the hand-set parameters below even if fit() is ever called\n",
"hmm = CategoricalHMM(n_components=n_tags, n_iter=100, init_params=\"\")\n",
"\n",
"hmm.n_features = len(vocab)\n",
"hmm.startprob_ = startprob\n",
"# hmmlearn expects transmat_[i, j] = P(j | i); trans_matrix is stored [next, previous]\n",
"hmm.transmat_ = trans_matrix.T\n",
"# emissionprob_ is (n_tags, n_words) while em_prob is (n_words, n_tags)\n",
"hmm.emissionprob_ = em_prob.T" ] },
{ "cell_type": "code", "execution_count": 155, "metadata": {}, "outputs": [], "source": [
"def encode_sentence(sentence: str, vocab: list[str], unk_token: str = \"\"):\n",
"    \"\"\"\n",
"    Encode a sentence into a list of integers\n",
"\n",
"    Every space-separated word is replaced by its position in `vocab`; words missing\n",
"    from `vocab` are replaced by the position of `unk_token`.\n",
"\n",
"    Parameters:\n",
"        sentence (str): The sentence to encode\n",
"        vocab (list): The vocabulary\n",
"        unk_token (str): The unknown-word entry of `vocab`. NOTE(review): it is the\n",
"            empty string in this copy of the notebook (the rendered tables show `<UNK>`);\n",
"            it must match the key created by `get_emission_prob_dict`.\n",
"\n",
"    Returns:\n",
"        list: The list of integers\n",
"\n",
"    Raises:\n",
"        ValueError: if a word is missing from `vocab` and so is `unk_token`\n",
"    \"\"\"\n",
"    # Build the lookup once: O(1) per word instead of a list scan per word.\n",
"    # setdefault keeps the first position of a duplicated entry, like list.index.\n",
"    word_to_index = {}\n",
"    for position, token in enumerate(vocab):\n",
"        word_to_index.setdefault(token, position)\n",
"\n",
"    encoded = []\n",
"    for word in sentence.split(\" \"):\n",
"        if word in word_to_index:\n",
"            encoded.append(word_to_index[word])\n",
"        elif unk_token in word_to_index:\n",
"            encoded.append(word_to_index[unk_token])\n",
"        else:\n",
"            raise ValueError(f\"{unk_token!r} is not in vocab\")\n",
"    return encoded" ] },
{ "cell_type": "code", "execution_count": 159, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "quel: O\n", "trajet: O\n", "fair: O\n", "aller: O\n", "perpignan: LOC-DEP\n", "montpelli: LOC-ARR\n" ] } ], "source": [
"from app.travel_resolver.libs.nlp.data_processing import process_sentence\n",
"\n",
"vocab = list(em_prob_dict.keys())\n",
"\n",
"raw_test_sentence = \"Quel trajet faire pour aller de Perpignan à Montpellier\"\n",
"\n",
"test_sentence = process_sentence(raw_test_sentence, stemming=True, rm_stopwords=True)\n",
"\n",
"# hmmlearn expects a column of symbol ids, shape (n_words, 1)\n",
"test_sentence_encoded = np.array(encode_sentence(test_sentence, vocab)).reshape(-1, 1)\n",
"\n",
"predicted_labels = hmm.predict(test_sentence_encoded)\n",
"\n",
"for word, label in zip(test_sentence.split(\" \"), predicted_labels):\n",
"    print(f\"{word}: {tags[label]}\")" ] },
{ "cell_type": "code", "execution_count": 160, "metadata": {}, "outputs": [], "source": [
"test_sentences_encoded = [\n",
"    encode_sentence(sentence, vocab) for sentence in test_sentences\n",
"]\n",
"test_sentences_lengths = [len(sentence) for sentence in test_sentences_encoded]\n",
"\n",
"\n",
"test_sentences_encoded_flat = [\n",
"    item for sublist in test_sentences_encoded for item in sublist\n",
"]\n",
"# From here on `test_sentences_encoded` is the flat (n_words, 1) array hmmlearn expects\n",
"test_sentences_encoded = np.array(test_sentences_encoded_flat).reshape(-1, 1)\n",
"\n",
"predicted_labels_test = hmm.predict(test_sentences_encoded, test_sentences_lengths)" ] },
{ "cell_type": "code", "execution_count": 161, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "F1 Score: 0.6749156355455568\n" ] } ], "source": [
"from sklearn.metrics import accuracy_score, f1_score\n",
"\n",
"y_test_flat = [item for sublist in y_test for item in sublist]\n",
"\n",
"# Micro-averaged F1 over the entity tags only (ids 1 and 2), ignoring O\n",
"f1 = f1_score(y_test_flat, predicted_labels_test, average=\"micro\", labels=[1, 2])\n",
"\n",
"print(f\"F1 Score: {f1}\")" ] },
{ "cell_type": "code", "execution_count": 163, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# labels= keeps the matrix n_tags x n_tags even if a tag is absent from both arrays\n",
"conf_matrix = confusion_matrix(\n",
"    y_test_flat, predicted_labels_test, labels=list(range(n_tags))\n",
")\n",
"\n",
"print(\"conf matrix\", conf_matrix)\n",
"\n",
"# Row-normalise so each row (true tag) sums to 1. keepdims=True keeps the row totals\n",
"# as a column vector so they broadcast along the rows; a plain sum(axis=1) would\n",
"# divide each *column* j by the total of row j. Empty rows are left at 0.\n",
"row_totals = conf_matrix.sum(axis=1, keepdims=True)\n",
"conf_matrix_norm = conf_matrix / np.where(row_totals == 0, 1, row_totals)\n",
"\n",
"conf_matrix_df = pd.DataFrame(conf_matrix_norm, columns=tags, index=tags)\n",
"\n",
"\n",
"plt.figure(figsize=(8, 5))\n",
"\n",
"sns.set(font_scale=1)\n",
"\n",
"sns.heatmap(\n",
"    conf_matrix_df, annot=True, square=True, annot_kws={\"fontsize\": 12}, cmap=\"Blues\"\n",
")\n",
"\n",
"plt.xlabel(\"Predicted\")\n",
"plt.ylabel(\"True\")\n",
"plt.title(\"Confusion Matrix on test data\")" ] },
{ "cell_type": "code", "execution_count": 164, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"from sklearn.metrics import RocCurveDisplay\n",
"from sklearn.preprocessing import LabelBinarizer\n",
"\n",
"class_of_interest = \"LOC-DEP\"\n",
"class_id = list(unique_labels.keys()).index(class_of_interest)\n",
"\n",
"# One-vs-rest ground truth; tag ids are the HMM state ids. (LabelBinarizer columns\n",
"# follow the classes present in y_test_flat, so they need not line up with class_id.)\n",
"y_test_is_class = np.asarray(y_test_flat) == class_id\n",
"# Posterior P(state = class_of_interest) per word: a continuous score gives a real ROC\n",
"# curve, whereas hard Viterbi labels only give a single operating point\n",
"y_score_class = hmm.predict_proba(test_sentences_encoded, test_sentences_lengths)[\n",
"    :, class_id\n",
"]\n",
"\n",
"display = RocCurveDisplay.from_predictions(\n",
"    y_test_is_class,\n",
"    y_score_class,\n",
"    name=f\"{class_of_interest} vs Rest\",\n",
"    plot_chance_level=True,\n",
")\n",
"\n",
"_ = display.ax_.set(\n",
"    xlabel=\"False Positive Rate\",\n",
"    ylabel=\"True Positive Rate\",\n",
")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluating new sentences\n", "\n", "In this section, we evaluate the model on sentences that did not contribute to the model's probabilities.\n", "\n", "The corpus is formed from an initial dataset of 20 unique blueprint sentences that were used to generate 1,000 sentences with different \"Departures\" and \"Arrivals\".\n" ] },
{ "cell_type": "code", "execution_count": 165, "metadata": {}, "outputs": [], "source": [ "new_sentences, new_labels, new_vocab, new_unique_labels = from_bio_file_to_examples(\n", "    \"./data/bio/fr.bio/800_eval_small_samples.bio\"\n", ")" ] },
{ "cell_type": "code", "execution_count": 166, "metadata": {}, "outputs": [], "source": [
"# Same preprocessing as for the training corpus\n",
"sls_new_sentences = []\n",
"sls_new_labels = []\n",
"\n",
"for sentence, label in zip(new_sentences, new_labels):\n",
"    p_sentence, p_label = process_sentence(\n",
"        sentence, stemming=True, labels_to_adapt=label, rm_stopwords=True\n",
"    )\n",
"    sls_new_sentences.append(p_sentence)\n",
"    sls_new_labels.append(p_label)" ] },
{ "cell_type": "code", "execution_count": 167, "metadata": {}, "outputs": [], "source": [
"sls_new_sentences_encoded = [\n",
"    encode_sentence(sentence, vocab) for sentence in sls_new_sentences\n",
"]\n",
"\n",
"sls_new_sentences_length = [len(sentence) for sentence in sls_new_sentences_encoded]\n",
"\n",
"\n",
"sls_new_sentences_flat = [\n",
"    item for sublist in sls_new_sentences_encoded for item in sublist\n",
"]\n",
"\n",
"sls_new_sentences_encoded = np.array(sls_new_sentences_flat).reshape(-1, 1)\n",
"\n",
"new_sentences_pred = hmm.predict(sls_new_sentences_encoded, sls_new_sentences_length)" ] },
{ "cell_type": "code", "execution_count": 168, "metadata": {}, "outputs": [], "source": [ "y_real = [item for sublist in sls_new_labels for item in sublist]" ] },
{ "cell_type": "code", "execution_count": 169, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "F1: 0.3812801101169993\n" ] } ], "source": [ "f1_score_new = f1_score(y_real, new_sentences_pred, average=\"micro\", labels=[1, 2])\n", "\n", "print(f\"F1: {f1_score_new}\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can see a sharp drop compared with the held-out test split. This is a limitation of HMMs: they can only go as far as what they know, i.e. the observations (words) seen during training and the emission probabilities estimated from them; every unseen word is mapped to the same unknown-word entry.\n" ] },
{ "cell_type": "code", "execution_count": 170, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[3314, 186, 820],\n", " [ 0, 241, 608],\n", " [ 1, 237, 590]])" ] }, "execution_count": 170, "metadata": {}, "output_type": "execute_result" } ], "source": [ "confusion_matrix_new = confusion_matrix(\n", "    y_real, new_sentences_pred, labels=list(range(n_tags))\n", ")\n", "\n", "confusion_matrix_new" ] },
{ "cell_type": "code", "execution_count": 171, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Counter({0: 4320, 1: 849, 2: 828})" ] }, "execution_count": 171, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from collections import Counter\n", "\n", "Counter(y_real)" ] },
{ "cell_type": "code", "execution_count": 172, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 
1.0, 'Confusion Matrix HMM on new data')" ] }, "execution_count": 172, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"conf_matrix = confusion_matrix(y_real, new_sentences_pred, labels=list(range(n_tags)))\n",
"\n",
"# Row-normalise (each true tag sums to 1), the same normalisation as for the test-data\n",
"# matrix; keepdims=True makes the row totals broadcast along the rows. Empty rows stay 0.\n",
"row_totals = conf_matrix.sum(axis=1, keepdims=True)\n",
"conf_matrix = np.round(conf_matrix / np.where(row_totals == 0, 1, row_totals), 2)\n",
"\n",
"conf_matrix_df = pd.DataFrame(conf_matrix, columns=tags, index=tags)\n",
"\n",
"plt.figure(figsize=(8, 5))\n",
"\n",
"sns.set(font_scale=1)\n",
"\n",
"sns.heatmap(\n",
"    conf_matrix_df, annot=True, square=True, annot_kws={\"fontsize\": 12}, cmap=\"Blues\"\n",
")\n",
"\n",
"plt.xlabel(\"Predicted Labels\")\n",
"plt.ylabel(\"True Labels\")\n",
"plt.title(\"Confusion Matrix HMM on new data\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"from sklearn.utils import resample\n",
"\n",
"N_BOOTSTRAP = 30\n",
"tag_ids = list(range(n_tags))\n",
"\n",
"f1_scores = []\n",
"\n",
"for seed in range(N_BOOTSTRAP):\n",
"    # Resample the *training* split with replacement (seeded for reproducibility)\n",
"    train_sentences_resampled, y_train_resampled = resample(\n",
"        train_sentences, y_train, replace=True, random_state=seed\n",
"    )\n",
"\n",
"    # Re-estimate the parameters by counting, exactly like the supervised model above.\n",
"    # hmm.fit() must not be used here: it runs unsupervised Baum-Welch, re-initialises\n",
"    # every parameter with the default init_params, and its hidden states are then no\n",
"    # longer tied to the tag ids, which makes the F1 score meaningless.\n",
"    em_prob_dict_resampled = get_emission_prob_dict(\n",
"        train_sentences_resampled, y_train_resampled\n",
"    )\n",
"    vocab_resampled = list(em_prob_dict_resampled.keys())\n",
"\n",
"    # Same orientation as trans_matrix: [next tag, previous tag]\n",
"    trans_matrix_resampled = np.array(\n",
"        [\n",
"            [\n",
"                t2_given_t1(next_id, prev_id, train_bag=y_train_resampled)\n",
"                for prev_id in tag_ids\n",
"            ]\n",
"            for next_id in tag_ids\n",
"        ]\n",
"    )\n",
"\n",
"    hmm_resampled = CategoricalHMM(n_components=n_tags, init_params=\"\")\n",
"    hmm_resampled.n_features = len(vocab_resampled)\n",
"    hmm_resampled.startprob_ = np.array(\n",
"        list(get_start_prob(y_train_resampled).values())\n",
"    )\n",
"    hmm_resampled.transmat_ = trans_matrix_resampled.T\n",
"    # (n_tags, n_words): rows follow the tag order, columns the vocabulary order\n",
"    hmm_resampled.emissionprob_ = (\n",
"        pd.DataFrame(em_prob_dict_resampled).reindex(tags).to_numpy()\n",
"    )\n",
"\n",
"    # The vocabulary depends on the resample, so the test split is re-encoded\n",
"    test_encoded_resampled = [\n",
"        encode_sentence(sentence, vocab_resampled) for sentence in test_sentences\n",
"    ]\n",
"    predicted_labels_test_resampled = hmm_resampled.predict(\n",
"        np.concatenate(test_encoded_resampled).reshape(-1, 1),\n",
"        [len(sentence) for sentence in test_encoded_resampled],\n",
"    )\n",
"\n",
"    # Compute the F1 score on the entity tags only, as for the main model\n",
"    f1_resampled = f1_score(\n",
"        y_test_flat, predicted_labels_test_resampled, average=\"micro\", labels=[1, 2]\n",
"    )\n",
"    f1_scores.append(f1_resampled)\n",
"\n",
"print(f1_scores)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f1_scores" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "with open(\"hmm_f1_results\", \"wb\") as f:\n", "    pickle.dump(f1_scores, f)" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 2 }