Jihuai committed · Commit d572f56 · 0 parents (orphan branch)

have to create an orphan branch to bypass large file history: cleanup .ipynb and create LFS

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .github/FUNDING.yml +14 -0
  3. .gitignore +166 -0
  4. .idea/.gitignore +8 -0
  5. .idea/inspectionProfiles/Project_Default.xml +50 -0
  6. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  7. .idea/misc.xml +4 -0
  8. .idea/modules.xml +8 -0
  9. .idea/query-bandit.iml +8 -0
  10. .idea/vcs.xml +6 -0
  11. .vscode/launch.json +21 -0
  12. LICENSE +21 -0
  13. README.md +124 -0
  14. assets/banquet-logo.png +0 -0
  15. config/data/moisesdb-test.yml +55 -0
  16. config/data/setup-a/moisesdb-vdb-query-d-aug.yml +63 -0
  17. config/data/setup-a/moisesdb-vdb-query-d.yml +46 -0
  18. config/data/setup-a/moisesdb-vdb-query.yml +46 -0
  19. config/data/setup-b/moisesdb-vdbgp-query-d-aug-bal.yml +67 -0
  20. config/data/setup-b/moisesdb-vdbgp-query-d-aug.yml +67 -0
  21. config/data/setup-b/moisesdb-vdbgp-query-d.yml +50 -0
  22. config/data/setup-b/moisesdb-vdbgp-query.yml +50 -0
  23. config/data/setup-c/moisesdb-everything-query-d-aug-bal.yml +117 -0
  24. config/data/setup-c/moisesdb-everything-query-d-aug.yml +117 -0
  25. config/data/setup-c/moisesdb-everything-query-d-bal.yml +100 -0
  26. config/data/setup-c/moisesdb-everything-query-d.yml +100 -0
  27. config/data/vdbo/moisesdb-vdbo-aug.yml +35 -0
  28. config/data/vdbo/moisesdb-vdbo.yml +18 -0
  29. config/losses/both_l1snr.yml +4 -0
  30. config/losses/both_l1snrdbm.yml +4 -0
  31. config/models/bandit-query-pre.yml +31 -0
  32. config/models/bandit-query-prefz.yml +31 -0
  33. config/models/bandit-query.yml +29 -0
  34. config/models/bandit-vdbo.yml +27 -0
  35. config/optim/adam.yml +9 -0
  36. config/trainer/default-long.yml +12 -0
  37. config/trainer/default.yml +12 -0
  38. core/__init__.py +0 -0
  39. core/data/__init__.py +0 -0
  40. core/data/base.py +138 -0
  41. core/data/moisesdb/__init__.py +97 -0
  42. core/data/moisesdb/audio.ipynb +76 -0
  43. core/data/moisesdb/datamodule.py +239 -0
  44. core/data/moisesdb/dataset.py +1383 -0
  45. core/data/moisesdb/eda.ipynb +0 -0
  46. core/data/moisesdb/npyify.py +923 -0
  47. core/data/moisesdb/passt.ipynb +32 -0
  48. core/losses/__init__.py +0 -0
  49. core/losses/base.py +171 -0
  50. core/losses/l1snr.py +110 -0
.gitattributes ADDED
@@ -0,0 +1 @@
+ *.pdf filter=lfs diff=lfs merge=lfs -text
.github/FUNDING.yml ADDED
@@ -0,0 +1,14 @@
+ # These are supported funding model platforms
+
+ github: kwatcharasupat # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
+ patreon: # Replace with a single Patreon username
+ open_collective: # Replace with a single Open Collective username
+ ko_fi: # Replace with a single Ko-fi username
+ tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
+ community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
+ liberapay: # Replace with a single Liberapay username
+ issuehunt: # Replace with a single IssueHunt username
+ lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
+ polar: # Replace with a single Polar username
+ buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
+ custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
.gitignore ADDED
@@ -0,0 +1,166 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+
+ input/
+ output/
+ logs/
+ checkpoints/
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,50 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+       <Languages>
+         <language minSize="52" name="Python" />
+       </Languages>
+     </inspection_tool>
+     <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
+     <inspection_tool class="PyAttributeOutsideInitInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
+     <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+       <option name="ignoredPackages">
+         <value>
+           <list size="8">
+             <item index="0" class="java.lang.String" itemvalue="pytorch_lightning" />
+             <item index="1" class="java.lang.String" itemvalue="torch" />
+             <item index="2" class="java.lang.String" itemvalue="torchaudio" />
+             <item index="3" class="java.lang.String" itemvalue="matplotlib" />
+             <item index="4" class="java.lang.String" itemvalue="ipython" />
+             <item index="5" class="java.lang.String" itemvalue="numpy" />
+             <item index="6" class="java.lang.String" itemvalue="opencv_python" />
+             <item index="7" class="java.lang.String" itemvalue="Pillow" />
+           </list>
+         </value>
+       </option>
+     </inspection_tool>
+     <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+       <option name="ignoredErrors">
+         <list>
+           <option value="N806" />
+           <option value="N812" />
+           <option value="N802" />
+           <option value="N803" />
+         </list>
+       </option>
+     </inspection_tool>
+     <inspection_tool class="PyShadowingBuiltinsInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+       <option name="ignoredNames">
+         <list>
+           <option value="round" />
+         </list>
+       </option>
+     </inspection_tool>
+     <inspection_tool class="SpellCheckingInspection" enabled="true" level="TYPO" enabled_by_default="true">
+       <option name="processCode" value="false" />
+       <option name="processLiterals" value="true" />
+       <option name="processComments" value="true" />
+     </inspection_tool>
+   </profile>
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (aa-listening-test-sigsep-gen)" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/query-bandit.iml" filepath="$PROJECT_DIR$/.idea/query-bandit.iml" />
+     </modules>
+   </component>
+ </project>
.idea/query-bandit.iml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="inheritedJdk" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="" vcs="Git" />
+   </component>
+ </project>
.vscode/launch.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "version": "0.2.0",
+   "configurations": [
+     {
+       "name": "Python Debugger: Remote Attach",
+       "type": "debugpy",
+       "request": "attach",
+       "justMyCode": true,
+       "connect": {
+         "host": "localhost",
+         "port": 5678
+       },
+       "pathMappings": [
+         {
+           "localRoot": "${workspaceFolder}",
+           "remoteRoot": "."
+         }
+       ]
+     }
+   ]
+ }
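
This launch configuration is the attach side; the debuggee must open the matching port first. A minimal sketch of the Python side, assuming the `debugpy` package is installed:

```python
# Sketch: debuggee side for the "Python Debugger: Remote Attach" config above.
# The port must match "connect.port" in launch.json.
import debugpy

debugpy.listen(5678)       # serve the debug adapter on localhost:5678
debugpy.wait_for_client()  # block until VS Code attaches
# ...then run the code to be debugged, e.g. the train.py entry point.
```

Equivalently, launching with `python -m debugpy --listen 5678 --wait-for-client train.py ...` (as hinted in the README below) achieves the same without code changes.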
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Karn Watcharasupat
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,124 @@
+ # Language-Audio Banquet
+ <a href='https://github.com/ModistAndrew/query-bandit'><img alt="Static Badge" src="https://img.shields.io/badge/github_repo-lightgrey?logo=github"></a>
+ <a href='https://huggingface.co/spaces/chenxie95/Language-Audio-Banquet'><img alt="Static Badge" src="https://img.shields.io/badge/huggingface_space-yellow?logo=huggingface"></a>
+
+ - Change the query embedding model from PaSST to CLAP, which supports language queries.
+
+ - Change the RNN to a Transformer.
+
+ - Add some utility functions for inference.
+
+ - (TODO) Train on more datasets.
+
+ ## Model weights
+ Download the model weights from the [Hugging Face model repo](https://huggingface.co/chenxie95/Language-Audio-Banquet-ckpt) and put them in `checkpoints/`. `bandit-vdbo-roformer.ckpt` is needed for training; either `ev-pre.ckpt` or `ev-pre-aug.ckpt` can be chosen for inference.
+
+ In addition, download the CLAP query embedding model from [here](https://huggingface.co/lukewys/laion_clap/blob/main/music_speech_epoch_15_esc_89.25.pt) and put it in `checkpoints/querier/`.
+
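+ A minimal download sketch using `huggingface_hub` (a sketch, assuming `pip install huggingface_hub`; repo and file names are those linked above):
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # Separator checkpoint; swap in ev-pre.ckpt or bandit-vdbo-roformer.ckpt as needed.
+ hf_hub_download(
+     repo_id="chenxie95/Language-Audio-Banquet-ckpt",
+     filename="ev-pre-aug.ckpt",
+     local_dir="checkpoints",
+ )
+
+ # CLAP query embedding model.
+ hf_hub_download(
+     repo_id="lukewys/laion_clap",
+     filename="music_speech_epoch_15_esc_89.25.pt",
+     local_dir="checkpoints/querier",
+ )
+ ```
+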
+ ## Inference examples
+ ```bash
+ export CONFIG_ROOT=./config
+ # To attach a debugger, insert `-m debugpy --listen 5678 --wait-for-client` between `python` and `train.py`.
+ python \
+     train.py inference_byoq \
+     checkpoints/ev-pre-aug.ckpt \
+     input/491c1ff5-1e7b-4046-8029-a82d4a8aefb4.wav \
+     input/491c1ff5-1e7b-4046-8029-a82d4a8aefb4_bass.wav \
+     output/491c1ff5-1e7b-4046-8029-a82d4a8aefb4_bass.wav \
+     --batch_size=12 \
+     --use_cuda=true
+
+ python \
+     train.py inference_byoq_text \
+     checkpoints/ev-pre-aug.ckpt \
+     input/491c1ff5-1e7b-4046-8029-a82d4a8aefb4.wav \
+     piano \
+     output/491c1ff5-1e7b-4046-8029-a82d4a8aefb4_piano.wav \
+     --batch_size=12 \
+     --use_cuda=true
+
+ python \
+     train.py inference_test_folder \
+     checkpoints/ev-pre-aug.ckpt \
+     /inspire/hdd/project/multilingualspeechrecognition/chenxie-25019/data/karaoke_converted/test \
+     output/karaoke \
+     bass \
+     --batch_size=30 \
+     --use_cuda=true \
+     --input_name=mixture
+ ```
+
+ ## Training examples
+ ```bash
+ export CONFIG_ROOT=./config
+ # export DATA_ROOT=/inspire/hdd/project/multilingualspeechrecognition/chenxie-25019/data
+ # export DATA_ROOT=/dev/shm
+ export DATA_ROOT=/inspire/ssd/project/multilingualspeechrecognition/public
+ export LOG_ROOT=./logs/ev-pre-aug-bal
+ export CUDA_VISIBLE_DEVICES=0
+ python \
+     train.py train \
+     expt/setup-c/bandit-everything-query-pre-d-aug-bal.yml \
+     --ckpt_path=logs/ev-pre-aug-bal/e2e/HBRPOI/lightning_logs/version_1/checkpoints/last.ckpt
+ # You may modify the batch size in the YAML files in config/data/. A batch size of 3 fits on an NVIDIA 4090 (48 GB).
+ ```
+
+ ---
+
+ > ### Please consider giving back to the community if you have benefited from this work.
+ >
+ > If you've **benefited commercially from this work**, which we've poured significant effort into and released under permissive licenses, we hope you've found it valuable! While these licenses give you lots of freedom, we believe in nurturing a vibrant ecosystem where innovation can continue to flourish.
+ >
+ > So, as a gesture of appreciation and responsibility, we strongly urge commercial entities that have gained from this software to consider making voluntary contributions to music-related non-profit organizations of your choice. Your contribution directly helps support the foundational work that empowers your commercial success and ensures open-source innovation keeps moving forward.
+ >
+ > Some suggestions for beneficiaries are provided [here](https://github.com/the-secret-source/nonprofits). Please do not hesitate to contribute to the list by opening pull requests there.
+
+ ---
+
+ <div align="center">
+ <img src="assets/banquet-logo.png">
+ </div>
+
+ # Banquet: A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems
+
+ Repository for **A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems**
+ by Karn N. Watcharasupat and Alexander Lerch. [arXiv](https://arxiv.org/abs/2406.18747)
+
+ > Despite significant recent progress across multiple subtasks of audio source separation, few music source separation systems support separation beyond the four-stem vocals, drums, bass, and other (VDBO) setup. Of the very few current systems that support source separation beyond this setup, most continue to rely on an inflexible decoder setup that can only support a fixed pre-defined set of stems. Increasing stem support in these inflexible systems correspondingly requires increasing computational complexity, rendering extensions of these systems computationally infeasible for long-tail instruments. In this work, we propose Banquet, a system that allows source separation of multiple stems using just one decoder. A bandsplit source separation model is extended to work in a query-based setup in tandem with a music instrument recognition PaSST model. On the MoisesDB dataset, Banquet, at only 24.9 M trainable parameters, approached the performance level of the significantly more complex 6-stem Hybrid Transformer Demucs on VDBO stems and outperformed it on guitar and piano. The query-based setup allows for the separation of narrow instrument classes such as clean acoustic guitars, and can be successfully applied to the extraction of less common stems such as reeds and organs.
+
+ For the Cinematic Audio Source Separation model, Bandit, see [this repository](https://github.com/kwatcharasupat/bandit).
+
+ ## Inference
+
+ ```bash
+ git clone https://github.com/kwatcharasupat/query-bandit.git
+ cd query-bandit
+ export CONFIG_ROOT="./config"
+
+ python train.py inference_byoq \
+     --ckpt_path="/path/to/checkpoint/see-below.ckpt" \
+     --input_path="/path/to/input/file/fearOfMatlab.wav" \
+     --output_path="/path/to/output/file/fearOfMatlabStemEst/guitar.wav" \
+     --query_path="/path/to/query/file/random-guitar.wav" \
+     --batch_size=12 \
+     --use_cuda=true
+ ```
+ A batch size of 12 _usually_ fits on an RTX 4090.
+
+ ### Model weights
+ Model weights are available on Zenodo [here](https://zenodo.org/records/13694558).
+ If you are not sure, use `ev-pre-aug.ckpt`.
+
+ ## Citation
+ ```
+ @inproceedings{Watcharasupat2024Banquet,
+   title = {A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems},
+   booktitle = {Proceedings of the 25th International Society for Music Information Retrieval Conference},
+   author = {Watcharasupat, Karn N. and Lerch, Alexander},
+   year = {2024},
+   month = {nov},
+   eprint = {2406.18747},
+   address = {San Francisco, CA, USA},
+ }
+ ```
assets/banquet-logo.png ADDED
config/data/moisesdb-test.yml ADDED
@@ -0,0 +1,55 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesTestDataModule
+ batch_size: 1
+ effective_batch_size: null
+ num_workers: 8
+
+ inference_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 0.5
+   batch_size: 12
+   fs: 44100
+
+ test_kwargs:
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems: [
+     "drums",
+     "lead_male_singer",
+     "lead_female_singer",
+     # "human_choir",
+     "background_vocals",
+     # "other_vocals",
+     "bass_guitar",
+     "bass_synthesizer",
+     # "contrabass_double_bass",
+     # "tuba",
+     # "bassoon",
+     "fx",
+     "clean_electric_guitar",
+     "distorted_electric_guitar",
+     # "lap_steel_guitar_or_slide_guitar",
+     "acoustic_guitar",
+     "other_plucked",
+     "pitched_percussion",
+     "grand_piano",
+     "electric_piano",
+     "organ_electric_organ",
+     "synth_pad",
+     "synth_lead",
+     # "violin",
+     # "viola",
+     # "cello",
+     # "violin_section",
+     # "viola_section",
+     # "cello_section",
+     "string_section",
+     "other_strings",
+     "brass",
+     # "flutes",
+     "reeds",
+     "other_wind"
+   ]
+   query_file: "query-10s"
+ n_channels: 2
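
The `${oc.env:DATA_ROOT}` syntax is OmegaConf's environment-variable resolver, which is why the examples export `DATA_ROOT` before running. A minimal sketch of how `data_root` resolves, assuming the YAML is loaded with OmegaConf:

```python
# Sketch: resolving ${oc.env:DATA_ROOT} with OmegaConf.
import os
from omegaconf import OmegaConf

os.environ["DATA_ROOT"] = "/data"  # normally exported in the shell
conf = OmegaConf.load("config/data/moisesdb-test.yml")
print(conf.data_root)  # -> "/data/moisesdb"; interpolations resolve on access
```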
config/data/setup-a/moisesdb-vdb-query-d-aug.yml ADDED
@@ -0,0 +1,63 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesDataModule
+ batch_size: 4
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       # "distorted_electric_guitar",
+       # "clean_electric_guitar",
+       # "acoustic_guitar",
+     ]
+   query_file: "query-10s"
+   augment:
+     - cls: Shift
+       kwargs:
+         p: 1.0
+         min_shift: -0.5
+         max_shift: 0.5
+     - cls: Gain
+       kwargs:
+         p: 1.0
+         min_gain_in_db: -6
+         max_gain_in_db: 6
+     - cls: ShuffleChannels
+       kwargs:
+         p: 0.5
+     - cls: PolarityInversion
+       kwargs:
+         p: 0.5
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       # "distorted_electric_guitar",
+       # "clean_electric_guitar",
+       # "acoustic_guitar",
+     ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
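
The `augment` entries carry `cls` names and kwargs that match the `torch-audiomentations` API (`Gain`, `Shift`, `PolarityInversion`, `ShuffleChannels`); a sketch of how the list might be materialized, under that assumption:

```python
# Sketch: building the waveform augmentation pipeline from the `augment` list,
# assuming the `cls` names resolve to torch_audiomentations transforms.
import torch_audiomentations as aug

augment_cfg = [
    {"cls": "Shift", "kwargs": {"p": 1.0, "min_shift": -0.5, "max_shift": 0.5}},
    {"cls": "Gain", "kwargs": {"p": 1.0, "min_gain_in_db": -6, "max_gain_in_db": 6}},
    {"cls": "ShuffleChannels", "kwargs": {"p": 0.5}},
    {"cls": "PolarityInversion", "kwargs": {"p": 0.5}},
]
transforms = [getattr(aug, c["cls"])(**c["kwargs"]) for c in augment_cfg]
pipeline = aug.Compose(transforms)  # expects (batch, channels, time) tensors
```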
config/data/setup-a/moisesdb-vdb-query-d.yml ADDED
@@ -0,0 +1,46 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesDataModule
+ batch_size: 4
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       # "distorted_electric_guitar",
+       # "clean_electric_guitar",
+       # "acoustic_guitar",
+     ]
+   query_file: "query-10s"
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       # "distorted_electric_guitar",
+       # "clean_electric_guitar",
+       # "acoustic_guitar",
+     ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/setup-a/moisesdb-vdb-query.yml ADDED
@@ -0,0 +1,46 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesDataModule
+ batch_size: 4
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: true
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       # "distorted_electric_guitar",
+       # "clean_electric_guitar",
+       # "acoustic_guitar",
+     ]
+   query_file: "query-10s"
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: true
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       # "distorted_electric_guitar",
+       # "clean_electric_guitar",
+       # "acoustic_guitar",
+     ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/setup-b/moisesdb-vdbgp-query-d-aug-bal.yml ADDED
@@ -0,0 +1,67 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesBalancedTrainDataModule
+ batch_size: 4
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       "distorted_electric_guitar",
+       "clean_electric_guitar",
+       "acoustic_guitar",
+       'grand_piano',
+       'electric_piano',
+     ]
+   query_file: "query-10s"
+   augment:
+     - cls: Shift
+       kwargs:
+         p: 1.0
+         min_shift: -0.5
+         max_shift: 0.5
+     - cls: Gain
+       kwargs:
+         p: 1.0
+         min_gain_in_db: -6
+         max_gain_in_db: 6
+     - cls: ShuffleChannels
+       kwargs:
+         p: 0.5
+     - cls: PolarityInversion
+       kwargs:
+         p: 0.5
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       "distorted_electric_guitar",
+       "clean_electric_guitar",
+       "acoustic_guitar",
+       'grand_piano',
+       'electric_piano',
+     ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/setup-b/moisesdb-vdbgp-query-d-aug.yml ADDED
@@ -0,0 +1,67 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesDataModule
+ batch_size: 4
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       "distorted_electric_guitar",
+       "clean_electric_guitar",
+       "acoustic_guitar",
+       'grand_piano',
+       'electric_piano',
+     ]
+   query_file: "query-10s"
+   augment:
+     - cls: Shift
+       kwargs:
+         p: 1.0
+         min_shift: -0.5
+         max_shift: 0.5
+     - cls: Gain
+       kwargs:
+         p: 1.0
+         min_gain_in_db: -6
+         max_gain_in_db: 6
+     - cls: ShuffleChannels
+       kwargs:
+         p: 0.5
+     - cls: PolarityInversion
+       kwargs:
+         p: 0.5
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       "distorted_electric_guitar",
+       "clean_electric_guitar",
+       "acoustic_guitar",
+       'grand_piano',
+       'electric_piano',
+     ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/setup-b/moisesdb-vdbgp-query-d.yml ADDED
@@ -0,0 +1,50 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesDataModule
+ batch_size: 4
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       "distorted_electric_guitar",
+       "clean_electric_guitar",
+       "acoustic_guitar",
+       'grand_piano',
+       'electric_piano',
+     ]
+   query_file: "query-10s"
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       "distorted_electric_guitar",
+       "clean_electric_guitar",
+       "acoustic_guitar",
+       'grand_piano',
+       'electric_piano',
+     ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/setup-b/moisesdb-vdbgp-query.yml ADDED
@@ -0,0 +1,50 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesDataModule
+ batch_size: 4
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: true
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       "distorted_electric_guitar",
+       "clean_electric_guitar",
+       "acoustic_guitar",
+       'grand_piano',
+       'electric_piano',
+     ]
+   query_file: "query-10s"
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: true
+   allowed_stems:
+     [
+       "bass",
+       "drums",
+       "lead_male_singer",
+       "lead_female_singer",
+       "distorted_electric_guitar",
+       "clean_electric_guitar",
+       "acoustic_guitar",
+       'grand_piano',
+       'electric_piano',
+     ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/setup-c/moisesdb-everything-query-d-aug-bal.yml ADDED
@@ -0,0 +1,117 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesBalancedTrainDataModule
+ batch_size: 3
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems: [
+     "drums",
+     "lead_male_singer",
+     "lead_female_singer",
+     # "human_choir",
+     "background_vocals",
+     # "other_vocals",
+     "bass_guitar",
+     "bass_synthesizer",
+     # "contrabass_double_bass",
+     # "tuba",
+     # "bassoon",
+     "fx",
+     "clean_electric_guitar",
+     "distorted_electric_guitar",
+     # "lap_steel_guitar_or_slide_guitar",
+     "acoustic_guitar",
+     "other_plucked",
+     "pitched_percussion",
+     "grand_piano",
+     "electric_piano",
+     "organ_electric_organ",
+     "synth_pad",
+     "synth_lead",
+     # "violin",
+     # "viola",
+     # "cello",
+     # "violin_section",
+     # "viola_section",
+     # "cello_section",
+     "string_section",
+     "other_strings",
+     "brass",
+     # "flutes",
+     "reeds",
+     "other_wind"
+   ]
+   query_file: "query-10s"
+   augment:
+     - cls: Shift
+       kwargs:
+         p: 1.0
+         min_shift: -0.5
+         max_shift: 0.5
+     - cls: Gain
+       kwargs:
+         p: 1.0
+         min_gain_in_db: -6
+         max_gain_in_db: 6
+     - cls: ShuffleChannels
+       kwargs:
+         p: 0.5
+     - cls: PolarityInversion
+       kwargs:
+         p: 0.5
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems: [
+     "drums",
+     "lead_male_singer",
+     "lead_female_singer",
+     # "human_choir",
+     "background_vocals",
+     # "other_vocals",
+     "bass_guitar",
+     "bass_synthesizer",
+     # "contrabass_double_bass",
+     # "tuba",
+     # "bassoon",
+     "fx",
+     "clean_electric_guitar",
+     "distorted_electric_guitar",
+     # "lap_steel_guitar_or_slide_guitar",
+     "acoustic_guitar",
+     "other_plucked",
+     "pitched_percussion",
+     "grand_piano",
+     "electric_piano",
+     "organ_electric_organ",
+     "synth_pad",
+     "synth_lead",
+     # "violin",
+     # "viola",
+     # "cello",
+     # "violin_section",
+     # "viola_section",
+     # "cello_section",
+     "string_section",
+     "other_strings",
+     "brass",
+     # "flutes",
+     "reeds",
+     "other_wind"
+   ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/setup-c/moisesdb-everything-query-d-aug.yml ADDED
@@ -0,0 +1,117 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesDataModule
+ batch_size: 3
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems: [
+     "drums",
+     "lead_male_singer",
+     "lead_female_singer",
+     # "human_choir",
+     "background_vocals",
+     # "other_vocals",
+     "bass_guitar",
+     "bass_synthesizer",
+     # "contrabass_double_bass",
+     # "tuba",
+     # "bassoon",
+     "fx",
+     "clean_electric_guitar",
+     "distorted_electric_guitar",
+     # "lap_steel_guitar_or_slide_guitar",
+     "acoustic_guitar",
+     "other_plucked",
+     "pitched_percussion",
+     "grand_piano",
+     "electric_piano",
+     "organ_electric_organ",
+     "synth_pad",
+     "synth_lead",
+     # "violin",
+     # "viola",
+     # "cello",
+     # "violin_section",
+     # "viola_section",
+     # "cello_section",
+     "string_section",
+     "other_strings",
+     "brass",
+     # "flutes",
+     "reeds",
+     "other_wind"
+   ]
+   query_file: "query-10s"
+   augment:
+     - cls: Shift
+       kwargs:
+         p: 1.0
+         min_shift: -0.5
+         max_shift: 0.5
+     - cls: Gain
+       kwargs:
+         p: 1.0
+         min_gain_in_db: -6
+         max_gain_in_db: 6
+     - cls: ShuffleChannels
+       kwargs:
+         p: 0.5
+     - cls: PolarityInversion
+       kwargs:
+         p: 0.5
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems: [
+     "drums",
+     "lead_male_singer",
+     "lead_female_singer",
+     # "human_choir",
+     "background_vocals",
+     # "other_vocals",
+     "bass_guitar",
+     "bass_synthesizer",
+     # "contrabass_double_bass",
+     # "tuba",
+     # "bassoon",
+     "fx",
+     "clean_electric_guitar",
+     "distorted_electric_guitar",
+     # "lap_steel_guitar_or_slide_guitar",
+     "acoustic_guitar",
+     "other_plucked",
+     "pitched_percussion",
+     "grand_piano",
+     "electric_piano",
+     "organ_electric_organ",
+     "synth_pad",
+     "synth_lead",
+     # "violin",
+     # "viola",
+     # "cello",
+     # "violin_section",
+     # "viola_section",
+     # "cello_section",
+     "string_section",
+     "other_strings",
+     "brass",
+     # "flutes",
+     "reeds",
+     "other_wind"
+   ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/setup-c/moisesdb-everything-query-d-bal.yml ADDED
@@ -0,0 +1,100 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesBalancedTrainDataModule
+ batch_size: 3
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems: [
+     "drums",
+     "lead_male_singer",
+     "lead_female_singer",
+     # "human_choir",
+     "background_vocals",
+     # "other_vocals",
+     "bass_guitar",
+     "bass_synthesizer",
+     # "contrabass_double_bass",
+     # "tuba",
+     # "bassoon",
+     "fx",
+     "clean_electric_guitar",
+     "distorted_electric_guitar",
+     # "lap_steel_guitar_or_slide_guitar",
+     "acoustic_guitar",
+     "other_plucked",
+     "pitched_percussion",
+     "grand_piano",
+     "electric_piano",
+     "organ_electric_organ",
+     "synth_pad",
+     "synth_lead",
+     # "violin",
+     # "viola",
+     # "cello",
+     # "violin_section",
+     # "viola_section",
+     # "cello_section",
+     "string_section",
+     "other_strings",
+     "brass",
+     # "flutes",
+     "reeds",
+     "other_wind"
+   ]
+   query_file: "query-10s"
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems: [
+     "drums",
+     "lead_male_singer",
+     "lead_female_singer",
+     # "human_choir",
+     "background_vocals",
+     # "other_vocals",
+     "bass_guitar",
+     "bass_synthesizer",
+     # "contrabass_double_bass",
+     # "tuba",
+     # "bassoon",
+     "fx",
+     "clean_electric_guitar",
+     "distorted_electric_guitar",
+     # "lap_steel_guitar_or_slide_guitar",
+     "acoustic_guitar",
+     "other_plucked",
+     "pitched_percussion",
+     "grand_piano",
+     "electric_piano",
+     "organ_electric_organ",
+     "synth_pad",
+     "synth_lead",
+     # "violin",
+     # "viola",
+     # "cello",
+     # "violin_section",
+     # "viola_section",
+     # "cello_section",
+     "string_section",
+     "other_strings",
+     "brass",
+     # "flutes",
+     "reeds",
+     "other_wind"
+   ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/setup-c/moisesdb-everything-query-d.yml ADDED
@@ -0,0 +1,100 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesDataModule
+ batch_size: 3
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems: [
+     "drums",
+     "lead_male_singer",
+     "lead_female_singer",
+     # "human_choir",
+     "background_vocals",
+     # "other_vocals",
+     "bass_guitar",
+     "bass_synthesizer",
+     # "contrabass_double_bass",
+     # "tuba",
+     # "bassoon",
+     "fx",
+     "clean_electric_guitar",
+     "distorted_electric_guitar",
+     # "lap_steel_guitar_or_slide_guitar",
+     "acoustic_guitar",
+     "other_plucked",
+     "pitched_percussion",
+     "grand_piano",
+     "electric_piano",
+     "organ_electric_organ",
+     "synth_pad",
+     "synth_lead",
+     # "violin",
+     # "viola",
+     # "cello",
+     # "violin_section",
+     # "viola_section",
+     # "cello_section",
+     "string_section",
+     "other_strings",
+     "brass",
+     # "flutes",
+     "reeds",
+     "other_wind"
+   ]
+   query_file: "query-10s"
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   query_size_seconds: 10.0
+   top_k_instrument: 10
+   npy_memmap: true
+   mixture_stem: mixture
+   use_own_query: false
+   allowed_stems: [
+     "drums",
+     "lead_male_singer",
+     "lead_female_singer",
+     # "human_choir",
+     "background_vocals",
+     # "other_vocals",
+     "bass_guitar",
+     "bass_synthesizer",
+     # "contrabass_double_bass",
+     # "tuba",
+     # "bassoon",
+     "fx",
+     "clean_electric_guitar",
+     "distorted_electric_guitar",
+     # "lap_steel_guitar_or_slide_guitar",
+     "acoustic_guitar",
+     "other_plucked",
+     "pitched_percussion",
+     "grand_piano",
+     "electric_piano",
+     "organ_electric_organ",
+     "synth_pad",
+     "synth_lead",
+     # "violin",
+     # "viola",
+     # "cello",
+     # "violin_section",
+     # "viola_section",
+     # "cello_section",
+     "string_section",
+     "other_strings",
+     "brass",
+     # "flutes",
+     "reeds",
+     "other_wind"
+   ]
+   query_file: "query-10s"
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/vdbo/moisesdb-vdbo-aug.yml ADDED
@@ -0,0 +1,35 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesVDBODataModule
+ batch_size: 4
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   fs: 44100
+   npy_memmap: true
+   augment:
+     - cls: Shift
+       kwargs:
+         p: 1.0
+         min_shift: -0.5
+         max_shift: 0.5
+     - cls: Gain
+       kwargs:
+         p: 1.0
+         min_gain_in_db: -6
+         max_gain_in_db: 6
+     - cls: ShuffleChannels
+       kwargs:
+         p: 0.5
+     - cls: PolarityInversion
+       kwargs:
+         p: 0.5
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   fs: 44100
+   npy_memmap: true
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/data/vdbo/moisesdb-vdbo.yml ADDED
@@ -0,0 +1,18 @@
+ data_root: ${oc.env:DATA_ROOT}/moisesdb
+ cls: MoisesVDBODataModule
+ batch_size: 4
+ effective_batch_size: null
+ num_workers: 8
+ train_kwargs:
+   target_length: 8192
+   chunk_size_seconds: 6.0
+   fs: 44100
+   npy_memmap: true
+ val_kwargs:
+   chunk_size_seconds: 6.0
+   hop_size_seconds: 6.0
+   fs: 44100
+   npy_memmap: true
+ test_kwargs:
+   npy_memmap: true
+ n_channels: 2
config/losses/both_l1snr.yml ADDED
@@ -0,0 +1,4 @@
+ cls: L1SNRLoss
+ modality:
+   - audio
+   - spectrogram
config/losses/both_l1snrdbm.yml ADDED
@@ -0,0 +1,4 @@
+ cls: L1SNRDecibelMatchLoss
+ modality:
+   - audio
+   - spectrogram
config/models/bandit-query-pre.yml ADDED
@@ -0,0 +1,31 @@
+ cls: PasstFiLMConditionedBandit
+ kwargs:
+   in_channel: 2
+   band_type: "musical"
+   n_bands: 64
+   additive_film: true
+   multiplicative_film: true
+   film_depth: 2
+   n_sqm_modules: 8
+   emb_dim: 128
+   rnn_dim: 256
+   bidirectional: true
+   rnn_type: "GRU"
+   mlp_dim: 512
+   hidden_activation: "Tanh"
+   hidden_activation_kwargs: null
+   complex_mask: true
+   use_freq_weights: true
+   n_fft: 2048
+   win_length: 2048
+   hop_length: 512
+   window_fn: "hann_window"
+   wkwargs: null
+   power: null
+   center: true
+   normalized: true
+   pad_mode: "reflect"
+   onesided: true
+   fs: 44100
+   pretrain_encoder: checkpoints/bandit-vdbo-roformer.ckpt
+   freeze_encoder: false
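
The `n_fft`/`win_length`/`hop_length`/`window_fn`/`power` kwargs mirror `torchaudio.transforms.Spectrogram`; a sketch of the STFT front end they describe (`power: null` keeps the complex STFT, which a complex mask requires):

```python
# Sketch: the STFT front end implied by the spectrogram kwargs above.
import torch
import torchaudio

spec = torchaudio.transforms.Spectrogram(
    n_fft=2048,
    win_length=2048,
    hop_length=512,
    window_fn=torch.hann_window,
    power=None,        # None -> complex STFT rather than magnitude
    center=True,
    normalized=True,
    pad_mode="reflect",
    onesided=True,
)
x = torch.randn(2, 44100)  # (channels, samples) at fs = 44100
X = spec(x)                # complex tensor of shape (2, 1025, frames)
```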
config/models/bandit-query-prefz.yml ADDED
@@ -0,0 +1,31 @@
+ cls: PasstFiLMConditionedBandit
+ kwargs:
+   in_channel: 2
+   band_type: "musical"
+   n_bands: 64
+   additive_film: true
+   multiplicative_film: true
+   film_depth: 2
+   n_sqm_modules: 8
+   emb_dim: 128
+   rnn_dim: 256
+   bidirectional: true
+   rnn_type: "GRU"
+   mlp_dim: 512
+   hidden_activation: "Tanh"
+   hidden_activation_kwargs: null
+   complex_mask: true
+   use_freq_weights: true
+   n_fft: 2048
+   win_length: 2048
+   hop_length: 512
+   window_fn: "hann_window"
+   wkwargs: null
+   power: null
+   center: true
+   normalized: true
+   pad_mode: "reflect"
+   onesided: true
+   fs: 44100
+   pretrain_encoder: checkpoints/bandit-vdbo-roformer.ckpt
+   freeze_encoder: true
config/models/bandit-query.yml ADDED
@@ -0,0 +1,29 @@
+ cls: PasstFiLMConditionedBandit
+ kwargs:
+   in_channel: 2
+   band_type: "musical"
+   n_bands: 64
+   additive_film: true
+   multiplicative_film: true
+   film_depth: 2
+   n_sqm_modules: 8
+   emb_dim: 128
+   rnn_dim: 256
+   bidirectional: true
+   rnn_type: "GRU"
+   mlp_dim: 512
+   hidden_activation: "Tanh"
+   hidden_activation_kwargs: null
+   complex_mask: true
+   use_freq_weights: true
+   n_fft: 2048
+   win_length: 2048
+   hop_length: 512
+   window_fn: "hann_window"
+   wkwargs: null
+   power: null
+   center: true
+   normalized: true
+   pad_mode: "reflect"
+   onesided: true
+   fs: 44100
config/models/bandit-vdbo.yml ADDED
@@ -0,0 +1,27 @@
+ cls: Bandit
+ kwargs:
+   in_channel: 2
+   stems: ["vocals", "bass", "drums", "vdbo_others"]
+   band_type: "musical"
+   n_bands: 64
+   n_sqm_modules: 8
+   emb_dim: 128
+   rnn_dim: 256
+   bidirectional: true
+   rnn_type: "GRU"
+   mlp_dim: 512
+   hidden_activation: "Tanh"
+   hidden_activation_kwargs: null
+   complex_mask: true
+   use_freq_weights: true
+   n_fft: 2048
+   win_length: 2048
+   hop_length: 512
+   window_fn: "hann_window"
+   wkwargs: null
+   power: null
+   center: true
+   normalized: true
+   pad_mode: "reflect"
+   onesided: true
+   fs: 44100
config/optim/adam.yml ADDED
@@ -0,0 +1,9 @@
+ optimizer:
+   cls: Adam
+   kwargs:
+     lr: 1.0e-3
+ scheduler:
+   cls: StepLR
+   kwargs:
+     step_size: 1
+     gamma: 0.98
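
In PyTorch terms, this config describes the following pair; with `step_size: 1`, the learning rate decays by 2% at every scheduler step (a sketch; the model is a stand-in):

```python
# Sketch: optimizer/scheduler pair described by config/optim/adam.yml.
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the actual model
optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.98)
```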
config/trainer/default-long.yml ADDED
@@ -0,0 +1,12 @@
+ callbacks:
+   checkpoint:
+     monitor: val/loss
+     mode: min
+     save_top_k: 3
+     save_last: True
+ max_epochs: 500
+ accumulate_grad_batches: null
+ gradient_clip_val: 10.0
+ gradient_clip_algorithm: norm
+ logger:
+   save_dir: ${oc.env:LOG_ROOT}/e2e
config/trainer/default.yml ADDED
@@ -0,0 +1,12 @@
+ callbacks:
+   checkpoint:
+     monitor: val/loss
+     mode: min
+     save_top_k: 3
+     save_last: True
+ max_epochs: 150
+ accumulate_grad_batches: null
+ gradient_clip_val: 10.0
+ gradient_clip_algorithm: norm
+ logger:
+   save_dir: ${oc.env:LOG_ROOT}/e2e
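
A sketch of the PyTorch Lightning objects this maps onto (the logger class is an assumption; the config only specifies a `save_dir`):

```python
# Sketch: Trainer/checkpoint setup described by config/trainer/default.yml.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(monitor="val/loss", mode="min", save_top_k=3, save_last=True)
trainer = pl.Trainer(
    max_epochs=150,
    gradient_clip_val=10.0,
    gradient_clip_algorithm="norm",
    callbacks=[checkpoint],
    logger=pl.loggers.TensorBoardLogger(save_dir="./logs/e2e"),  # ${oc.env:LOG_ROOT}/e2e
)
```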
core/__init__.py ADDED
File without changes
core/data/__init__.py ADDED
File without changes
core/data/base.py ADDED
@@ -0,0 +1,138 @@
+ import inspect
+ import os
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
+
+ import numpy as np
+ import torch
+ import torchaudio as ta
+ from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+ from torch.utils import data
+
+ from pytorch_lightning import LightningDataModule
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
+
+
+ def from_datasets(
+     train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
+     val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+     test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+     predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+     batch_size: int = 1,
+     num_workers: int = 0,
+     **datamodule_kwargs: Any,
+ ) -> "LightningDataModule":
+
+     def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
+         shuffle &= not isinstance(ds, IterableDataset)
+         return DataLoader(
+             ds,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=num_workers,
+             pin_memory=True,
+             prefetch_factor=4,
+             persistent_workers=True,
+         )
+
+     def train_dataloader() -> TRAIN_DATALOADERS:
+         assert train_dataset
+
+         if isinstance(train_dataset, Mapping):
+             return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
+         if isinstance(train_dataset, Sequence):
+             return [dataloader(ds, shuffle=True) for ds in train_dataset]
+         return dataloader(train_dataset, shuffle=True)
+
+     def val_dataloader() -> EVAL_DATALOADERS:
+         assert val_dataset
+
+         if isinstance(val_dataset, Sequence):
+             return [dataloader(ds) for ds in val_dataset]
+         return dataloader(val_dataset)
+
+     def test_dataloader() -> EVAL_DATALOADERS:
+         assert test_dataset
+
+         if isinstance(test_dataset, Sequence):
+             return [dataloader(ds) for ds in test_dataset]
+         return dataloader(test_dataset)
+
+     def predict_dataloader() -> EVAL_DATALOADERS:
+         assert predict_dataset
+
+         if isinstance(predict_dataset, Sequence):
+             return [dataloader(ds) for ds in predict_dataset]
+         return dataloader(predict_dataset)
+
+     candidate_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
+     accepted_params = inspect.signature(LightningDataModule.__init__).parameters
+     accepts_kwargs = any(param.kind == param.VAR_KEYWORD for param in accepted_params.values())
+     if accepts_kwargs:
+         special_kwargs = candidate_kwargs
+     else:
+         accepted_param_names = set(accepted_params)
+         accepted_param_names.discard("self")
+         special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names}
+
+     datamodule = LightningDataModule(**datamodule_kwargs, **special_kwargs)
+     if train_dataset is not None:
+         datamodule.train_dataloader = train_dataloader  # type: ignore[method-assign]
+     if val_dataset is not None:
+         datamodule.val_dataloader = val_dataloader  # type: ignore[method-assign]
+     if test_dataset is not None:
+         datamodule.test_dataloader = test_dataloader  # type: ignore[method-assign]
+     if predict_dataset is not None:
+         datamodule.predict_dataloader = predict_dataloader  # type: ignore[method-assign]
+
+     return datamodule
+
+
+ class BaseSourceSeparationDataset(data.Dataset, ABC):
+     def __init__(
+         self,
+         split: str,
+         stems: List[str],
+         files: List[str],
+         data_path: str,
+         fs: int,
+         npy_memmap: bool,
+         recompute_mixture: bool,
+     ):
+         if "mixture" not in stems:
+             stems = ["mixture"] + stems
+
+         self.split = split
+         self.stems = stems
+         self.stems_no_mixture = [s for s in stems if s != "mixture"]
+         self.files = files
+         self.data_path = data_path
+         self.fs = fs
+         self.npy_memmap = npy_memmap
+         self.recompute_mixture = recompute_mixture
+
+     @abstractmethod
+     def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
+         raise NotImplementedError
+
+     def _get_audio(self, stems, identifier: Dict[str, Any]):
+         audio = {}
+         for stem in stems:
+             audio[stem] = self.get_stem(stem=stem, identifier=identifier)
+
+         return audio
+
+     def get_audio(self, identifier: Dict[str, Any]):
+         if self.recompute_mixture:
+             audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
+             audio["mixture"] = self.compute_mixture(audio)
+             return audio
+         else:
+             return self._get_audio(self.stems, identifier=identifier)
+
+     @abstractmethod
+     def get_identifier(self, index: int) -> Dict[str, Any]:
+         pass
+
+     def compute_mixture(self, audio) -> torch.Tensor:
+         return sum(audio[stem] for stem in audio if stem != "mixture")
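
A usage sketch for `from_datasets` with placeholder datasets (note the hard-coded `persistent_workers=True`/`prefetch_factor=4` in `dataloader` require `num_workers > 0`):

```python
# Sketch: minimal use of from_datasets; TensorDataset stands in for MoisesDB datasets.
import torch
from torch.utils.data import TensorDataset

train_ds = TensorDataset(torch.randn(32, 2, 44100))
val_ds = TensorDataset(torch.randn(8, 2, 44100))

dm = from_datasets(train_dataset=train_ds, val_dataset=val_ds, batch_size=4, num_workers=2)
batch = next(iter(dm.train_dataloader()))  # list holding one (4, 2, 44100) tensor
```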
core/data/moisesdb/__init__.py ADDED
@@ -0,0 +1,97 @@
+ taxonomy = {
+     "vocals": [
+         "lead male singer",
+         "lead female singer",
+         "human choir",
+         "background vocals",
+         "other (vocoder, beatboxing etc)",
+     ],
+     "bass": [
+         "bass guitar",
+         "bass synthesizer (moog etc)",
+         "contrabass/double bass (bass of instrings)",
+         "tuba (bass of brass)",
+         "bassoon (bass of woodwind)",
+     ],
+     "drums": [
+         "snare drum",
+         "toms",
+         "kick drum",
+         "cymbals",
+         "overheads",
+         "full acoustic drumkit",
+         "drum machine",
+     ],
+     "other": [
+         "fx/processed sound, scratches, gun shots, explosions etc",
+         "click track",
+     ],
+     "guitar": [
+         "clean electric guitar",
+         "distorted electric guitar",
+         "lap steel guitar or slide guitar",
+         "acoustic guitar",
+     ],
+     "other plucked": ["banjo, mandolin, ukulele, harp etc"],
+     "percussion": [
+         "a-tonal percussion (claps, shakers, congas, cowbell etc)",
+         "pitched percussion (mallets, glockenspiel, ...)",
+     ],
+     "piano": [
+         "grand piano",
+         "electric piano (rhodes, wurlitzer, piano sound alike)",
+     ],
+     "other keys": [
+         "organ, electric organ",
+         "synth pad",
+         "synth lead",
+         "other sounds (hapischord, melotron etc)",
+     ],
+     "bowed strings": [
+         "violin (solo)",
+         "viola (solo)",
+         "cello (solo)",
+         "violin section",
+         "viola section",
+         "cello section",
+         "string section",
+         "other strings",
+     ],
+     "wind": [
+         "brass (trumpet, trombone, french horn, brass etc)",
+         "flutes (piccolo, bamboo flute, panpipes, flutes etc)",
+         "reeds (saxophone, clarinets, oboe, english horn, bagpipe)",
+         "other wind",
+     ],
+ }
+
+
+ def clean_track_inst(inst):
+
+     if "fx" in inst:
+         inst = "fx"
+
+     if "contrabass_double_bass" in inst:
+         inst = "double_bass"
+
+     if "banjo" in inst:
+         return "other_plucked"
+
+     if "(" in inst:
+         inst = inst.split("(")[0]
+
+     for s in [",", "-"]:
+         if s in inst:
+             inst = inst.replace(s, "")
+
+     for s in ["/"]:
+         if s in inst:
+             inst = inst.replace(s, "_")
+
+     if inst[-1] == "_":
+         inst = inst[:-1]
+
+     return inst
+
+
+ taxonomy = {k: [clean_track_inst(i.replace(" ", "_")) for i in v] for k, v in taxonomy.items()}
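
A few worked examples of `clean_track_inst` on raw labels (spaces already replaced by underscores, as in the comprehension above):

```python
# Worked examples of the label cleanup above.
print(clean_track_inst("electric_piano_(rhodes,_wurlitzer,_piano_sound_alike)"))
# -> "electric_piano"   (parenthetical dropped, trailing "_" stripped)
print(clean_track_inst("contrabass_double_bass_(bass_of_instrings)"))
# -> "double_bass"
print(clean_track_inst("fx/processed_sound,_scratches,_gun_shots,_explosions_etc"))
# -> "fx"
```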
core/data/moisesdb/audio.ipynb ADDED
@@ -0,0 +1,76 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2024-02-17T00:59:48.228125593Z",
+      "start_time": "2024-02-17T00:59:47.533488738Z"
+     },
+     "collapsed": true
+    },
+    "outputs": [],
+    "source": [
+     "import numpy as np\n",
+     "import IPython\n",
+     "file = \"/home/kwatchar3/Documents/data/moisesdb/npy/e2ccbc17-44bf-431a-af2b-4cf2fbd19a72/mixture.npy\"\n",
+     "\n",
+     "audio = np.load(file)\n",
+     "\n",
+     "IPython.display.Audio(audio, rate=44100)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {
+     "ExecuteTime": {
+      "end_time": "2024-02-16T18:01:10.487779628Z",
+      "start_time": "2024-02-16T18:01:06.898408871Z"
+     },
+     "collapsed": false
+    },
+    "outputs": [],
+    "source": [
+     "from scipy.signal import spectrogram\n",
+     "\n",
+     "f, t, Sxx = spectrogram(audio, 44100, nperseg=1024, noverlap=512)\n",
+     "\n",
+     "import matplotlib.pyplot as plt\n",
+     "\n",
+     "plt.pcolormesh(t, f, 10 * np.log10(Sxx[0]))"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {
+     "collapsed": false
+    },
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 2
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython2",
+    "version": "2.7.6"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+ }
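The same quick inspection as a plain script, for running outside Jupyter. The .npy path is a placeholder, the array is assumed to be (channels, samples) float32 at 44.1 kHz, and the small epsilon (not in the notebook above) guards against log of zero on silent frames.

import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import spectrogram

audio = np.load("mixture.npy")  # placeholder path, shape (2, n_samples)
f, t, Sxx = spectrogram(audio, fs=44100, nperseg=1024, noverlap=512)
plt.pcolormesh(t, f, 10 * np.log10(Sxx[0] + 1e-12))  # left channel, in dB
plt.xlabel("time (s)")
plt.ylabel("frequency (Hz)")
plt.show()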
core/data/moisesdb/datamodule.py ADDED
@@ -0,0 +1,239 @@
+ import os.path
+ from typing import Mapping, Optional
+
+ import pytorch_lightning as pl
+
+ from core.data.base import from_datasets
+ from core.data.moisesdb.dataset import MoisesDBRandomChunkBalancedRandomQueryDataset, MoisesDBRandomChunkRandomQueryDataset, \
+     MoisesDBDeterministicChunkDeterministicQueryDataset, \
+     MoisesDBFullTrackDataset, MoisesDBVDBODeterministicChunkDataset, \
+     MoisesDBVDBOFullTrackDataset, MoisesDBVDBORandomChunkDataset, \
+     MoisesDBFullTrackTestQueryDataset
+
+
+ def MoisesDataModule(
+     data_root: str,
+     batch_size: int,
+     num_workers: int = 8,
+     train_kwargs: Optional[Mapping] = None,
+     val_kwargs: Optional[Mapping] = None,
+     test_kwargs: Optional[Mapping] = None,
+     datamodule_kwargs: Optional[Mapping] = None,
+ ) -> pl.LightningDataModule:
+     if train_kwargs is None:
+         train_kwargs = {}
+
+     if val_kwargs is None:
+         val_kwargs = {}
+
+     if test_kwargs is None:
+         test_kwargs = {}
+
+     if datamodule_kwargs is None:
+         datamodule_kwargs = {}
+
+     train_dataset = MoisesDBRandomChunkRandomQueryDataset(
+         data_root=data_root, split="train", **train_kwargs
+     )
+
+     val_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
+         data_root=data_root, split="val", **val_kwargs
+     )
+
+     test_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
+         data_root=data_root, split="test", **test_kwargs
+     )
+
+     datamodule = from_datasets(
+         train_dataset=train_dataset,
+         val_dataset=val_dataset,
+         test_dataset=test_dataset,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         **datamodule_kwargs
+     )
+
+     datamodule.predict_dataloader = (  # type: ignore[method-assign]
+         datamodule.test_dataloader
+     )
+
+     return datamodule
+
+
+ def MoisesBalancedTrainDataModule(
+     data_root: str,
+     batch_size: int,
+     num_workers: int = 8,
+     train_kwargs: Optional[Mapping] = None,
+     val_kwargs: Optional[Mapping] = None,
+     test_kwargs: Optional[Mapping] = None,
+     datamodule_kwargs: Optional[Mapping] = None,
+ ) -> pl.LightningDataModule:
+     if train_kwargs is None:
+         train_kwargs = {}
+
+     if val_kwargs is None:
+         val_kwargs = {}
+
+     if test_kwargs is None:
+         test_kwargs = {}
+
+     if datamodule_kwargs is None:
+         datamodule_kwargs = {}
+
+     train_dataset = MoisesDBRandomChunkBalancedRandomQueryDataset(
+         data_root=data_root, split="train", **train_kwargs
+     )
+
+     val_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
+         data_root=data_root, split="val", **val_kwargs
+     )
+
+     test_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
+         data_root=data_root, split="test", **test_kwargs
+     )
+
+     datamodule = from_datasets(
+         train_dataset=train_dataset,
+         val_dataset=val_dataset,
+         test_dataset=test_dataset,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         **datamodule_kwargs
+     )
+
+     datamodule.predict_dataloader = (  # type: ignore[method-assign]
+         datamodule.test_dataloader
+     )
+
+     return datamodule
+
+
+ def MoisesValidationDataModule(
+     data_root: str,
+     batch_size: int,
+     num_workers: int = 8,
+     val_kwargs: Optional[Mapping] = None,
+     datamodule_kwargs: Optional[Mapping] = None,
+     **kwargs
+ ) -> pl.LightningDataModule:
+     if val_kwargs is None:
+         val_kwargs = {}
+
+     if datamodule_kwargs is None:
+         datamodule_kwargs = {}
+
+     allowed_stems = val_kwargs.get("allowed_stems", None)
+
+     assert allowed_stems is not None, "allowed_stems must be provided"
+
+     val_datasets = []
+
+     for allowed_stem in allowed_stems:
+         kwargs = val_kwargs.copy()
+         kwargs["allowed_stems"] = [allowed_stem]
+         val_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
+             data_root=data_root, split="val",
+             **kwargs
+         )
+
+         val_datasets.append(val_dataset)
+
+     datamodule = from_datasets(
+         val_dataset=val_datasets,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         **datamodule_kwargs
+     )
+
+     datamodule.predict_dataloader = (  # type: ignore[method-assign]
+         datamodule.val_dataloader
+     )
+
+     return datamodule
+
+
+ def MoisesTestDataModule(
+     data_root: str,
+     batch_size: int = 1,
+     num_workers: int = 8,
+     test_kwargs: Optional[Mapping] = None,
+     datamodule_kwargs: Optional[Mapping] = None,
+     **kwargs
+ ) -> pl.LightningDataModule:
+     if test_kwargs is None:
+         test_kwargs = {}
+
+     if datamodule_kwargs is None:
+         datamodule_kwargs = {}
+
+     allowed_stems = test_kwargs.get("allowed_stems", None)
+
+     assert allowed_stems is not None, "allowed_stems must be provided"
+
+     test_dataset = MoisesDBFullTrackTestQueryDataset(
+         data_root=data_root, split="test",
+         **test_kwargs
+     )
+
+     datamodule = from_datasets(
+         test_dataset=test_dataset,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         **datamodule_kwargs
+     )
+
+     datamodule.predict_dataloader = (  # type: ignore[method-assign]
+         datamodule.test_dataloader
+     )
+
+     return datamodule
+
+
+ def MoisesVDBODataModule(
+     data_root: str,
+     batch_size: int,
+     num_workers: int = 8,
+     train_kwargs: Optional[Mapping] = None,
+     val_kwargs: Optional[Mapping] = None,
+     test_kwargs: Optional[Mapping] = None,
+     datamodule_kwargs: Optional[Mapping] = None,
+ ):
+     if train_kwargs is None:
+         train_kwargs = {}
+
+     if val_kwargs is None:
+         val_kwargs = {}
+
+     if test_kwargs is None:
+         test_kwargs = {}
+
+     if datamodule_kwargs is None:
+         datamodule_kwargs = {}
+
+     train_dataset = MoisesDBVDBORandomChunkDataset(
+         data_root=data_root, split="train", **train_kwargs
+     )
+
+     val_dataset = MoisesDBVDBODeterministicChunkDataset(
+         data_root=data_root, split="val", **val_kwargs
+     )
+
+     test_dataset = MoisesDBVDBOFullTrackDataset(
+         data_root=data_root, split="test", **test_kwargs
+     )
+
+     predict_dataset = test_dataset
+
+     datamodule = from_datasets(
+         train_dataset=train_dataset,
+         val_dataset=val_dataset,
+         test_dataset=test_dataset,
+         predict_dataset=predict_dataset,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         **datamodule_kwargs
+     )
+
+     return datamodule
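A minimal sketch of how these factory functions are typically consumed with a Lightning trainer; the data path and kwargs below are placeholders, and `model` stands in for whatever LightningModule is being trained.

import pytorch_lightning as pl

from core.data.moisesdb.datamodule import MoisesVDBODataModule

datamodule = MoisesVDBODataModule(
    data_root="/path/to/moisesdb",  # placeholder
    batch_size=4,
    train_kwargs={"target_length": 8192, "chunk_size_seconds": 4.0},
)
trainer = pl.Trainer(max_epochs=1, accelerator="auto")
# trainer.fit(model, datamodule=datamodule)  # model: a pl.LightningModule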
core/data/moisesdb/dataset.py ADDED
@@ -0,0 +1,1383 @@
+ import math
+ import os
+ import random
+ import warnings
+ from abc import ABC
+ from collections import defaultdict
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ from omegaconf import OmegaConf
+ import pandas as pd
+ import torch
+ from torch_audiomentations.utils.object_dict import ObjectDict
+ import torchaudio as ta
+ from torch.utils import data
+ from tqdm import tqdm
+
+ from core.data.base import BaseSourceSeparationDataset
+ from core.types import input_dict
+
+ from . import clean_track_inst
+
+ from torch import Tensor, nn
+
+ DBFS_HOP_SIZE = int(0.125 * 44100)
+ DBFS_CHUNK_SIZE = int(1 * 44100)
+
+ INST_BY_OCCURRENCE = [
+     "bass_guitar",
+     "kick_drum",
+     "snare_drum",
+     "lead_male_singer",
+     "distorted_electric_guitar",
+     "clean_electric_guitar",
+     "toms",
+     "acoustic_guitar",
+     "background_vocals",
+     "hi_hat",
+     "overheads",
+     "atonal_percussion",
+     "grand_piano",
+     "cymbals",
+     "lead_female_singer",
+     "synth_lead",
+     "bass_synthesizer",
+     "synth_pad",
+     "organ_electric_organ",
+     "fx",
+     "drum_machine",
+     "string_section",
+     "electric_piano",
+     "full_acoustic_drumkit",
+     "other_sounds",
+     "pitched_percussion",
+     "brass",
+     "reeds",
+     "contrabass_double_bass",
+     "other_plucked",
+     "other_strings",
+     "other_wind",
+     "cello",
+     "other",
+     "flutes",
+     "viola_section",
+     "viola",
+     "cello_section",
+ ]
+
+ FINE_LEVEL_INSTRUMENTS = {
+     "lead_male_singer",
+     "lead_female_singer",
+     "human_choir",
+     "background_vocals",
+     "other_vocals",
+     "bass_guitar",
+     "bass_synthesizer",
+     "contrabass_double_bass",
+     "tuba",
+     "bassoon",
+     "snare_drum",
+     "toms",
+     "kick_drum",
+     "cymbals",
+     "overheads",
+     "full_acoustic_drumkit",
+     "drum_machine",
+     "hihat",
+     "fx",
+     "click_track",
+     "clean_electric_guitar",
+     "distorted_electric_guitar",
+     "lap_steel_guitar_or_slide_guitar",
+     "acoustic_guitar",
+     "other_plucked",
+     "atonal_percussion",
+     "pitched_percussion",
+     "grand_piano",
+     "electric_piano",
+     "organ_electric_organ",
+     "synth_pad",
+     "synth_lead",
+     "other_sounds",
+     "violin",
+     "viola",
+     "cello",
+     "violin_section",
+     "viola_section",
+     "cello_section",
+     "string_section",
+     "other_strings",
+     "brass",
+     "flutes",
+     "reeds",
+     "other_wind",
+ }
+
+ COARSE_LEVEL_INSTRUMENTS = {
+     "vocals",
+     "bass",
+     "drums",
+     "guitar",
+     "other_plucked",
+     "percussion",
+     "piano",
+     "other_keys",
+     "bowed_strings",
+     "wind",
+     "other",
+ }
+
+ COARSE_TO_FINE = {
+     "vocals": [
+         "lead_male_singer",
+         "lead_female_singer",
+         "human_choir",
+         "background_vocals",
+         "other_vocals",
+     ],
+     "bass": [
+         "bass_guitar",
+         "bass_synthesizer",
+         "contrabass_double_bass",
+         "tuba",
+         "bassoon",
+     ],
+     "drums": [
+         "snare_drum",
+         "toms",
+         "kick_drum",
+         "cymbals",
+         "overheads",
+         "full_acoustic_drumkit",
+         "drum_machine",
+         "hihat",
+     ],
+     "other": ["fx", "click_track"],
+     "guitar": [
+         "clean_electric_guitar",
+         "distorted_electric_guitar",
+         "lap_steel_guitar_or_slide_guitar",
+         "acoustic_guitar",
+     ],
+     "other_plucked": ["other_plucked"],
+     "percussion": ["atonal_percussion", "pitched_percussion"],
+     "piano": ["grand_piano", "electric_piano"],
+     "other_keys": ["organ_electric_organ", "synth_pad", "synth_lead", "other_sounds"],
+     "bowed_strings": [
+         "violin",
+         "viola",
+         "cello",
+         "violin_section",
+         "viola_section",
+         "cello_section",
+         "string_section",
+         "other_strings",
+     ],
+     "wind": ["brass", "flutes", "reeds", "other_wind"],
+ }
+
+ COARSE_TO_FINE = {k: set(v) for k, v in COARSE_TO_FINE.items()}
+ FINE_TO_COARSE = {k: kk for kk, v in COARSE_TO_FINE.items() for k in v}
+
+ ALL_LEVEL_INSTRUMENTS = COARSE_LEVEL_INSTRUMENTS.union(FINE_LEVEL_INSTRUMENTS)
+
+
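A quick check of the two-level label hierarchy defined above; the import pulls in the module's torch dependencies, so this assumes the repo's requirements are installed.

from core.data.moisesdb.dataset import COARSE_TO_FINE, FINE_TO_COARSE

assert FINE_TO_COARSE["kick_drum"] == "drums"
assert "bass_guitar" in COARSE_TO_FINE["bass"]
# every fine label maps to exactly one coarse label
assert set(FINE_TO_COARSE) == set().union(*COARSE_TO_FINE.values())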
186
+ class MoisesDBBaseDataset(BaseSourceSeparationDataset, ABC):
187
+ def __init__(
188
+ self,
189
+ split: str,
190
+ data_path: str = "/home/kwatchar3/Documents/data/moisesdb",
191
+ fs: int = 44100,
192
+ return_stems: Union[bool, List[str]] = False,
193
+ npy_memmap=True,
194
+ recompute_mixture=False,
195
+ train_folds=None,
196
+ val_folds=None,
197
+ test_folds=None,
198
+ query_file="query",
199
+ ) -> None:
200
+ if test_folds is None:
201
+ test_folds = [5]
202
+
203
+ if val_folds is None:
204
+ val_folds = [4]
205
+
206
+ if train_folds is None:
207
+ train_folds = [1, 2, 3]
208
+
209
+ split_path = os.path.join(data_path, "splits.csv")
210
+ splits = pd.read_csv(split_path)
211
+
212
+ metadata_path = os.path.join(data_path, "stems.csv")
213
+ metadata = pd.read_csv(metadata_path)
214
+
215
+ if split == "train":
216
+ folds = train_folds
217
+ elif split == "val":
218
+ folds = val_folds
219
+ elif split == "test":
220
+ folds = test_folds
221
+ else:
222
+ raise NameError
223
+
224
+ files = splits[splits["split"].isin(folds)]["song_id"].tolist()
225
+ metadata = metadata[metadata["song_id"].isin(files)]
226
+
227
+ super().__init__(
228
+ split=split,
229
+ stems=["mixture"],
230
+ files=files,
231
+ data_path=data_path,
232
+ fs=fs,
233
+ npy_memmap=npy_memmap,
234
+ recompute_mixture=recompute_mixture,
235
+ )
236
+
237
+ self.folds = folds
238
+
239
+ self.metadata = metadata.rename(
240
+ columns={k: k.replace(" ", "_") for k in metadata.columns}
241
+ )
242
+
243
+ self.song_to_stem = (
244
+ metadata.set_index("song_id")
245
+ .apply(lambda row: row[row == 1].index.tolist(), axis=1)
246
+ .to_dict()
247
+ )
248
+ self.stem_to_song = (
249
+ metadata.set_index("song_id")
250
+ .transpose()
251
+ .apply(lambda row: row[row == 1].index.tolist(), axis=1)
252
+ .to_dict()
253
+ )
254
+
255
+ self.true_length = len(self.files)
256
+ self.n_channels = 2
257
+
258
+ self.audio_path = os.path.join(data_path, "npy2")
259
+
260
+ self.return_stems = return_stems
261
+
262
+ self.query_file = query_file
263
+
264
+ def get_full_stem(self, *, stem: str, identifier) -> torch.Tensor:
265
+ song_id = identifier["song_id"]
266
+ path = os.path.join(self.data_path, "npy2", song_id)
267
+ # noinspection PyUnresolvedReferences
268
+
269
+ assert self.npy_memmap
270
+
271
+ if os.path.exists(os.path.join(path, f"{stem}.npy")):
272
+ audio = np.load(os.path.join(path, f"{stem}.npy"), mmap_mode="r")
273
+ else:
274
+ audio = None
275
+
276
+ return audio
277
+
278
+ def get_query_stem(self, *, stem: str, identifier) -> torch.Tensor:
279
+ song_id = identifier["song_id"]
280
+ path = os.path.join(self.data_path, "npyq", song_id)
281
+ # noinspection PyUnresolvedReferences
282
+
283
+ if self.npy_memmap:
284
+ # print(self.npy_memmap)
285
+ audio = np.load(
286
+ os.path.join(path, f"{stem}.{self.query_file}.npy"), mmap_mode="r"
287
+ )
288
+ else:
289
+ raise NotImplementedError
290
+
291
+ return audio
292
+
293
+ def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
294
+ audio = self.get_full_stem(stem=stem, identifier=identifier)
295
+ return audio
296
+
297
+ def get_identifier(self, index):
298
+ return dict(song_id=self.files[index % self.true_length])
299
+
300
+ def __getitem__(self, index: int):
301
+ identifier = self.get_identifier(index)
302
+ audio = self.get_audio(identifier)
303
+
304
+ mixture = audio["mixture"].copy()
305
+
306
+ if isinstance(self.return_stems, list):
307
+ sources = {
308
+ stem: audio.get(stem, np.zeros_like(mixture))
309
+ for stem in self.return_stems
310
+ }
311
+ elif isinstance(self.return_stems, bool):
312
+ if self.return_stems:
313
+ sources = {
314
+ stem: audio[stem].copy()
315
+ for stem in self.song_to_stem[identifier["song_id"]]
316
+ }
317
+ else:
318
+ sources = None
319
+ else:
320
+ raise ValueError
321
+
322
+ return input_dict(
323
+ mixture=mixture,
324
+ sources=sources,
325
+ metadata=identifier,
326
+ modality="audio",
327
+ )
328
+
329
+
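The two pandas one-liners above build a forward index (song to stems present) and an inverse index (stem to songs) from the 0/1 stem-presence matrix in stems.csv. A toy illustration with hypothetical column names:

import pandas as pd

metadata = pd.DataFrame({"song_id": ["a", "b"], "vocals": [1, 0], "bass": [1, 1]})

song_to_stem = (
    metadata.set_index("song_id")
    .apply(lambda row: row[row == 1].index.tolist(), axis=1)
    .to_dict()
)
stem_to_song = (
    metadata.set_index("song_id")
    .transpose()
    .apply(lambda row: row[row == 1].index.tolist(), axis=1)
    .to_dict()
)

assert song_to_stem == {"a": ["vocals", "bass"], "b": ["bass"]}
assert stem_to_song == {"vocals": ["a"], "bass": ["a", "b"]}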
+ class MoisesDBFullTrackDataset(MoisesDBBaseDataset):
+     def __init__(
+         self,
+         data_root: str,
+         split: str,
+         return_stems: Union[bool, List[str]] = False,
+         npy_memmap=True,
+         recompute_mixture=False,
+         query_file="query",
+     ) -> None:
+         super().__init__(
+             split=split,
+             data_path=data_root,
+             return_stems=return_stems,
+             npy_memmap=npy_memmap,
+             recompute_mixture=recompute_mixture,
+             query_file=query_file,
+         )
+
+     def __len__(self) -> int:
+         return self.true_length
+
+
+ class MoisesDBVDBOFullTrackDataset(MoisesDBFullTrackDataset):
+     def __init__(
+         self, data_root: str, split: str, npy_memmap=True, recompute_mixture=False
+     ) -> None:
+         super().__init__(
+             data_root=data_root,
+             split=split,
+             return_stems=["vocals", "bass", "drums", "vdbo_others"],
+             npy_memmap=npy_memmap,
+             recompute_mixture=recompute_mixture,
+             query_file=None,
+         )
+
+
+ import torch_audiomentations as audiomentations
+ from torch_audiomentations.utils.dsp import convert_decibels_to_amplitude_ratio
+
+
+ class SmartGain(audiomentations.Gain):
+     def __init__(
+         self, p=0.5, min_gain_in_db=-6, max_gain_in_db=6, dbfs_threshold=-45.0
+     ):
+         super().__init__(
+             p=p, min_gain_in_db=min_gain_in_db, max_gain_in_db=max_gain_in_db
+         )
+
+         self.dbfs_threshold = dbfs_threshold
+
+     def randomize_parameters(
+         self,
+         samples: Tensor = None,
+         sample_rate: Optional[int] = None,
+         targets: Optional[Tensor] = None,
+         target_rate: Optional[int] = None,
+     ):
+
+         dbfs = 10 * torch.log10(torch.mean(torch.square(samples)) + 1e-6)
+
+         if dbfs > self.dbfs_threshold:
+             low = self.min_gain_in_db
+         else:
+             low = max(0.0, self.min_gain_in_db)
+
+         distribution = torch.distributions.Uniform(
+             low=torch.tensor(low, dtype=torch.float32, device=samples.device),
+             high=torch.tensor(
+                 self.max_gain_in_db, dtype=torch.float32, device=samples.device
+             ),
+             validate_args=True,
+         )
+         selected_batch_size = samples.size(0)
+         self.transform_parameters["gain_factors"] = (
+             convert_decibels_to_amplitude_ratio(
+                 distribution.sample(sample_shape=(selected_batch_size,))
+             )
+             .unsqueeze(1)
+             .unsqueeze(1)
+         )
+
+
+ class Audiomentations(audiomentations.Compose):
+     def __init__(self, augment="gssp", fs: int = 44100):
+
+         if isinstance(augment, str):
+             if augment == "gssp":
+                 augment = OmegaConf.create(
+                     [
+                         dict(
+                             cls="Shift",
+                             kwargs=dict(p=1.0, min_shift=-0.5, max_shift=0.5),
+                         ),
+                         dict(
+                             cls="Gain",
+                             kwargs=dict(p=1.0, min_gain_in_db=-6, max_gain_in_db=6),
+                         ),
+                         dict(cls="ShuffleChannels", kwargs=dict(p=0.5)),
+                         dict(cls="PolarityInversion", kwargs=dict(p=0.5)),
+                     ]
+                 )
+             else:
+                 raise ValueError
+
+         transforms = []
+
+         for transform in augment:
+
+             if transform.cls == "Gain":
+                 transforms.append(SmartGain(**transform.kwargs))
+             else:
+                 transforms.append(
+                     getattr(audiomentations, transform.cls)(**transform.kwargs)
+                 )
+
+         super().__init__(transforms=transforms, shuffle=True)
+
+         self.fs = fs
+
+     def forward(
+         self,
+         samples: torch.Tensor = None,
+     ) -> ObjectDict:
+         return super().forward(samples, sample_rate=self.fs)
+
+
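The gating in SmartGain.randomize_parameters boils down to the arithmetic below: chunks quieter than dbfs_threshold only ever draw non-negative gains, so near-silence is never attenuated further. A standalone trace of that rule (values are illustrative):

import torch

samples = 0.001 * torch.randn(1, 2, 44100)  # a very quiet chunk
dbfs = 10 * torch.log10(torch.mean(torch.square(samples)) + 1e-6)
low = -6.0 if dbfs > -45.0 else max(0.0, -6.0)  # quiet chunk -> low = 0.0
print(f"{dbfs.item():.1f} dBFS -> gain drawn from [{low}, 6.0] dB")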
+ class MoisesDBVDBORandomChunkDataset(MoisesDBVDBOFullTrackDataset):
+     def __init__(
+         self,
+         data_root: str,
+         split: str,
+         chunk_size_seconds: float = 4.0,
+         fs: int = 44100,
+         target_length: int = 8192,
+         augment=None,
+         npy_memmap=True,
+         recompute_mixture=True,
+         db_threshold=-24.0,
+         db_step=-12.0,
+     ) -> None:
+         super().__init__(
+             data_root=data_root,
+             split=split,
+             npy_memmap=npy_memmap,
+             recompute_mixture=recompute_mixture,
+         )
+
+         self.chunk_size_seconds = chunk_size_seconds
+         self.chunk_size_samples = int(chunk_size_seconds * fs)
+         self.fs = fs
+
+         self.target_length = target_length
+
+         self.db_threshold = db_threshold
+         self.db_step = db_step
+
+         if augment is not None:
+             assert self.recompute_mixture
+             self.augment = Audiomentations(augment, fs)
+         else:
+             self.augment = None
+
+     def __len__(self) -> int:
+         return self.target_length
+
+     def _chunk_audio(self, audio, start, end):
+         audio = {k: v[..., start:end] for k, v in audio.items()}
+
+         return audio
+
+     def _get_start_end(self, audio, identifier):
+         n_samples = audio.shape[-1]
+         start = np.random.randint(0, n_samples - self.chunk_size_samples)
+         end = start + self.chunk_size_samples
+
+         return start, end
+
+     def _get_audio(self, stems, identifier: Dict[str, Any]):
+         audio = {}
+
+         for stem in stems:
+             audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)
+
+         for stem in stems:
+             if audio[stem] is None:
+                 audio[stem] = np.zeros(
+                     audio[
+                         (
+                             "mixture"
+                             if "mixture" in stems
+                             else [s for s in stems if audio[s] is not None][0]
+                         )
+                     ].shape,
+                     dtype=np.float32,
+                 )
+
+         start, end = self._get_start_end(audio[stems[0]], identifier)
+         audio = self._chunk_audio(audio, start, end)
+
+         if self.augment is not None:
+             audio = {
+                 k: self.augment(torch.from_numpy(v[None, :, :]))[0, :, :].numpy()
+                 for k, v in audio.items()
+             }
+
+         return audio
+
+     def get_audio(self, identifier: Dict[str, Any]):
+         if self.recompute_mixture:
+             audio = self._get_audio(
+                 ["vocals", "bass", "drums", "vdbo_others"], identifier=identifier
+             )
+             audio["mixture"] = self.compute_mixture(audio)
+             return audio
+         else:
+             return self._get_audio(
+                 ["mixture", "vocals", "bass", "drums", "vdbo_others"],
+                 identifier=identifier,
+             )
+
+     def __getitem__(self, index: int):
+
+         identifier = self.get_identifier(index)
+         audio = self.get_audio(identifier=identifier)
+
+         mixture = audio["mixture"].copy()
+
+         sources = {
+             stem: audio.get(stem, np.zeros_like(mixture)) for stem in self.return_stems
+         }
+
+         return input_dict(
+             mixture=mixture,
+             sources=sources,
+             metadata=identifier,
+             modality="audio",
+         )
+
+
+ class MoisesDBVDBODeterministicChunkDataset(MoisesDBVDBORandomChunkDataset):
+     def __init__(
+         self,
+         data_root: str,
+         split: str,
+         chunk_size_seconds: float = 4.0,
+         hop_size_seconds: float = 8.0,
+         fs: int = 44100,
+         npy_memmap=True,
+         recompute_mixture=False,
+     ) -> None:
+         super().__init__(
+             data_root=data_root,
+             split=split,
+             chunk_size_seconds=chunk_size_seconds,
+             npy_memmap=npy_memmap,
+             recompute_mixture=recompute_mixture,
+         )
+
+         self.hop_size_seconds = hop_size_seconds
+         self.hop_size_samples = int(hop_size_seconds * fs)
+
+         self.index_to_identifiers = self._generate_index()
+         self.length = len(self.index_to_identifiers)
+
+     def __len__(self) -> int:
+         return self.length
+
+     def _generate_index(self):
+
+         identifiers = []
+
+         for song_id in self.files:
+             audio = self.get_full_stem(stem="mixture", identifier=dict(song_id=song_id))
+             n_samples = audio.shape[-1]
+             n_chunks = math.floor(
+                 (n_samples - self.chunk_size_samples) / self.hop_size_samples
+             )
+
+             for i in range(n_chunks):
+                 chunk_start = i * self.hop_size_samples
+                 identifiers.append(dict(song_id=song_id, chunk_start=chunk_start))
+
+         return identifiers
+
+     def get_identifier(self, index):
+         return self.index_to_identifiers[index]
+
+     def _get_start_end(self, audio, identifier):
+
+         start = identifier["chunk_start"]
+         end = start + self.chunk_size_samples
+
+         return start, end
+
+
+ def round_samples(seconds, fs, hop_size, downsample):
+     n_frames = math.ceil(seconds * fs / hop_size) + 1
+     n_frames_down = math.ceil(n_frames / downsample)
+     n_frames = n_frames_down * downsample
+     n_samples = (n_frames - 1) * hop_size
+
+     return int(n_samples)
+
+
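round_samples snaps a duration up to a sample count whose STFT frame count is a multiple of the downsampling factor. For the values used below (4 s at 44.1 kHz, hop 512, downsample 2**6), the arithmetic works out as follows:

import math

seconds, fs, hop, down = 4.0, 44100, 512, 2**6
n_frames = math.ceil(seconds * fs / hop) + 1   # 346 STFT frames
n_frames = math.ceil(n_frames / down) * down   # padded up to 384 = 6 * 64
n_samples = (n_frames - 1) * hop               # 196096 samples (~4.45 s)
assert n_samples == 196096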
636
+ def __init__(
637
+ self,
638
+ data_root: str,
639
+ split: str,
640
+ target_length: int,
641
+ chunk_size_seconds: float = 4.0,
642
+ query_size_seconds: float = 1.0,
643
+ round_query: bool = False,
644
+ min_query_dbfs: float = -40.0,
645
+ min_target_dbfs: float = -36.0,
646
+ min_target_dbfs_step: float = -12.0,
647
+ max_dbfs_tries: int = 10,
648
+ top_k_instrument: int = 10,
649
+ mixture_stem: str = "mixture",
650
+ use_own_query: bool = True,
651
+ npy_memmap=True,
652
+ allowed_stems=None,
653
+ query_file="query",
654
+ augment=None,
655
+ ) -> None:
656
+
657
+ super().__init__(
658
+ data_root=data_root,
659
+ split=split,
660
+ npy_memmap=npy_memmap,
661
+ recompute_mixture=augment is not None,
662
+ query_file=query_file,
663
+ )
664
+
665
+ self.mixture_stem = mixture_stem
666
+
667
+ self.chunk_size_seconds = chunk_size_seconds
668
+ self.chunk_size_samples = round_samples(
669
+ self.chunk_size_seconds, self.fs, 512, 2**6
670
+ )
671
+
672
+ self.query_size_seconds = query_size_seconds
673
+
674
+ if round_query:
675
+ self.query_size_samples = round_samples(
676
+ self.query_size_seconds, self.fs, 512, 2**6
677
+ )
678
+ else:
679
+ self.query_size_samples = int(self.query_size_seconds * self.fs)
680
+
681
+ self.target_length = target_length
682
+
683
+ self.min_query_dbfs = min_query_dbfs
684
+
685
+ if min_target_dbfs is None:
686
+ min_target_dbfs = -np.inf
687
+ min_target_dbfs_step = None
688
+ max_dbfs_tries = 1
689
+
690
+ self.min_target_dbfs = min_target_dbfs
691
+ self.min_target_dbfs_step = min_target_dbfs_step
692
+ self.max_dbfs_tries = max_dbfs_tries
693
+
694
+ self.top_k_instrument = top_k_instrument
695
+
696
+ if allowed_stems is None:
697
+ allowed_stems = INST_BY_OCCURRENCE[: self.top_k_instrument]
698
+ else:
699
+ self.top_k_instrument = None
700
+
701
+ self.allowed_stems = allowed_stems
702
+
703
+ self.song_to_all_stems = {
704
+ k: list(set(v) & set(ALL_LEVEL_INSTRUMENTS))
705
+ for k, v in self.song_to_stem.items()
706
+ }
707
+
708
+ self.song_to_stem = {
709
+ k: list(set(v) & set(self.allowed_stems))
710
+ for k, v in self.song_to_stem.items()
711
+ }
712
+ self.stem_to_song = {
713
+ k: list(set(v) & set(self.files)) for k, v in self.stem_to_song.items()
714
+ }
715
+
716
+ self.queriable_songs = [k for k, v in self.song_to_stem.items() if len(v) > 0]
717
+
718
+ self.use_own_query = use_own_query
719
+
720
+ if self.use_own_query:
721
+ self.files = [k for k in self.files if len(self.song_to_stem[k]) > 0]
722
+ self.true_length = len(self.files)
723
+
724
+ if augment is not None:
725
+ assert self.recompute_mixture
726
+ self.augment = Audiomentations(augment, self.fs)
727
+ else:
728
+ self.augment = None
729
+
730
+ def __len__(self) -> int:
731
+ return self.target_length
732
+
733
+ def _chunk_audio(self, audio, start, end):
734
+ audio = {k: v[..., start:end] for k, v in audio.items()}
735
+
736
+ return audio
737
+
738
+ def _get_start_end(self, audio):
739
+ n_samples = audio.shape[-1]
740
+ start = np.random.randint(0, n_samples - self.chunk_size_samples)
741
+ end = start + self.chunk_size_samples
742
+
743
+ return start, end
744
+
745
+ def _target_dbfs(self, audio):
746
+ return 10.0 * np.log10(np.mean(np.square(np.abs(audio))) + 1e-6)
747
+
748
+ def _chunk_and_check_dbfs_threshold(self, audio_, target_stem, threshold):
749
+
750
+ target_dict = {target_stem: audio_[target_stem]}
751
+
752
+ for _ in range(self.max_dbfs_tries):
753
+ start, end = self._get_start_end(audio_[target_stem])
754
+ taudio = self._chunk_audio(target_dict, start, end)
755
+
756
+ dbfs = self._target_dbfs(taudio[target_stem])
757
+ if dbfs > threshold:
758
+ return self._chunk_audio(audio_, start, end)
759
+
760
+ return None
761
+
762
+ def _chunk_and_check_dbfs(self, audio_, target_stem):
763
+ out = self._chunk_and_check_dbfs_threshold(
764
+ audio_, target_stem, self.min_target_dbfs
765
+ )
766
+
767
+ if out is not None:
768
+ return out
769
+
770
+ out = self._chunk_and_check_dbfs_threshold(
771
+ audio_, target_stem, self.min_target_dbfs + self.min_target_dbfs_step
772
+ )
773
+
774
+ if out is not None:
775
+ return out
776
+
777
+ start, end = self._get_start_end(audio_[target_stem])
778
+ audio = self._chunk_audio(audio_, start, end)
779
+
780
+ return audio
781
+
782
+ def _augment(self, audio, target_stem):
783
+ stack_audio = np.stack([v for v in audio.values()], axis=0)
784
+ aug_audio = self.augment(torch.from_numpy(stack_audio)).numpy()
785
+ mixture = np.sum(aug_audio, axis=0)
786
+
787
+ out = {
788
+ "mixture": mixture,
789
+ }
790
+
791
+ if target_stem is not None:
792
+ target_idx = list(audio.keys()).index(target_stem)
793
+ out[target_stem] = aug_audio[target_idx]
794
+
795
+ return out
796
+
797
+ def _choose_stems_for_augment(self, identifier, target_stem):
798
+ stems_for_song = set(self.song_to_all_stems[identifier["song_id"]])
799
+
800
+ stems_ = []
801
+ coarse_level_accounted = set()
802
+
803
+ is_none_target = target_stem is None
804
+ is_coarse_target = target_stem in COARSE_LEVEL_INSTRUMENTS
805
+
806
+ if is_coarse_target or is_none_target:
807
+ coarse_target = target_stem
808
+ else:
809
+ coarse_target = FINE_TO_COARSE[target_stem]
810
+
811
+ fine_level_stems = stems_for_song & FINE_LEVEL_INSTRUMENTS
812
+ coarse_level_stems = stems_for_song & COARSE_LEVEL_INSTRUMENTS
813
+
814
+ for s in fine_level_stems:
815
+ coarse_level = FINE_TO_COARSE[s]
816
+
817
+ if is_coarse_target and coarse_level == coarse_target:
818
+ continue
819
+ else:
820
+ stems_.append(s)
821
+
822
+ coarse_level_accounted.add(coarse_level)
823
+
824
+ stems_ += list(coarse_level_stems - coarse_level_accounted)
825
+
826
+ if target_stem is not None:
827
+ assert target_stem in stems_, f"stems: {stems_}, target stem: {target_stem}"
828
+
829
+ if len(stems_for_song) > 1:
830
+ assert (
831
+ len(stems_) > 1
832
+ ), f"stems: {stems_}, stems in song: {stems_for_song},\n target stem: {target_stem}"
833
+
834
+ assert "mixture" not in stems_
835
+
836
+ return stems_
837
+
838
+ def _get_audio(
839
+ self, stems, identifier: Dict[str, Any], check_dbfs=True, no_target=False
840
+ ):
841
+
842
+ target_stem = stems[0] if not no_target else None
843
+
844
+ if self.augment is not None:
845
+ stems_ = self._choose_stems_for_augment(identifier, target_stem)
846
+ else:
847
+ stems_ = stems
848
+
849
+ audio = {}
850
+ for stem in stems_:
851
+ audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)
852
+
853
+ audio_ = {k: v.copy() for k, v in audio.items()}
854
+
855
+ if check_dbfs:
856
+ assert target_stem is not None
857
+ audio = self._chunk_and_check_dbfs(audio_, target_stem)
858
+ else:
859
+ first_key = list(audio_.keys())[0]
860
+ start, end = self._get_start_end(audio_[first_key])
861
+ audio = self._chunk_audio(audio_, start, end)
862
+
863
+ if self.augment is not None:
864
+ assert "mixture" not in audio
865
+ audio = self._augment(audio, target_stem)
866
+ assert "mixture" in audio
867
+
868
+ return audio
869
+
870
+ def __getitem__(self, index: int):
871
+
872
+ mix_identifier = self.get_identifier(index)
873
+ mix_stems = self.song_to_stem[mix_identifier["song_id"]]
874
+
875
+ if self.use_own_query:
876
+ query_id = mix_identifier["song_id"]
877
+ query_identifier = dict(song_id=query_id)
878
+ possible_stem = mix_stems
879
+
880
+ assert len(possible_stem) > 0
881
+
882
+ zero_target = False
883
+ else:
884
+ query_id = random.choice(self.queriable_songs)
885
+ query_identifier = dict(song_id=query_id)
886
+ query_stems = self.song_to_stem[query_id]
887
+ possible_stem = list(set(mix_stems) & set(query_stems))
888
+
889
+ if len(possible_stem) == 0:
890
+ possible_stem = query_stems
891
+ zero_target = True
892
+ # print(f"Mix {mix_identifier['song_id']} and query {query_id} have no common stems.")
893
+ # return self.__getitem__(index + 1)
894
+ else:
895
+ zero_target = False
896
+
897
+ assert (
898
+ len(possible_stem) > 0
899
+ ), f"{mix_identifier['song_id']} and {query_id} have no common stems. zero target is {zero_target}"
900
+ stem = random.choice(possible_stem)
901
+
902
+ if zero_target:
903
+ audio = self._get_audio(
904
+ [self.mixture_stem],
905
+ identifier=mix_identifier,
906
+ check_dbfs=False,
907
+ no_target=True,
908
+ )
909
+ mixture = audio[self.mixture_stem].copy()
910
+ sources = {"target": np.zeros_like(mixture)}
911
+ else:
912
+ audio = self._get_audio(
913
+ [stem, self.mixture_stem], identifier=mix_identifier, check_dbfs=True
914
+ )
915
+ mixture = audio[self.mixture_stem].copy()
916
+ sources = {"target": audio[stem].copy()}
917
+
918
+ query = self.get_query_stem(stem=stem, identifier=query_identifier)
919
+ query = query.copy()
920
+
921
+ assert mixture.shape[-1] == self.chunk_size_samples
922
+ assert query.shape[-1] == self.query_size_samples
923
+ assert sources["target"].shape[-1] == self.chunk_size_samples
924
+
925
+ return input_dict(
926
+ mixture=mixture,
927
+ sources=sources,
928
+ query=query,
929
+ metadata={
930
+ "mix": mix_identifier,
931
+ "query": query_identifier,
932
+ "stem": stem,
933
+ },
934
+ modality="audio",
935
+ )
936
+
937
+
+ class MoisesDBRandomChunkBalancedRandomQueryDataset(
+     MoisesDBRandomChunkRandomQueryDataset
+ ):
+     def __init__(
+         self,
+         data_root: str,
+         split: str,
+         target_length: int,
+         chunk_size_seconds: float = 4,
+         query_size_seconds: float = 1,
+         round_query: bool = False,
+         min_query_dbfs: float = -40.0,
+         min_target_dbfs: float = -36.0,
+         min_target_dbfs_step: float = -12.0,
+         max_dbfs_tries: int = 10,
+         top_k_instrument: int = 10,
+         mixture_stem: str = "mixture",
+         use_own_query: bool = True,
+         npy_memmap=True,
+         allowed_stems=None,
+         query_file="query",
+         augment=None,
+     ) -> None:
+         super().__init__(
+             data_root,
+             split,
+             target_length,
+             chunk_size_seconds,
+             query_size_seconds,
+             round_query,
+             min_query_dbfs,
+             min_target_dbfs,
+             min_target_dbfs_step,
+             max_dbfs_tries,
+             top_k_instrument,
+             mixture_stem,
+             use_own_query,
+             npy_memmap,
+             allowed_stems,
+             query_file,
+             augment,
+         )
+
+         self.stem_to_n_songs = {k: len(v) for k, v in self.stem_to_song.items()}
+         self.trainable_stems = [k for k, v in self.stem_to_n_songs.items() if v > 1]
+         self.n_allowed_stems = len(self.allowed_stems)
+
+     def __getitem__(self, index: int):
+
+         stem = self.allowed_stems[index % self.n_allowed_stems]
+         song_ids_with_stem = self.stem_to_song[stem]
+
+         song_id = song_ids_with_stem[index % self.stem_to_n_songs[stem]]
+
+         mix_identifier = dict(song_id=song_id)
+
+         audio = self._get_audio([stem, self.mixture_stem], identifier=mix_identifier, check_dbfs=True)
+         mixture = audio[self.mixture_stem].copy()
+
+         if self.use_own_query:
+             query_id = song_id
+             query_identifier = dict(song_id=query_id)
+         else:
+             query_id = random.choice(song_ids_with_stem)
+             query_identifier = dict(song_id=query_id)
+
+         query = self.get_query_stem(stem=stem, identifier=query_identifier)
+         query = query.copy()
+
+         sources = {"target": audio[stem].copy()}
+
+         return input_dict(
+             mixture=mixture,
+             sources=sources,
+             query=query,
+             metadata={
+                 "mix": mix_identifier,
+                 "query": query_identifier,
+                 "stem": stem,
+             },
+             modality="audio",
+         )
+
+
+ class MoisesDBDeterministicChunkDeterministicQueryDataset(
+     MoisesDBRandomChunkRandomQueryDataset
+ ):
+     def __init__(
+         self,
+         data_root: str,
+         split: str,
+         chunk_size_seconds: float = 4.0,
+         hop_size_seconds: float = 8.0,
+         query_size_seconds: float = 1.0,
+         min_query_dbfs: float = -40.0,
+         top_k_instrument: int = 10,
+         n_queries_per_chunk: int = 1,
+         mixture_stem: str = "mixture",
+         use_own_query: bool = True,
+         npy_memmap=True,
+         allowed_stems: List[str] = None,
+         query_file="query",
+     ) -> None:
+
+         super().__init__(
+             data_root=data_root,
+             split=split,
+             target_length=None,
+             chunk_size_seconds=chunk_size_seconds,
+             query_size_seconds=query_size_seconds,
+             min_query_dbfs=min_query_dbfs,
+             top_k_instrument=top_k_instrument,
+             mixture_stem=mixture_stem,
+             use_own_query=use_own_query,
+             npy_memmap=npy_memmap,
+             allowed_stems=allowed_stems,
+             query_file=query_file,
+         )
+
+         if hop_size_seconds is None:
+             hop_size_seconds = chunk_size_seconds
+
+         self.chunk_hop_size_seconds = hop_size_seconds
+
+         self.chunk_hop_size_samples = int(hop_size_seconds * self.fs)
+
+         self.n_queries_per_chunk = n_queries_per_chunk
+
+         self._overwrite = False
+
+         self.query_tuples = self.find_query_tuples_or_generate()
+         self.n_chunks = len(self.query_tuples)
+
+     def __len__(self) -> int:
+         return self.n_chunks
+
+     def _get_audio(self, stems, identifier: Dict[str, Any]):
+         audio = {}
+
+         for stem in stems:
+             audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)
+
+         start = identifier["chunk_start"]
+         end = start + self.chunk_size_samples
+         audio = self._chunk_audio(audio, start, end)
+
+         return audio
+
+     def find_query_tuples_or_generate(self):
+         query_path = os.path.join(self.data_path, "queries")
+         val_folds = "-".join(map(str, self.folds))
+
+         path_so_far = os.path.join(query_path, val_folds)
+
+         if not os.path.exists(path_so_far):
+             return self.generate_index()
+
+         chunk_specs = f"chunk{self.chunk_size_samples}-hop{self.chunk_hop_size_samples}"
+         path_so_far = os.path.join(path_so_far, chunk_specs)
+
+         if not os.path.exists(path_so_far):
+             return self.generate_index()
+
+         query_specs = f"query{self.query_size_samples}-n{self.n_queries_per_chunk}"
+         path_so_far = os.path.join(path_so_far, query_specs)
+
+         if not os.path.exists(path_so_far):
+             return self.generate_index()
+
+         if self.top_k_instrument is not None:
+             path_so_far = os.path.join(
+                 path_so_far, f"queries-top{self.top_k_instrument}.csv"
+             )
+         else:
+             if len(self.allowed_stems) > 5:
+                 allowed_stems = (
+                     str(len(self.allowed_stems))
+                     + "stems:"
+                     + ":".join([k[0] for k in self.allowed_stems if k != "mixture"])
+                 )
+             else:
+                 allowed_stems = ":".join(self.allowed_stems)
+
+             path_so_far = os.path.join(path_so_far, f"queries-{allowed_stems}.csv")
+
+         if not os.path.exists(path_so_far):
+             return self.generate_index()
+
+         print(f"Loading query tuples from {path_so_far}")
+
+         return pd.read_csv(path_so_far)
+
+     def _get_index_path(self):
+         query_root = os.path.join(self.data_path, "queries")
+         val_folds = "-".join(map(str, self.folds))
+         chunk_specs = f"chunk{self.chunk_size_samples}-hop{self.chunk_hop_size_samples}"
+         query_specs = f"query{self.query_size_samples}-n{self.n_queries_per_chunk}"
+         query_dir = os.path.join(query_root, val_folds, chunk_specs, query_specs)
+
+         if self.top_k_instrument is not None:
+             query_path = os.path.join(
+                 query_dir, f"queries-top{self.top_k_instrument}.csv"
+             )
+         else:
+             if len(self.allowed_stems) > 5:
+                 allowed_stems = (
+                     str(len(self.allowed_stems))
+                     + "stems:"
+                     + ":".join([k[0] for k in self.allowed_stems if k != "mixture"])
+                 )
+             else:
+                 allowed_stems = ":".join(self.allowed_stems)
+             query_path = os.path.join(query_dir, f"queries-{allowed_stems}.csv")
+
+         if not self._overwrite:
+             assert not os.path.exists(
+                 query_path
+             ), f"Query path {query_path} already exists."
+
+         os.makedirs(query_dir, exist_ok=True)
+
+         return query_path
+
+     def generate_index(self):
+
+         query_path = self._get_index_path()
+
+         durations = pd.read_csv(os.path.join(self.data_path, "durations.csv"))
+         durations = (
+             durations[["song_id", "duration"]]
+             .set_index("song_id")["duration"]
+             .to_dict()
+         )
+
+         tuples = []
+
+         stems_without_queries = defaultdict(list)
+
+         for i, song_id in tqdm(enumerate(self.files), total=len(self.files)):
+             song_duration = durations[song_id]
+             mix_stems = self.song_to_stem[song_id]
+
+             n_mix_chunks = math.floor(
+                 (song_duration - self.chunk_size_seconds) / self.chunk_hop_size_seconds
+             )
+
+             for stem in mix_stems:
+                 # work on a copy so the shared stem_to_song index is not mutated
+                 possible_queries = list(self.stem_to_song[stem])
+                 if song_id in possible_queries:
+                     possible_queries.remove(song_id)
+
+                 if len(possible_queries) == 0:
+                     stems_without_queries[song_id].append(stem)
+                     continue
+
+                 for k in tqdm(range(n_mix_chunks), desc=f"song{i + 1}/{stem}"):
+                     mix_chunk_start = int(k * self.chunk_hop_size_samples)
+
+                     for j in range(self.n_queries_per_chunk):
+                         query = random.choice(possible_queries)
+
+                         tuples.append(
+                             dict(
+                                 mix=song_id,
+                                 query=query,
+                                 stem=stem,
+                                 mix_chunk_start=mix_chunk_start,
+                             )
+                         )
+
+         if len(stems_without_queries) > 0:
+             print("Stems without queries:")
+             for song_id, stems in stems_without_queries.items():
+                 print(f"{song_id}: {stems}")
+
+         tuples = pd.DataFrame(tuples)
+
+         print(
+             f"Generating query tuples for {self.split} set with {len(tuples)} tuples."
+         )
+         print(f"Saving query tuples to {query_path}")
+
+         tuples.to_csv(query_path, index=False)
+
+         return tuples
+
+     def index_to_identifiers(self, index: int) -> Tuple[str, str, str, int]:
+
+         row = self.query_tuples.iloc[index]
+         mix_id = row["mix"]
+
+         if self.use_own_query:
+             query_id = mix_id
+         else:
+             query_id = row["query"]
+
+         stem = row["stem"]
+         mix_chunk_start = row["mix_chunk_start"]
+
+         return mix_id, query_id, stem, mix_chunk_start
+
+     def __getitem__(self, index: int):
+
+         mix_id, query_id, stem, mix_chunk_start = self.index_to_identifiers(index)
+
+         mix_identifier = dict(song_id=mix_id, chunk_start=mix_chunk_start)
+         query_identifier = dict(song_id=query_id)
+
+         audio = self._get_audio([stem, self.mixture_stem], identifier=mix_identifier)
+         query = self.get_query_stem(stem=stem, identifier=query_identifier)
+
+         mixture = audio[self.mixture_stem].copy()
+         sources = {"target": audio[stem].copy()}
+         query = query.copy()
+
+         assert mixture.shape[-1] == self.chunk_size_samples
+         # print(query.shape[-1], self.query_size_samples)
+         assert query.shape[-1] == self.query_size_samples
+         assert sources["target"].shape[-1] == self.chunk_size_samples
+
+         return input_dict(
+             mixture=mixture,
+             sources=sources,
+             query=query,
+             metadata={
+                 "mix": mix_identifier,
+                 "query": query_identifier,
+                 "stem": stem,
+             },
+             modality="audio",
+         )
+
+
1275
+ class MoisesDBFullTrackTestQueryDataset(MoisesDBFullTrackDataset):
1276
+ def __init__(
1277
+ self,
1278
+ data_root: str,
1279
+ split: str = "test",
1280
+ top_k_instrument: int = 10,
1281
+ mixture_stem: str = "mixture",
1282
+ use_own_query: bool = True,
1283
+ npy_memmap=True,
1284
+ allowed_stems: List[str] = None,
1285
+ query_file="query-10s",
1286
+ ) -> None:
1287
+ super().__init__(
1288
+ data_root=data_root,
1289
+ split=split,
1290
+ npy_memmap=npy_memmap,
1291
+ recompute_mixture=False,
1292
+ query_file=query_file,
1293
+ )
1294
+
1295
+ self.use_own_query = use_own_query
1296
+
1297
+ self.allowed_stems = allowed_stems
1298
+
1299
+ test_indices = pd.read_csv(os.path.join(data_root, "test_indices.csv"))
1300
+
1301
+ test_indices = test_indices[test_indices.stem.isin(self.allowed_stems)]
1302
+
1303
+ self.test_indices = test_indices
1304
+
1305
+ self.length = len(self.test_indices)
1306
+
1307
+ def __len__(self) -> int:
1308
+ return self.length
1309
+
1310
+ def index_to_identifiers(self, index: int) -> Tuple[str, str, str]:
1311
+
1312
+ row = self.test_indices.iloc[index]
1313
+ mix_id = row["song_id"]
1314
+ if self.use_own_query:
1315
+ query_id = mix_id
1316
+ else:
1317
+ query_id = row["query_id"]
1318
+ stem = row["stem"]
1319
+
1320
+ return mix_id, query_id, stem
1321
+
1322
+ def _get_audio(self, stems, identifier: Dict[str, Any]):
1323
+ audio = {}
1324
+
1325
+ for stem in stems:
1326
+ audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)
1327
+
1328
+ return audio
1329
+
1330
+ def __getitem__(self, index: int):
1331
+
1332
+ mix_id, query_id, stem = self.index_to_identifiers(index)
1333
+
1334
+ mix_identifier = dict(song_id=mix_id)
1335
+
1336
+ query_identifier = dict(song_id=query_id)
1337
+
1338
+ audio = self._get_audio([stem, "mixture"], identifier=mix_identifier)
1339
+ query = self.get_query_stem(stem=stem, identifier=query_identifier)
1340
+
1341
+ mixture = audio["mixture"].copy()
1342
+ sources = {stem: audio[stem].copy()}
1343
+ query = query.copy()
1344
+
1345
+ return input_dict(
1346
+ mixture=mixture,
1347
+ sources=sources,
1348
+ query=query,
1349
+ metadata={
1350
+ "mix": mix_identifier["song_id"],
1351
+ "query": query_identifier["song_id"],
1352
+ "stem": stem,
1353
+ },
1354
+ modality="audio",
1355
+ )
1356
+
1357
+
1358
+ if __name__ == "__main__":
1359
+
1360
+ print("Beginning")
1361
+
1362
+ config = "/storage/home/hcoda1/1/kwatchar3/coda/config/data/moisesdb-everything-query-d-aug.yml"
1363
+
1364
+ config = OmegaConf.load(config)
1365
+
1366
+ print("Loaded config")
1367
+
1368
+ dataset = MoisesDBRandomChunkRandomQueryDataset(
1369
+ data_root=config.data_root, split="train", **config.train_kwargs
1370
+ )
1371
+
1372
+ print("Loaded dataset")
1373
+
1374
+ for item in tqdm(dataset, total=len(dataset)):
1375
+ target_audio = item["sources"]["target"]["audio"]
1376
+ mixture = item["mixture"]["audio"]
1377
+
1378
+ if target_audio is None:
1379
+ raise ValueError
1380
+ else:
1381
+ tdb = 10.0 * torch.log10(torch.mean(torch.square(target_audio)) + 1e-6)
1382
+ mdb = 10.0 * torch.log10(torch.mean(torch.square(mixture)) + 1e-6)
1383
+ print(f"Target db: {tdb}, Mixture db: {mdb}")
core/data/moisesdb/eda.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
core/data/moisesdb/npyify.py ADDED
@@ -0,0 +1,923 @@
1
+ from collections import defaultdict
2
+ import glob
3
+ import json
4
+ import math
5
+ import os
6
+ import shutil
7
+ from itertools import chain
8
+ from pprint import pprint
9
+ from types import SimpleNamespace
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ from omegaconf import OmegaConf
14
+
15
+ from tqdm.contrib.concurrent import process_map
16
+
17
+ from tqdm import tqdm as tdqm, tqdm
18
+ import torchaudio as ta
19
+
20
+ import librosa
21
+
22
+ taxonomy = {
23
+ "vocals": [
24
+ "lead male singer",
25
+ "lead female singer",
26
+ "human choir",
27
+ "background vocals",
28
+ "other (vocoder, beatboxing etc)",
29
+ ],
30
+ "bass": [
31
+ "bass guitar",
32
+ "bass synthesizer (moog etc)",
33
+ "contrabass/double bass (bass of instrings)",
34
+ "tuba (bass of brass)",
35
+ "bassoon (bass of woodwind)",
36
+ ],
37
+ "drums": [
38
+ "snare drum",
39
+ "toms",
40
+ "kick drum",
41
+ "cymbals",
42
+ "overheads",
43
+ "full acoustic drumkit",
44
+ "drum machine",
45
+ "hi-hat"
46
+ ],
47
+ "other": [
48
+ "fx/processed sound, scratches, gun shots, explosions etc",
49
+ "click track",
50
+ ],
51
+ "guitar": [
52
+ "clean electric guitar",
53
+ "distorted electric guitar",
54
+ "lap steel guitar or slide guitar",
55
+ "acoustic guitar",
56
+ ],
57
+ "other plucked": ["banjo, mandolin, ukulele, harp etc"],
58
+ "percussion": [
59
+ "a-tonal percussion (claps, shakers, congas, cowbell etc)",
60
+ "pitched percussion (mallets, glockenspiel, ...)",
61
+ ],
62
+ "piano": [
63
+ "grand piano",
64
+ "electric piano (rhodes, wurlitzer, piano sound alike)",
65
+ ],
66
+ "other keys": [
67
+ "organ, electric organ",
68
+ "synth pad",
69
+ "synth lead",
70
+ "other sounds (hapischord, melotron etc)",
71
+ ],
72
+ "bowed strings": [
73
+ "violin (solo)",
74
+ "viola (solo)",
75
+ "cello (solo)",
76
+ "violin section",
77
+ "viola section",
78
+ "cello section",
79
+ "string section",
80
+ "other strings",
81
+ ],
82
+ "wind": [
83
+ "brass (trumpet, trombone, french horn, brass etc)",
84
+ "flutes (piccolo, bamboo flute, panpipes, flutes etc)",
85
+ "reeds (saxophone, clarinets, oboe, english horn, bagpipe)",
86
+ "other wind",
87
+ ],
88
+ }
89
+
90
+ def clean_npy_other_vox(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npyq"):
91
+ npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
92
+
93
+
94
+ npys = [npy for npy in npys if "other" in npy]
95
+ npys = [npy for npy in npys if "vdbo_" not in npy]
96
+ npys = [npy for npy in npys if "other_" not in npy]
97
+
98
+ stems = set([
99
+ os.path.basename(npy).split(".")[0] for npy in npys
100
+ ])
101
+
102
+ assert len(stems) == 1
103
+
104
+ for npy in tqdm(npys):
105
+ shutil.move(npy, npy.replace("other", "other_vocals"))
106
+
107
+
108
+
109
+
110
+ def clean_track_inst(inst):
111
+
112
+ if "vocoder" in inst:
113
+ inst = "other_vocals"
114
+
115
+ if "fx" in inst:
116
+ inst = "fx"
117
+
118
+ if "contrabass_double_bass" in inst:
119
+ inst = "double_bass"
120
+
121
+ if "banjo" in inst:
122
+ return "other_plucked"
123
+
124
+ if "(" in inst:
125
+ inst = inst.split("(")[0]
126
+
127
+ for s in [",", "-"]:
128
+ if s in inst:
129
+ inst = inst.replace(s, "")
130
+
131
+ for s in ["/"]:
132
+ if s in inst:
133
+ inst = inst.replace(s, "_")
134
+
135
+ if inst[-1] == "_":
136
+ inst = inst[:-1]
137
+
138
+ return inst
+ 
+ 
+ taxonomy = {
+     k.replace(" ", "_"): [clean_track_inst(i.replace(" ", "_")) for i in v]
+     for k, v in taxonomy.items()
+ }
+ 
+ fine_to_coarse = {}
+ 
+ for k, v in taxonomy.items():
+     for vv in v:
+         fine_to_coarse[vv] = k
+ 
+ # pprint(fine_to_coarse)
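+ 
+ # Usage sketch: fine_to_coarse maps a normalized fine label to its coarse stem,
+ # e.g. fine_to_coarse["bass_guitar"] == "bass" and
+ # fine_to_coarse["snare_drum"] == "drums".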
+ 
+ 
+ def save_taxonomy():
+     with open("taxonomy.json", "w") as f:
+         json.dump(taxonomy, f, indent=4)
+ 
+     taxonomy_coarse = list(taxonomy.keys())
+ 
+     with open("taxonomy_coarse.json", "w") as f:
+         json.dump(taxonomy_coarse, f, indent=4)
+ 
+     taxonomy_fine = list(chain(*taxonomy.values()))
+ 
+     count_ = defaultdict(int)
+     for t in taxonomy_fine:
+         count_[t] += 1
+ 
+     with open("taxonomy_fine.json", "w") as f:
+         json.dump(taxonomy_fine, f, indent=4)
+ 
+ 
+ possible_coarse = list(taxonomy.keys())
+ possible_fine = list(set(chain(*taxonomy.values())))
+ 
+ 
+ def trim_and_mix(audios, length_=None):
+     length = min([a.shape[-1] for a in audios])
+ 
+     if length_ is not None:
+         length = min(length, length_)
+ 
+     audios = [a[..., :length] for a in audios]
+     return np.sum(np.stack(audios, axis=0), axis=0), length
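+ 
+ # trim_and_mix truncates all tracks to the shortest length present (optionally
+ # capped by length_), sums them into one mix, and returns both the mix and the
+ # length used so callers can re-trim previously saved arrays to match.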
+ 
+ 
+ def retrim_npys(saved_npy, new_length):
+     print("retrimming")
+     for npy in saved_npy:
+         audio = np.load(npy)
+         audio = audio[..., :new_length]
+         np.save(npy, audio)
+ 
+ 
+ def convert_one(inout):
+     input_path = inout.input_path
+     output_root = inout.output_root
+ 
+     song_id = os.path.basename(input_path)
+     output_root = os.path.join(output_root, song_id)
+     os.makedirs(output_root, exist_ok=True)
+ 
+     metadata = OmegaConf.load(os.path.join(input_path, "data.json"))
+     stems = metadata.stems
+ 
+     min_length = None
+     saved_npy = []
+ 
+     all_tracks = []
+     other_tracks = []
+ 
+     outfile = None
+ 
+     added_tracks = set()
+     duplicated_tracks = set()
+     track_to_stem = defaultdict(list)
+     added_stems = set()
+     duplicated_stems = set()
+ 
+     stem_name_to_stems = defaultdict(list)
+ 
+     for stem in stems:
+         stem_name = stem.stemName
+         stem_name_to_stems[stem_name].append(stem)
+ 
+     for stem_name in tqdm(stem_name_to_stems):
+         stem_tracks = []
+         for stem in stem_name_to_stems[stem_name]:
+             stem_name = stem.stemName
+ 
+             if stem_name in added_stems:
+                 print(f"Duplicate stem {stem_name} in {song_id}")
+                 duplicated_stems.add(stem_name)
+ 
+             added_stems.add(stem_name)
+ 
+             for track in stem.tracks:
+                 track_inst = track.trackType
+                 track_inst = clean_track_inst(track_inst)
+ 
+                 if track_inst in added_tracks:
+                     if stem_name in track_to_stem[track_inst]:
+                         continue
+                     print(f"Duplicate track {track_inst} in {song_id}")
+                     print(f"Stems: {track_to_stem[track_inst]}")
+                     duplicated_tracks.add(track_inst)
+                     raise ValueError
+                 else:
+                     added_tracks.add(track_inst)
+ 
+                 track_to_stem[track_inst].append(stem_name)
+                 track_id = track.id
+ 
+                 audio, fs = ta.load(os.path.join(input_path, stem_name, f"{track_id}.wav"))
+ 
+                 if fs != 44100:
+                     print(f"fs is {fs} for {track_id}")
+                     # append so that entries for multiple off-rate tracks are all kept
+                     with open(os.path.join(output_root, "fs.txt"), "a") as f:
+                         f.write(f"{song_id}\t{track_id}\t{fs}\n")
+ 
+                 if min_length is None:
+                     min_length = audio.shape[-1]
+                 else:
+                     if audio.shape[-1] < min_length:
+                         min_length = audio.shape[-1]
+ 
+                         if len(saved_npy) > 0:
+                             retrim_npys(saved_npy, min_length)
+ 
+                 audio = audio[..., :min_length]
+                 audio = audio.numpy()
+                 audio = audio.astype(np.float32)
+ 
+                 if audio.shape[0] == 1:
+                     print("mono")
+                 if audio.shape[0] > 2:
+                     print("multi channel")
+ 
+                 assert outfile is None
+                 outfile = os.path.join(output_root, f"{track_inst}.npy")
+                 np.save(outfile, audio)
+                 saved_npy.append(outfile)
+                 outfile = None
+                 stem_tracks.append(audio)
+                 audio = None
+ 
+         stem_track, min_length = trim_and_mix(stem_tracks)
+ 
+         assert outfile is None
+         outfile = os.path.join(output_root, f"{stem_name}.npy")
+         np.save(outfile, stem_track)
+         saved_npy.append(outfile)
+         outfile = None
+ 
+         all_tracks.append(stem_track)
+ 
+         if stem_name not in ["vocals", "drums", "bass"]:
+             # print(f"Putting {stem_name} in other")
+             other_tracks.append(stem_track)
+ 
+     assert outfile is None
+     all_track, min_length_ = trim_and_mix(all_tracks, min_length)
+     outfile = os.path.join(output_root, "mixture.npy")
+     np.save(outfile, all_track)
+ 
+     if min_length_ != min_length:
+         retrim_npys(saved_npy, min_length_)
+         min_length = min_length_
+ 
+     saved_npy.append(outfile)
+     outfile = None
+ 
+     other_track, min_length_ = trim_and_mix(other_tracks, min_length)
+     np.save(os.path.join(output_root, "vdbo_others.npy"), other_track)
+ 
+     if min_length_ != min_length:
+         retrim_npys(saved_npy, min_length_)
+         min_length = min_length_
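+ 
+ # Tracks within a song can differ slightly in length, so convert_one keeps a
+ # running min_length and calls retrim_npys on everything already written
+ # whenever a shorter track or mix appears; all saved arrays therefore end up
+ # sample-aligned with the final mixture.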
+ 
+ 
+ def convert_to_npy(
+     data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/canonical",
+     output_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2",
+ ):
+     if output_root is None:
+         output_root = os.path.join(os.path.dirname(data_root), "npy")
+ 
+     files = os.listdir(data_root)
+     files = [
+         os.path.join(data_root, f)
+         for f in files
+         if os.path.isdir(os.path.join(data_root, f))
+     ]
+ 
+     inout = [SimpleNamespace(input_path=f, output_root=output_root) for f in files]
+ 
+     process_map(convert_one, inout)
+ 
+     # for io in tqdm(inout):
+     #     convert_one(io)
+ 
+ 
+ def make_others_one(input_path, dry_run=False):
+     other_stems = [k for k in taxonomy.keys() if k not in ["vocals", "bass", "drums"]]
+     npys = glob.glob(os.path.join(input_path, "**/*.npy"), recursive=True)
+ 
+     npys = [npy for npy in npys if ".dbfs" not in npy]
+     npys = [npy for npy in npys if ".query" not in npy]
+     npys = [npy for npy in npys if "mixture" not in npy]
+     npys = [npy for npy in npys if os.path.basename(npy).split(".")[0] in other_stems]
+ 
+     print(f"Using stems: {[os.path.basename(npy).split('.')[0] for npy in npys]}")
+ 
+     if len(npys) == 0:
+         # no matching stems: fall back to a silent (channel, time) array
+         # shaped like the mixture
+         audio = np.zeros_like(np.load(os.path.join(input_path, "mixture.npy")))
+     else:
+         # stack/sum only in this branch; the zero-stem case above is already a
+         # single (channel, time) array, and summing it again would collapse it to 1-D
+         audio = [np.load(npy) for npy in npys]
+         audio = np.sum(np.stack(audio, axis=0), axis=0)
+ 
+     assert audio.shape[0] == 2
+ 
+     output = os.path.join(input_path, "vdbo_others.npy")
+ 
+     if dry_run:
+         return
+ 
+     np.save(output, audio)
+ 
+ 
+ def check_vdbo_one(f):
+     s = np.sum(
+         np.stack(
+             [
+                 np.load(os.path.join(f, stem + ".npy"))
+                 for stem in ["vocals", "drums", "bass", "vdbo_others"]
+                 if os.path.exists(os.path.join(f, stem + ".npy"))
+             ],
+             axis=0,
+         ),
+         axis=0,
+     )
+     m = np.load(os.path.join(f, "mixture.npy"))
+     snr = 10 * np.log10(np.mean(np.square(m)) / np.mean(np.square(s - m)))
+     print(snr)
+ 
+     return snr
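+ 
+ # Sanity check: if the VDBO stems were written consistently, their sum s should
+ # be close to the mixture m, so the reconstruction SNR
+ # 10 * log10(E[m^2] / E[(s - m)^2]) should be large; low values flag songs
+ # whose stems do not add back up to the mixture.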
+ 
+ 
+ def check_vdbo(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2"):
+     files = os.listdir(data_root)
+ 
+     files = [
+         os.path.join(data_root, f)
+         for f in files
+         if os.path.isdir(os.path.join(data_root, f))
+     ]
+ 
+     snrs = process_map(check_vdbo_one, files)
+ 
+     np.save("/storage/home/hcoda1/1/kwatchar3/data/vdbo.npy", np.array(snrs))
+ 
+ 
+ def make_others(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2"):
+     files = os.listdir(data_root)
+ 
+     files = [
+         os.path.join(data_root, f)
+         for f in files
+         if os.path.isdir(os.path.join(data_root, f))
+     ]
+ 
+     process_map(make_others_one, files)
+ 
+     # for f in tqdm(files):
+     #     make_others_one(f, dry_run=False)
+ 
+ 
+ def extract_metadata_one(input_path):
+     song_id = os.path.basename(input_path)
+     metadata = OmegaConf.load(os.path.join(input_path, "data.json"))
+ 
+     song = metadata.song
+     artist = metadata.artist
+     genre = metadata.genre
+ 
+     stems = metadata.stems
+     data_out = []
+ 
+     for stem in stems:
+         stem_name = stem.stemName
+         stem_id = stem.id
+         for track in stem.tracks:
+             track_inst = track.trackType
+             track_id = track.id
+ 
+             data_out.append(
+                 {
+                     "song_id": song_id,
+                     "song": song,
+                     "artist": artist,
+                     "genre": genre,
+                     "stem_name": stem_name,
+                     "stem_id": stem_id,
+                     "track_inst": track_inst,
+                     "track_id": track_id,
+                     "has_bleed": track.has_bleed,
+                 }
+             )
+ 
+     return data_out
+ 
+ 
+ def consolidate_metadata(
+     data_root="/home/kwatchar3/Documents/data/moisesdb/canonical",
+ ):
+     files = os.listdir(data_root)
+     files = [
+         os.path.join(data_root, f)
+         for f in files
+         if os.path.isdir(os.path.join(data_root, f))
+     ]
+ 
+     data = process_map(extract_metadata_one, files)
+ 
+     df = pd.DataFrame.from_records(list(chain(*data)))
+ 
+     df.to_csv(os.path.join(os.path.dirname(data_root), "metadata.csv"), index=False)
+ 
+ 
+ def clean_canonical(data_root="/home/kwatchar3/Documents/data/moisesdb/canonical"):
+     npy = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
+ 
+     for n in tqdm(npy):
+         os.remove(n)
+ 
+ 
+ def remove_dbfs(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy"):
+     npy = glob.glob(os.path.join(data_root, "**/*.dbfs.npy"), recursive=True)
+ 
+     for n in tqdm(npy):
+         os.remove(n)
+ 
+ 
+ def make_split(
+     metadata_path="/home/kwatchar3/Documents/data/moisesdb/metadata.csv",
+     n_splits=5,
+     seed=42,
+ ):
+     df = pd.read_csv(metadata_path)
+     # print(df.columns)
+     df = df[["song_id", "genre"]].drop_duplicates()
+ 
+     genres = df["genre"].value_counts()
+     genres_map = {g: g if c > n_splits else "other" for g, c in genres.items()}
+ 
+     df["genre"] = df["genre"].map(genres_map)
+ 
+     n_samples = len(df)
+     n_per_split = n_samples // n_splits
+ 
+     np.random.seed(seed)
+ 
+     from sklearn.model_selection import train_test_split
+ 
+     splits = []
+ 
+     df_ = df.copy()
+ 
+     for i in range(n_splits - 1):
+         df_, test = train_test_split(
+             df_,
+             test_size=n_per_split,
+             random_state=seed,
+             stratify=df_["genre"],
+             shuffle=True,
+         )
+ 
+         dfs = test[["song_id"]].copy().sort_values(by="song_id")
+         dfs["split"] = i + 1
+         splits.append(dfs)
+ 
+     test = df_
+     dfs = test[["song_id"]].copy().sort_values(by="song_id")
+     dfs["split"] = n_splits
+     splits.append(dfs)
+ 
+     splits = pd.concat(splits)
+ 
+     splits.to_csv(
+         os.path.join(os.path.dirname(metadata_path), "splits.csv"), index=False
+     )
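+ 
+ # Split logic in brief: genres with at most n_splits songs are lumped into
+ # "other" so stratification cannot fail; n_splits - 1 genre-stratified chunks
+ # of n_samples // n_splits songs are then peeled off with train_test_split,
+ # and whatever remains becomes the last split.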
+ 
+ 
+ def consolidate_stems(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
+     metadata = pd.read_csv(os.path.join(os.path.dirname(data_root), "metadata.csv"))
+ 
+     dfg = metadata.groupby("song_id")[["stem_name", "track_inst"]]
+ 
+     pprint(dfg)
+ 
+     df = []
+ 
+     def make_stem_dict(song_id, track_inst, stem_names):
+         d = {"song_id": song_id}
+ 
+         for inst in possible_fine:
+             d[inst] = int(inst in track_inst)
+ 
+         for inst in possible_coarse:
+             d[inst] = int(inst in stem_names)
+ 
+         return d
+ 
+     for song_id, dfgg in dfg:
+         track_inst = dfgg["track_inst"].tolist()
+         track_inst = list(set(track_inst))
+         track_inst = [clean_track_inst(inst) for inst in track_inst]
+ 
+         stem_names = dfgg["stem_name"].tolist()
+         stem_names = list(set([clean_track_inst(inst) for inst in stem_names]))
+ 
+         d = make_stem_dict(song_id, track_inst, stem_names)
+         df.append(d)
+ 
+     print(df)
+ 
+     df = pd.DataFrame.from_records(df)
+ 
+     df.to_csv(os.path.join(os.path.dirname(data_root), "stems.csv"), index=False)
+ 
+ 
+ def get_dbfs(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
+     npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
+ 
+     dbfs = []
+ 
+     for npy in tqdm(npys):
+         audio = np.load(npy)
+         song_id = os.path.basename(os.path.dirname(npy))
+         track_id = os.path.basename(npy).split(".")[0]
+ 
+         dbfs.append(
+             {
+                 "song_id": song_id,
+                 "track_id": track_id,
+                 "dbfs": 10 * np.log10(np.mean(np.square(audio))),
+             }
+         )
+ 
+     dbfs = pd.DataFrame.from_records(dbfs)
+ 
+     dbfs.to_csv(os.path.join(os.path.dirname(data_root), "dbfs.csv"), index=False)
+ 
+     return dbfs
+ 
+ 
+ def get_dbfs_by_chunk_one(inout):
+     audio = np.load(inout.audio_path, mmap_mode="r")
+     chunk_size = inout.chunk_size
+     fs = inout.fs
+     hop_size = inout.hop_size
+ 
+     n_chan, n_samples = audio.shape
+     chunk_size_samples = int(chunk_size * fs)
+     hop_size_samples = int(hop_size * fs)
+ 
+     x2win = np.lib.stride_tricks.sliding_window_view(
+         np.square(audio), chunk_size_samples, axis=1
+     )[:, ::hop_size_samples, :]
+ 
+     x2win_mean = np.mean(x2win, axis=(0, 2))
+     x2win_mean[x2win_mean == 0] = 1e-8
+     dbfs = 10 * np.log10(x2win_mean)
+ 
+     # song_id = os.path.basename(os.path.dirname(inout.audio_path))
+     track_id = os.path.basename(inout.audio_path).split(".")[0]
+ 
+     # NB: the dBFS array is saved next to the source file; inout.output_path is
+     # not used here
+     np.save(
+         os.path.join(os.path.dirname(inout.audio_path), f"{track_id}.dbfs.npy"), dbfs
+     )
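+ 
+ # Windowed dBFS in brief: sliding_window_view lays a chunk_size-second window
+ # at every sample offset, the ::hop_size_samples stride keeps one window every
+ # hop_size seconds, and averaging the squared signal over channels and window
+ # samples gives mean power per window, floored at 1e-8 (-80 dB) before the log
+ # so silent windows do not produce -inf.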
+ 
+ 
+ def clean_data_root(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
+     npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
+ 
+     for npy in tqdm(npys):
+         if ".dbfs" in npy or ".query" in npy:
+             # print("removing", npy)
+             os.remove(npy)
+ 
+ 
+ def get_dbfs_by_chunk(
+     data_root="/home/kwatchar3/Documents/data/moisesdb/npy",
+     query_root="/home/kwatchar3/Documents/data/moisesdb/npyq",
+ ):
+     npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
+ 
+     inout = [
+         SimpleNamespace(
+             audio_path=npy,
+             chunk_size=1,
+             hop_size=0.125,
+             fs=44100,
+             output_path=npy.replace(data_root, query_root).replace(
+                 ".npy", ".query.npy"
+             ),
+         )
+         for npy in npys
+     ]
+ 
+     process_map(get_dbfs_by_chunk_one, inout, chunksize=2)
+ 
+ 
+ def round_samples(seconds, fs, hop_size, downsample):
+     n_frames = math.ceil(seconds * fs / hop_size) + 1
+     n_frames_down = math.ceil(n_frames / downsample)
+     n_frames = n_frames_down * downsample
+     n_samples = (n_frames - 1) * hop_size
+ 
+     return int(n_samples)
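+ 
+ # Worked example for round_samples: seconds=10, fs=44100, hop_size=512,
+ # downsample=64 gives ceil(441000 / 512) + 1 = 863 frames, rounded up to
+ # 896 = 14 * 64, i.e. (896 - 1) * 512 = 458240 samples, so a segment of this
+ # length produces an STFT frame count divisible by the downsampling factor.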
+ 
+ 
+ def get_query_one(inout):
+     audio = np.load(inout.audio_path, mmap_mode="r")
+     chunk_size = inout.chunk_size
+     fs = inout.fs
+     output_path = inout.output_path
+     do_round = inout.round  # local name chosen so the builtin round() is not shadowed
+     hop_size = inout.hop_size
+ 
+     if do_round:
+         chunk_size_samples = round_samples(chunk_size, fs, 512, 2**6)
+     else:
+         chunk_size_samples = int(chunk_size * fs)
+ 
+     audio_mono = np.mean(audio, axis=0)
+ 
+     onset = librosa.onset.onset_detect(  # currently unused
+         y=audio_mono, sr=fs, units="frames", hop_length=hop_size
+     )
+ 
+     onset_strength = librosa.onset.onset_strength(
+         y=audio_mono, sr=fs, hop_length=hop_size
+     )
+ 
+     n_frames_per_chunk = chunk_size_samples // hop_size
+ 
+     onset_strength_slide = np.lib.stride_tricks.sliding_window_view(
+         onset_strength, n_frames_per_chunk, axis=0
+     )
+ 
+     onset_strength = np.mean(onset_strength_slide, axis=1)
+ 
+     max_onset_frame = np.argmax(onset_strength)
+ 
+     max_onset_samples = librosa.frames_to_samples(max_onset_frame, hop_length=hop_size)
+ 
+     track_id = os.path.basename(inout.audio_path).split(".")[0]
+ 
+     segment = audio[:, max_onset_samples : max_onset_samples + chunk_size_samples]
+ 
+     os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ 
+     np.save(output_path, segment)
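+ 
+ # Query selection in brief: onset strength is averaged over a sliding window
+ # one chunk long, and the window with the highest mean onset energy becomes
+ # the query segment, i.e. the most rhythmically active chunk_size seconds of
+ # the track.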
+ 
+ 
+ def get_query_from_onset(
+     data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2",  # "/home/kwatchar3/Documents/data/moisesdb/npy"
+     query_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npyq",  # "/home/kwatchar3/Documents/data/moisesdb/npyq"
+     query_file="query-10s",
+     pmap=True,
+ ):
+     npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
+ 
+     npys = [npy for npy in npys if "dbfs" not in npy]
+ 
+     inout = [
+         SimpleNamespace(
+             audio_path=npy,
+             chunk_size=10,
+             hop_size=512,
+             round=False,
+             fs=44100,
+             output_path=npy.replace(data_root, query_root).replace(
+                 ".npy", f".{query_file}.npy"
+             ),
+         )
+         for npy in npys
+     ]
+ 
+     if pmap:
+         process_map(get_query_one, inout, chunksize=2, max_workers=24)
+     else:
+         for io in tqdm(inout):
+             get_query_one(io)
+ 
+ 
+ def get_durations(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
+     npys = glob.glob(os.path.join(data_root, "**/mixture.npy"), recursive=True)
+ 
+     durations = []
+ 
+     for npy in tqdm(npys):
+         audio = np.load(npy, mmap_mode="r")
+         song_id = os.path.basename(os.path.dirname(npy))
+         track_id = os.path.basename(npy).split(".")[0]
+ 
+         durations.append(
+             {
+                 "song_id": song_id,
+                 "track_id": track_id,
+                 "duration": audio.shape[-1] / 44100,
+             }
+         )
+ 
+     durations = pd.DataFrame.from_records(durations)
+ 
+     durations.to_csv(
+         os.path.join(os.path.dirname(data_root), "durations.csv"), index=False
+     )
+ 
+     return durations
+ 
+ 
+ def clean_query_root(
+     data_root="/home/kwatchar3/Documents/data/moisesdb/npy",
+     query_root="/home/kwatchar3/Documents/data/moisesdb/npyq",
+ ):
+     npys = glob.glob(os.path.join(data_root, "**/*.query.npy"), recursive=True)
+ 
+     for npy in tqdm(npys):
+         dst = npy.replace(data_root, query_root)
+         dstdir = os.path.dirname(dst)
+         os.makedirs(dstdir, exist_ok=True)
+         shutil.move(npy, dst)
+ 
+ 
+ def make_test_indices(
+     metadata_path="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/metadata.csv",
+     stem_path="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/stems.csv",
+     splits_path="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/splits.csv",
+     test_split=5,
+ ):
+     coarse_stems = set(taxonomy.keys())
+     fine_stems = set(chain(*taxonomy.values()))
+ 
+     metadata = pd.read_csv(metadata_path)
+     splits = pd.read_csv(splits_path)
+     stems = pd.read_csv(stem_path)
+ 
+     file_in_test = splits[splits["split"] == test_split]["song_id"].tolist()
+ 
+     stems_test = stems[stems["song_id"].isin(file_in_test)]
+     metadata_test = metadata[metadata["song_id"].isin(file_in_test)]
+     splits_test = splits[splits["split"] == test_split]
+ 
+     stems_test = stems_test.set_index("song_id")
+     metadata_test = metadata_test.drop_duplicates("song_id").set_index("song_id")
+     splits_test = splits_test.set_index("song_id")
+ 
+     stem_to_song_id = defaultdict(list)
+     song_id_to_stem = defaultdict(list)
+ 
+     for song_id in file_in_test:
+         stems_ = stems_test.loc[song_id]
+         stem_names = stems_.T
+         stem_names = stem_names[stem_names == 1].index.tolist()
+ 
+         for stem in stem_names:
+             stem_to_song_id[stem].append(song_id)
+ 
+         song_id_to_stem[song_id] = stem_names
+ 
+     indices = []
+     no_query = []
+ 
+     for song_id in file_in_test:
+         genre = metadata_test.loc[song_id, "genre"]
+         # print(genre)
+         artist = metadata_test.loc[song_id, "artist"]
+         # print(artist)
+ 
+         stems_ = song_id_to_stem[song_id]
+ 
+         for stem in stems_:
+             possible_query = stem_to_song_id[stem]
+             possible_query = [p for p in possible_query if p != song_id]
+ 
+             if len(possible_query) == 0:
+                 print(f"No possible query for {song_id} with {stem}")
+ 
+                 no_query.append({"song_id": song_id, "stem": stem})
+                 continue
+ 
+             query_df = metadata_test.loc[possible_query, ["genre", "artist"]]
+ 
+             assert len(query_df) > 0
+ 
+             query_df_ = query_df.copy()
+ 
+             same_genre = True
+             different_artist = True
+             query_df = query_df[(query_df["genre"] == genre) & (query_df["artist"] != artist)]
+ 
+             if len(query_df) == 0:
+                 same_genre = False
+                 different_artist = True
+ 
+                 query_df = query_df_.copy()
+                 query_df = query_df[(query_df["artist"] != artist)]
+ 
+             if len(query_df) == 0:
+                 same_genre = True
+                 different_artist = False
+ 
+                 query_df = query_df_.copy()
+                 query_df = query_df[(query_df["genre"] == genre)]
+ 
+             if len(query_df) == 0:
+                 same_genre = False
+                 different_artist = False
+ 
+                 query_df = query_df_.copy()
+ 
+             query_id = query_df.sample(1).index[0]
+ 
+             indices.append(
+                 {
+                     "song_id": song_id,
+                     "query_id": query_id,
+                     "stem": stem,
+                     "same_genre": same_genre,
+                     "different_artist": different_artist,
+                 }
+             )
+ 
+     indices = pd.DataFrame.from_records(indices)
+     no_query = pd.DataFrame.from_records(no_query)
+ 
+     indices.to_csv(
+         os.path.join(os.path.dirname(metadata_path), "test_indices.csv"), index=False
+     )
+ 
+     no_query.to_csv(
+         os.path.join(os.path.dirname(metadata_path), "no_query.csv"), index=False
+     )
+ 
+     print("Total number of queries:", len(indices))
+     print("Total number of no queries:", len(no_query))
+ 
+     query_type = indices.groupby(["same_genre", "different_artist"]).size()
+ 
+     print(query_type)
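+ 
+ # Query fallback ladder: prefer a query song with the same genre and a
+ # different artist; failing that, any different artist; then same genre
+ # regardless of artist; and finally any other test song containing the stem.
+ # The (same_genre, different_artist) flags record which rung was used.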
+ 
+ 
+ if __name__ == "__main__":
+     import fire
+ 
+     fire.Fire()
core/data/moisesdb/passt.ipynb ADDED
@@ -0,0 +1,32 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import glob\n",
+     "\n",
+     "\n",
+     "data_root = \"/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/passt\"\n",
+     "\n",
+     "files = glob.glob(data_root + \"/*.passt.npy\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "language_info": {
+    "name": "python"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
core/losses/__init__.py ADDED
File without changes
core/losses/base.py ADDED
@@ -0,0 +1,171 @@
+ from typing import Dict, List, Optional, Union
+ 
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn.modules.loss import _Loss
+ 
+ from core.types import BatchedInputOutput
+ 
+ 
+ class BaseLossHandler(nn.Module):
+     def __init__(
+         self, loss: nn.Module, modality: Union[str, List[str]], name: Optional[str] = None
+     ) -> None:
+         super().__init__()
+ 
+         self.loss = loss
+ 
+         if isinstance(modality, str):
+             modality = [modality]
+ 
+         self.modality = modality
+ 
+         if name is None:
+             name = "loss"
+ 
+         if name == "__auto__":
+             name = self.loss.__class__.__name__
+ 
+         self.name = name
+ 
+     def _audio_preprocess(self, y_pred, y_true):
+         n_sample_true = y_true.shape[-1]
+         n_sample_pred = y_pred.shape[-1]
+ 
+         if n_sample_pred > n_sample_true:
+             y_pred = y_pred[..., :n_sample_true]
+         elif n_sample_pred < n_sample_true:
+             y_true = y_true[..., :n_sample_pred]
+ 
+         return y_pred, y_true
+ 
+     def forward(self, batch: BatchedInputOutput):
+         y_true = batch.sources
+         y_pred = batch.estimates
+ 
+         loss_contribs = {}
+ 
+         stem_contribs = {stem: 0.0 for stem in y_pred.keys()}
+ 
+         for stem in y_pred.keys():
+             for modality in self.modality:
+                 if modality not in y_pred[stem].keys():
+                     continue
+ 
+                 if y_pred[stem][modality].shape[-1] == 0:
+                     continue
+ 
+                 y_true_ = y_true[stem][modality]
+                 y_pred_ = y_pred[stem][modality]
+ 
+                 if modality == "audio":
+                     y_pred_, y_true_ = self._audio_preprocess(y_pred_, y_true_)
+                 elif modality == "spectrogram":
+                     y_pred_ = torch.view_as_real(y_pred_)
+                     y_true_ = torch.view_as_real(y_true_)
+ 
+                 loss_contribs[f"{self.name}/{stem}/{modality}"] = self.loss(
+                     y_pred_, y_true_
+                 )
+ 
+                 stem_contribs[stem] += loss_contribs[f"{self.name}/{stem}/{modality}"]
+ 
+         total_loss = sum(stem_contribs.values())
+         loss_contribs[self.name] = total_loss
+ 
+         with torch.no_grad():
+             for stem in stem_contribs.keys():
+                 loss_contribs[f"{self.name}/{stem}"] = stem_contribs[stem]
+ 
+         return loss_contribs
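+ 
+ # BaseLossHandler.forward returns a flat dict of scalars: the total under
+ # self.name, per-stem sums under f"{self.name}/{stem}", and per-stem,
+ # per-modality terms under f"{self.name}/{stem}/{modality}".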
+ 
+ 
+ class AdversarialLossHandler(BaseLossHandler):
+     def __init__(self, loss: nn.Module, modality: str, name: Optional[str] = "adv_loss"):
+         super().__init__(loss, modality, name)
+ 
+     def discriminator_forward(self, batch: BatchedInputOutput):
+         y_true = batch.sources
+         y_pred = batch.estimates
+ 
+         # the base class normalizes `modality` to a list; this handler expects
+         # a single modality, so unwrap it before the membership checks below
+         modality = self.modality[0]
+ 
+         # g_loss_contribs = {}
+         d_loss_contribs = {}
+ 
+         for stem in y_pred.keys():
+             if modality not in y_pred[stem].keys():
+                 continue
+ 
+             if y_pred[stem][modality].shape[-1] == 0:
+                 continue
+ 
+             y_true_ = y_true[stem][modality]
+             y_pred_ = y_pred[stem][modality]
+ 
+             if modality == "audio":
+                 y_pred_, y_true_ = self._audio_preprocess(y_pred_, y_true_)
+ 
+             # g_loss_contribs[f"{self.name}:g/{stem}"] = self.loss.generator_loss(
+             #     y_pred_, y_true_
+             # )
+ 
+             d_loss_contribs[f"{self.name}:d/{stem}"] = self.loss.discriminator_loss(
+                 y_pred_, y_true_
+             )
+ 
+         # g_total_loss = sum(g_loss_contribs.values())
+         d_total_loss = sum(d_loss_contribs.values())
+ 
+         # g_loss_contribs["loss"] = g_total_loss
+         d_loss_contribs["disc_loss"] = d_total_loss
+ 
+         return d_loss_contribs
+ 
+     def generator_forward(self, batch: BatchedInputOutput):
+         y_true = batch.sources
+         y_pred = batch.estimates
+ 
+         modality = self.modality[0]
+ 
+         g_loss_contribs = {}
+         # d_loss_contribs = {}
+ 
+         for stem in y_pred.keys():
+             if modality not in y_pred[stem].keys():
+                 continue
+ 
+             if y_pred[stem][modality].shape[-1] == 0:
+                 continue
+ 
+             y_true_ = y_true[stem][modality]
+             y_pred_ = y_pred[stem][modality]
+ 
+             if modality == "audio":
+                 y_pred_, y_true_ = self._audio_preprocess(y_pred_, y_true_)
+ 
+             g_loss_contribs[f"{self.name}:g/{stem}"] = self.loss.generator_loss(
+                 y_pred_, y_true_
+             )
+ 
+             # d_loss_contribs[f"{self.name}:d/{stem}"] = self.loss.discriminator_loss(
+             #     y_pred_, y_true_
+             # )
+ 
+         g_total_loss = sum(g_loss_contribs.values())
+         # d_total_loss = sum(d_loss_contribs.values())
+ 
+         g_loss_contribs["gen_loss"] = g_total_loss
+         # d_loss_contribs["loss"] = d_total_loss
+ 
+         return g_loss_contribs
+ 
+     def forward(self, batch: BatchedInputOutput):
+         return {
+             "generator": self.generator_forward(batch),
+             "discriminator": self.discriminator_forward(batch),
+         }
core/losses/l1snr.py ADDED
@@ -0,0 +1,110 @@
+ import torch
+ import torch.nn.functional as F
+ from torch.nn.modules.loss import _Loss
+ 
+ 
+ class WeightedL1Loss(_Loss):
+     # NB: the `weights` argument is currently unused; per-example weights are
+     # derived from the mean absolute value of the target in forward() instead
+     def __init__(self, weights=None):
+         super().__init__()
+ 
+     def forward(self, y_pred, y_true):
+         ndim = y_pred.ndim
+         dims = list(range(1, ndim))
+         loss = F.l1_loss(y_pred, y_true, reduction="none")
+         loss = torch.mean(loss, dim=dims)
+         weights = torch.mean(torch.abs(y_true), dim=dims)
+ 
+         loss = torch.sum(loss * weights) / torch.sum(weights)
+ 
+         return loss
+ 
+ 
+ class L1MatchLoss(_Loss):
+     def __init__(self):
+         super().__init__()
+ 
+     def forward(self, y_pred, y_true):
+         batch_size = y_pred.shape[0]
+ 
+         y_pred = y_pred.reshape(batch_size, -1)
+         y_true = y_true.reshape(batch_size, -1)
+ 
+         l1_true = torch.mean(torch.abs(y_true), dim=-1)
+         l1_pred = torch.mean(torch.abs(y_pred), dim=-1)
+         loss = torch.mean(torch.abs(l1_pred - l1_true))
+ 
+         return loss
+ 
+ 
+ class DecibelMatchLoss(_Loss):
+     def __init__(self, eps=1e-3):
+         super().__init__()
+ 
+         self.eps = eps
+ 
+     def forward(self, y_pred, y_true):
+         batch_size = y_pred.shape[0]
+ 
+         y_pred = y_pred.reshape(batch_size, -1)
+         y_true = y_true.reshape(batch_size, -1)
+ 
+         db_true = 10.0 * torch.log10(self.eps + torch.mean(torch.square(torch.abs(y_true)), dim=-1))
+         db_pred = 10.0 * torch.log10(self.eps + torch.mean(torch.square(torch.abs(y_pred)), dim=-1))
+         loss = torch.mean(torch.abs(db_pred - db_true))
+ 
+         return loss
+ 
+ 
+ class L1SNRLoss(_Loss):
+     def __init__(self, eps=1e-3):
+         super().__init__()
+         self.eps = torch.tensor(eps)
+ 
+     def forward(self, y_pred, y_true):
+         batch_size = y_pred.shape[0]
+ 
+         y_pred = y_pred.reshape(batch_size, -1)
+         y_true = y_true.reshape(batch_size, -1)
+ 
+         l1_error = torch.mean(torch.abs(y_pred - y_true), dim=-1)
+         l1_true = torch.mean(torch.abs(y_true), dim=-1)
+ 
+         snr = 20.0 * torch.log10((l1_true + self.eps) / (l1_error + self.eps))
+ 
+         return -torch.mean(snr)
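+ 
+ # L1SNRLoss maximizes an L1-based SNR per example:
+ #   snr = 20 * log10((mean|y_true| + eps) / (mean|y_pred - y_true| + eps))
+ # and returns its negative mean over the batch; eps both avoids log(0) and
+ # caps the attainable SNR for near-silent targets.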
+ 
+ 
+ class L1SNRLossIgnoreSilence(_Loss):
+     def __init__(self, eps=1e-3, dbthresh=-20, dbthresh_step=20):
+         super().__init__()
+         self.eps = torch.tensor(eps)
+         self.dbthresh = dbthresh
+         self.dbthresh_step = dbthresh_step
+ 
+     def forward(self, y_pred, y_true):
+         batch_size = y_pred.shape[0]
+ 
+         y_pred = y_pred.reshape(batch_size, -1)
+         y_true = y_true.reshape(batch_size, -1)
+ 
+         l1_error = torch.mean(torch.abs(y_pred - y_true), dim=-1)
+         l1_true = torch.mean(torch.abs(y_true), dim=-1)
+ 
+         snr = 20.0 * torch.log10((l1_true + self.eps) / (l1_error + self.eps))
+ 
+         db = 10.0 * torch.log10(torch.mean(torch.square(y_true), dim=-1) + 1e-6)
+ 
+         # average only over examples above the loudness threshold, relaxing the
+         # threshold by dbthresh_step (and then entirely) if nothing passes
+         if torch.sum(db > self.dbthresh) == 0:
+             if torch.sum(db > self.dbthresh - self.dbthresh_step) == 0:
+                 return -torch.mean(snr)
+             else:
+                 return -torch.mean(snr[db > self.dbthresh - self.dbthresh_step])
+ 
+         return -torch.mean(snr[db > self.dbthresh])
+ 
+ 
+ class L1SNRDecibelMatchLoss(_Loss):
+     def __init__(self, db_weight=0.1, l1snr_eps=1e-3, dbeps=1e-3):
+         super().__init__()
+         self.l1snr = L1SNRLoss(l1snr_eps)
+         self.decibel_match = DecibelMatchLoss(dbeps)
+         self.db_weight = db_weight
+ 
+     def forward(self, y_pred, y_true):
+         # the decibel-match term is scaled by db_weight, matching the
+         # constructor signature
+         return self.l1snr(y_pred, y_true) + self.db_weight * self.decibel_match(y_pred, y_true)
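+ 
+ # Hedged usage sketch (class names from this file; tensor shapes assumed to be
+ # (batch, channels, samples), which is an assumption, not verified here):
+ #   criterion = L1SNRDecibelMatchLoss(db_weight=0.1)
+ #   loss = criterion(estimates, targets)  # L1-SNR term + 0.1 * dB-match term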