Az-r-ow commited on
Commit
65b2047
·
1 Parent(s): 4e0aa6b

WIP(interface): added map with route

Browse files
Files changed (2) hide show
  1. app/app.py +95 -30
  2. test1.txt +1 -0
app/app.py CHANGED
@@ -8,6 +8,7 @@ from travel_resolver.libs.nlp.ner.data_processing import process_sentence
8
  from travel_resolver.libs.pathfinder.CSVTravelGraph import CSVTravelGraph
9
  from travel_resolver.libs.pathfinder.graph import Graph
10
  import time
 
11
 
12
  transcriber = pipeline(
13
  "automatic-speech-recognition", model="openai/whisper-base", device="cpu"
@@ -16,21 +17,14 @@ transcriber = pipeline(
16
  models = {"LSTM": LSTM_NER(), "BiLSTM": BiLSTM_NER(), "CamemBERT": CamemBERT_NER()}
17
 
18
 
19
- def handle_model_change(audio, file, model):
20
- if audio:
21
- render_tabs([transcribe(audio)], model, gr.Progress())
22
- elif file:
23
- with open(file.name, "r") as f:
24
- sentences = f.read().split("\n")
25
- return render_tabs(sentences, model, gr.Progress())
26
-
27
-
28
  def handle_audio(audio, model, progress=gr.Progress()):
29
  progress(
30
  0,
31
  )
32
  promptAudio = transcribe(audio)
33
 
 
 
34
  time.sleep(1)
35
 
36
  return render_tabs([promptAudio], model, progress)
@@ -69,28 +63,56 @@ with gr.Blocks() as demo:
69
  interactive=True,
70
  )
71
 
72
- with gr.Column() as tabs:
73
- pass
74
-
75
- audio.upload(handle_audio, inputs=[audio, model], outputs=[tabs])
76
- file.upload(handle_file, inputs=[file, model], outputs=[tabs])
77
- model.change(handle_model_change, inputs=[audio, file, model], outputs=[tabs])
78
-
79
-
80
- def handleCityChange(city):
81
- stations = getStationsByCityName(city)
82
- return gr.update(choices=stations, value=stations[0], interactive=True)
83
 
84
 
85
  def handleCityChange(city):
86
  stations = getStationsByCityName(city)
87
- return gr.update(choices=stations, value=stations[0], interactive=True)
 
 
 
 
88
 
89
 
90
  def formatPath(path):
91
  return "\n".join([f"{i + 1}. {elem}" for i, elem in enumerate(path)])
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def handleStationChange(departureStation, destinationStation):
95
  if departureStation and destinationStation:
96
  dijkstraPath, dijkstraCost = getDijkstraResult(
@@ -98,18 +120,21 @@ def handleStationChange(departureStation, destinationStation):
98
  )
99
  dijkstraPathFormatted = formatPath(dijkstraPath)
100
  AStarPath, AStarCost = getAStarResult(departureStation, destinationStation)
 
101
  AStarPathFormatted = formatPath(AStarPath)
102
  return (
103
  gr.update(value=dijkstraCost),
104
  gr.update(value=dijkstraPathFormatted, lines=len(dijkstraPath)),
105
  gr.update(value=AStarCost),
106
  gr.update(value=AStarPathFormatted, lines=len(AStarPath)),
 
107
  )
108
  return (
109
  gr.HTML(HTML_COMPONENTS.NO_PROMPT.value),
110
  gr.update(value=""),
111
  gr.HTML(HTML_COMPONENTS.NO_PROMPT.value),
112
  gr.update(value=""),
 
113
  )
114
 
115
 
@@ -174,8 +199,24 @@ def getAStarResult(depart, destination):
174
 
175
  def getStationsByCityName(city: str):
176
  data = pd.read_csv("../data/sncf/gares_info.csv", sep=",")
177
- stations = tuple(data[data["Commune"] == city]["Nom de la gare"])
178
- return stations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
 
181
  def getEntitiesPositions(text, entity):
@@ -225,6 +266,7 @@ def render_tabs(sentences: list[str], model: str, progress_bar: gr.Progress):
225
  for sentence in progress_bar.tqdm(sentences, desc=PROGRESS.PROCESSING.value):
226
  with gr.Tab(f"Sentence {idx}"):
227
  dep, arr = getDepartureAndArrivalFromText(sentence, model)
 
228
  entities = []
229
  for entity in [dep, arr]:
230
  if entity:
@@ -237,10 +279,16 @@ def render_tabs(sentences: list[str], model: str, progress_bar: gr.Progress):
237
  # Get the available stations
238
  departureStations = getStationsByCityName(departureCityValue)
239
  departureStationValue = (
240
- departureStations[0] if departureStations else ""
 
 
241
  )
242
  arrivalStations = getStationsByCityName(arrivalCityValue)
243
- arrivalStationValue = arrivalStations[0] if arrivalStations else ""
 
 
 
 
244
 
245
  dijkstraPathValues = []
246
  AStarPathValues = []
@@ -255,6 +303,7 @@ def render_tabs(sentences: list[str], model: str, progress_bar: gr.Progress):
255
  AStarPathValues, timeAStarValue = getAStarResult(
256
  departureStationValue, arrivalStationValue
257
  )
 
258
 
259
  dijkstraPathFormatted = formatPath(dijkstraPathValues)
260
  AStarPathFormatted = formatPath(AStarPathValues)
@@ -276,14 +325,19 @@ def render_tabs(sentences: list[str], model: str, progress_bar: gr.Progress):
276
  with gr.Row():
277
  departureStation = gr.Dropdown(
278
  label="Gare de départ",
279
- choices=departureStations,
280
  value=departureStationValue,
281
  )
282
  arrivalStation = gr.Dropdown(
283
  label="Gare d'arrivée",
284
- choices=arrivalStations,
285
  value=arrivalStationValue,
286
  )
 
 
 
 
 
287
  with gr.Tab("Dijkstra"):
288
  timeDijkstra = gr.HTML(value=timeDijkstraValue)
289
  dijkstraPath = gr.Textbox(
@@ -313,16 +367,27 @@ def render_tabs(sentences: list[str], model: str, progress_bar: gr.Progress):
313
  departureStation.change(
314
  handleStationChange,
315
  inputs=[departureStation, arrivalStation],
316
- outputs=[timeDijkstra, dijkstraPath, timeAStar, AstarPath],
 
 
 
 
 
 
317
  )
318
  arrivalStation.change(
319
  handleStationChange,
320
  inputs=[departureStation, arrivalStation],
321
- outputs=[timeDijkstra, dijkstraPath, timeAStar, AstarPath],
 
 
 
 
 
 
322
  )
323
 
324
  idx += 1
325
- return tabs
326
 
327
 
328
  if __name__ == "__main__":
 
8
  from travel_resolver.libs.pathfinder.CSVTravelGraph import CSVTravelGraph
9
  from travel_resolver.libs.pathfinder.graph import Graph
10
  import time
11
+ import plotly.graph_objects as go
12
 
13
  transcriber = pipeline(
14
  "automatic-speech-recognition", model="openai/whisper-base", device="cpu"
 
17
  models = {"LSTM": LSTM_NER(), "BiLSTM": BiLSTM_NER(), "CamemBERT": CamemBERT_NER()}
18
 
19
 
 
 
 
 
 
 
 
 
 
20
  def handle_audio(audio, model, progress=gr.Progress()):
21
  progress(
22
  0,
23
  )
24
  promptAudio = transcribe(audio)
25
 
26
+ print(f"prompt : {promptAudio}")
27
+
28
  time.sleep(1)
29
 
30
  return render_tabs([promptAudio], model, progress)
 
63
  interactive=True,
64
  )
65
 
66
+ @gr.render(
67
+ inputs=[audio, file, model], triggers=[audio.change, file.upload, model.change]
68
+ )
69
+ def handle_changes(audio, file, model):
70
+ if audio:
71
+ return handle_audio(audio, model)
72
+ elif file:
73
+ return handle_file(file, model)
 
 
 
74
 
75
 
76
  def handleCityChange(city):
77
  stations = getStationsByCityName(city)
78
+ return gr.update(
79
+ choices=[station["Nom de le gare"] for station in stations],
80
+ value=stations[0]["Nom de la gare"],
81
+ interactive=True,
82
+ )
83
 
84
 
85
  def formatPath(path):
86
  return "\n".join([f"{i + 1}. {elem}" for i, elem in enumerate(path)])
87
 
88
 
89
+ def plotMap(stationsInformation: dict):
90
+ stationNames = stationsInformation["stations"] if len(stationsInformation) else []
91
+ stationsLat = stationsInformation["lat"] if len(stationsInformation) else []
92
+ stationsLon = stationsInformation["lon"] if len(stationsInformation) else []
93
+
94
+ plt = go.Figure(
95
+ go.Scattermapbox(
96
+ lat=stationsLat,
97
+ lon=stationsLon,
98
+ mode="markers+lines",
99
+ marker=go.scattermapbox.Marker(size=14),
100
+ text=stationNames,
101
+ )
102
+ )
103
+
104
+ plt.update_layout(
105
+ mapbox_style="open-street-map",
106
+ mapbox=dict(
107
+ center=go.layout.mapbox.Center(lat=stationsLat[0], lon=stationsLon[0]),
108
+ pitch=0,
109
+ zoom=3,
110
+ ),
111
+ )
112
+
113
+ return plt
114
+
115
+
116
  def handleStationChange(departureStation, destinationStation):
117
  if departureStation and destinationStation:
118
  dijkstraPath, dijkstraCost = getDijkstraResult(
 
120
  )
121
  dijkstraPathFormatted = formatPath(dijkstraPath)
122
  AStarPath, AStarCost = getAStarResult(departureStation, destinationStation)
123
+ AStarStationsInformation = getStationsInformation(AStarPath)
124
  AStarPathFormatted = formatPath(AStarPath)
125
  return (
126
  gr.update(value=dijkstraCost),
127
  gr.update(value=dijkstraPathFormatted, lines=len(dijkstraPath)),
128
  gr.update(value=AStarCost),
129
  gr.update(value=AStarPathFormatted, lines=len(AStarPath)),
130
+ plotMap(AStarStationsInformation),
131
  )
132
  return (
133
  gr.HTML(HTML_COMPONENTS.NO_PROMPT.value),
134
  gr.update(value=""),
135
  gr.HTML(HTML_COMPONENTS.NO_PROMPT.value),
136
  gr.update(value=""),
137
+ gr.update(value=plotMap(AStarStationsInformation)),
138
  )
139
 
140
 
 
199
 
200
  def getStationsByCityName(city: str):
201
  data = pd.read_csv("../data/sncf/gares_info.csv", sep=",")
202
+ stations = data[data["Commune"] == city]
203
+ return dict(
204
+ stations=stations["Nom de la gare"].to_list(),
205
+ lat=stations["Latitude"].to_list(),
206
+ lon=stations["Longitude"].to_list(),
207
+ )
208
+
209
+
210
+ def getStationsInformation(stations: list[str]):
211
+ data = pd.read_csv("../data/sncf/gares_info.csv", sep=",")
212
+ data = data[data["Nom de la gare"].isin(stations)]
213
+ print(stations)
214
+ print(data)
215
+ return dict(
216
+ stations=data["Nom de la gare"].to_list(),
217
+ lat=data["Latitude"].to_list(),
218
+ lon=data["Longitude"].to_list(),
219
+ )
220
 
221
 
222
  def getEntitiesPositions(text, entity):
 
266
  for sentence in progress_bar.tqdm(sentences, desc=PROGRESS.PROCESSING.value):
267
  with gr.Tab(f"Sentence {idx}"):
268
  dep, arr = getDepartureAndArrivalFromText(sentence, model)
269
+ print(f"dep: {dep}, arr: {arr}")
270
  entities = []
271
  for entity in [dep, arr]:
272
  if entity:
 
279
  # Get the available stations
280
  departureStations = getStationsByCityName(departureCityValue)
281
  departureStationValue = (
282
+ departureStations["stations"][0]
283
+ if len(departureStations["stations"])
284
+ else ""
285
  )
286
  arrivalStations = getStationsByCityName(arrivalCityValue)
287
+ arrivalStationValue = (
288
+ arrivalStations["stations"][0]
289
+ if len(arrivalStations["stations"])
290
+ else ""
291
+ )
292
 
293
  dijkstraPathValues = []
294
  AStarPathValues = []
 
303
  AStarPathValues, timeAStarValue = getAStarResult(
304
  departureStationValue, arrivalStationValue
305
  )
306
+ AStarStationsInformation = getStationsInformation(AStarPathValues)
307
 
308
  dijkstraPathFormatted = formatPath(dijkstraPathValues)
309
  AStarPathFormatted = formatPath(AStarPathValues)
 
325
  with gr.Row():
326
  departureStation = gr.Dropdown(
327
  label="Gare de départ",
328
+ choices=departureStations["stations"],
329
  value=departureStationValue,
330
  )
331
  arrivalStation = gr.Dropdown(
332
  label="Gare d'arrivée",
333
+ choices=arrivalStations["stations"],
334
  value=arrivalStationValue,
335
  )
336
+
337
+ plt = plotMap(AStarStationsInformation)
338
+
339
+ map = gr.Plot(plt)
340
+
341
  with gr.Tab("Dijkstra"):
342
  timeDijkstra = gr.HTML(value=timeDijkstraValue)
343
  dijkstraPath = gr.Textbox(
 
367
  departureStation.change(
368
  handleStationChange,
369
  inputs=[departureStation, arrivalStation],
370
+ outputs=[
371
+ timeDijkstra,
372
+ dijkstraPath,
373
+ timeAStar,
374
+ AstarPath,
375
+ map,
376
+ ],
377
  )
378
  arrivalStation.change(
379
  handleStationChange,
380
  inputs=[departureStation, arrivalStation],
381
+ outputs=[
382
+ timeDijkstra,
383
+ dijkstraPath,
384
+ timeAStar,
385
+ AstarPath,
386
+ map,
387
+ ],
388
  )
389
 
390
  idx += 1
 
391
 
392
 
393
  if __name__ == "__main__":
test1.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Je veux prendre le train de Lyon à Marseille.