Az-r-ow commited on
Commit
ee9d38f
·
1 Parent(s): d914a99

chore(deepl): padding sentences and labels

Browse files
Files changed (1) hide show
  1. deepl_ner.ipynb +230 -14
deepl_ner.ipynb CHANGED
@@ -6,26 +6,38 @@
6
  "source": [
7
  "# Deep learning NER\n",
8
  "\n",
9
- "In this notebook, we will discover two deep learning techniques for Named Entity Recognition (or NER). \n",
10
  "\n",
11
  "- LSTM (Long Short Term Memory)\n",
12
- "- Transformers"
13
  ]
14
  },
15
  {
16
  "cell_type": "code",
17
- "execution_count": 5,
18
  "metadata": {},
19
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
20
  "source": [
21
  "from app.travel_resolver.libs.nlp import data_processing as dp\n",
22
  "\n",
23
- "sentences, labels, vocab, unique_labels = dp.from_bio_file_to_examples('./data/bio/fr.bio/fr.sentences.bio')"
 
 
24
  ]
25
  },
26
  {
27
  "cell_type": "code",
28
- "execution_count": 6,
29
  "metadata": {},
30
  "outputs": [],
31
  "source": [
@@ -33,18 +45,62 @@
33
  "processed_labels = []\n",
34
  "\n",
35
  "for sentence, label in zip(sentences, labels):\n",
36
- " sentence, label = dp.process_sentence(sentence, stemming=True, return_tokens=True, labels_to_adapt=label)\n",
 
 
37
  " processed_sentences.append(sentence)\n",
38
- " processed_labels.append(label)\n"
39
  ]
40
  },
41
  {
42
  "cell_type": "code",
43
- "execution_count": null,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  "metadata": {},
45
  "outputs": [],
46
  "source": [
47
- "def encode_sentence(sentence: str, vocab: list[str]):\n",
 
 
 
48
  " \"\"\"\n",
49
  " Encode a sentence into a list of integers\n",
50
  "\n",
@@ -55,10 +111,170 @@
55
  " Returns:\n",
56
  " list: The list of integers\n",
57
  " \"\"\"\n",
58
- " return [\n",
59
  " vocab.index(word) if word in vocab else vocab.index(\"<UNK>\")\n",
60
- " for word in sentence.split(\" \")\n",
61
- " ]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  ]
63
  }
64
  ],
@@ -78,7 +294,7 @@
78
  "name": "python",
79
  "nbconvert_exporter": "python",
80
  "pygments_lexer": "ipython3",
81
- "version": "3.10.1"
82
  }
83
  },
84
  "nbformat": 4,
 
6
  "source": [
7
  "# Deep learning NER\n",
8
  "\n",
9
+ "In this notebook, we will discover two deep learning techniques for Named Entity Recognition (or NER).\n",
10
  "\n",
11
  "- LSTM (Long Short Term Memory)\n",
12
+ "- Transformers\n"
13
  ]
14
  },
15
  {
16
  "cell_type": "code",
17
+ "execution_count": null,
18
  "metadata": {},
19
+ "outputs": [
20
+ {
21
+ "name": "stderr",
22
+ "output_type": "stream",
23
+ "text": [
24
+ "[nltk_data] Downloading package punkt_tab to /Users/az-r-\n",
25
+ "[nltk_data] ow/nltk_data...\n",
26
+ "[nltk_data] Package punkt_tab is already up-to-date!\n"
27
+ ]
28
+ }
29
+ ],
30
  "source": [
31
  "from app.travel_resolver.libs.nlp import data_processing as dp\n",
32
  "\n",
33
+ "sentences, labels, vocab, unique_labels = dp.from_bio_file_to_examples(\n",
34
+ " \"./data/bio/fr.bio/10k_samples.bio\"\n",
35
+ ")"
36
  ]
37
  },
38
  {
39
  "cell_type": "code",
40
+ "execution_count": null,
41
  "metadata": {},
42
  "outputs": [],
43
  "source": [
 
45
  "processed_labels = []\n",
46
  "\n",
47
  "for sentence, label in zip(sentences, labels):\n",
48
+ " sentence, label = dp.process_sentence(\n",
49
+ " sentence, stemming=True, return_tokens=True, labels_to_adapt=label\n",
50
+ " )\n",
51
  " processed_sentences.append(sentence)\n",
52
+ " processed_labels.append(label)"
53
  ]
54
  },
55
  {
56
  "cell_type": "code",
57
+ "execution_count": 31,
58
+ "metadata": {},
59
+ "outputs": [
60
+ {
61
+ "data": {
62
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGzCAYAAAAxPS2EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1l0lEQVR4nO3de1xU5d7//zcHGVBhEBVGEpHUUjyHJ1JRy0RjW6ZllnlK824H9UWNzF0qHorUPOzcprd3O+2g7WrfaaV5wHMlamJsS92kpWka2jYFj4iyfn/0Y+5GUMTA4cLX8/FYj5xrfWatzzUM9GYdGA/LsiwBAAAYxNPdDQAAAJQUAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBjedunXravDgwe5uo8KbNm2abr31Vnl5ealFixbubqdC27Bhgzw8PLRhwwZ3t1IsDw8PJSQkuLsNVAAEGBht4cKF8vDw0Pbt24tc37lzZzVp0uQP7+ezzz5TcnLyH97OzWL16tV67rnn1L59ey1YsEAvv/yyu1tyceTIESUnJysjI8PdrVRImzdvVnJysk6ePOnuVlCBebu7AeBGy8zMlKdnybL7Z599pjlz5hBirtG6devk6empv//97/Lx8XF3O4UcOXJEEyZMUN26dTk6VAY2b96sCRMmaPDgwQoMDHR3O6igOAKDm47NZlOlSpXc3UaJnDlzxt0tlMixY8fk5+dXLsMLgIqBAIObzuXXwOTl5WnChAlq0KCBfH19Vb16dXXo0EGpqamSpMGDB2vOnDmSfjt/X7AUOHPmjEaNGqWwsDDZbDbdfvvtevXVV3X5B72fO3dOzzzzjGrUqCF/f3/dd999Onz4sDw8PFyO7CQnJ8vDw0O7d+/Wo48+qmrVqqlDhw6SpJ07d2rw4MG69dZb5evrK4fDoccff1zHjx932VfBNr777js99thjstvtqlmzpsaOHSvLsnTo0CHdf//9CggIkMPh0PTp06/ptbt48aImTZqkevXqyWazqW7duvrLX/6i3NxcZ42Hh4cWLFigM2fOOF+rhQsXXnGbe/fuVZ8+feRwOOTr66vatWurX79+ys7Odql79913FRUVJT8/PwUFBalfv346dOiQS03BKcPdu3erS5cuqly5sm655RZNnTrVWbNhwwa1bt1akjRkyJAie9y6dau6d+8uu92uypUrq1OnTvryyy+LfI337dvnPNJgt9s1ZMgQnT17ttA83333XbVp00aVK1dWtWrVFBMTo9WrV7vUrFixQh07dlSVKlXk7++vuLg47dq164qvXXFKex7X8h5OTk5WUlKSJCkiIsL5+h44cMBlW0uXLlWTJk1ks9nUuHFjrVy50mX9qVOnlJiYqLp168pmsyk4OFj33HOPduzYcd2vByoWTiGhQsjOztZ//vOfQuN5eXnFPjc5OVkpKSkaNmyY2rRpo5ycHG3fvl07duzQPffco//6r//SkSNHlJqaqnfeecfluZZl6b777tP69es1dOhQtWjRQqtWrVJSUpIOHz6smTNnOmsHDx6sDz74QAMGDFC7du20ceNGxcXFXbGvhx56SA0aNNDLL7/sDEOpqan64YcfNGTIEDkcDu3atUvz58/Xrl27tGXLFpdgJUkPP/ywGjVqpFdeeUXLly/X5MmTFRQUpP/+7//WXXfdpSlTpmjRokV69tln1bp1a8XExFz1tRo2bJjeeustPfjggxo1apS2bt2qlJQU7dmzR0uWLJEkvfPOO5o/f762bdumN954Q5J05513Frm9CxcuKDY2Vrm5uXr66aflcDh0+PBhLVu2TCdPnpTdbpckvfTSSxo7dqz69u2rYcOG6ZdfftHs2bMVExOjr7/+2uU0xYkTJ9S9e3f17t1bffv21T//+U+NHj1aTZs2VY8ePdSoUSNNnDhR48aN0/Dhw9WxY0eXHtetW6cePXooKipK48ePl6enpxYsWKC77rpLn3/+udq0aeMyh759+yoiIkIpKSnasWOH3njjDQUHB2vKlCnOmgkTJig5OVl33nmnJk6cKB8fH23dulXr1q1Tt27dnK/boEGDFBsbqylTpujs2bOaO3euOnTooK+//lp169a96tfmcmUxj2t5D/fu3Vvfffed3nvvPc2cOVM1atSQJNWsWdNZ88UXX+ijjz7SU089JX9/f7322mvq06ePDh48qOrVq0uSnnzySf3zn/9UQkKCIiMjdfz4cX3xxRfas2eP7rjjjhK9FqigLMBgCxYssCRddWncuLHLc8LDw61BgwY5Hzdv3tyKi4u76n7i4+Otor5dli5dakmyJk+e7DL+4IMPWh4eHta+ffssy7Ks9PR0S5KVmJjoUjd48GBLkjV+/Hjn2Pjx4y1J1iOPPFJof2fPni009t5771mSrE2bNhXaxvDhw51jFy9etGrXrm15eHhYr7zyinP8xIkTlp+fn8trUpSMjAxLkjVs2DCX8WeffdaSZK1bt845NmjQIKtKlSpX3Z5lWdbXX39tSbI+/PDDK9YcOHDA8vLysl566SWX8W+++cby9vZ2Ge/UqZMlyXr77bedY7m5uZbD4bD69OnjHPvqq68sSdaCBQtctpmfn281aNDAio2NtfLz853jZ8+etSIiIqx77rnHOVbwGj/++OMu23jggQes6tWrOx/v3bvX8vT0tB544AHr0qVLhfZnWZZ16tQpKzAw0HriiSdc1mdlZVl2u73Q+OXWr19vSbLWr19fZvMoyXt42rRpliRr//79hXqVZPn4+Di/NyzLsv71r39ZkqzZs2c7x+x2uxUfH3/VeePmxikkVAhz5sxRampqoaVZs2bFPjcwMFC7du3S3r17S7zfzz77TF5eXnrmmWdcxkeNGiXLsrRixQpJch4ef+qpp1zqnn766Stu+8knnyw05ufn5/z3+fPn9Z///Eft2rWTpCIPrQ8bNsz5by8vL7Vq1UqWZWno0KHO8cDAQN1+++364YcfrtiL9NtcJWnkyJEu46NGjZIkLV++/KrPL0rBEZZVq1YVedpFkj766CPl5+erb9+++s9//uNcHA6HGjRooPXr17vUV61aVY899pjzsY+Pj9q0aVPs/CQpIyNDe/fu1aOPPqrjx48793XmzBndfffd2rRpk/Lz812ec/nXqWPHjjp+/LhycnIk/XaqJD8/X+PGjSt08XjBEbPU1FSdPHlSjzzyiMscvby81LZt20JzdMc8ruc9fCVdu3ZVvXr1nI+bNWumgIAAl69RYGCgtm7dqiNHjpR4+7g5cAoJFUKbNm3UqlWrQuPVqlUr8tTS702cOFH333+/brvtNjVp0kTdu3fXgAEDrin8/PjjjwoNDZW/v7/LeKNGjZzrC/7r6empiIgIl7r69etfcduX10rSr7/+qgkTJugf//iHjh075rLu8mtGJKlOnTouj+12u3x9fZ2H9X8/fvl1NJcrmMPlPTscDgUGBjrnWhIREREaOXKkZsyYoUWLFqljx4667777nNftSL9dI2NZlho0aFDkNi6/ILt27dqFTqVVq1ZNO3fuLLafghA7aNCgK9ZkZ2erWrVqzseXv8YF606cOKGAgAB9//338vT0VGRkZLH7veuuu4pcHxAQUGzvRW2vNOdxPe/hK7l8XwX7O3HihPPx1KlTNWjQIIWFhSkqKkr33nuvBg4cqFtvvbXE+0PFRIDBTS8mJkbff/+9Pv74Y61evVpvvPGGZs6cqXnz5rkcwbjRfn+0pUDfvn21efNmJSUlqUWLFqpatary8/PVvXv3Qr9RS78ddbmWMUmFLjq+ksvDwR81ffp0DR482Pn6P/PMM0pJSdGWLVtUu3Zt5efny8PDQytWrCiy96pVq7o8/iPzK3gNp02bdsXbq0tzf5fv95133pHD4Si03tu7ZD+q3TWPa3Ut++rbt686duyoJUuWaPXq1Zo2bZqmTJmijz76SD169Cj1nmAeAgwgKSgoSEOGDNGQIUN0+vRpxcTEKDk52RlgrvQ/7fDwcK1Zs0anTp1yOQrz73//27m+4L/5+fnav3+/y5GEffv2XXOPJ06c0Nq1azVhwgSNGzfOOX49p76uR8Ec9u7d6zzCJElHjx7VyZMnnXO9Hk2bNlXTpk314osvavPmzWrfvr3mzZunyZMnq169erIsSxEREbrttttKYypX/HoWnNYICAhQ165dS2Vf9erVU35+vnbv3n3FMFGw3+Dg4FLZb1nMoyTv4dIKubVq1dJTTz2lp556SseOHdMdd9yhl156iQADSdxGDRQ6dVK1alXVr1/f5dbgKlWqSFKhvyx677336tKlS/rb3/7mMj5z5kx5eHg4f9DGxsZKkl5//XWXutmzZ19znwW/tV7+G/GsWbOueRt/xL333lvk/mbMmCFJV72j6kpycnJ08eJFl7GmTZvK09PT+fr37t1bXl5emjBhQqG5W5ZV7Kmvolzp6xkVFaV69erp1Vdf1enTpws975dffinxvnr16iVPT09NnDix0FGygvnExsYqICBAL7/8cpF3zpV0v2Uxj5K8h6/0+l6rS5cuFTolGhwcrNDQUJfvS9zcOAKDm15kZKQ6d+6sqKgoBQUFafv27c7bNwtERUVJkp555hnFxsbKy8tL/fr1U8+ePdWlSxe98MILOnDggJo3b67Vq1fr448/VmJiovM34aioKPXp00ezZs3S8ePHnbegfvfdd5Ku7TfWgIAAxcTEaOrUqcrLy9Mtt9yi1atXa//+/WXwqhTWvHlzDRo0SPPnz9fJkyfVqVMnbdu2TW+99ZZ69eqlLl26lHib69atU0JCgh566CHddtttunjxot555x15eXmpT58+kn47mjB58mSNGTNGBw4cUK9eveTv76/9+/dryZIlGj58uJ599tkS7bdevXoKDAzUvHnz5O/vrypVqqht27aKiIjQG2+8oR49eqhx48YaMmSIbrnlFh0+fFjr169XQECAPv300xLtq379+nrhhRc0adIkdezYUb1795bNZtNXX32l0NBQpaSkKCAgQHPnztWAAQN0xx13qF+/fqpZs6YOHjyo5cuXq3379oVC8tV4enqW+jxK8h4u+H554YUX1K9fP1WqVEk9e/Z0BpvinDp1SrVr19aDDz6o5s2bq2rVqlqzZo2++uqra/6bRbgJuOXeJ6CUFNxG/dVXXxW5vlOnTsXeRj158mSrTZs2VmBgoOXn52c1bNjQeumll6wLFy44ay5evGg9/fTTVs2aNS0PDw+XW6pPnTpljRgxwgoNDbUqVapkNWjQwJo2bZrL7auWZVlnzpyx4uPjraCgIKtq1apWr169rMzMTEuSy23NBbe1/vLLL4Xm89NPP1kPPPCAFRgYaNntduuhhx6yjhw5csVbsS/fxpVuby7qdSpKXl6eNWHCBCsiIsKqVKmSFRYWZo0ZM8Y6f/78Ne3ncj/88IP1+OOPW/Xq1bN8fX2toKAgq0uXLtaaNWsK1f7v//6v1aFDB6tKlSpWlSpVrIYNG1rx8fFWZmZmsfMYNGiQFR4e7jL28ccfW5GRkZa3t3ehW6q//vprq3fv3lb16tUtm81mhYeHW3379rXWrl3rrLnSa1zwnrz8FuI333zTatmypWWz2axq1apZnTp1slJTU11q1q9fb8XGxlp2u93y9fW16tWrZw0ePNjavn37VV/Hy2+jLqt5XOt72LIsa9KkSdYtt9xieXp6umxHUpG3R//++zI3N9dKSkqymjdvbvn7+1tVqlSxmjdvbr3++utXfR1wc/GwrDK4QgvANcnIyFDLli317rvvqn///u5uBygx3sNwF66BAW6Qc+fOFRqbNWuWPD09i/0LuEB5wHsY5QnXwAA3yNSpU5Wenq4uXbrI29tbK1as0IoVKzR8+HCFhYW5uz2gWLyHUZ5wCgm4QVJTUzVhwgTt3r1bp0+fVp06dTRgwAC98MILJf47H4A78B5GeUKAAQAAxuEaGAAAYBwCDAAAME6FPWmZn5+vI0eOyN/fv9Q/uwUAAJQNy7J06tQphYaGFvoE99+rsAHmyJEjXBUPAIChDh06pNq1a19xfYUNMAUfrHfo0KESfxQ9AABwj5ycHIWFhbl8QG5RKmyAKThtFBAQQIABAMAwxV3+wUW8AADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMbxdncDwNXUfX55sTUHXom7AZ0AAMoTjsAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcUoUYFJSUtS6dWv5+/srODhYvXr1UmZmpktN586d5eHh4bI8+eSTLjUHDx5UXFycKleurODgYCUlJenixYsuNRs2bNAdd9whm82m+vXra+HChdc3QwAAUOGUKMBs3LhR8fHx2rJli1JTU5WXl6du3brpzJkzLnVPPPGEfv75Z+cydepU57pLly4pLi5OFy5c0ObNm/XWW29p4cKFGjdunLNm//79iouLU5cuXZSRkaHExEQNGzZMq1at+oPTBQAAFYF3SYpXrlzp8njhwoUKDg5Wenq6YmJinOOVK1eWw+EochurV6/W7t27tWbNGoWEhKhFixaaNGmSRo8ereTkZPn4+GjevHmKiIjQ9OnTJUmNGjXSF198oZkzZyo2NrakcwQAABXMH7oGJjs7W5IUFBTkMr5o0SLVqFFDTZo00ZgxY3T27FnnurS0NDVt2lQhISHOsdjYWOXk5GjXrl3Omq5du7psMzY2VmlpaVfsJTc3Vzk5OS4LAAComEp0BOb38vPzlZiYqPbt26tJkybO8UcffVTh4eEKDQ3Vzp07NXr0aGVmZuqjjz6SJGVlZbmEF0nOx1lZWVetycnJ0blz5+Tn51eon5SUFE2YMOF6pwMAAAxy3QEmPj5e3377rb744guX8eHDhzv/3bRpU9WqVUt33323vv/+e9WrV+/6Oy3GmDFjNHLkSOfjnJwchYWFldn+AACA+1zXKaSEhAQtW7ZM69evV+3ata9a27ZtW0nSvn37JEkOh0NHjx51qSl4XHDdzJVqAgICijz6Ikk2m00BAQEuCwAAqJhKFGAsy1JCQoKWLFmidevWKSIiotjnZGRkSJJq1aolSYqOjtY333yjY8eOOWtSU1MVEBCgyMhIZ83atWtdtpOamqro6OiStAsAACqoEgWY+Ph4vfvuu1q8eLH8/f2VlZWlrKwsnTt3TpL0/fffa9KkSUpPT9eBAwf0ySefaODAgYqJiVGzZs0kSd26dVNkZKQGDBigf/3rX1q1apVefPFFxcfHy2azSZKefPJJ/fDDD3ruuef073//W6+//ro++OADjRgxopSnDwAATFSiADN37lxlZ2erc+fOqlWrlnN5//33JUk+Pj5as2aNunXrpoYNG2rUqFHq06ePPv30U+c2vLy8tGzZMnl5eSk6OlqPPfaYBg4cqIkTJzprIiIitHz5cqWmpqp58+aaPn263njjDW6hBgAAkiQPy7IsdzdRFnJycmS325Wdnc31MAar+/zyYmsOvBJ3AzoBANwI1/r/bz4LCQAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgnBIFmJSUFLVu3Vr+/v4KDg5Wr169lJmZ6VJz/vx5xcfHq3r16qpatar69Omjo0ePutQcPHhQcXFxqly5soKDg5WUlKSLFy+61GzYsEF33HGHbDab6tevr4ULF17fDAEAQIVTogCzceNGxcfHa8uWLUpNTVVeXp66deumM2fOOGtGjBihTz/9VB9++KE2btyoI0eOqHfv3s71ly5dUlxcnC5cuKDNmzfrrbfe0sKFCzVu3Dhnzf79+xUXF6cuXbooIyNDiYmJGjZsmFatWlUKUwYAAKbzsCzLut4n//LLLwoODtbGjRsVExOj7Oxs1axZU4sXL9aDDz4oSfr3v/+tRo0aKS0tTe3atdOKFSv0pz/9SUeOHFFISIgkad68eRo9erR++eUX+fj4aPTo0Vq+fLm+/fZb57769eunkydPauXKldfUW05Ojux2u7KzsxUQEHC9U4Sb1X1+ebE1B16JuwGdAABuhGv9//cfugYmOztbkhQUFCRJSk9PV15enrp27eqsadiwoerUqaO0tDRJUlpampo2beoML5IUGxurnJwc7dq1y1nz+20U1BRsoyi5ubnKyclxWQAAQMV03QEmPz9fiYmJat++vZo0aSJJysrKko+PjwIDA11qQ0JClJWV5az5fXgpWF+w7mo1OTk5OnfuXJH9pKSkyG63O5ewsLDrnRoAACjnrjvAxMfH69tvv9U//vGP0uznuo0ZM0bZ2dnO5dChQ+5uCQAAlBHv63lSQkKCli1bpk2bNql27drOcYfDoQsXLujkyZMuR2GOHj0qh8PhrNm2bZvL9gruUvp9zeV3Lh09elQBAQHy8/MrsiebzSabzXY90wEAAIYp0REYy7KUkJCgJUuWaN26dYqIiHBZHxUVpUqVKmnt2rXOsczMTB08eFDR0dGSpOjoaH3zzTc6duyYsyY1NVUBAQGKjIx01vx+GwU1BdsAAAA3txIdgYmPj9fixYv18ccfy9/f33nNit1ul5+fn+x2u4YOHaqRI0cqKChIAQEBevrppxUdHa127dpJkrp166bIyEgNGDBAU6dOVVZWll588UXFx8c7j6A8+eST+tvf/qbnnntOjz/+uNatW6cPPvhAy5cXf0cKAACo+Ep0BGbu3LnKzs5W586dVatWLefy/vvvO2tmzpypP/3pT+rTp49iYmLkcDj00UcfOdd7eXlp2bJl8vLyUnR0tB577DENHDhQEydOdNZERERo+fLlSk1NVfPmzTV9+nS98cYbio2NLYUpAwAA0/2hvwNTnvF3YCoG/g4MANxcbsjfgQEAAHAHAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHFKHGA2bdqknj17KjQ0VB4eHlq6dKnL+sGDB8vDw8Nl6d69u0vNr7/+qv79+ysgIECBgYEaOnSoTp8+7VKzc+dOdezYUb6+vgoLC9PUqVNLPjsA16Tu88uLXQCgPClxgDlz5oyaN2+uOXPmXLGme/fu+vnnn53Le++957K+f//+2rVrl1JTU7Vs2TJt2rRJw4cPd67PyclRt27dFB4ervT0dE2bNk3JycmaP39+SdsFAAAVkHdJn9CjRw/16NHjqjU2m00Oh6PIdXv27NHKlSv11VdfqVWrVpKk2bNn695779Wrr76q0NBQLVq0SBcuXNCbb74pHx8fNW7cWBkZGZoxY4ZL0AEAADenEgeYa7FhwwYFBwerWrVquuuuuzR58mRVr15dkpSWlqbAwEBneJGkrl27ytPTU1u3btUDDzygtLQ0xcTEyMfHx1kTGxurKVOm6MSJE6pWrVqhfebm5io3N9f5OCcnpyymVuFdy6mCA6/E3YBOAAC4slK/iLd79+56++23tXbtWk2ZMkUbN25Ujx49dOnSJUlSVlaWgoODXZ7j7e2toKAgZWVlOWtCQkJcagoeF9RcLiUlRXa73bmEhYWV9tQAAEA5UepHYPr16+f8d9OmTdWsWTPVq1dPGzZs0N13313au3MaM2aMRo4c6Xyck5NDiAEAoIIq89uob731VtWoUUP79u2TJDkcDh07dsyl5uLFi/r111+d1804HA4dPXrUpabg8ZWurbHZbAoICHBZAABAxVTmAeann37S8ePHVatWLUlSdHS0Tp48qfT0dGfNunXrlJ+fr7Zt2zprNm3apLy8PGdNamqqbr/99iKvfwEAADeXEgeY06dPKyMjQxkZGZKk/fv3KyMjQwcPHtTp06eVlJSkLVu26MCBA1q7dq3uv/9+1a9fX7GxsZKkRo0aqXv37nriiSe0bds2ffnll0pISFC/fv0UGhoqSXr00Ufl4+OjoUOHateuXXr//ff117/+1eUUEQAAuHmVOMBs375dLVu2VMuWLSVJI0eOVMuWLTVu3Dh5eXlp586duu+++3Tbbbdp6NChioqK0ueffy6bzebcxqJFi9SwYUPdfffduvfee9WhQweXv/Fit9u1evVq7d+/X1FRURo1apTGjRvHLdQAAEDSdVzE27lzZ1mWdcX1q1atKnYbQUFBWrx48VVrmjVrps8//7yk7QEAgJsAn4UEAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIzj7e4GgPKi7vPLi6058ErcDegEAFAcjsAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwTokDzKZNm9SzZ0+FhobKw8NDS5cudVlvWZbGjRunWrVqyc/PT127dtXevXtdan799Vf1799fAQEBCgwM1NChQ3X69GmXmp07d6pjx47y9fVVWFiYpk6dWvLZAQCACqnEAebMmTNq3ry55syZU+T6qVOn6rXXXtO8efO0detWValSRbGxsTp//ryzpn///tq1a5dSU1O1bNkybdq0ScOHD3euz8nJUbdu3RQeHq709HRNmzZNycnJmj9//nVMEQAAVDTeJX1Cjx491KNHjyLXWZalWbNm6cUXX9T9998vSXr77bcVEhKipUuXql+/ftqzZ49Wrlypr776Sq1atZIkzZ49W/fee69effVVhYaGatGiRbpw4YLefPNN+fj4qHHjxsrIyNCMGTNcgs7v5ebmKjc31/k4JyenpFMDAACGKNVrYPbv36+srCx17drVOWa329W2bVulpaVJktLS0hQYGOgML5LUtWtXeXp6auvWrc6amJgY+fj4OGtiY2OVmZmpEydOFLnvlJQU2e125xIWFlaaUwMAAOVIqQaYrKwsSVJISIjLeEhIiHNdVlaWgoODXdZ7e3srKCjIpaaobfx+H5cbM2aMsrOzncuhQ4f++IQAAEC5VOJTSOWVzWaTzWZzdxsAAOAGKNUjMA6HQ5J09OhRl/GjR4861zkcDh07dsxl/cWLF/Xrr7+61BS1jd/vAwAA3LxKNcBERETI4XBo7dq1zrGcnBxt3bpV0dHRkqTo6GidPHlS6enpzpp169YpPz9fbdu2ddZs2rRJeXl5zprU1FTdfvvtqlatWmm2DAAADFTiAHP69GllZGQoIyND0m8X7mZkZOjgwYPy8PBQYmKiJk+erE8++UTffPONBg4cqNDQUPXq1UuS1KhRI3Xv3l1PPPGEtm3bpi+//FIJCQnq16+fQkNDJUmPPvqofHx8NHToUO3atUvvv/++/vrXv2rkyJGlNnEAAGCuEl8Ds337dnXp0sX5uCBUDBo0SAsXLtRzzz2nM2fOaPjw4Tp58qQ6dOiglStXytfX1/mcRYsWKSEhQXfffbc8PT3Vp08fvfbaa871drtdq1evVnx8vKKiolSjRg2NGzfuirdQAwCAm0uJA0znzp1lWdYV13t4eGjixImaOHHiFWuCgoK0ePHiq+6nWbNm+vzzz0vaHgAAuAnwWUgAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGKfEH+YI4I+r+/zyYmsOvBJ3AzoBADNxBAYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADG8XZ3A7i6us8vL7bmwCtxN6ATAADKD47AAAAA4xBgAACAcQgwAADAOAQYAABgHC7iBVBhcRE8UHFxBAYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIe/xAvghuKv4wIoDRyBAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYp9QDTHJysjw8PFyWhg0bOtefP39e8fHxql69uqpWrao+ffro6NGjLts4ePCg4uLiVLlyZQUHByspKUkXL14s7VYBAIChyuQP2TVu3Fhr1qz5v514/99uRowYoeXLl+vDDz+U3W5XQkKCevfurS+//FKSdOnSJcXFxcnhcGjz5s36+eefNXDgQFWqVEkvv/xyWbQLAAAMUyYBxtvbWw6Ho9B4dna2/v73v2vx4sW66667JEkLFixQo0aNtGXLFrVr106rV6/W7t27tWbNGoWEhKhFixaaNGmSRo8ereTkZPn4+BS5z9zcXOXm5jof5+TklMXUAABAOVAm18Ds3btXoaGhuvXWW9W/f38dPHhQkpSenq68vDx17drVWduwYUPVqVNHaWlpkqS0tDQ1bdpUISEhzprY2Fjl5ORo165dV9xnSkqK7Ha7cwkLCyuLqQEAgHKg1ANM27ZttXDhQq1cuVJz587V/v371bFjR506dUpZWVny8fFRYGCgy3NCQkKUlZUlScrKynIJLwXrC9ZdyZgxY5Sdne1cDh06VLoTAwAA5Uapn0Lq0aOH89/NmjVT27ZtFR4erg8++EB+fn6lvTsnm80mm81WZtsHAADlR5nfRh0YGKjbbrtN+/btk8Ph0IULF3Ty5EmXmqNHjzqvmXE4HIXuSip4XNR1NQAA4OZT5gHm9OnT+v7771WrVi1FRUWpUqVKWrt2rXN9ZmamDh48qOjoaElSdHS0vvnmGx07dsxZk5qaqoCAAEVGRpZ1uwAAwAClfgrp2WefVc+ePRUeHq4jR45o/Pjx8vLy0iOPPCK73a6hQ4dq5MiRCgoKUkBAgJ5++mlFR0erXbt2kqRu3bopMjJSAwYM0NSpU5WVlaUXX3xR8fHxnCICAACSyiDA/PTTT3rkkUd0/Phx1axZUx06dNCWLVtUs2ZNSdLMmTPl6empPn36KDc3V7GxsXr99dedz/fy8tKyZcv05z//WdHR0apSpYoGDRqkiRMnlnarAADAUKUeYP7xj39cdb2vr6/mzJmjOXPmXLEmPDxcn332WWm3BgAAKgg+CwkAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxvN3dAADcLOo+v7zYmgOvxN2ATgDzcQQGAAAYhwADAACMQ4ABAADGIcAAAADjcBHvdeBCPAAA3IsjMAAAwDgEGAAAYBwCDAAAME65DjBz5sxR3bp15evrq7Zt22rbtm3ubgkAAJQD5fYi3vfff18jR47UvHnz1LZtW82aNUuxsbHKzMxUcHCwu9sDgHKNmw1Q0ZXbIzAzZszQE088oSFDhigyMlLz5s1T5cqV9eabb7q7NQAA4Gbl8gjMhQsXlJ6erjFjxjjHPD091bVrV6WlpRX5nNzcXOXm5jofZ2dnS5JycnJKvb/83LPF1pTWfm/kvtyxv+JU1Nea17ni7etaVNS5Nxm/qtiabyfElsq+rsWN7Ke8zb0iKHhfWpZ19UKrHDp8+LAlydq8ebPLeFJSktWmTZsinzN+/HhLEgsLCwsLC0sFWA4dOnTVrFAuj8BcjzFjxmjkyJHOx/n5+fr1119VvXp1eXh4uLGz0pGTk6OwsDAdOnRIAQEB7m7nhrjZ5sx8KzbmW7Ex39JjWZZOnTql0NDQq9aVywBTo0YNeXl56ejRoy7jR48elcPhKPI5NptNNpvNZSwwMLCsWnSbgICAm+Kb4/dutjkz34qN+VZszLd02O32YmvK5UW8Pj4+ioqK0tq1a51j+fn5Wrt2raKjo93YGQAAKA/K5REYSRo5cqQGDRqkVq1aqU2bNpo1a5bOnDmjIUOGuLs1AADgZuU2wDz88MP65ZdfNG7cOGVlZalFixZauXKlQkJC3N2aW9hsNo0fP77QabKK7GabM/Ot2JhvxcZ8bzwPyyruPiUAAIDypVxeAwMAAHA1BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgDHA4cOH9dhjj6l69ery8/NT06ZNtX37dne3VSYuXbqksWPHKiIiQn5+fqpXr54mTZpU/Id6GWLTpk3q2bOnQkND5eHhoaVLl7qstyxL48aNU61ateTn56euXbtq79697mm2lFxtznl5eRo9erSaNm2qKlWqKDQ0VAMHDtSRI0fc1/AfVNzX+PeefPJJeXh4aNasWTesv9J2LfPds2eP7rvvPtntdlWpUkWtW7fWwYMHb3yzpaC4+Z4+fVoJCQmqXbu2/Pz8FBkZqXnz5rmn2VKQkpKi1q1by9/fX8HBwerVq5cyMzNdas6fP6/4+HhVr15dVatWVZ8+fQr9Jf2yQIAp506cOKH27durUqVKWrFihXbv3q3p06erWrVq7m6tTEyZMkVz587V3/72N+3Zs0dTpkzR1KlTNXv2bHe3VirOnDmj5s2ba86cOUWunzp1ql577TXNmzdPW7duVZUqVRQbG6vz58/f4E5Lz9XmfPbsWe3YsUNjx47Vjh079NFHHykzM1P33XefGzotHcV9jQssWbJEW7ZsKfbzXsq74ub7/fffq0OHDmrYsKE2bNignTt3auzYsfL19b3BnZaO4uY7cuRIrVy5Uu+++6727NmjxMREJSQk6JNPPrnBnZaOjRs3Kj4+Xlu2bFFqaqry8vLUrVs3nTlzxlkzYsQIffrpp/rwww+1ceNGHTlyRL179y775krj06NRdkaPHm116NDB3W3cMHFxcdbjjz/uMta7d2+rf//+buqo7EiylixZ4nycn59vORwOa9q0ac6xkydPWjabzXrvvffc0GHpu3zORdm2bZslyfrxxx9vTFNl6Erz/emnn6xbbrnF+vbbb63w8HBr5syZN7y3slDUfB9++GHrsccec09DZayo+TZu3NiaOHGiy9gdd9xhvfDCCzews7Jz7NgxS5K1ceNGy7J++xlVqVIl68MPP3TW7Nmzx5JkpaWllWkvHIEp5z755BO1atVKDz30kIKDg9WyZUv9z//8j7vbKjN33nmn1q5dq++++06S9K9//UtffPGFevTo4ebOyt7+/fuVlZWlrl27Osfsdrvatm2rtLQ0N3Z2Y2VnZ8vDw6NCfhir9Nvnug0YMEBJSUlq3Lixu9spU/n5+Vq+fLluu+02xcbGKjg4WG3btr3qaTXT3Xnnnfrkk090+PBhWZal9evX67vvvlO3bt3c3VqpyM7OliQFBQVJktLT05WXl+fyc6thw4aqU6dOmf/cIsCUcz/88IPmzp2rBg0aaNWqVfrzn/+sZ555Rm+99Za7WysTzz//vPr166eGDRuqUqVKatmypRITE9W/f393t1bmsrKyJKnQx2WEhIQ411V058+f1+jRo/XII49U2E/0nTJliry9vfXMM8+4u5Uyd+zYMZ0+fVqvvPKKunfvrtWrV+uBBx5Q7969tXHjRne3VyZmz56tyMhI1a5dWz4+PurevbvmzJmjmJgYd7f2h+Xn5ysxMVHt27dXkyZNJP32c8vHx6fQLxw34udWuf0sJPwmPz9frVq10ssvvyxJatmypb799lvNmzdPgwYNcnN3pe+DDz7QokWLtHjxYjVu3FgZGRlKTExUaGhohZwv/k9eXp769u0ry7I0d+5cd7dTJtLT0/XXv/5VO3bskIeHh7vbKXP5+fmSpPvvv18jRoyQJLVo0UKbN2/WvHnz1KlTJ3e2VyZmz56tLVu26JNPPlF4eLg2bdqk+Ph4hYaGuhylMFF8fLy+/fZbffHFF+5uRRJHYMq9WrVqKTIy0mWsUaNGxl7BX5ykpCTnUZimTZtqwIABGjFihFJSUtzdWplzOBySVOjq/aNHjzrXVVQF4eXHH39UampqhT368vnnn+vYsWOqU6eOvL295e3trR9//FGjRo1S3bp13d1eqatRo4a8vb1vmp9h586d01/+8hfNmDFDPXv2VLNmzZSQkKCHH35Yr776qrvb+0MSEhK0bNkyrV+/XrVr13aOOxwOXbhwQSdPnnSpvxE/twgw5Vz79u0L3bL23XffKTw83E0dla2zZ8/K09P1benl5eX8Ta4ii4iIkMPh0Nq1a51jOTk52rp1q6Kjo93YWdkqCC979+7VmjVrVL16dXe3VGYGDBignTt3KiMjw7mEhoYqKSlJq1atcnd7pc7Hx0etW7e+aX6G5eXlKS8vr0L9DLMsSwkJCVqyZInWrVuniIgIl/VRUVGqVKmSy8+tzMxMHTx4sMx/bnEKqZwbMWKE7rzzTr388svq27evtm3bpvnz52v+/Pnubq1M9OzZUy+99JLq1Kmjxo0b6+uvv9aMGTP0+OOPu7u1UnH69Gnt27fP+Xj//v3KyMhQUFCQ6tSpo8TERE2ePFkNGjRQRESExo4dq9DQUPXq1ct9Tf9BV5tzrVq19OCDD2rHjh1atmyZLl265DxvHhQUJB8fH3e1fd2K+xpfHtAqVaokh8Oh22+//Ua3WiqKm29SUpIefvhhxcTEqEuXLlq5cqU+/fRTbdiwwX1N/wHFzbdTp05KSkqSn5+fwsPDtXHjRr399tuaMWOGG7u+fvHx8Vq8eLE+/vhj+fv7O78/7Xa7/Pz8ZLfbNXToUI0cOVJBQUEKCAjQ008/rejoaLVr165smyvTe5xQKj799FOrSZMmls1msxo2bGjNnz/f3S2VmZycHOv//b//Z9WpU8fy9fW1br31VuuFF16wcnNz3d1aqVi/fr0lqdAyaNAgy7J+u5V67NixVkhIiGWz2ay7777byszMdG/Tf9DV5rx///4i10my1q9f7+7Wr0txX+PLmX4b9bXM9+9//7tVv359y9fX12revLm1dOlS9zX8BxU3359//tkaPHiwFRoaavn6+lq33367NX36dCs/P9+9jV+nK31/LliwwFlz7tw566mnnrKqVatmVa5c2XrggQesn3/+ucx78/j/GwQAADAG18AAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDj/Hwci2gbf6S+hAAAAAElFTkSuQmCC",
63
+ "text/plain": [
64
+ "<Figure size 640x480 with 1 Axes>"
65
+ ]
66
+ },
67
+ "metadata": {},
68
+ "output_type": "display_data"
69
+ }
70
+ ],
71
+ "source": [
72
+ "import matplotlib.pyplot as plt\n",
73
+ "\n",
74
+ "plt.hist([len(sentence) for sentence in processed_sentences], bins=50)\n",
75
+ "plt.title(\"Histogram of sentence lengths\")\n",
76
+ "\n",
77
+ "plt.show()"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 48,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "\"\"\"\n",
87
+ " This variable will control the maximum length of the sentence \n",
88
+ " as well as the embedding size\n",
89
+ "\"\"\"\n",
90
+ "\n",
91
+ "MAX_LEN = 30"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 49,
97
  "metadata": {},
98
  "outputs": [],
99
  "source": [
100
+ "import tensorflow as tf\n",
101
+ "\n",
102
+ "\n",
103
+ "def encode_and_pad_sentence(sentence: str, vocab: list[str], max_length: int = MAX_LEN):\n",
104
  " \"\"\"\n",
105
  " Encode a sentence into a list of integers\n",
106
  "\n",
 
111
  " Returns:\n",
112
  " list: The list of integers\n",
113
  " \"\"\"\n",
114
+ " encoded_sentence = [\n",
115
  " vocab.index(word) if word in vocab else vocab.index(\"<UNK>\")\n",
116
+ " for word in sentence\n",
117
+ " ]\n",
118
+ "\n",
119
+ " return tf.keras.utils.pad_sequences(\n",
120
+ " [encoded_sentence], maxlen=max_length, padding=\"post\", value=0\n",
121
+ " )[0]"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 25,
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "get_vocab_from_corpus = lambda corpus: list(\n",
131
+ " set([word for sentence in corpus for word in sentence])\n",
132
+ ") + [\"<UNK>\"]"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": 26,
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "vocab = get_vocab_from_corpus(processed_sentences)"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": 29,
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "encoded_sentences = [\n",
151
+ " encode_and_pad_sentence(sentence, vocab) for sentence in processed_sentences\n",
152
+ "]"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 50,
158
+ "metadata": {},
159
+ "outputs": [],
160
+ "source": [
161
+ "padded_labels = tf.keras.preprocessing.sequence.pad_sequences(\n",
162
+ " processed_labels, maxlen=MAX_LEN, padding=\"post\", value=-1\n",
163
+ ")"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": 47,
169
+ "metadata": {},
170
+ "outputs": [
171
+ {
172
+ "name": "stderr",
173
+ "output_type": "stream",
174
+ "text": [
175
+ "2024-11-09 16:56:24.038756: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
176
+ ]
177
+ }
178
+ ],
179
+ "source": [
180
+ "dataset = tf.data.Dataset.from_tensor_slices((encoded_sentences, padded_labels))\n",
181
+ "\n",
182
+ "# Split the dataset into a training and testing dataset\n",
183
+ "train_dataset, test_dataset = tf.keras.utils.split_dataset(dataset, left_size=0.8)"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": 51,
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "lstm = tf.keras.models.Sequential(\n",
193
+ " layers=[\n",
194
+ " tf.keras.layers.Embedding(len(vocab) + 1, MAX_LEN, mask_zero=True),\n",
195
+ " tf.keras.layers.LSTM(MAX_LEN, return_sequences=True),\n",
196
+ " tf.keras.layers.Dense(len(unique_labels), activation=tf.nn.log_softmax),\n",
197
+ " ]\n",
198
+ ")"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "markdown",
203
+ "metadata": {},
204
+ "source": [
205
+ "## Masked loss and metrics\n",
206
+ "\n",
207
+ "Before training the model, we need to create your own function to compute the accuracy. Tensorflow has built-in accuracy metrics but we cannot pass values to be ignored. This will impact the calculations, since we must remove the padded values.\n",
208
+ "\n",
209
+ "Usually, the metric that inputs true labels and predicted labels and outputs how many times the predicted and true labels match is called accuracy. In some cases, however, there is one more step before getting the predicted labels. This may happen if, instead of passing the predicted labels, a vector of probabilities is passed. In such case, there is a need to perform an `argmax` for each prediction to find the appropriate predicted label. Such situations happen very often, therefore Tensorflow has a set of functions, with prefix `Sparse`, that performs this operation in the backend. Unfortunately, it does not provide values to ignore in the accuracy case. This is what you will work on now.\n",
210
+ "\n",
211
+ "Note that the model's prediction has 3 axes:\n",
212
+ "\n",
213
+ "- the number of examples (batch size)\n",
214
+ "- the number of words in each example (padded to be as long as the longest sentence in the batch)\n",
215
+ "- the number of possible targets (the 17 named entity tags).\n",
216
+ "\n",
217
+ "Another important function is the loss function. In this case, we will use the Cross Entropy loss, but we need a multiclass implementation of it, also we may look for its Sparse version. Tensorflow has a SparseCategoricalCrossentropy loss function, which it is already imported by the name SparseCategoricalCrossEntropy.\n",
218
+ "\n",
219
+ "SparseCategoricalCrossentropy: The Sparse Categorical Crossentropy Loss Function.\n",
220
+ "\n",
221
+ "The arguments you will need:\n",
222
+ "\n",
223
+ "1. `from_logits`: This indicates if the values are raw values or normalized values (probabilities). Since the last layer of the model finishes with a LogSoftMax call, the results are not normalized - they do not lie between 0 and 1.\n",
224
+ "2. `ignore_class`: This indicates which class should be ignored when computing the crossentropy. Remember that the class related to padding value is set to be 0.\n"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "class CustomSparseCategoricalCrossentropy(tf.keras.losses.Loss):\n",
234
+ " def __init__(self, from_logits=False, ignore_class=-1):\n",
235
+ " super().__init__()\n",
236
+ " self.from_logits = from_logits\n",
237
+ " self.ignore_class = ignore_class\n",
238
+ "\n",
239
+ " def call(self, y_true, y_pred):\n",
240
+ " # Ensure inputs are tensors\n",
241
+ " y_true = tf.convert_to_tensor(y_true)\n",
242
+ " y_pred = tf.convert_to_tensor(y_pred)\n",
243
+ "\n",
244
+ " # Generate a mask that is False where y_true equals ignore_class and True elsewhere\n",
245
+ " mask = tf.not_equal(y_true, self.ignore_class)\n",
246
+ "\n",
247
+ " # Use this mask to filter out ignored values from y_true and y_pred\n",
248
+ " y_true_filtered = tf.boolean_mask(y_true, mask)\n",
249
+ " y_pred_filtered = tf.boolean_mask(y_pred, mask)\n",
250
+ "\n",
251
+ " # Compute the sparse categorical crossentropy on filtered targets and predictions\n",
252
+ " loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
253
+ " y_true_filtered, y_pred_filtered, from_logits=self.from_logits\n",
254
+ " )\n",
255
+ "\n",
256
+ " # Return the mean loss value\n",
257
+ " return tf.reduce_mean(loss)\n",
258
+ "\n",
259
+ "\n",
260
+ "def masked_loss(y_true, y_pred):\n",
261
+ " \"\"\"\n",
262
+ " Calculate the masked sparse categorical cross-entropy loss.\n",
263
+ "\n",
264
+ " Parameters:\n",
265
+ " y_true (tensor): True labels.\n",
266
+ " y_pred (tensor): Predicted logits.\n",
267
+ "\n",
268
+ " Returns:\n",
269
+ " loss (tensor): Calculated loss.\n",
270
+ " \"\"\"\n",
271
+ "\n",
272
+ " # Calculate the loss for each item in the batch. Remember to pass the right arguments, as discussed above!\n",
273
+ " loss_fn = CustomSparseCategoricalCrossentropy(from_logits=True, ignore_class=-1)\n",
274
+ " # Use the previous defined function to compute the loss\n",
275
+ " loss = loss_fn(y_true, y_pred)\n",
276
+ "\n",
277
+ " return loss"
278
  ]
279
  }
280
  ],
 
294
  "name": "python",
295
  "nbconvert_exporter": "python",
296
  "pygments_lexer": "ipython3",
297
+ "version": "3.12.4"
298
  }
299
  },
300
  "nbformat": 4,