Az-r-ow committed on
Commit
dcd93f5
·
1 Parent(s): 74d5c55

feat(ihm): semi-functional interface, minor things to add

Browse files
.gitignore CHANGED
@@ -167,15 +167,5 @@ cython_debug/
167
  # Remove test ouptuts
168
  output.*
169
 
170
- # Remove vscode settings
171
- .vscode
172
-
173
- # Remove macos ds store
174
  # Macos generated files
175
- .DS_Store
176
-
177
- # Remove vscode settings
178
- .vscode
179
-
180
- # Remove macos ds store
181
  .DS_Store
 
167
  # Remove test ouptuts
168
  output.*
169
 
 
 
 
 
170
  # Macos generated files
 
 
 
 
 
 
171
  .DS_Store
app/app.py CHANGED
@@ -3,17 +3,107 @@ from transformers import pipeline
3
  import numpy as np
4
  import pandas as pd
5
  from travel_resolver.libs.nlp.ner.models import BiLSTM_NER, LSTM_NER, CamemBERT_NER
6
- import torch
 
7
  from travel_resolver.libs.nlp.ner.data_processing import process_sentence
8
  from travel_resolver.libs.pathfinder.CSVTravelGraph import CSVTravelGraph
9
  from travel_resolver.libs.pathfinder.graph import Graph
10
- import os
11
 
12
  transcriber = pipeline(
13
  "automatic-speech-recognition", model="openai/whisper-base", device="cpu"
14
  )
15
 
16
- models = {"LSTM": None, "BiLSTM": None, "CamemBERT": CamemBERT_NER()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def transcribe(audio):
@@ -81,162 +171,151 @@ def getStationsByCityName(city: str):
81
  return stations
82
 
83
 
84
- def getDepartureAndArrivalFromText(text: str, model: str):
85
- entities = models[model].get_entities(text)
86
- tokenized_sentence = process_sentence(text, return_tokens=True)
87
-
88
- dep_idx = entities.index(1)
89
- arr_idx = entities.index(2)
90
 
91
- return tokenized_sentence[dep_idx].upper(), tokenized_sentence[arr_idx].upper()
92
 
93
 
94
- def handle_audio(audio):
95
-
96
- promptAudio = transcribe(audio)
97
-
98
- # todo : replace with the model selected by the user
99
- dep, arr = getDepartureAndArrivalFromText(promptAudio, "CamemBERT")
100
-
101
- return (
102
- gr.update(visible=True),
103
- gr.update(visible=False),
104
- gr.update(value=promptAudio),
105
- gr.update(value=dep),
106
- gr.update(value=arr),
107
- )
108
-
109
 
110
- def handle_file(file):
111
- loading_screen.update(visible=True)
112
  dep = None
113
  arr = None
114
- if file is not None:
115
- with open(file.name, "r") as f:
116
- file_content = f.read()
117
- row = file_content.split("\n")
118
- if len(row) > 1:
119
- return
120
- else:
121
- dep, arr = getDepartureAndArrivalFromText(file_content, "CamemBERT")
122
- else:
123
- file_content = "Aucun fichier uploadé."
124
-
125
- loading_screen.update(visible=False)
126
- return (
127
- gr.update(visible=True),
128
- gr.update(visible=False),
129
- gr.update(value=file_content),
130
- gr.update(value=dep),
131
- gr.update(value=arr),
132
- )
133
-
134
-
135
- def handle_back():
136
- audio.clear()
137
- file.clear()
138
- return (gr.update(visible=False), gr.update(visible=True))
139
-
140
-
141
- def handleCityChange(city):
142
- stations = getStationsByCityName(city)
143
- return gr.update(choices=stations, value=stations[0], interactive=True)
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- def handleStationChange(departureStation, destinationStation):
147
- if departureStation and destinationStation:
148
- dijkstraPath, dijkstraCost = getDijkstraResult(
149
- departureStation, destinationStation
150
- )
151
- dijkstraPathFormatted = "\n".join(
152
- [f"{i + 1}. {elem}" for i, elem in enumerate(dijkstraPath)]
153
- )
154
- AStarPath, AStarCost = getAStarResult(departureStation, destinationStation)
155
- AStarPathFormatted = "\n".join(
156
- [f"{i + 1}. {elem}" for i, elem in enumerate(AStarPath)]
157
- )
158
- return (
159
- gr.update(value=dijkstraCost),
160
- gr.update(value=dijkstraPathFormatted, lines=len(dijkstraPath)),
161
- gr.update(value=AStarCost),
162
- gr.update(value=AStarPathFormatted, lines=len(AStarPath)),
163
- )
164
- return (
165
- gr.HTML("<p>Aucun prompt renseigné</p>"),
166
- gr.update(value=""),
167
- gr.HTML("<p>Aucun prompt renseigné</p>"),
168
- gr.update(value=""),
169
- )
170
-
171
-
172
- with gr.Blocks(css="#back-button {width: fit-content}") as demo:
173
- with gr.Row(visible=False) as loading_screen:
174
- gr.Text("Chargement ...", elem_id="loading")
175
- with gr.Column() as promptChooser:
176
- with gr.Row():
177
- audio = gr.Audio(label="Fichier audio")
178
- file = gr.File(
179
- label="Fichier texte", file_types=["text"], file_count="single"
180
- )
181
- with gr.Column(visible=False) as content:
182
- backButton = gr.Button("← Back", elem_id="back-button")
183
- with gr.Row():
184
- with gr.Column(scale=1, min_width=300) as parameters:
185
- prompt = gr.Textbox(label="Prompt")
186
- departureCity = gr.Textbox(label="Ville de départ")
187
- destinationCity = gr.Textbox(label="Ville de de destination")
188
- with gr.Column(scale=2, min_width=300) as result:
189
  with gr.Row():
190
- departureStation = gr.Dropdown(label="Gare de départ")
191
- destinationStation = gr.Dropdown(label="Gare d'arrivée")
192
- with gr.Tab("Dijkstra"):
193
- timeDijkstra = gr.HTML("<p>Aucun prompt renseigné</p>")
194
- dijkstraPath = gr.Textbox(label="Chemin emprunté")
195
-
196
- with gr.Tab("AStar"):
197
- timeAStar = gr.HTML("<p>Aucun prompt renseigné</p>")
198
- AstarPath = gr.Textbox(label="Chemin emprunté")
199
- audio.change(
200
- handle_audio,
201
- inputs=[audio],
202
- outputs=[
203
- content,
204
- promptChooser,
205
- prompt,
206
- departureCity,
207
- destinationCity,
208
- ], # On rend la section "content" visible
209
- show_progress="full",
210
- )
211
- file.upload(
212
- handle_file,
213
- inputs=[file],
214
- outputs=[
215
- content,
216
- promptChooser,
217
- prompt,
218
- departureCity,
219
- destinationCity,
220
- ], # On rend la section "content" visible
221
- show_progress="full",
222
- )
223
- backButton.click(handle_back, inputs=[], outputs=[content, promptChooser])
224
- departureCity.change(
225
- handleCityChange, inputs=[departureCity], outputs=[departureStation]
226
- )
227
- destinationCity.change(
228
- handleCityChange, inputs=[destinationCity], outputs=[destinationStation]
229
- )
230
- departureStation.change(
231
- handleStationChange,
232
- inputs=[departureStation, destinationStation],
233
- outputs=[timeDijkstra, dijkstraPath, timeAStar, AstarPath],
234
- )
235
- destinationStation.change(
236
- handleStationChange,
237
- inputs=[departureStation, destinationStation],
238
- outputs=[timeDijkstra, dijkstraPath, timeAStar, AstarPath],
239
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  if __name__ == "__main__":
242
  demo.launch()
 
3
  import numpy as np
4
  import pandas as pd
5
  from travel_resolver.libs.nlp.ner.models import BiLSTM_NER, LSTM_NER, CamemBERT_NER
6
+
7
+ # import torch
8
  from travel_resolver.libs.nlp.ner.data_processing import process_sentence
9
  from travel_resolver.libs.pathfinder.CSVTravelGraph import CSVTravelGraph
10
  from travel_resolver.libs.pathfinder.graph import Graph
11
+ import time
12
 
13
  transcriber = pipeline(
14
  "automatic-speech-recognition", model="openai/whisper-base", device="cpu"
15
  )
16
 
17
+ models = {"LSTM": LSTM_NER(), "BiLSTM": BiLSTM_NER(), "CamemBERT": CamemBERT_NER()}
18
+
19
+ entities_label_mapping = {1: "LOC-DEP", 2: "LOC-ARR"}
20
+
21
+ with gr.Blocks(css="#back-button {width: fit-content}") as demo:
22
+ with gr.Column() as promptChooser:
23
+ with gr.Row():
24
+ audio = gr.Audio(label="Fichier audio")
25
+ file = gr.File(
26
+ label="Fichier texte", file_types=["text"], file_count="single"
27
+ )
28
+
29
+ model = gr.Dropdown(
30
+ label="Modèle NER", choices=models.keys(), value="CamemBERT"
31
+ )
32
+
33
+ @gr.render(inputs=[audio, file, model], triggers=[model.change])
34
+ def handle_model_change(audio, file, model):
35
+ if audio:
36
+ render_tabs([transcribe(audio)], model, gr.Progress())
37
+ elif file:
38
+ with open(file.name, "r") as f:
39
+ sentences = f.read().split("\n")
40
+ render_tabs(sentences, model, gr.Progress())
41
+
42
+ @gr.render(inputs=[audio, model], triggers=[audio.change])
43
+ def handle_audio(audio, model, progress=gr.Progress()):
44
+ progress(0, "Analyzing audio...")
45
+ promptAudio = transcribe(audio)
46
+
47
+ time.sleep(1)
48
+
49
+ render_tabs([promptAudio], model, progress)
50
+
51
+ @gr.render(
52
+ inputs=[file, model],
53
+ triggers=[file.upload],
54
+ )
55
+ def handle_file(file, model, progress=gr.Progress()):
56
+ progress(0, desc="Analyzing file...")
57
+ time.sleep(1)
58
+ if file is not None:
59
+ with open(file.name, "r") as f:
60
+ progress(0.33, desc="Reading file...")
61
+ file_content = f.read()
62
+ rows = file_content.split("\n")
63
+ sentences = [row for row in rows if row]
64
+ render_tabs(sentences, model, progress)
65
+
66
+
67
+ def handle_back():
68
+ audio.clear()
69
+ file.clear()
70
+ return (gr.update(visible=False), gr.update(visible=True))
71
+
72
+
73
+ def handleCityChange(city):
74
+ stations = getStationsByCityName(city)
75
+ return gr.update(choices=stations, value=stations[0], interactive=True)
76
+
77
+
78
+ def handleCityChange(city):
79
+ stations = getStationsByCityName(city)
80
+ return gr.update(choices=stations, value=stations[0], interactive=True)
81
+
82
+
83
+ def formatPath(path):
84
+ return "\n".join([f"{i + 1}. {elem}" for i, elem in enumerate(path)])
85
+
86
+
87
+ def handleStationChange(departureStation, destinationStation):
88
+ if departureStation and destinationStation:
89
+ dijkstraPath, dijkstraCost = getDijkstraResult(
90
+ departureStation, destinationStation
91
+ )
92
+ dijkstraPathFormatted = formatPath(dijkstraPath)
93
+ AStarPath, AStarCost = getAStarResult(departureStation, destinationStation)
94
+ AStarPathFormatted = formatPath(AStarPath)
95
+ return (
96
+ gr.update(value=dijkstraCost),
97
+ gr.update(value=dijkstraPathFormatted, lines=len(dijkstraPath)),
98
+ gr.update(value=AStarCost),
99
+ gr.update(value=AStarPathFormatted, lines=len(AStarPath)),
100
+ )
101
+ return (
102
+ gr.HTML("<p>Aucun prompt renseigné</p>"),
103
+ gr.update(value=""),
104
+ gr.HTML("<p>Aucun prompt renseigné</p>"),
105
+ gr.update(value=""),
106
+ )
107
 
108
 
109
  def transcribe(audio):
 
171
  return stations
172
 
173
 
174
+ def getEntitiesPositions(text, entity):
175
+ start_idx = text.find(entity)
176
+ end_idx = start_idx + len(entity)
 
 
 
177
 
178
+ return start_idx, end_idx
179
 
180
 
181
+ def getDepartureAndArrivalFromText(text: str, model: str):
182
+ entities = models[model].get_entities(text)
183
+ if not isinstance(entities, list):
184
+ entities = entities.tolist()
185
+ tokenized_sentence = process_sentence(text, return_tokens=True)
 
 
 
 
 
 
 
 
 
 
186
 
 
 
187
  dep = None
188
  arr = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ if 1 in entities:
191
+ dep_idx = entities.index(1)
192
+ dep = tokenized_sentence[dep_idx]
193
+ start, end = getEntitiesPositions(text, dep)
194
+ dep = {
195
+ "entity": entities_label_mapping[1],
196
+ "word": dep,
197
+ "start": start,
198
+ "end": end,
199
+ }
200
+
201
+ if 2 in entities:
202
+ arr_idx = entities.index(2)
203
+ arr = tokenized_sentence[arr_idx]
204
+ start, end = getEntitiesPositions(text, arr)
205
+ arr = {
206
+ "entity": entities_label_mapping[2],
207
+ "word": arr,
208
+ "start": start,
209
+ "end": end,
210
+ }
211
+
212
+ return dep, arr
213
+
214
+
215
+ def render_tabs(sentences: list[str], model: str, progress_bar: gr.Progress):
216
+ idx = 0
217
+ with gr.Tabs() as tabs:
218
+ for sentence in progress_bar.tqdm(sentences, desc="Processing sentences..."):
219
+ with gr.Tab(f"Sentence {idx}"):
220
+ dep, arr = getDepartureAndArrivalFromText(sentence, model)
221
+ entities = []
222
+ for entity in [dep, arr]:
223
+ if entity:
224
+ entities.append(entity)
225
+
226
+ # Format the classified entities
227
+ departureCityValue = dep["word"].upper() if dep else ""
228
+ arrivalCityValue = arr["word"].upper() if arr else ""
229
+
230
+ # Get the available stations
231
+ departureStations = getStationsByCityName(departureCityValue)
232
+ departureStationValue = (
233
+ departureStations[0] if departureStations else ""
234
+ )
235
+ arrivalStations = getStationsByCityName(arrivalCityValue)
236
+ arrivalStationValue = arrivalStations[0] if arrivalStations else ""
237
+
238
+ dijkstraPathValues = []
239
+ AStarPathValues = []
240
+ timeDijkstraValue = "<p>Aucun prompt renseigné</p>"
241
+ timeAStarValue = "<p>Aucun prompt renseigné</p>"
242
+
243
+ # Get the paths and time for the two algorithms
244
+ if departureStationValue and arrivalStationValue:
245
+ dijkstraPathValues, timeDijkstraValue = getDijkstraResult(
246
+ departureStationValue, arrivalStationValue
247
+ )
248
+ AStarPathValues, timeAStarValue = getAStarResult(
249
+ departureStationValue, arrivalStationValue
250
+ )
251
+
252
+ dijkstraPathFormatted = formatPath(dijkstraPathValues)
253
+ AStarPathFormatted = formatPath(AStarPathValues)
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  with gr.Row():
256
+ with gr.Column(scale=1, min_width=300):
257
+ gr.HighlightedText(
258
+ value={"text": sentence, "entities": entities}
259
+ )
260
+ departureCity = gr.Textbox(
261
+ label="Ville de départ",
262
+ value=departureCityValue,
263
+ )
264
+ arrivalCity = gr.Textbox(
265
+ label="Ville d'arrivée",
266
+ value=arrivalCityValue,
267
+ )
268
+ with gr.Column(scale=2, min_width=300):
269
+ with gr.Row():
270
+ departureStation = gr.Dropdown(
271
+ label="Gare de départ",
272
+ choices=departureStations,
273
+ value=departureStationValue,
274
+ )
275
+ arrivalStation = gr.Dropdown(
276
+ label="Gare d'arrivée",
277
+ choices=arrivalStations,
278
+ value=arrivalStationValue,
279
+ )
280
+ with gr.Tab("Dijkstra"):
281
+ timeDijkstra = gr.HTML(value=timeDijkstraValue)
282
+ dijkstraPath = gr.Textbox(
283
+ label="Chemin emprunté",
284
+ value=dijkstraPathFormatted,
285
+ lines=len(dijkstraPathValues),
286
+ )
287
+
288
+ with gr.Tab("AStar"):
289
+ timeAStar = gr.HTML(value=timeAStarValue)
290
+ AstarPath = gr.Textbox(
291
+ label="Chemin emprunté",
292
+ value=AStarPathFormatted,
293
+ lines=len(AStarPathValues),
294
+ )
295
+
296
+ departureCity.change(
297
+ handleCityChange,
298
+ inputs=[departureCity],
299
+ outputs=[departureStation],
300
+ )
301
+ arrivalCity.change(
302
+ handleCityChange,
303
+ inputs=[arrivalCity],
304
+ outputs=[arrivalStation],
305
+ )
306
+ departureStation.change(
307
+ handleStationChange,
308
+ inputs=[departureStation, arrivalStation],
309
+ outputs=[timeDijkstra, dijkstraPath, timeAStar, AstarPath],
310
+ )
311
+ arrivalStation.change(
312
+ handleStationChange,
313
+ inputs=[departureStation, arrivalStation],
314
+ outputs=[timeDijkstra, dijkstraPath, timeAStar, AstarPath],
315
+ )
316
+
317
+ idx += 1
318
+
319
 
320
  if __name__ == "__main__":
321
  demo.launch()
app/travel_resolver/libs/nlp/ner/models.py CHANGED
@@ -11,6 +11,8 @@ from .data_processing import (
11
  )
12
  from .metrics import masked_loss, masked_accuracy, entity_accuracy
13
  import stanza
 
 
14
 
15
  nlp = stanza.Pipeline("fr", processors="tokenize,pos")
16
 
@@ -37,18 +39,14 @@ class NERModel(ABC):
37
 
38
  class LSTM_NER(NERModel):
39
  def __init__(self):
40
- self.model_path = os.path.join(
41
- self.file_path, "models", "lstm_with_pos", "model.keras"
42
- )
43
- self.model = tf.keras.models.load_model(
44
- self.model_path,
45
- custom_objects={
46
- "masked_loss": masked_loss,
47
- "masked_accuracy": masked_accuracy,
48
- "entity_accuracy": entity_accuracy,
49
- "log_softmax_v2": tf.nn.log_softmax,
50
- },
51
  )
 
 
52
 
53
  def encode_sentence(self, sentence: str):
54
  processed_sentence = process_sentence(
@@ -75,18 +73,11 @@ class LSTM_NER(NERModel):
75
 
76
  class BiLSTM_NER(NERModel):
77
  def __init__(self):
78
- self.model_path = os.path.join(
79
- self.file_path, "models", "bilstm", "model.keras"
80
- )
81
- self.model = tf.keras.models.load_model(
82
- self.model_path,
83
- custom_objects={
84
- "masked_loss": masked_loss,
85
- "masked_accuracy": masked_accuracy,
86
- "entity_accuracy": entity_accuracy,
87
- "log_softmax_v2": tf.nn.log_softmax,
88
- },
89
  )
 
 
90
 
91
  def encode_sentence(self, sentence: str):
92
  processed_sentence = process_sentence(
@@ -167,8 +158,6 @@ class CamemBERT_NER(NERModel):
167
  if current_word is not None:
168
  sentence_labels.append(word_label)
169
 
170
- print(i)
171
- print(token_idx)
172
  # Reset for the new word
173
  current_word = word_idx
174
  word_label = predictions[i][token_idx]
 
11
  )
12
  from .metrics import masked_loss, masked_accuracy, entity_accuracy
13
  import stanza
14
+ from .models_definitions.bilstm.architecture import BiLSTM
15
+ from .models_definitions.lstm_with_pos.architecture import LSTM
16
 
17
  nlp = stanza.Pipeline("fr", processors="tokenize,pos")
18
 
 
39
 
40
  class LSTM_NER(NERModel):
41
  def __init__(self):
42
+ self.model_weights_path = os.path.join(
43
+ self.file_path,
44
+ "models_definitions",
45
+ "lstm_with_pos",
46
+ "lstm_with_pos.weights.h5",
 
 
 
 
 
 
47
  )
48
+ self.model = LSTM(self.vocab, 3, self.pos_tags)
49
+ self.model.load_from_weights(self.model_weights_path)
50
 
51
  def encode_sentence(self, sentence: str):
52
  processed_sentence = process_sentence(
 
73
 
74
  class BiLSTM_NER(NERModel):
75
  def __init__(self):
76
+ self.model_weights_path = os.path.join(
77
+ self.file_path, "models_definitions", "bilstm", "bilstm.weights.h5"
 
 
 
 
 
 
 
 
 
78
  )
79
+ self.model = BiLSTM(self.vocab, 3)
80
+ self.model.load_from_weights(self.model_weights_path)
81
 
82
  def encode_sentence(self, sentence: str):
83
  processed_sentence = process_sentence(
 
158
  if current_word is not None:
159
  sentence_labels.append(word_label)
160
 
 
 
161
  # Reset for the new word
162
  current_word = word_idx
163
  word_label = predictions[i][token_idx]
app/travel_resolver/libs/nlp/ner/models_definitions/__init__.py ADDED
File without changes
app/travel_resolver/libs/nlp/ner/models_definitions/bilstm/__init__.py ADDED
File without changes
app/travel_resolver/libs/nlp/ner/models_definitions/bilstm/architecture.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
+
4
+ class BiLSTM:
5
+ def __init__(self, vocab, nb_labels, emb_dim=100):
6
+ self.model = tf.keras.models.Sequential(
7
+ layers=[
8
+ tf.keras.layers.Embedding(len(vocab) + 1, emb_dim, mask_zero=True),
9
+ tf.keras.layers.Bidirectional(
10
+ tf.keras.layers.LSTM(emb_dim, return_sequences=True)
11
+ ),
12
+ tf.keras.layers.Dropout(0.3),
13
+ tf.keras.layers.Dense(nb_labels, activation=tf.nn.log_softmax),
14
+ ]
15
+ )
16
+
17
+ def load_from_weights(self, weights_path):
18
+ self.model.load_weights(weights_path)
19
+
20
+ def predict(self, x, verbose=0):
21
+ return self.model.predict(x, verbose=verbose)
app/travel_resolver/libs/nlp/ner/models_definitions/bilstm/bilstm.weights.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c4457e1c5249bef9062556a0f371f75bbe41f1c903ca485de55899d3a99c453
3
+ size 7964368
app/travel_resolver/libs/nlp/ner/{models → models_definitions}/bilstm/model.keras RENAMED
File without changes
app/travel_resolver/libs/nlp/ner/{models → models_definitions}/bilstm/tf_version.txt RENAMED
File without changes
app/travel_resolver/libs/nlp/ner/models_definitions/lstm_with_pos/__init__.py ADDED
File without changes
app/travel_resolver/libs/nlp/ner/models_definitions/lstm_with_pos/architecture.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
+
4
+ class LSTM:
5
+ def __init__(self, vocab, nb_labels: int, pos_tags: list, emb_dim=100, emb_size=32):
6
+ word_input = tf.keras.layers.Input(shape=(emb_dim,), name="word_input")
7
+ pos_input = tf.keras.layers.Input(shape=(emb_dim,), name="pos_input")
8
+
9
+ word_embedding = tf.keras.layers.Embedding(
10
+ len(vocab), emb_size, name="word_embedding"
11
+ )(word_input)
12
+
13
+ pos_embedding = tf.keras.layers.Embedding(
14
+ len(pos_tags),
15
+ emb_size,
16
+ name="pos_embedding",
17
+ )(pos_input)
18
+
19
+ concatenated = tf.keras.layers.Concatenate()([word_embedding, pos_embedding])
20
+
21
+ masked_cat = tf.keras.layers.Masking(mask_value=0)(concatenated)
22
+
23
+ lstm_layer_with_pos = tf.keras.layers.LSTM(
24
+ emb_size, return_sequences=True, name="lstm_layer"
25
+ )(masked_cat)
26
+
27
+ dropout = tf.keras.layers.Dropout(0.2)(lstm_layer_with_pos)
28
+
29
+ output = tf.keras.layers.Dense(nb_labels, activation=tf.nn.log_softmax)(dropout)
30
+
31
+ self.model = tf.keras.Model(inputs=[word_input, pos_input], outputs=output)
32
+
33
+ def load_from_weights(self, weights_path):
34
+ self.model.load_weights(weights_path)
35
+
36
+ def predict(self, x, verbose=0):
37
+ return self.model.predict(x, verbose=verbose)
app/travel_resolver/libs/nlp/ner/models_definitions/lstm_with_pos/lstm_with_pos.weights.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb5a1a558c9156caa7dc56a64576be9331b13991bc8c87886ae69546fdeb80ca
3
+ size 2111328
app/travel_resolver/libs/nlp/ner/{models → models_definitions}/lstm_with_pos/model.keras RENAMED
File without changes
app/travel_resolver/libs/nlp/ner/{models → models_definitions}/lstm_with_pos/tf_version.txt RENAMED
File without changes
app/travel_resolver/tests/data_processing_test.py CHANGED
@@ -1,6 +1,6 @@
1
  import unittest
2
  from pathlib import Path
3
- from travel_resolver.libs.nlp.data_processing import (
4
  get_tagged_content,
5
  convert_tagged_sentence_to_bio,
6
  from_tagged_file_to_bio_file,
 
1
  import unittest
2
  from pathlib import Path
3
+ from app.travel_resolver.libs.nlp.ner.data_processing import (
4
  get_tagged_content,
5
  convert_tagged_sentence_to_bio,
6
  from_tagged_file_to_bio_file,
test.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.travel_resolver.libs.nlp.ner.models import LSTM_NER, BiLSTM_NER, CamemBERT_NER
2
+ import tensorflow as tf
3
+
4
+ print(tf.__version__)
5
+
6
+ ner_model = LSTM_NER()
7
+
8
+ sentence = "Je voudrais voyager de Nice à Clermont Ferrand."
9
+
10
+ print(ner_model.get_entities(sentence))
test.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Je veux partir de Montpellier à Clermont-Ferrand.
test2.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Je suis à Paris. Je veux prendre le train à Montpellier.
2
+ Je veux prendre le train de Lyon Γ  Marseille.