Spaces:
Runtime error
Runtime error
Commit
·
d572f56
0
Parent(s):
have to create an orphan branch to bypass large file history: cleanup .ipynb and create LFS
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .github/FUNDING.yml +14 -0
- .gitignore +166 -0
- .idea/.gitignore +8 -0
- .idea/inspectionProfiles/Project_Default.xml +50 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/query-bandit.iml +8 -0
- .idea/vcs.xml +6 -0
- .vscode/launch.json +21 -0
- LICENSE +21 -0
- README.md +124 -0
- assets/banquet-logo.png +0 -0
- config/data/moisesdb-test.yml +55 -0
- config/data/setup-a/moisesdb-vdb-query-d-aug.yml +63 -0
- config/data/setup-a/moisesdb-vdb-query-d.yml +46 -0
- config/data/setup-a/moisesdb-vdb-query.yml +46 -0
- config/data/setup-b/moisesdb-vdbgp-query-d-aug-bal.yml +67 -0
- config/data/setup-b/moisesdb-vdbgp-query-d-aug.yml +67 -0
- config/data/setup-b/moisesdb-vdbgp-query-d.yml +50 -0
- config/data/setup-b/moisesdb-vdbgp-query.yml +50 -0
- config/data/setup-c/moisesdb-everything-query-d-aug-bal.yml +117 -0
- config/data/setup-c/moisesdb-everything-query-d-aug.yml +117 -0
- config/data/setup-c/moisesdb-everything-query-d-bal.yml +100 -0
- config/data/setup-c/moisesdb-everything-query-d.yml +100 -0
- config/data/vdbo/moisesdb-vdbo-aug.yml +35 -0
- config/data/vdbo/moisesdb-vdbo.yml +18 -0
- config/losses/both_l1snr.yml +4 -0
- config/losses/both_l1snrdbm.yml +4 -0
- config/models/bandit-query-pre.yml +31 -0
- config/models/bandit-query-prefz.yml +31 -0
- config/models/bandit-query.yml +29 -0
- config/models/bandit-vdbo.yml +27 -0
- config/optim/adam.yml +9 -0
- config/trainer/default-long.yml +12 -0
- config/trainer/default.yml +12 -0
- core/__init__.py +0 -0
- core/data/__init__.py +0 -0
- core/data/base.py +138 -0
- core/data/moisesdb/__init__.py +97 -0
- core/data/moisesdb/audio.ipynb +76 -0
- core/data/moisesdb/datamodule.py +239 -0
- core/data/moisesdb/dataset.py +1383 -0
- core/data/moisesdb/eda.ipynb +0 -0
- core/data/moisesdb/npyify.py +923 -0
- core/data/moisesdb/passt.ipynb +32 -0
- core/losses/__init__.py +0 -0
- core/losses/base.py +171 -0
- core/losses/l1snr.py +110 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
.github/FUNDING.yml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# These are supported funding model platforms
|
| 2 |
+
|
| 3 |
+
github: kwatcharasupat # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
| 4 |
+
patreon: # Replace with a single Patreon username
|
| 5 |
+
open_collective: # Replace with a single Open Collective username
|
| 6 |
+
ko_fi: # Replace with a single Ko-fi username
|
| 7 |
+
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
| 8 |
+
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
| 9 |
+
liberapay: # Replace with a single Liberapay username
|
| 10 |
+
issuehunt: # Replace with a single IssueHunt username
|
| 11 |
+
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
| 12 |
+
polar: # Replace with a single Polar username
|
| 13 |
+
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
|
| 14 |
+
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
.gitignore
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
#.idea/
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
input/
|
| 164 |
+
output/
|
| 165 |
+
logs/
|
| 166 |
+
checkpoints/
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<profile version="1.0">
|
| 3 |
+
<option name="myName" value="Project Default" />
|
| 4 |
+
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
| 5 |
+
<Languages>
|
| 6 |
+
<language minSize="52" name="Python" />
|
| 7 |
+
</Languages>
|
| 8 |
+
</inspection_tool>
|
| 9 |
+
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
| 10 |
+
<inspection_tool class="PyAttributeOutsideInitInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
|
| 11 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
| 12 |
+
<option name="ignoredPackages">
|
| 13 |
+
<value>
|
| 14 |
+
<list size="8">
|
| 15 |
+
<item index="0" class="java.lang.String" itemvalue="pytorch_lightning" />
|
| 16 |
+
<item index="1" class="java.lang.String" itemvalue="torch" />
|
| 17 |
+
<item index="2" class="java.lang.String" itemvalue="torchaudio" />
|
| 18 |
+
<item index="3" class="java.lang.String" itemvalue="matplotlib" />
|
| 19 |
+
<item index="4" class="java.lang.String" itemvalue="ipython" />
|
| 20 |
+
<item index="5" class="java.lang.String" itemvalue="numpy" />
|
| 21 |
+
<item index="6" class="java.lang.String" itemvalue="opencv_python" />
|
| 22 |
+
<item index="7" class="java.lang.String" itemvalue="Pillow" />
|
| 23 |
+
</list>
|
| 24 |
+
</value>
|
| 25 |
+
</option>
|
| 26 |
+
</inspection_tool>
|
| 27 |
+
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
| 28 |
+
<option name="ignoredErrors">
|
| 29 |
+
<list>
|
| 30 |
+
<option value="N806" />
|
| 31 |
+
<option value="N812" />
|
| 32 |
+
<option value="N802" />
|
| 33 |
+
<option value="N803" />
|
| 34 |
+
</list>
|
| 35 |
+
</option>
|
| 36 |
+
</inspection_tool>
|
| 37 |
+
<inspection_tool class="PyShadowingBuiltinsInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
| 38 |
+
<option name="ignoredNames">
|
| 39 |
+
<list>
|
| 40 |
+
<option value="round" />
|
| 41 |
+
</list>
|
| 42 |
+
</option>
|
| 43 |
+
</inspection_tool>
|
| 44 |
+
<inspection_tool class="SpellCheckingInspection" enabled="true" level="TYPO" enabled_by_default="true">
|
| 45 |
+
<option name="processCode" value="false" />
|
| 46 |
+
<option name="processLiterals" value="true" />
|
| 47 |
+
<option name="processComments" value="true" />
|
| 48 |
+
</inspection_tool>
|
| 49 |
+
</profile>
|
| 50 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (aa-listening-test-sigsep-gen)" project-jdk-type="Python SDK" />
|
| 4 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/query-bandit.iml" filepath="$PROJECT_DIR$/.idea/query-bandit.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/query-bandit.iml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="inheritedJdk" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
</module>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
.vscode/launch.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "0.2.0",
|
| 3 |
+
"configurations": [
|
| 4 |
+
{
|
| 5 |
+
"name": "Python Debugger: Remote Attach",
|
| 6 |
+
"type": "debugpy",
|
| 7 |
+
"request": "attach",
|
| 8 |
+
"justMyCode": true,
|
| 9 |
+
"connect": {
|
| 10 |
+
"host": "localhost",
|
| 11 |
+
"port": 5678
|
| 12 |
+
},
|
| 13 |
+
"pathMappings": [
|
| 14 |
+
{
|
| 15 |
+
"localRoot": "${workspaceFolder}",
|
| 16 |
+
"remoteRoot": "."
|
| 17 |
+
}
|
| 18 |
+
]
|
| 19 |
+
}
|
| 20 |
+
]
|
| 21 |
+
}
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Karn Watcharasupat
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Language-Audio Banquet
|
| 2 |
+
<a href='https://github.com/ModistAndrew/query-bandit'><img alt="Static Badge" src="https://img.shields.io/badge/github_repo-lightgrey?logo=github"></a>
|
| 3 |
+
<a href='https://huggingface.co/spaces/chenxie95/Language-Audio-Banquet'><img alt="Static Badge" src="https://img.shields.io/badge/huggingface_space-yellow?logo=huggingface"></a>
|
| 4 |
+
|
| 5 |
+
- Change the query embedding model from PaSST to CLAP, which supports language queries.
|
| 6 |
+
|
| 7 |
+
- Change RNN to Transformer.
|
| 8 |
+
|
| 9 |
+
- Some utility functions for inference.
|
| 10 |
+
|
| 11 |
+
- (TODO) Train on more datasets.
|
| 12 |
+
|
| 13 |
+
## Model weights
|
| 14 |
+
You need to download the model weights from [huggingface model](https://huggingface.co/chenxie95/Language-Audio-Banquet-ckpt) and put them in `checkpoints/`. `bandit-vdbo-roformer.ckpt` is needed for training. `ev-pre.ckpt` and `ev-pre-aug.ckpt` can be choosen for inference.
|
| 15 |
+
|
| 16 |
+
What's more, you need to download the query embedding model CLAP from [here](https://huggingface.co/lukewys/laion_clap/blob/main/music_speech_epoch_15_esc_89.25.pt) and put it in `checkpoints/querier/`.
|
| 17 |
+
|
| 18 |
+
## Inference examples
|
| 19 |
+
```bash
|
| 20 |
+
export CONFIG_ROOT=./config
|
| 21 |
+
python \
|
| 22 |
+
# -m debugpy --listen 5678 --wait-for-client \
|
| 23 |
+
train.py inference_byoq \
|
| 24 |
+
checkpoints/ev-pre-aug.ckpt \
|
| 25 |
+
input/491c1ff5-1e7b-4046-8029-a82d4a8aefb4.wav \
|
| 26 |
+
input/491c1ff5-1e7b-4046-8029-a82d4a8aefb4_bass.wav \
|
| 27 |
+
output/491c1ff5-1e7b-4046-8029-a82d4a8aefb4_bass.wav \
|
| 28 |
+
--batch_size=12 \
|
| 29 |
+
--use_cuda=true
|
| 30 |
+
|
| 31 |
+
python \
|
| 32 |
+
train.py inference_byoq_text \
|
| 33 |
+
checkpoints/ev-pre-aug.ckpt \
|
| 34 |
+
input/491c1ff5-1e7b-4046-8029-a82d4a8aefb4.wav \
|
| 35 |
+
piano \
|
| 36 |
+
output/491c1ff5-1e7b-4046-8029-a82d4a8aefb4_piano.wav \
|
| 37 |
+
--batch_size=12 \
|
| 38 |
+
--use_cuda=true
|
| 39 |
+
|
| 40 |
+
python \
|
| 41 |
+
train.py inference_test_folder \
|
| 42 |
+
checkpoints/ev-pre-aug.ckpt \
|
| 43 |
+
/inspire/hdd/project/multilingualspeechrecognition/chenxie-25019/data/karaoke_converted/test \
|
| 44 |
+
output/karaoke \
|
| 45 |
+
bass \
|
| 46 |
+
--batch_size=30 \
|
| 47 |
+
--use_cuda=true \
|
| 48 |
+
--input_name=mixture
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
## Training examples
|
| 52 |
+
```bash
|
| 53 |
+
export CONFIG_ROOT=./config
|
| 54 |
+
# export DATA_ROOT=/inspire/hdd/project/multilingualspeechrecognition/chenxie-25019/data
|
| 55 |
+
# export DATA_ROOT=/dev/shm
|
| 56 |
+
export DATA_ROOT=/inspire/ssd/project/multilingualspeechrecognition/public
|
| 57 |
+
export LOG_ROOT=./logs/ev-pre-aug-bal
|
| 58 |
+
export CUDA_VISIBLE_DEVICES=0
|
| 59 |
+
python \
|
| 60 |
+
train.py train \
|
| 61 |
+
expt/setup-c/bandit-everything-query-pre-d-aug-bal.yml \
|
| 62 |
+
--ckpt_path=logs/ev-pre-aug-bal/e2e/HBRPOI/lightning_logs/version_1/checkpoints/last.ckpt
|
| 63 |
+
# You may modify the batch size in yaml files in config/data/. A batch size of 3 fits on a NVIDIA 4090 (48GB).
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
---
|
| 67 |
+
|
| 68 |
+
> ### Please consider giving back to the community if you have benefited from this work.
|
| 69 |
+
>
|
| 70 |
+
> If you've **benefited commercially from this work**, which we've poured significant effort into and released under permissive licenses, we hope you've found it valuable! While these licenses give you lots of freedom, we believe in nurturing a vibrant ecosystem where innovation can continue to flourish.
|
| 71 |
+
>
|
| 72 |
+
> So, as a gesture of appreciation and responsibility, we strongly urge commercial entities that have gained from this software to consider making voluntary contributions to music-related non-profit organizations of your choice. Your contribution directly helps support the foundational work that empowers your commercial success and ensures open-source innovation keeps moving forward.
|
| 73 |
+
>
|
| 74 |
+
> Some suggestions for the beneficiaries are provided [here](https://github.com/the-secret-source/nonprofits). Please do not hesitate to contribute to the list by opening pull requests there.
|
| 75 |
+
|
| 76 |
+
---
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
<div align="center">
|
| 80 |
+
<img src="assets/banquet-logo.png">
|
| 81 |
+
</div>
|
| 82 |
+
|
| 83 |
+
# Banquet: A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems
|
| 84 |
+
|
| 85 |
+
Repository for **A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems**
|
| 86 |
+
by Karn N. Watcharasupat and Alexander Lerch. [arXiv](https://arxiv.org/abs/2406.18747)
|
| 87 |
+
|
| 88 |
+
> Despite significant recent progress across multiple subtasks of audio source separation, few music source separation systems support separation beyond the four-stem vocals, drums, bass, and other (VDBO) setup. Of the very few current systems that support source separation beyond this setup, most continue to rely on an inflexible decoder setup that can only support a fixed pre-defined set of stems. Increasing stem support in these inflexible systems correspondingly requires increasing computational complexity, rendering extensions of these systems computationally infeasible for long-tail instruments. In this work, we propose Banquet, a system that allows source separation of multiple stems using just one decoder. A bandsplit source separation model is extended to work in a query-based setup in tandem with a music instrument recognition PaSST model. On the MoisesDB dataset, Banquet, at only 24.9 M trainable parameters, approached the performance level of the significantly more complex 6-stem Hybrid Transformer Demucs on VDBO stems and outperformed it on guitar and piano. The query-based setup allows for the separation of narrow instrument classes such as clean acoustic guitars, and can be successfully applied to the extraction of less common stems such as reeds and organs.
|
| 89 |
+
|
| 90 |
+
For the Cinematic Audio Source Separation model, Bandit, see [this repository](https://github.com/kwatcharasupat/bandit).
|
| 91 |
+
|
| 92 |
+
## Inference
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
git clone https://github.com/kwatcharasupat/query-bandit.git
|
| 96 |
+
cd query-bandit
|
| 97 |
+
export CONFIG_ROOT="./config"
|
| 98 |
+
|
| 99 |
+
python train.py inference_byoq \
|
| 100 |
+
--ckpt_path="/path/to/checkpoint/see-below.ckpt" \
|
| 101 |
+
--input_path="/path/to/input/file/fearOfMatlab.wav" \
|
| 102 |
+
--output_path="/path/to/output/file/fearOfMatlabStemEst/guitar.wav" \
|
| 103 |
+
--query_path="/path/to/query/file/random-guitar.wav" \
|
| 104 |
+
--batch_size=12 \
|
| 105 |
+
--use_cuda=true
|
| 106 |
+
```
|
| 107 |
+
Batch size of 12 _usually_ fits on a RTX 4090.
|
| 108 |
+
|
| 109 |
+
### Model weights
|
| 110 |
+
Model weights are available on Zenodo [here](https://zenodo.org/records/13694558).
|
| 111 |
+
If you are not sure, use `ev-pre-aug.ckpt`.
|
| 112 |
+
|
| 113 |
+
## Citation
|
| 114 |
+
```
|
| 115 |
+
@inproceedings{Watcharasupat2024Banquet,
|
| 116 |
+
title = {A Stem-Agnostic Single-Decoder System for Music Source Separation Beyond Four Stems},
|
| 117 |
+
booktitle = {To Appear in the Proceedings of the 25th International Society for Music Information Retrieval},
|
| 118 |
+
author = {Watcharasupat, Karn N. and Lerch, Alexander},
|
| 119 |
+
year = {2024},
|
| 120 |
+
month = {nov},
|
| 121 |
+
eprint = {2406.18747},
|
| 122 |
+
address = {San Francisco, CA, USA},
|
| 123 |
+
}
|
| 124 |
+
```
|
assets/banquet-logo.png
ADDED
|
config/data/moisesdb-test.yml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesTestDataModule
|
| 3 |
+
batch_size: 1
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
|
| 7 |
+
inference_kwargs:
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
hop_size_seconds: 0.5
|
| 10 |
+
batch_size: 12
|
| 11 |
+
fs: 44100
|
| 12 |
+
|
| 13 |
+
test_kwargs:
|
| 14 |
+
npy_memmap: true
|
| 15 |
+
mixture_stem: mixture
|
| 16 |
+
use_own_query: false
|
| 17 |
+
allowed_stems: [
|
| 18 |
+
"drums",
|
| 19 |
+
"lead_male_singer",
|
| 20 |
+
"lead_female_singer",
|
| 21 |
+
# "human_choir",
|
| 22 |
+
"background_vocals",
|
| 23 |
+
# "other_vocals",
|
| 24 |
+
"bass_guitar",
|
| 25 |
+
"bass_synthesizer",
|
| 26 |
+
# "contrabass_double_bass",
|
| 27 |
+
# "tuba",
|
| 28 |
+
# "bassoon",
|
| 29 |
+
"fx",
|
| 30 |
+
"clean_electric_guitar",
|
| 31 |
+
"distorted_electric_guitar",
|
| 32 |
+
# "lap_steel_guitar_or_slide_guitar",
|
| 33 |
+
"acoustic_guitar",
|
| 34 |
+
"other_plucked",
|
| 35 |
+
"pitched_percussion",
|
| 36 |
+
"grand_piano",
|
| 37 |
+
"electric_piano",
|
| 38 |
+
"organ_electric_organ",
|
| 39 |
+
"synth_pad",
|
| 40 |
+
"synth_lead",
|
| 41 |
+
# "violin",
|
| 42 |
+
# "viola",
|
| 43 |
+
# "cello",
|
| 44 |
+
# "violin_section",
|
| 45 |
+
# "viola_section",
|
| 46 |
+
# "cello_section",
|
| 47 |
+
"string_section",
|
| 48 |
+
"other_strings",
|
| 49 |
+
"brass",
|
| 50 |
+
# "flutes",
|
| 51 |
+
"reeds",
|
| 52 |
+
"other_wind"
|
| 53 |
+
]
|
| 54 |
+
query_file: "query-10s"
|
| 55 |
+
n_channels: 2
|
config/data/setup-a/moisesdb-vdb-query-d-aug.yml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesDataModule
|
| 3 |
+
batch_size: 4
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: false
|
| 14 |
+
allowed_stems:
|
| 15 |
+
[
|
| 16 |
+
"bass",
|
| 17 |
+
"drums",
|
| 18 |
+
"lead_male_singer",
|
| 19 |
+
"lead_female_singer",
|
| 20 |
+
# "distorted_electric_guitar",
|
| 21 |
+
# "clean_electric_guitar",
|
| 22 |
+
# "acoustic_guitar",
|
| 23 |
+
]
|
| 24 |
+
query_file: "query-10s"
|
| 25 |
+
augment:
|
| 26 |
+
- cls: Shift
|
| 27 |
+
kwargs:
|
| 28 |
+
p: 1.0
|
| 29 |
+
min_shift: -0.5
|
| 30 |
+
max_shift: 0.5
|
| 31 |
+
- cls: Gain
|
| 32 |
+
kwargs:
|
| 33 |
+
p: 1.0
|
| 34 |
+
min_gain_in_db: -6
|
| 35 |
+
max_gain_in_db: 6
|
| 36 |
+
- cls: ShuffleChannels
|
| 37 |
+
kwargs:
|
| 38 |
+
p: 0.5
|
| 39 |
+
- cls: PolarityInversion
|
| 40 |
+
kwargs:
|
| 41 |
+
p: 0.5
|
| 42 |
+
val_kwargs:
|
| 43 |
+
chunk_size_seconds: 6.0
|
| 44 |
+
hop_size_seconds: 6.0
|
| 45 |
+
query_size_seconds: 10.0
|
| 46 |
+
top_k_instrument: 10
|
| 47 |
+
npy_memmap: true
|
| 48 |
+
mixture_stem: mixture
|
| 49 |
+
use_own_query: false
|
| 50 |
+
allowed_stems:
|
| 51 |
+
[
|
| 52 |
+
"bass",
|
| 53 |
+
"drums",
|
| 54 |
+
"lead_male_singer",
|
| 55 |
+
"lead_female_singer",
|
| 56 |
+
# "distorted_electric_guitar",
|
| 57 |
+
# "clean_electric_guitar",
|
| 58 |
+
# "acoustic_guitar",
|
| 59 |
+
]
|
| 60 |
+
query_file: "query-10s"
|
| 61 |
+
test_kwargs:
|
| 62 |
+
npy_memmap: true
|
| 63 |
+
n_channels: 2
|
config/data/setup-a/moisesdb-vdb-query-d.yml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesDataModule
|
| 3 |
+
batch_size: 4
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: false
|
| 14 |
+
allowed_stems:
|
| 15 |
+
[
|
| 16 |
+
"bass",
|
| 17 |
+
"drums",
|
| 18 |
+
"lead_male_singer",
|
| 19 |
+
"lead_female_singer",
|
| 20 |
+
# "distorted_electric_guitar",
|
| 21 |
+
# "clean_electric_guitar",
|
| 22 |
+
# "acoustic_guitar",
|
| 23 |
+
]
|
| 24 |
+
query_file: "query-10s"
|
| 25 |
+
val_kwargs:
|
| 26 |
+
chunk_size_seconds: 6.0
|
| 27 |
+
hop_size_seconds: 6.0
|
| 28 |
+
query_size_seconds: 10.0
|
| 29 |
+
top_k_instrument: 10
|
| 30 |
+
npy_memmap: true
|
| 31 |
+
mixture_stem: mixture
|
| 32 |
+
use_own_query: false
|
| 33 |
+
allowed_stems:
|
| 34 |
+
[
|
| 35 |
+
"bass",
|
| 36 |
+
"drums",
|
| 37 |
+
"lead_male_singer",
|
| 38 |
+
"lead_female_singer",
|
| 39 |
+
# "distorted_electric_guitar",
|
| 40 |
+
# "clean_electric_guitar",
|
| 41 |
+
# "acoustic_guitar",
|
| 42 |
+
]
|
| 43 |
+
query_file: "query-10s"
|
| 44 |
+
test_kwargs:
|
| 45 |
+
npy_memmap: true
|
| 46 |
+
n_channels: 2
|
config/data/setup-a/moisesdb-vdb-query.yml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesDataModule
|
| 3 |
+
batch_size: 4
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: true
|
| 14 |
+
allowed_stems:
|
| 15 |
+
[
|
| 16 |
+
"bass",
|
| 17 |
+
"drums",
|
| 18 |
+
"lead_male_singer",
|
| 19 |
+
"lead_female_singer",
|
| 20 |
+
# "distorted_electric_guitar",
|
| 21 |
+
# "clean_electric_guitar",
|
| 22 |
+
# "acoustic_guitar",
|
| 23 |
+
]
|
| 24 |
+
query_file: "query-10s"
|
| 25 |
+
val_kwargs:
|
| 26 |
+
chunk_size_seconds: 6.0
|
| 27 |
+
hop_size_seconds: 6.0
|
| 28 |
+
query_size_seconds: 10.0
|
| 29 |
+
top_k_instrument: 10
|
| 30 |
+
npy_memmap: true
|
| 31 |
+
mixture_stem: mixture
|
| 32 |
+
use_own_query: true
|
| 33 |
+
allowed_stems:
|
| 34 |
+
[
|
| 35 |
+
"bass",
|
| 36 |
+
"drums",
|
| 37 |
+
"lead_male_singer",
|
| 38 |
+
"lead_female_singer",
|
| 39 |
+
# "distorted_electric_guitar",
|
| 40 |
+
# "clean_electric_guitar",
|
| 41 |
+
# "acoustic_guitar",
|
| 42 |
+
]
|
| 43 |
+
query_file: "query-10s"
|
| 44 |
+
test_kwargs:
|
| 45 |
+
npy_memmap: true
|
| 46 |
+
n_channels: 2
|
config/data/setup-b/moisesdb-vdbgp-query-d-aug-bal.yml
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesBalancedTrainDataModule
|
| 3 |
+
batch_size: 4
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: false
|
| 14 |
+
allowed_stems:
|
| 15 |
+
[
|
| 16 |
+
"bass",
|
| 17 |
+
"drums",
|
| 18 |
+
"lead_male_singer",
|
| 19 |
+
"lead_female_singer",
|
| 20 |
+
"distorted_electric_guitar",
|
| 21 |
+
"clean_electric_guitar",
|
| 22 |
+
"acoustic_guitar",
|
| 23 |
+
'grand_piano',
|
| 24 |
+
'electric_piano',
|
| 25 |
+
]
|
| 26 |
+
query_file: "query-10s"
|
| 27 |
+
augment:
|
| 28 |
+
- cls: Shift
|
| 29 |
+
kwargs:
|
| 30 |
+
p: 1.0
|
| 31 |
+
min_shift: -0.5
|
| 32 |
+
max_shift: 0.5
|
| 33 |
+
- cls: Gain
|
| 34 |
+
kwargs:
|
| 35 |
+
p: 1.0
|
| 36 |
+
min_gain_in_db: -6
|
| 37 |
+
max_gain_in_db: 6
|
| 38 |
+
- cls: ShuffleChannels
|
| 39 |
+
kwargs:
|
| 40 |
+
p: 0.5
|
| 41 |
+
- cls: PolarityInversion
|
| 42 |
+
kwargs:
|
| 43 |
+
p: 0.5
|
| 44 |
+
val_kwargs:
|
| 45 |
+
chunk_size_seconds: 6.0
|
| 46 |
+
hop_size_seconds: 6.0
|
| 47 |
+
query_size_seconds: 10.0
|
| 48 |
+
top_k_instrument: 10
|
| 49 |
+
npy_memmap: true
|
| 50 |
+
mixture_stem: mixture
|
| 51 |
+
use_own_query: false
|
| 52 |
+
allowed_stems:
|
| 53 |
+
[
|
| 54 |
+
"bass",
|
| 55 |
+
"drums",
|
| 56 |
+
"lead_male_singer",
|
| 57 |
+
"lead_female_singer",
|
| 58 |
+
"distorted_electric_guitar",
|
| 59 |
+
"clean_electric_guitar",
|
| 60 |
+
"acoustic_guitar",
|
| 61 |
+
'grand_piano',
|
| 62 |
+
'electric_piano',
|
| 63 |
+
]
|
| 64 |
+
query_file: "query-10s"
|
| 65 |
+
test_kwargs:
|
| 66 |
+
npy_memmap: true
|
| 67 |
+
n_channels: 2
|
config/data/setup-b/moisesdb-vdbgp-query-d-aug.yml
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesDataModule
|
| 3 |
+
batch_size: 4
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: false
|
| 14 |
+
allowed_stems:
|
| 15 |
+
[
|
| 16 |
+
"bass",
|
| 17 |
+
"drums",
|
| 18 |
+
"lead_male_singer",
|
| 19 |
+
"lead_female_singer",
|
| 20 |
+
"distorted_electric_guitar",
|
| 21 |
+
"clean_electric_guitar",
|
| 22 |
+
"acoustic_guitar",
|
| 23 |
+
'grand_piano',
|
| 24 |
+
'electric_piano',
|
| 25 |
+
]
|
| 26 |
+
query_file: "query-10s"
|
| 27 |
+
augment:
|
| 28 |
+
- cls: Shift
|
| 29 |
+
kwargs:
|
| 30 |
+
p: 1.0
|
| 31 |
+
min_shift: -0.5
|
| 32 |
+
max_shift: 0.5
|
| 33 |
+
- cls: Gain
|
| 34 |
+
kwargs:
|
| 35 |
+
p: 1.0
|
| 36 |
+
min_gain_in_db: -6
|
| 37 |
+
max_gain_in_db: 6
|
| 38 |
+
- cls: ShuffleChannels
|
| 39 |
+
kwargs:
|
| 40 |
+
p: 0.5
|
| 41 |
+
- cls: PolarityInversion
|
| 42 |
+
kwargs:
|
| 43 |
+
p: 0.5
|
| 44 |
+
val_kwargs:
|
| 45 |
+
chunk_size_seconds: 6.0
|
| 46 |
+
hop_size_seconds: 6.0
|
| 47 |
+
query_size_seconds: 10.0
|
| 48 |
+
top_k_instrument: 10
|
| 49 |
+
npy_memmap: true
|
| 50 |
+
mixture_stem: mixture
|
| 51 |
+
use_own_query: false
|
| 52 |
+
allowed_stems:
|
| 53 |
+
[
|
| 54 |
+
"bass",
|
| 55 |
+
"drums",
|
| 56 |
+
"lead_male_singer",
|
| 57 |
+
"lead_female_singer",
|
| 58 |
+
"distorted_electric_guitar",
|
| 59 |
+
"clean_electric_guitar",
|
| 60 |
+
"acoustic_guitar",
|
| 61 |
+
'grand_piano',
|
| 62 |
+
'electric_piano',
|
| 63 |
+
]
|
| 64 |
+
query_file: "query-10s"
|
| 65 |
+
test_kwargs:
|
| 66 |
+
npy_memmap: true
|
| 67 |
+
n_channels: 2
|
config/data/setup-b/moisesdb-vdbgp-query-d.yml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesDataModule
|
| 3 |
+
batch_size: 4
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: false
|
| 14 |
+
allowed_stems:
|
| 15 |
+
[
|
| 16 |
+
"bass",
|
| 17 |
+
"drums",
|
| 18 |
+
"lead_male_singer",
|
| 19 |
+
"lead_female_singer",
|
| 20 |
+
"distorted_electric_guitar",
|
| 21 |
+
"clean_electric_guitar",
|
| 22 |
+
"acoustic_guitar",
|
| 23 |
+
'grand_piano',
|
| 24 |
+
'electric_piano',
|
| 25 |
+
]
|
| 26 |
+
query_file: "query-10s"
|
| 27 |
+
val_kwargs:
|
| 28 |
+
chunk_size_seconds: 6.0
|
| 29 |
+
hop_size_seconds: 6.0
|
| 30 |
+
query_size_seconds: 10.0
|
| 31 |
+
top_k_instrument: 10
|
| 32 |
+
npy_memmap: true
|
| 33 |
+
mixture_stem: mixture
|
| 34 |
+
use_own_query: false
|
| 35 |
+
allowed_stems:
|
| 36 |
+
[
|
| 37 |
+
"bass",
|
| 38 |
+
"drums",
|
| 39 |
+
"lead_male_singer",
|
| 40 |
+
"lead_female_singer",
|
| 41 |
+
"distorted_electric_guitar",
|
| 42 |
+
"clean_electric_guitar",
|
| 43 |
+
"acoustic_guitar",
|
| 44 |
+
'grand_piano',
|
| 45 |
+
'electric_piano',
|
| 46 |
+
]
|
| 47 |
+
query_file: "query-10s"
|
| 48 |
+
test_kwargs:
|
| 49 |
+
npy_memmap: true
|
| 50 |
+
n_channels: 2
|
config/data/setup-b/moisesdb-vdbgp-query.yml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesDataModule
|
| 3 |
+
batch_size: 4
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: true
|
| 14 |
+
allowed_stems:
|
| 15 |
+
[
|
| 16 |
+
"bass",
|
| 17 |
+
"drums",
|
| 18 |
+
"lead_male_singer",
|
| 19 |
+
"lead_female_singer",
|
| 20 |
+
"distorted_electric_guitar",
|
| 21 |
+
"clean_electric_guitar",
|
| 22 |
+
"acoustic_guitar",
|
| 23 |
+
'grand_piano',
|
| 24 |
+
'electric_piano',
|
| 25 |
+
]
|
| 26 |
+
query_file: "query-10s"
|
| 27 |
+
val_kwargs:
|
| 28 |
+
chunk_size_seconds: 6.0
|
| 29 |
+
hop_size_seconds: 6.0
|
| 30 |
+
query_size_seconds: 10.0
|
| 31 |
+
top_k_instrument: 10
|
| 32 |
+
npy_memmap: true
|
| 33 |
+
mixture_stem: mixture
|
| 34 |
+
use_own_query: true
|
| 35 |
+
allowed_stems:
|
| 36 |
+
[
|
| 37 |
+
"bass",
|
| 38 |
+
"drums",
|
| 39 |
+
"lead_male_singer",
|
| 40 |
+
"lead_female_singer",
|
| 41 |
+
"distorted_electric_guitar",
|
| 42 |
+
"clean_electric_guitar",
|
| 43 |
+
"acoustic_guitar",
|
| 44 |
+
'grand_piano',
|
| 45 |
+
'electric_piano',
|
| 46 |
+
]
|
| 47 |
+
query_file: "query-10s"
|
| 48 |
+
test_kwargs:
|
| 49 |
+
npy_memmap: true
|
| 50 |
+
n_channels: 2
|
config/data/setup-c/moisesdb-everything-query-d-aug-bal.yml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesBalancedTrainDataModule
|
| 3 |
+
batch_size: 3
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: false
|
| 14 |
+
allowed_stems: [
|
| 15 |
+
"drums",
|
| 16 |
+
"lead_male_singer",
|
| 17 |
+
"lead_female_singer",
|
| 18 |
+
# "human_choir",
|
| 19 |
+
"background_vocals",
|
| 20 |
+
# "other_vocals",
|
| 21 |
+
"bass_guitar",
|
| 22 |
+
"bass_synthesizer",
|
| 23 |
+
# "contrabass_double_bass",
|
| 24 |
+
# "tuba",
|
| 25 |
+
# "bassoon",
|
| 26 |
+
"fx",
|
| 27 |
+
"clean_electric_guitar",
|
| 28 |
+
"distorted_electric_guitar",
|
| 29 |
+
# "lap_steel_guitar_or_slide_guitar",
|
| 30 |
+
"acoustic_guitar",
|
| 31 |
+
"other_plucked",
|
| 32 |
+
"pitched_percussion",
|
| 33 |
+
"grand_piano",
|
| 34 |
+
"electric_piano",
|
| 35 |
+
"organ_electric_organ",
|
| 36 |
+
"synth_pad",
|
| 37 |
+
"synth_lead",
|
| 38 |
+
# "violin",
|
| 39 |
+
# "viola",
|
| 40 |
+
# "cello",
|
| 41 |
+
# "violin_section",
|
| 42 |
+
# "viola_section",
|
| 43 |
+
# "cello_section",
|
| 44 |
+
"string_section",
|
| 45 |
+
"other_strings",
|
| 46 |
+
"brass",
|
| 47 |
+
# "flutes",
|
| 48 |
+
"reeds",
|
| 49 |
+
"other_wind"
|
| 50 |
+
]
|
| 51 |
+
query_file: "query-10s"
|
| 52 |
+
augment:
|
| 53 |
+
- cls: Shift
|
| 54 |
+
kwargs:
|
| 55 |
+
p: 1.0
|
| 56 |
+
min_shift: -0.5
|
| 57 |
+
max_shift: 0.5
|
| 58 |
+
- cls: Gain
|
| 59 |
+
kwargs:
|
| 60 |
+
p: 1.0
|
| 61 |
+
min_gain_in_db: -6
|
| 62 |
+
max_gain_in_db: 6
|
| 63 |
+
- cls: ShuffleChannels
|
| 64 |
+
kwargs:
|
| 65 |
+
p: 0.5
|
| 66 |
+
- cls: PolarityInversion
|
| 67 |
+
kwargs:
|
| 68 |
+
p: 0.5
|
| 69 |
+
val_kwargs:
|
| 70 |
+
chunk_size_seconds: 6.0
|
| 71 |
+
hop_size_seconds: 6.0
|
| 72 |
+
query_size_seconds: 10.0
|
| 73 |
+
top_k_instrument: 10
|
| 74 |
+
npy_memmap: true
|
| 75 |
+
mixture_stem: mixture
|
| 76 |
+
use_own_query: false
|
| 77 |
+
allowed_stems: [
|
| 78 |
+
"drums",
|
| 79 |
+
"lead_male_singer",
|
| 80 |
+
"lead_female_singer",
|
| 81 |
+
# "human_choir",
|
| 82 |
+
"background_vocals",
|
| 83 |
+
# "other_vocals",
|
| 84 |
+
"bass_guitar",
|
| 85 |
+
"bass_synthesizer",
|
| 86 |
+
# "contrabass_double_bass",
|
| 87 |
+
# "tuba",
|
| 88 |
+
# "bassoon",
|
| 89 |
+
"fx",
|
| 90 |
+
"clean_electric_guitar",
|
| 91 |
+
"distorted_electric_guitar",
|
| 92 |
+
# "lap_steel_guitar_or_slide_guitar",
|
| 93 |
+
"acoustic_guitar",
|
| 94 |
+
"other_plucked",
|
| 95 |
+
"pitched_percussion",
|
| 96 |
+
"grand_piano",
|
| 97 |
+
"electric_piano",
|
| 98 |
+
"organ_electric_organ",
|
| 99 |
+
"synth_pad",
|
| 100 |
+
"synth_lead",
|
| 101 |
+
# "violin",
|
| 102 |
+
# "viola",
|
| 103 |
+
# "cello",
|
| 104 |
+
# "violin_section",
|
| 105 |
+
# "viola_section",
|
| 106 |
+
# "cello_section",
|
| 107 |
+
"string_section",
|
| 108 |
+
"other_strings",
|
| 109 |
+
"brass",
|
| 110 |
+
# "flutes",
|
| 111 |
+
"reeds",
|
| 112 |
+
"other_wind"
|
| 113 |
+
]
|
| 114 |
+
query_file: "query-10s"
|
| 115 |
+
test_kwargs:
|
| 116 |
+
npy_memmap: true
|
| 117 |
+
n_channels: 2
|
config/data/setup-c/moisesdb-everything-query-d-aug.yml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesDataModule
|
| 3 |
+
batch_size: 3
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: false
|
| 14 |
+
allowed_stems: [
|
| 15 |
+
"drums",
|
| 16 |
+
"lead_male_singer",
|
| 17 |
+
"lead_female_singer",
|
| 18 |
+
# "human_choir",
|
| 19 |
+
"background_vocals",
|
| 20 |
+
# "other_vocals",
|
| 21 |
+
"bass_guitar",
|
| 22 |
+
"bass_synthesizer",
|
| 23 |
+
# "contrabass_double_bass",
|
| 24 |
+
# "tuba",
|
| 25 |
+
# "bassoon",
|
| 26 |
+
"fx",
|
| 27 |
+
"clean_electric_guitar",
|
| 28 |
+
"distorted_electric_guitar",
|
| 29 |
+
# "lap_steel_guitar_or_slide_guitar",
|
| 30 |
+
"acoustic_guitar",
|
| 31 |
+
"other_plucked",
|
| 32 |
+
"pitched_percussion",
|
| 33 |
+
"grand_piano",
|
| 34 |
+
"electric_piano",
|
| 35 |
+
"organ_electric_organ",
|
| 36 |
+
"synth_pad",
|
| 37 |
+
"synth_lead",
|
| 38 |
+
# "violin",
|
| 39 |
+
# "viola",
|
| 40 |
+
# "cello",
|
| 41 |
+
# "violin_section",
|
| 42 |
+
# "viola_section",
|
| 43 |
+
# "cello_section",
|
| 44 |
+
"string_section",
|
| 45 |
+
"other_strings",
|
| 46 |
+
"brass",
|
| 47 |
+
# "flutes",
|
| 48 |
+
"reeds",
|
| 49 |
+
"other_wind"
|
| 50 |
+
]
|
| 51 |
+
query_file: "query-10s"
|
| 52 |
+
augment:
|
| 53 |
+
- cls: Shift
|
| 54 |
+
kwargs:
|
| 55 |
+
p: 1.0
|
| 56 |
+
min_shift: -0.5
|
| 57 |
+
max_shift: 0.5
|
| 58 |
+
- cls: Gain
|
| 59 |
+
kwargs:
|
| 60 |
+
p: 1.0
|
| 61 |
+
min_gain_in_db: -6
|
| 62 |
+
max_gain_in_db: 6
|
| 63 |
+
- cls: ShuffleChannels
|
| 64 |
+
kwargs:
|
| 65 |
+
p: 0.5
|
| 66 |
+
- cls: PolarityInversion
|
| 67 |
+
kwargs:
|
| 68 |
+
p: 0.5
|
| 69 |
+
val_kwargs:
|
| 70 |
+
chunk_size_seconds: 6.0
|
| 71 |
+
hop_size_seconds: 6.0
|
| 72 |
+
query_size_seconds: 10.0
|
| 73 |
+
top_k_instrument: 10
|
| 74 |
+
npy_memmap: true
|
| 75 |
+
mixture_stem: mixture
|
| 76 |
+
use_own_query: false
|
| 77 |
+
allowed_stems: [
|
| 78 |
+
"drums",
|
| 79 |
+
"lead_male_singer",
|
| 80 |
+
"lead_female_singer",
|
| 81 |
+
# "human_choir",
|
| 82 |
+
"background_vocals",
|
| 83 |
+
# "other_vocals",
|
| 84 |
+
"bass_guitar",
|
| 85 |
+
"bass_synthesizer",
|
| 86 |
+
# "contrabass_double_bass",
|
| 87 |
+
# "tuba",
|
| 88 |
+
# "bassoon",
|
| 89 |
+
"fx",
|
| 90 |
+
"clean_electric_guitar",
|
| 91 |
+
"distorted_electric_guitar",
|
| 92 |
+
# "lap_steel_guitar_or_slide_guitar",
|
| 93 |
+
"acoustic_guitar",
|
| 94 |
+
"other_plucked",
|
| 95 |
+
"pitched_percussion",
|
| 96 |
+
"grand_piano",
|
| 97 |
+
"electric_piano",
|
| 98 |
+
"organ_electric_organ",
|
| 99 |
+
"synth_pad",
|
| 100 |
+
"synth_lead",
|
| 101 |
+
# "violin",
|
| 102 |
+
# "viola",
|
| 103 |
+
# "cello",
|
| 104 |
+
# "violin_section",
|
| 105 |
+
# "viola_section",
|
| 106 |
+
# "cello_section",
|
| 107 |
+
"string_section",
|
| 108 |
+
"other_strings",
|
| 109 |
+
"brass",
|
| 110 |
+
# "flutes",
|
| 111 |
+
"reeds",
|
| 112 |
+
"other_wind"
|
| 113 |
+
]
|
| 114 |
+
query_file: "query-10s"
|
| 115 |
+
test_kwargs:
|
| 116 |
+
npy_memmap: true
|
| 117 |
+
n_channels: 2
|
config/data/setup-c/moisesdb-everything-query-d-bal.yml
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesBalancedTrainDataModule
|
| 3 |
+
batch_size: 3
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: false
|
| 14 |
+
allowed_stems: [
|
| 15 |
+
"drums",
|
| 16 |
+
"lead_male_singer",
|
| 17 |
+
"lead_female_singer",
|
| 18 |
+
# "human_choir",
|
| 19 |
+
"background_vocals",
|
| 20 |
+
# "other_vocals",
|
| 21 |
+
"bass_guitar",
|
| 22 |
+
"bass_synthesizer",
|
| 23 |
+
# "contrabass_double_bass",
|
| 24 |
+
# "tuba",
|
| 25 |
+
# "bassoon",
|
| 26 |
+
"fx",
|
| 27 |
+
"clean_electric_guitar",
|
| 28 |
+
"distorted_electric_guitar",
|
| 29 |
+
# "lap_steel_guitar_or_slide_guitar",
|
| 30 |
+
"acoustic_guitar",
|
| 31 |
+
"other_plucked",
|
| 32 |
+
"pitched_percussion",
|
| 33 |
+
"grand_piano",
|
| 34 |
+
"electric_piano",
|
| 35 |
+
"organ_electric_organ",
|
| 36 |
+
"synth_pad",
|
| 37 |
+
"synth_lead",
|
| 38 |
+
# "violin",
|
| 39 |
+
# "viola",
|
| 40 |
+
# "cello",
|
| 41 |
+
# "violin_section",
|
| 42 |
+
# "viola_section",
|
| 43 |
+
# "cello_section",
|
| 44 |
+
"string_section",
|
| 45 |
+
"other_strings",
|
| 46 |
+
"brass",
|
| 47 |
+
# "flutes",
|
| 48 |
+
"reeds",
|
| 49 |
+
"other_wind"
|
| 50 |
+
]
|
| 51 |
+
query_file: "query-10s"
|
| 52 |
+
val_kwargs:
|
| 53 |
+
chunk_size_seconds: 6.0
|
| 54 |
+
hop_size_seconds: 6.0
|
| 55 |
+
query_size_seconds: 10.0
|
| 56 |
+
top_k_instrument: 10
|
| 57 |
+
npy_memmap: true
|
| 58 |
+
mixture_stem: mixture
|
| 59 |
+
use_own_query: false
|
| 60 |
+
allowed_stems: [
|
| 61 |
+
"drums",
|
| 62 |
+
"lead_male_singer",
|
| 63 |
+
"lead_female_singer",
|
| 64 |
+
# "human_choir",
|
| 65 |
+
"background_vocals",
|
| 66 |
+
# "other_vocals",
|
| 67 |
+
"bass_guitar",
|
| 68 |
+
"bass_synthesizer",
|
| 69 |
+
# "contrabass_double_bass",
|
| 70 |
+
# "tuba",
|
| 71 |
+
# "bassoon",
|
| 72 |
+
"fx",
|
| 73 |
+
"clean_electric_guitar",
|
| 74 |
+
"distorted_electric_guitar",
|
| 75 |
+
# "lap_steel_guitar_or_slide_guitar",
|
| 76 |
+
"acoustic_guitar",
|
| 77 |
+
"other_plucked",
|
| 78 |
+
"pitched_percussion",
|
| 79 |
+
"grand_piano",
|
| 80 |
+
"electric_piano",
|
| 81 |
+
"organ_electric_organ",
|
| 82 |
+
"synth_pad",
|
| 83 |
+
"synth_lead",
|
| 84 |
+
# "violin",
|
| 85 |
+
# "viola",
|
| 86 |
+
# "cello",
|
| 87 |
+
# "violin_section",
|
| 88 |
+
# "viola_section",
|
| 89 |
+
# "cello_section",
|
| 90 |
+
"string_section",
|
| 91 |
+
"other_strings",
|
| 92 |
+
"brass",
|
| 93 |
+
# "flutes",
|
| 94 |
+
"reeds",
|
| 95 |
+
"other_wind"
|
| 96 |
+
]
|
| 97 |
+
query_file: "query-10s"
|
| 98 |
+
test_kwargs:
|
| 99 |
+
npy_memmap: true
|
| 100 |
+
n_channels: 2
|
config/data/setup-c/moisesdb-everything-query-d.yml
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesDataModule
|
| 3 |
+
batch_size: 3
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
query_size_seconds: 10.0
|
| 10 |
+
top_k_instrument: 10
|
| 11 |
+
npy_memmap: true
|
| 12 |
+
mixture_stem: mixture
|
| 13 |
+
use_own_query: false
|
| 14 |
+
allowed_stems: [
|
| 15 |
+
"drums",
|
| 16 |
+
"lead_male_singer",
|
| 17 |
+
"lead_female_singer",
|
| 18 |
+
# "human_choir",
|
| 19 |
+
"background_vocals",
|
| 20 |
+
# "other_vocals",
|
| 21 |
+
"bass_guitar",
|
| 22 |
+
"bass_synthesizer",
|
| 23 |
+
# "contrabass_double_bass",
|
| 24 |
+
# "tuba",
|
| 25 |
+
# "bassoon",
|
| 26 |
+
"fx",
|
| 27 |
+
"clean_electric_guitar",
|
| 28 |
+
"distorted_electric_guitar",
|
| 29 |
+
# "lap_steel_guitar_or_slide_guitar",
|
| 30 |
+
"acoustic_guitar",
|
| 31 |
+
"other_plucked",
|
| 32 |
+
"pitched_percussion",
|
| 33 |
+
"grand_piano",
|
| 34 |
+
"electric_piano",
|
| 35 |
+
"organ_electric_organ",
|
| 36 |
+
"synth_pad",
|
| 37 |
+
"synth_lead",
|
| 38 |
+
# "violin",
|
| 39 |
+
# "viola",
|
| 40 |
+
# "cello",
|
| 41 |
+
# "violin_section",
|
| 42 |
+
# "viola_section",
|
| 43 |
+
# "cello_section",
|
| 44 |
+
"string_section",
|
| 45 |
+
"other_strings",
|
| 46 |
+
"brass",
|
| 47 |
+
# "flutes",
|
| 48 |
+
"reeds",
|
| 49 |
+
"other_wind"
|
| 50 |
+
]
|
| 51 |
+
query_file: "query-10s"
|
| 52 |
+
val_kwargs:
|
| 53 |
+
chunk_size_seconds: 6.0
|
| 54 |
+
hop_size_seconds: 6.0
|
| 55 |
+
query_size_seconds: 10.0
|
| 56 |
+
top_k_instrument: 10
|
| 57 |
+
npy_memmap: true
|
| 58 |
+
mixture_stem: mixture
|
| 59 |
+
use_own_query: false
|
| 60 |
+
allowed_stems: [
|
| 61 |
+
"drums",
|
| 62 |
+
"lead_male_singer",
|
| 63 |
+
"lead_female_singer",
|
| 64 |
+
# "human_choir",
|
| 65 |
+
"background_vocals",
|
| 66 |
+
# "other_vocals",
|
| 67 |
+
"bass_guitar",
|
| 68 |
+
"bass_synthesizer",
|
| 69 |
+
# "contrabass_double_bass",
|
| 70 |
+
# "tuba",
|
| 71 |
+
# "bassoon",
|
| 72 |
+
"fx",
|
| 73 |
+
"clean_electric_guitar",
|
| 74 |
+
"distorted_electric_guitar",
|
| 75 |
+
# "lap_steel_guitar_or_slide_guitar",
|
| 76 |
+
"acoustic_guitar",
|
| 77 |
+
"other_plucked",
|
| 78 |
+
"pitched_percussion",
|
| 79 |
+
"grand_piano",
|
| 80 |
+
"electric_piano",
|
| 81 |
+
"organ_electric_organ",
|
| 82 |
+
"synth_pad",
|
| 83 |
+
"synth_lead",
|
| 84 |
+
# "violin",
|
| 85 |
+
# "viola",
|
| 86 |
+
# "cello",
|
| 87 |
+
# "violin_section",
|
| 88 |
+
# "viola_section",
|
| 89 |
+
# "cello_section",
|
| 90 |
+
"string_section",
|
| 91 |
+
"other_strings",
|
| 92 |
+
"brass",
|
| 93 |
+
# "flutes",
|
| 94 |
+
"reeds",
|
| 95 |
+
"other_wind"
|
| 96 |
+
]
|
| 97 |
+
query_file: "query-10s"
|
| 98 |
+
test_kwargs:
|
| 99 |
+
npy_memmap: true
|
| 100 |
+
n_channels: 2
|
config/data/vdbo/moisesdb-vdbo-aug.yml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesVDBODataModule
|
| 3 |
+
batch_size: 4
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
fs: 44100
|
| 10 |
+
npy_memmap: true
|
| 11 |
+
augment:
|
| 12 |
+
- cls: Shift
|
| 13 |
+
kwargs:
|
| 14 |
+
p: 1.0
|
| 15 |
+
min_shift: -0.5
|
| 16 |
+
max_shift: 0.5
|
| 17 |
+
- cls: Gain
|
| 18 |
+
kwargs:
|
| 19 |
+
p: 1.0
|
| 20 |
+
min_gain_in_db: -6
|
| 21 |
+
max_gain_in_db: 6
|
| 22 |
+
- cls: ShuffleChannels
|
| 23 |
+
kwargs:
|
| 24 |
+
p: 0.5
|
| 25 |
+
- cls: PolarityInversion
|
| 26 |
+
kwargs:
|
| 27 |
+
p: 0.5
|
| 28 |
+
val_kwargs:
|
| 29 |
+
chunk_size_seconds: 6.0
|
| 30 |
+
hop_size_seconds: 6.0
|
| 31 |
+
fs: 44100
|
| 32 |
+
npy_memmap: true
|
| 33 |
+
test_kwargs:
|
| 34 |
+
npy_memmap: true
|
| 35 |
+
n_channels: 2
|
config/data/vdbo/moisesdb-vdbo.yml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: ${oc.env:DATA_ROOT}/moisesdb
|
| 2 |
+
cls: MoisesVDBODataModule
|
| 3 |
+
batch_size: 4
|
| 4 |
+
effective_batch_size: null
|
| 5 |
+
num_workers: 8
|
| 6 |
+
train_kwargs:
|
| 7 |
+
target_length: 8192
|
| 8 |
+
chunk_size_seconds: 6.0
|
| 9 |
+
fs: 44100
|
| 10 |
+
npy_memmap: true
|
| 11 |
+
val_kwargs:
|
| 12 |
+
chunk_size_seconds: 6.0
|
| 13 |
+
hop_size_seconds: 6.0
|
| 14 |
+
fs: 44100
|
| 15 |
+
npy_memmap: true
|
| 16 |
+
test_kwargs:
|
| 17 |
+
npy_memmap: true
|
| 18 |
+
n_channels: 2
|
config/losses/both_l1snr.yml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cls: L1SNRLoss
|
| 2 |
+
modality:
|
| 3 |
+
- audio
|
| 4 |
+
- spectrogram
|
config/losses/both_l1snrdbm.yml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cls: L1SNRDecibelMatchLoss
|
| 2 |
+
modality:
|
| 3 |
+
- audio
|
| 4 |
+
- spectrogram
|
config/models/bandit-query-pre.yml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cls: PasstFiLMConditionedBandit
|
| 2 |
+
kwargs:
|
| 3 |
+
in_channel: 2
|
| 4 |
+
band_type: "musical"
|
| 5 |
+
n_bands: 64
|
| 6 |
+
additive_film: true
|
| 7 |
+
multiplicative_film: true
|
| 8 |
+
film_depth: 2
|
| 9 |
+
n_sqm_modules: 8
|
| 10 |
+
emb_dim: 128
|
| 11 |
+
rnn_dim: 256
|
| 12 |
+
bidirectional: true
|
| 13 |
+
rnn_type: "GRU"
|
| 14 |
+
mlp_dim: 512
|
| 15 |
+
hidden_activation: "Tanh"
|
| 16 |
+
hidden_activation_kwargs: null
|
| 17 |
+
complex_mask: true
|
| 18 |
+
use_freq_weights: true
|
| 19 |
+
n_fft: 2048
|
| 20 |
+
win_length: 2048
|
| 21 |
+
hop_length: 512
|
| 22 |
+
window_fn: "hann_window"
|
| 23 |
+
wkwargs: null
|
| 24 |
+
power: null
|
| 25 |
+
center: true
|
| 26 |
+
normalized: true
|
| 27 |
+
pad_mode: "reflect"
|
| 28 |
+
onesided: true
|
| 29 |
+
fs: 44100
|
| 30 |
+
pretrain_encoder: checkpoints/bandit-vdbo-roformer.ckpt
|
| 31 |
+
freeze_encoder: false
|
config/models/bandit-query-prefz.yml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cls: PasstFiLMConditionedBandit
|
| 2 |
+
kwargs:
|
| 3 |
+
in_channel: 2
|
| 4 |
+
band_type: "musical"
|
| 5 |
+
n_bands: 64
|
| 6 |
+
additive_film: true
|
| 7 |
+
multiplicative_film: true
|
| 8 |
+
film_depth: 2
|
| 9 |
+
n_sqm_modules: 8
|
| 10 |
+
emb_dim: 128
|
| 11 |
+
rnn_dim: 256
|
| 12 |
+
bidirectional: true
|
| 13 |
+
rnn_type: "GRU"
|
| 14 |
+
mlp_dim: 512
|
| 15 |
+
hidden_activation: "Tanh"
|
| 16 |
+
hidden_activation_kwargs: null
|
| 17 |
+
complex_mask: true
|
| 18 |
+
use_freq_weights: true
|
| 19 |
+
n_fft: 2048
|
| 20 |
+
win_length: 2048
|
| 21 |
+
hop_length: 512
|
| 22 |
+
window_fn: "hann_window"
|
| 23 |
+
wkwargs: null
|
| 24 |
+
power: null
|
| 25 |
+
center: true
|
| 26 |
+
normalized: true
|
| 27 |
+
pad_mode: "reflect"
|
| 28 |
+
onesided: true
|
| 29 |
+
fs: 44100
|
| 30 |
+
pretrain_encoder: checkpoints/bandit-vdbo-roformer.ckpt
|
| 31 |
+
freeze_encoder: true
|
config/models/bandit-query.yml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cls: PasstFiLMConditionedBandit
|
| 2 |
+
kwargs:
|
| 3 |
+
in_channel: 2
|
| 4 |
+
band_type: "musical"
|
| 5 |
+
n_bands: 64
|
| 6 |
+
additive_film: true
|
| 7 |
+
multiplicative_film: true
|
| 8 |
+
film_depth: 2
|
| 9 |
+
n_sqm_modules: 8
|
| 10 |
+
emb_dim: 128
|
| 11 |
+
rnn_dim: 256
|
| 12 |
+
bidirectional: true
|
| 13 |
+
rnn_type: "GRU"
|
| 14 |
+
mlp_dim: 512
|
| 15 |
+
hidden_activation: "Tanh"
|
| 16 |
+
hidden_activation_kwargs: null
|
| 17 |
+
complex_mask: true
|
| 18 |
+
use_freq_weights: true
|
| 19 |
+
n_fft: 2048
|
| 20 |
+
win_length: 2048
|
| 21 |
+
hop_length: 512
|
| 22 |
+
window_fn: "hann_window"
|
| 23 |
+
wkwargs: null
|
| 24 |
+
power: null
|
| 25 |
+
center: true
|
| 26 |
+
normalized: true
|
| 27 |
+
pad_mode: "reflect"
|
| 28 |
+
onesided: true
|
| 29 |
+
fs: 44100
|
config/models/bandit-vdbo.yml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cls: Bandit
|
| 2 |
+
kwargs:
|
| 3 |
+
in_channel: 2
|
| 4 |
+
stems: ["vocals", "bass", "drums", "vdbo_others"]
|
| 5 |
+
band_type: "musical"
|
| 6 |
+
n_bands: 64
|
| 7 |
+
n_sqm_modules: 8
|
| 8 |
+
emb_dim: 128
|
| 9 |
+
rnn_dim: 256
|
| 10 |
+
bidirectional: true
|
| 11 |
+
rnn_type: "GRU"
|
| 12 |
+
mlp_dim: 512
|
| 13 |
+
hidden_activation: "Tanh"
|
| 14 |
+
hidden_activation_kwargs: null
|
| 15 |
+
complex_mask: true
|
| 16 |
+
use_freq_weights: true
|
| 17 |
+
n_fft: 2048
|
| 18 |
+
win_length: 2048
|
| 19 |
+
hop_length: 512
|
| 20 |
+
window_fn: "hann_window"
|
| 21 |
+
wkwargs: null
|
| 22 |
+
power: null
|
| 23 |
+
center: true
|
| 24 |
+
normalized: true
|
| 25 |
+
pad_mode: "reflect"
|
| 26 |
+
onesided: true
|
| 27 |
+
fs: 44100
|
config/optim/adam.yml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
optimizer:
|
| 2 |
+
cls: Adam
|
| 3 |
+
kwargs:
|
| 4 |
+
lr: 1.0e-3
|
| 5 |
+
scheduler:
|
| 6 |
+
cls: StepLR
|
| 7 |
+
kwargs:
|
| 8 |
+
step_size: 1
|
| 9 |
+
gamma: 0.98
|
config/trainer/default-long.yml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
callbacks:
|
| 2 |
+
checkpoint:
|
| 3 |
+
monitor: val/loss
|
| 4 |
+
mode: min
|
| 5 |
+
save_top_k: 3
|
| 6 |
+
save_last: True
|
| 7 |
+
max_epochs: 500
|
| 8 |
+
accumulate_grad_batches: null
|
| 9 |
+
gradient_clip_val: 10.0
|
| 10 |
+
gradient_clip_algorithm: norm
|
| 11 |
+
logger:
|
| 12 |
+
save_dir: ${oc.env:LOG_ROOT}/e2e
|
config/trainer/default.yml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
callbacks:
|
| 2 |
+
checkpoint:
|
| 3 |
+
monitor: val/loss
|
| 4 |
+
mode: min
|
| 5 |
+
save_top_k: 3
|
| 6 |
+
save_last: True
|
| 7 |
+
max_epochs: 150
|
| 8 |
+
accumulate_grad_batches: null
|
| 9 |
+
gradient_clip_val: 10.0
|
| 10 |
+
gradient_clip_algorithm: norm
|
| 11 |
+
logger:
|
| 12 |
+
save_dir: ${oc.env:LOG_ROOT}/e2e
|
core/__init__.py
ADDED
|
File without changes
|
core/data/__init__.py
ADDED
|
File without changes
|
core/data/base.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import os
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
|
| 10 |
+
from torch.utils import data
|
| 11 |
+
|
| 12 |
+
from pytorch_lightning import LightningDataModule
|
| 13 |
+
from torch.utils.data import Dataset, DataLoader, IterableDataset
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def from_datasets(
|
| 17 |
+
train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
|
| 18 |
+
val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
|
| 19 |
+
test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
|
| 20 |
+
predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
|
| 21 |
+
batch_size: int = 1,
|
| 22 |
+
num_workers: int = 0,
|
| 23 |
+
**datamodule_kwargs: Any,
|
| 24 |
+
) -> "LightningDataModule":
|
| 25 |
+
|
| 26 |
+
def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
|
| 27 |
+
shuffle &= not isinstance(ds, IterableDataset)
|
| 28 |
+
return DataLoader(
|
| 29 |
+
ds,
|
| 30 |
+
batch_size=batch_size,
|
| 31 |
+
shuffle=shuffle,
|
| 32 |
+
num_workers=num_workers,
|
| 33 |
+
pin_memory=True,
|
| 34 |
+
prefetch_factor=4,
|
| 35 |
+
persistent_workers=True,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def train_dataloader() -> TRAIN_DATALOADERS:
|
| 39 |
+
assert train_dataset
|
| 40 |
+
|
| 41 |
+
if isinstance(train_dataset, Mapping):
|
| 42 |
+
return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
|
| 43 |
+
if isinstance(train_dataset, Sequence):
|
| 44 |
+
return [dataloader(ds, shuffle=True) for ds in train_dataset]
|
| 45 |
+
return dataloader(train_dataset, shuffle=True)
|
| 46 |
+
|
| 47 |
+
def val_dataloader() -> EVAL_DATALOADERS:
|
| 48 |
+
assert val_dataset
|
| 49 |
+
|
| 50 |
+
if isinstance(val_dataset, Sequence):
|
| 51 |
+
return [dataloader(ds) for ds in val_dataset]
|
| 52 |
+
return dataloader(val_dataset)
|
| 53 |
+
|
| 54 |
+
def test_dataloader() -> EVAL_DATALOADERS:
|
| 55 |
+
assert test_dataset
|
| 56 |
+
|
| 57 |
+
if isinstance(test_dataset, Sequence):
|
| 58 |
+
return [dataloader(ds) for ds in test_dataset]
|
| 59 |
+
return dataloader(test_dataset)
|
| 60 |
+
|
| 61 |
+
def predict_dataloader() -> EVAL_DATALOADERS:
|
| 62 |
+
assert predict_dataset
|
| 63 |
+
|
| 64 |
+
if isinstance(predict_dataset, Sequence):
|
| 65 |
+
return [dataloader(ds) for ds in predict_dataset]
|
| 66 |
+
return dataloader(predict_dataset)
|
| 67 |
+
|
| 68 |
+
candidate_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
|
| 69 |
+
accepted_params = inspect.signature(LightningDataModule.__init__).parameters
|
| 70 |
+
accepts_kwargs = any(param.kind == param.VAR_KEYWORD for param in accepted_params.values())
|
| 71 |
+
if accepts_kwargs:
|
| 72 |
+
special_kwargs = candidate_kwargs
|
| 73 |
+
else:
|
| 74 |
+
accepted_param_names = set(accepted_params)
|
| 75 |
+
accepted_param_names.discard("self")
|
| 76 |
+
special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names}
|
| 77 |
+
|
| 78 |
+
datamodule = LightningDataModule(**datamodule_kwargs, **special_kwargs)
|
| 79 |
+
if train_dataset is not None:
|
| 80 |
+
datamodule.train_dataloader = train_dataloader # type: ignore[method-assign]
|
| 81 |
+
if val_dataset is not None:
|
| 82 |
+
datamodule.val_dataloader = val_dataloader # type: ignore[method-assign]
|
| 83 |
+
if test_dataset is not None:
|
| 84 |
+
datamodule.test_dataloader = test_dataloader # type: ignore[method-assign]
|
| 85 |
+
if predict_dataset is not None:
|
| 86 |
+
datamodule.predict_dataloader = predict_dataloader # type: ignore[method-assign]
|
| 87 |
+
|
| 88 |
+
return datamodule
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class BaseSourceSeparationDataset(data.Dataset, ABC):
|
| 92 |
+
def __init__(
|
| 93 |
+
self,
|
| 94 |
+
split: str,
|
| 95 |
+
stems: List[str],
|
| 96 |
+
files: List[str],
|
| 97 |
+
data_path: str,
|
| 98 |
+
fs: int,
|
| 99 |
+
npy_memmap: bool,
|
| 100 |
+
recompute_mixture: bool,
|
| 101 |
+
):
|
| 102 |
+
if "mixture" not in stems:
|
| 103 |
+
stems = ["mixture"] + stems
|
| 104 |
+
|
| 105 |
+
self.split = split
|
| 106 |
+
self.stems = stems
|
| 107 |
+
self.stems_no_mixture = [s for s in stems if s != "mixture"]
|
| 108 |
+
self.files = files
|
| 109 |
+
self.data_path = data_path
|
| 110 |
+
self.fs = fs
|
| 111 |
+
self.npy_memmap = npy_memmap
|
| 112 |
+
self.recompute_mixture = recompute_mixture
|
| 113 |
+
|
| 114 |
+
@abstractmethod
|
| 115 |
+
def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
|
| 116 |
+
raise NotImplementedError
|
| 117 |
+
|
| 118 |
+
def _get_audio(self, stems, identifier: Dict[str, Any]):
|
| 119 |
+
audio = {}
|
| 120 |
+
for stem in stems:
|
| 121 |
+
audio[stem] = self.get_stem(stem=stem, identifier=identifier)
|
| 122 |
+
|
| 123 |
+
return audio
|
| 124 |
+
|
| 125 |
+
def get_audio(self, identifier: Dict[str, Any]):
|
| 126 |
+
if self.recompute_mixture:
|
| 127 |
+
audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
|
| 128 |
+
audio["mixture"] = self.compute_mixture(audio)
|
| 129 |
+
return audio
|
| 130 |
+
else:
|
| 131 |
+
return self._get_audio(self.stems, identifier=identifier)
|
| 132 |
+
|
| 133 |
+
@abstractmethod
|
| 134 |
+
def get_identifier(self, index: int) -> Dict[str, Any]:
|
| 135 |
+
pass
|
| 136 |
+
|
| 137 |
+
def compute_mixture(self, audio) -> torch.Tensor:
|
| 138 |
+
return sum(audio[stem] for stem in audio if stem != "mixture")
|
core/data/moisesdb/__init__.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
taxonomy = {
|
| 2 |
+
"vocals": [
|
| 3 |
+
"lead male singer",
|
| 4 |
+
"lead female singer",
|
| 5 |
+
"human choir",
|
| 6 |
+
"background vocals",
|
| 7 |
+
"other (vocoder, beatboxing etc)",
|
| 8 |
+
],
|
| 9 |
+
"bass": [
|
| 10 |
+
"bass guitar",
|
| 11 |
+
"bass synthesizer (moog etc)",
|
| 12 |
+
"contrabass/double bass (bass of instrings)",
|
| 13 |
+
"tuba (bass of brass)",
|
| 14 |
+
"bassoon (bass of woodwind)",
|
| 15 |
+
],
|
| 16 |
+
"drums": [
|
| 17 |
+
"snare drum",
|
| 18 |
+
"toms",
|
| 19 |
+
"kick drum",
|
| 20 |
+
"cymbals",
|
| 21 |
+
"overheads",
|
| 22 |
+
"full acoustic drumkit",
|
| 23 |
+
"drum machine",
|
| 24 |
+
],
|
| 25 |
+
"other": [
|
| 26 |
+
"fx/processed sound, scratches, gun shots, explosions etc",
|
| 27 |
+
"click track",
|
| 28 |
+
],
|
| 29 |
+
"guitar": [
|
| 30 |
+
"clean electric guitar",
|
| 31 |
+
"distorted electric guitar",
|
| 32 |
+
"lap steel guitar or slide guitar",
|
| 33 |
+
"acoustic guitar",
|
| 34 |
+
],
|
| 35 |
+
"other plucked": ["banjo, mandolin, ukulele, harp etc"],
|
| 36 |
+
"percussion": [
|
| 37 |
+
"a-tonal percussion (claps, shakers, congas, cowbell etc)",
|
| 38 |
+
"pitched percussion (mallets, glockenspiel, ...)",
|
| 39 |
+
],
|
| 40 |
+
"piano": [
|
| 41 |
+
"grand piano",
|
| 42 |
+
"electric piano (rhodes, wurlitzer, piano sound alike)",
|
| 43 |
+
],
|
| 44 |
+
"other keys": [
|
| 45 |
+
"organ, electric organ",
|
| 46 |
+
"synth pad",
|
| 47 |
+
"synth lead",
|
| 48 |
+
"other sounds (hapischord, melotron etc)",
|
| 49 |
+
],
|
| 50 |
+
"bowed strings": [
|
| 51 |
+
"violin (solo)",
|
| 52 |
+
"viola (solo)",
|
| 53 |
+
"cello (solo)",
|
| 54 |
+
"violin section",
|
| 55 |
+
"viola section",
|
| 56 |
+
"cello section",
|
| 57 |
+
"string section",
|
| 58 |
+
"other strings",
|
| 59 |
+
],
|
| 60 |
+
"wind": [
|
| 61 |
+
"brass (trumpet, trombone, french horn, brass etc)",
|
| 62 |
+
"flutes (piccolo, bamboo flute, panpipes, flutes etc)",
|
| 63 |
+
"reeds (saxophone, clarinets, oboe, english horn, bagpipe)",
|
| 64 |
+
"other wind",
|
| 65 |
+
],
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def clean_track_inst(inst):
|
| 70 |
+
|
| 71 |
+
if "fx" in inst:
|
| 72 |
+
inst = "fx"
|
| 73 |
+
|
| 74 |
+
if "contrabass_double_bass" in inst:
|
| 75 |
+
inst = "double_bass"
|
| 76 |
+
|
| 77 |
+
if "banjo" in inst:
|
| 78 |
+
return "other_plucked"
|
| 79 |
+
|
| 80 |
+
if "(" in inst:
|
| 81 |
+
inst = inst.split("(")[0]
|
| 82 |
+
|
| 83 |
+
for s in [",", "-"]:
|
| 84 |
+
if s in inst:
|
| 85 |
+
inst = inst.replace(s, "")
|
| 86 |
+
|
| 87 |
+
for s in ["/"]:
|
| 88 |
+
if s in inst:
|
| 89 |
+
inst = inst.replace(s, "_")
|
| 90 |
+
|
| 91 |
+
if inst[-1] == "_":
|
| 92 |
+
inst = inst[:-1]
|
| 93 |
+
|
| 94 |
+
return inst
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
taxonomy = {k: [clean_track_inst(i.replace(" ", "_")) for i in v] for k, v in taxonomy.items()}
|
core/data/moisesdb/audio.ipynb
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {
|
| 7 |
+
"ExecuteTime": {
|
| 8 |
+
"end_time": "2024-02-17T00:59:48.228125593Z",
|
| 9 |
+
"start_time": "2024-02-17T00:59:47.533488738Z"
|
| 10 |
+
},
|
| 11 |
+
"collapsed": true
|
| 12 |
+
},
|
| 13 |
+
"outputs": [],
|
| 14 |
+
"source": [
|
| 15 |
+
"import numpy as np\n",
|
| 16 |
+
"import IPython\n",
|
| 17 |
+
"file = \"/home/kwatchar3/Documents/data/moisesdb/npy/e2ccbc17-44bf-431a-af2b-4cf2fbd19a72/mixture.npy\"\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"audio = np.load(file)\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"IPython.display.Audio(audio, rate=44100)"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": null,
|
| 27 |
+
"metadata": {
|
| 28 |
+
"ExecuteTime": {
|
| 29 |
+
"end_time": "2024-02-16T18:01:10.487779628Z",
|
| 30 |
+
"start_time": "2024-02-16T18:01:06.898408871Z"
|
| 31 |
+
},
|
| 32 |
+
"collapsed": false
|
| 33 |
+
},
|
| 34 |
+
"outputs": [],
|
| 35 |
+
"source": [
|
| 36 |
+
"from scipy.signal import spectrogram\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"f, t, Sxx = spectrogram(audio, 44100, nperseg=1024, noverlap=512)\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"import matplotlib.pyplot as plt\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"plt.pcolormesh(t, f, 10 * np.log10(Sxx[0]))"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": null,
|
| 48 |
+
"metadata": {
|
| 49 |
+
"collapsed": false
|
| 50 |
+
},
|
| 51 |
+
"outputs": [],
|
| 52 |
+
"source": []
|
| 53 |
+
}
|
| 54 |
+
],
|
| 55 |
+
"metadata": {
|
| 56 |
+
"kernelspec": {
|
| 57 |
+
"display_name": "Python 3",
|
| 58 |
+
"language": "python",
|
| 59 |
+
"name": "python3"
|
| 60 |
+
},
|
| 61 |
+
"language_info": {
|
| 62 |
+
"codemirror_mode": {
|
| 63 |
+
"name": "ipython",
|
| 64 |
+
"version": 2
|
| 65 |
+
},
|
| 66 |
+
"file_extension": ".py",
|
| 67 |
+
"mimetype": "text/x-python",
|
| 68 |
+
"name": "python",
|
| 69 |
+
"nbconvert_exporter": "python",
|
| 70 |
+
"pygments_lexer": "ipython2",
|
| 71 |
+
"version": "2.7.6"
|
| 72 |
+
}
|
| 73 |
+
},
|
| 74 |
+
"nbformat": 4,
|
| 75 |
+
"nbformat_minor": 0
|
| 76 |
+
}
|
core/data/moisesdb/datamodule.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from typing import Mapping, Optional
|
| 3 |
+
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
|
| 6 |
+
from core.data.base import from_datasets
|
| 7 |
+
from core.data.moisesdb.dataset import MoisesDBRandomChunkBalancedRandomQueryDataset, MoisesDBRandomChunkRandomQueryDataset, \
|
| 8 |
+
MoisesDBDeterministicChunkDeterministicQueryDataset, \
|
| 9 |
+
MoisesDBFullTrackDataset, MoisesDBVDBODeterministicChunkDataset, \
|
| 10 |
+
MoisesDBVDBOFullTrackDataset, MoisesDBVDBORandomChunkDataset, \
|
| 11 |
+
MoisesDBFullTrackTestQueryDataset
|
| 12 |
+
|
| 13 |
+
def MoisesDataModule(
|
| 14 |
+
data_root: str,
|
| 15 |
+
batch_size: int,
|
| 16 |
+
num_workers: int = 8,
|
| 17 |
+
train_kwargs: Optional[Mapping] = None,
|
| 18 |
+
val_kwargs: Optional[Mapping] = None,
|
| 19 |
+
test_kwargs: Optional[Mapping] = None,
|
| 20 |
+
datamodule_kwargs: Optional[Mapping] = None,
|
| 21 |
+
) -> pl.LightningDataModule:
|
| 22 |
+
if train_kwargs is None:
|
| 23 |
+
train_kwargs = {}
|
| 24 |
+
|
| 25 |
+
if val_kwargs is None:
|
| 26 |
+
val_kwargs = {}
|
| 27 |
+
|
| 28 |
+
if test_kwargs is None:
|
| 29 |
+
test_kwargs = {}
|
| 30 |
+
|
| 31 |
+
if datamodule_kwargs is None:
|
| 32 |
+
datamodule_kwargs = {}
|
| 33 |
+
|
| 34 |
+
train_dataset = MoisesDBRandomChunkRandomQueryDataset(
|
| 35 |
+
data_root=data_root, split="train", **train_kwargs
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
val_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
|
| 39 |
+
data_root=data_root, split="val", **val_kwargs
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
test_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
|
| 43 |
+
data_root=data_root, split="test", **test_kwargs
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
datamodule = from_datasets(
|
| 47 |
+
train_dataset=train_dataset,
|
| 48 |
+
val_dataset=val_dataset,
|
| 49 |
+
test_dataset=test_dataset,
|
| 50 |
+
batch_size=batch_size,
|
| 51 |
+
num_workers=num_workers,
|
| 52 |
+
**datamodule_kwargs
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
datamodule.predict_dataloader = ( # type: ignore[method-assign]
|
| 56 |
+
datamodule.test_dataloader
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
return datamodule
|
| 60 |
+
|
| 61 |
+
def MoisesBalancedTrainDataModule(
|
| 62 |
+
data_root: str,
|
| 63 |
+
batch_size: int,
|
| 64 |
+
num_workers: int = 8,
|
| 65 |
+
train_kwargs: Optional[Mapping] = None,
|
| 66 |
+
val_kwargs: Optional[Mapping] = None,
|
| 67 |
+
test_kwargs: Optional[Mapping] = None,
|
| 68 |
+
datamodule_kwargs: Optional[Mapping] = None,
|
| 69 |
+
) -> pl.LightningDataModule:
|
| 70 |
+
if train_kwargs is None:
|
| 71 |
+
train_kwargs = {}
|
| 72 |
+
|
| 73 |
+
if val_kwargs is None:
|
| 74 |
+
val_kwargs = {}
|
| 75 |
+
|
| 76 |
+
if test_kwargs is None:
|
| 77 |
+
test_kwargs = {}
|
| 78 |
+
|
| 79 |
+
if datamodule_kwargs is None:
|
| 80 |
+
datamodule_kwargs = {}
|
| 81 |
+
|
| 82 |
+
train_dataset = MoisesDBRandomChunkBalancedRandomQueryDataset(
|
| 83 |
+
data_root=data_root, split="train", **train_kwargs
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
val_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
|
| 87 |
+
data_root=data_root, split="val", **val_kwargs
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
test_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
|
| 91 |
+
data_root=data_root, split="test", **test_kwargs
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
datamodule = from_datasets(
|
| 95 |
+
train_dataset=train_dataset,
|
| 96 |
+
val_dataset=val_dataset,
|
| 97 |
+
test_dataset=test_dataset,
|
| 98 |
+
batch_size=batch_size,
|
| 99 |
+
num_workers=num_workers,
|
| 100 |
+
**datamodule_kwargs
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
datamodule.predict_dataloader = ( # type: ignore[method-assign]
|
| 104 |
+
datamodule.test_dataloader
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
return datamodule
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def MoisesValidationDataModule(
|
| 111 |
+
data_root: str,
|
| 112 |
+
batch_size: int,
|
| 113 |
+
num_workers: int = 8,
|
| 114 |
+
val_kwargs: Optional[Mapping] = None,
|
| 115 |
+
datamodule_kwargs: Optional[Mapping] = None,
|
| 116 |
+
**kwargs
|
| 117 |
+
) -> pl.LightningDataModule:
|
| 118 |
+
if val_kwargs is None:
|
| 119 |
+
val_kwargs = {}
|
| 120 |
+
|
| 121 |
+
if datamodule_kwargs is None:
|
| 122 |
+
datamodule_kwargs = {}
|
| 123 |
+
|
| 124 |
+
allowed_stems = val_kwargs.get("allowed_stems", None)
|
| 125 |
+
|
| 126 |
+
assert allowed_stems is not None, "allowed_stems must be provided"
|
| 127 |
+
|
| 128 |
+
val_datasets = []
|
| 129 |
+
|
| 130 |
+
for allowed_stem in allowed_stems:
|
| 131 |
+
kwargs = val_kwargs.copy()
|
| 132 |
+
kwargs["allowed_stems"] = [allowed_stem]
|
| 133 |
+
val_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
|
| 134 |
+
data_root=data_root, split="val",
|
| 135 |
+
**kwargs
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
val_datasets.append(val_dataset)
|
| 139 |
+
|
| 140 |
+
datamodule = from_datasets(
|
| 141 |
+
val_dataset=val_datasets,
|
| 142 |
+
batch_size=batch_size,
|
| 143 |
+
num_workers=num_workers,
|
| 144 |
+
**datamodule_kwargs
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
datamodule.predict_dataloader = ( # type: ignore[method-assign]
|
| 148 |
+
datamodule.val_dataloader
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
return datamodule
|
| 152 |
+
|
| 153 |
+
def MoisesTestDataModule(
|
| 154 |
+
data_root: str,
|
| 155 |
+
batch_size: int = 1,
|
| 156 |
+
num_workers: int = 8,
|
| 157 |
+
test_kwargs: Optional[Mapping] = None,
|
| 158 |
+
datamodule_kwargs: Optional[Mapping] = None,
|
| 159 |
+
**kwargs
|
| 160 |
+
) -> pl.LightningDataModule:
|
| 161 |
+
if test_kwargs is None:
|
| 162 |
+
test_kwargs = {}
|
| 163 |
+
|
| 164 |
+
if datamodule_kwargs is None:
|
| 165 |
+
datamodule_kwargs = {}
|
| 166 |
+
|
| 167 |
+
allowed_stems = test_kwargs.get("allowed_stems", None)
|
| 168 |
+
|
| 169 |
+
assert allowed_stems is not None, "allowed_stems must be provided"
|
| 170 |
+
|
| 171 |
+
test_dataset = MoisesDBFullTrackTestQueryDataset(
|
| 172 |
+
data_root=data_root, split="test",
|
| 173 |
+
**test_kwargs
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
datamodule = from_datasets(
|
| 177 |
+
test_dataset=test_dataset,
|
| 178 |
+
batch_size=batch_size,
|
| 179 |
+
num_workers=num_workers,
|
| 180 |
+
**datamodule_kwargs
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
datamodule.predict_dataloader = ( # type: ignore[method-assign]
|
| 184 |
+
datamodule.test_dataloader
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
return datamodule
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def MoisesVDBODataModule(
|
| 191 |
+
data_root: str,
|
| 192 |
+
batch_size: int,
|
| 193 |
+
num_workers: int = 8,
|
| 194 |
+
train_kwargs: Optional[Mapping] = None,
|
| 195 |
+
val_kwargs: Optional[Mapping] = None,
|
| 196 |
+
test_kwargs: Optional[Mapping] = None,
|
| 197 |
+
datamodule_kwargs: Optional[Mapping] = None,
|
| 198 |
+
):
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
if train_kwargs is None:
|
| 202 |
+
train_kwargs = {}
|
| 203 |
+
|
| 204 |
+
if val_kwargs is None:
|
| 205 |
+
val_kwargs = {}
|
| 206 |
+
|
| 207 |
+
if test_kwargs is None:
|
| 208 |
+
test_kwargs = {}
|
| 209 |
+
|
| 210 |
+
if datamodule_kwargs is None:
|
| 211 |
+
datamodule_kwargs = {}
|
| 212 |
+
|
| 213 |
+
train_dataset = MoisesDBVDBORandomChunkDataset(
|
| 214 |
+
data_root=data_root, split="train", **train_kwargs
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
val_dataset = MoisesDBVDBODeterministicChunkDataset(
|
| 218 |
+
data_root=data_root, split="val", **val_kwargs
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
test_dataset = MoisesDBVDBOFullTrackDataset(
|
| 222 |
+
data_root=data_root, split="test", **test_kwargs
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
predict_dataset = test_dataset
|
| 226 |
+
|
| 227 |
+
datamodule = from_datasets(
|
| 228 |
+
train_dataset=train_dataset,
|
| 229 |
+
val_dataset=val_dataset,
|
| 230 |
+
test_dataset=test_dataset,
|
| 231 |
+
predict_dataset=predict_dataset,
|
| 232 |
+
batch_size=batch_size,
|
| 233 |
+
num_workers=num_workers,
|
| 234 |
+
**datamodule_kwargs
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
return datamodule
|
| 238 |
+
|
| 239 |
+
|
core/data/moisesdb/dataset.py
ADDED
|
@@ -0,0 +1,1383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import warnings
|
| 5 |
+
from abc import ABC
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import torch
|
| 13 |
+
from torch_audiomentations.utils.object_dict import ObjectDict
|
| 14 |
+
import torchaudio as ta
|
| 15 |
+
from torch.utils import data
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
from core.data.base import BaseSourceSeparationDataset
|
| 19 |
+
from core.types import input_dict
|
| 20 |
+
|
| 21 |
+
from . import clean_track_inst
|
| 22 |
+
|
| 23 |
+
from torch import Tensor, nn
|
| 24 |
+
|
| 25 |
+
DBFS_HOP_SIZE = int(0.125 * 44100)
|
| 26 |
+
DBFS_CHUNK_SIZE = int(1 * 44100)
|
| 27 |
+
|
| 28 |
+
INST_BY_OCCURRENCE = [
|
| 29 |
+
"bass_guitar",
|
| 30 |
+
"kick_drum",
|
| 31 |
+
"snare_drum",
|
| 32 |
+
"lead_male_singer",
|
| 33 |
+
"distorted_electric_guitar",
|
| 34 |
+
"clean_electric_guitar",
|
| 35 |
+
"toms",
|
| 36 |
+
"acoustic_guitar",
|
| 37 |
+
"background_vocals",
|
| 38 |
+
"hi_hat",
|
| 39 |
+
"overheads",
|
| 40 |
+
"atonal_percussion",
|
| 41 |
+
"grand_piano",
|
| 42 |
+
"cymbals",
|
| 43 |
+
"lead_female_singer",
|
| 44 |
+
"synth_lead",
|
| 45 |
+
"bass_synthesizer",
|
| 46 |
+
"synth_pad",
|
| 47 |
+
"organ_electric_organ",
|
| 48 |
+
"fx",
|
| 49 |
+
"drum_machine",
|
| 50 |
+
"string_section",
|
| 51 |
+
"electric_piano",
|
| 52 |
+
"full_acoustic_drumkit",
|
| 53 |
+
"other_sounds",
|
| 54 |
+
"pitched_percussion",
|
| 55 |
+
"brass",
|
| 56 |
+
"reeds",
|
| 57 |
+
"contrabass_double_bass",
|
| 58 |
+
"other_plucked",
|
| 59 |
+
"other_strings",
|
| 60 |
+
"other_wind",
|
| 61 |
+
"cello",
|
| 62 |
+
"other",
|
| 63 |
+
"flutes",
|
| 64 |
+
"viola_section",
|
| 65 |
+
"viola",
|
| 66 |
+
"cello_section",
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
FINE_LEVEL_INSTRUMENTS = {
|
| 70 |
+
"lead_male_singer",
|
| 71 |
+
"lead_female_singer",
|
| 72 |
+
"human_choir",
|
| 73 |
+
"background_vocals",
|
| 74 |
+
"other_vocals",
|
| 75 |
+
"bass_guitar",
|
| 76 |
+
"bass_synthesizer",
|
| 77 |
+
"contrabass_double_bass",
|
| 78 |
+
"tuba",
|
| 79 |
+
"bassoon",
|
| 80 |
+
"snare_drum",
|
| 81 |
+
"toms",
|
| 82 |
+
"kick_drum",
|
| 83 |
+
"cymbals",
|
| 84 |
+
"overheads",
|
| 85 |
+
"full_acoustic_drumkit",
|
| 86 |
+
"drum_machine",
|
| 87 |
+
"hihat",
|
| 88 |
+
"fx",
|
| 89 |
+
"click_track",
|
| 90 |
+
"clean_electric_guitar",
|
| 91 |
+
"distorted_electric_guitar",
|
| 92 |
+
"lap_steel_guitar_or_slide_guitar",
|
| 93 |
+
"acoustic_guitar",
|
| 94 |
+
"other_plucked",
|
| 95 |
+
"atonal_percussion",
|
| 96 |
+
"pitched_percussion",
|
| 97 |
+
"grand_piano",
|
| 98 |
+
"electric_piano",
|
| 99 |
+
"organ_electric_organ",
|
| 100 |
+
"synth_pad",
|
| 101 |
+
"synth_lead",
|
| 102 |
+
"other_sounds",
|
| 103 |
+
"violin",
|
| 104 |
+
"viola",
|
| 105 |
+
"cello",
|
| 106 |
+
"violin_section",
|
| 107 |
+
"viola_section",
|
| 108 |
+
"cello_section",
|
| 109 |
+
"string_section",
|
| 110 |
+
"other_strings",
|
| 111 |
+
"brass",
|
| 112 |
+
"flutes",
|
| 113 |
+
"reeds",
|
| 114 |
+
"other_wind",
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
COARSE_LEVEL_INSTRUMENTS = {
|
| 118 |
+
"vocals",
|
| 119 |
+
"bass",
|
| 120 |
+
"drums",
|
| 121 |
+
"guitar",
|
| 122 |
+
"other_plucked",
|
| 123 |
+
"percussion",
|
| 124 |
+
"piano",
|
| 125 |
+
"other_keys",
|
| 126 |
+
"bowed_strings",
|
| 127 |
+
"wind",
|
| 128 |
+
"other",
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
COARSE_TO_FINE = {
|
| 132 |
+
"vocals": [
|
| 133 |
+
"lead_male_singer",
|
| 134 |
+
"lead_female_singer",
|
| 135 |
+
"human_choir",
|
| 136 |
+
"background_vocals",
|
| 137 |
+
"other_vocals",
|
| 138 |
+
],
|
| 139 |
+
"bass": [
|
| 140 |
+
"bass_guitar",
|
| 141 |
+
"bass_synthesizer",
|
| 142 |
+
"contrabass_double_bass",
|
| 143 |
+
"tuba",
|
| 144 |
+
"bassoon",
|
| 145 |
+
],
|
| 146 |
+
"drums": [
|
| 147 |
+
"snare_drum",
|
| 148 |
+
"toms",
|
| 149 |
+
"kick_drum",
|
| 150 |
+
"cymbals",
|
| 151 |
+
"overheads",
|
| 152 |
+
"full_acoustic_drumkit",
|
| 153 |
+
"drum_machine",
|
| 154 |
+
"hihat",
|
| 155 |
+
],
|
| 156 |
+
"other": ["fx", "click_track"],
|
| 157 |
+
"guitar": [
|
| 158 |
+
"clean_electric_guitar",
|
| 159 |
+
"distorted_electric_guitar",
|
| 160 |
+
"lap_steel_guitar_or_slide_guitar",
|
| 161 |
+
"acoustic_guitar",
|
| 162 |
+
],
|
| 163 |
+
"other_plucked": ["other_plucked"],
|
| 164 |
+
"percussion": ["atonal_percussion", "pitched_percussion"],
|
| 165 |
+
"piano": ["grand_piano", "electric_piano"],
|
| 166 |
+
"other_keys": ["organ_electric_organ", "synth_pad", "synth_lead", "other_sounds"],
|
| 167 |
+
"bowed_strings": [
|
| 168 |
+
"violin",
|
| 169 |
+
"viola",
|
| 170 |
+
"cello",
|
| 171 |
+
"violin_section",
|
| 172 |
+
"viola_section",
|
| 173 |
+
"cello_section",
|
| 174 |
+
"string_section",
|
| 175 |
+
"other_strings",
|
| 176 |
+
],
|
| 177 |
+
"wind": ["brass", "flutes", "reeds", "other_wind"],
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
COARSE_TO_FINE = {k: set(v) for k, v in COARSE_TO_FINE.items()}
|
| 181 |
+
FINE_TO_COARSE = {k: kk for kk, v in COARSE_TO_FINE.items() for k in v}
|
| 182 |
+
|
| 183 |
+
ALL_LEVEL_INSTRUMENTS = COARSE_LEVEL_INSTRUMENTS.union(FINE_LEVEL_INSTRUMENTS)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class MoisesDBBaseDataset(BaseSourceSeparationDataset, ABC):
|
| 187 |
+
def __init__(
|
| 188 |
+
self,
|
| 189 |
+
split: str,
|
| 190 |
+
data_path: str = "/home/kwatchar3/Documents/data/moisesdb",
|
| 191 |
+
fs: int = 44100,
|
| 192 |
+
return_stems: Union[bool, List[str]] = False,
|
| 193 |
+
npy_memmap=True,
|
| 194 |
+
recompute_mixture=False,
|
| 195 |
+
train_folds=None,
|
| 196 |
+
val_folds=None,
|
| 197 |
+
test_folds=None,
|
| 198 |
+
query_file="query",
|
| 199 |
+
) -> None:
|
| 200 |
+
if test_folds is None:
|
| 201 |
+
test_folds = [5]
|
| 202 |
+
|
| 203 |
+
if val_folds is None:
|
| 204 |
+
val_folds = [4]
|
| 205 |
+
|
| 206 |
+
if train_folds is None:
|
| 207 |
+
train_folds = [1, 2, 3]
|
| 208 |
+
|
| 209 |
+
split_path = os.path.join(data_path, "splits.csv")
|
| 210 |
+
splits = pd.read_csv(split_path)
|
| 211 |
+
|
| 212 |
+
metadata_path = os.path.join(data_path, "stems.csv")
|
| 213 |
+
metadata = pd.read_csv(metadata_path)
|
| 214 |
+
|
| 215 |
+
if split == "train":
|
| 216 |
+
folds = train_folds
|
| 217 |
+
elif split == "val":
|
| 218 |
+
folds = val_folds
|
| 219 |
+
elif split == "test":
|
| 220 |
+
folds = test_folds
|
| 221 |
+
else:
|
| 222 |
+
raise NameError
|
| 223 |
+
|
| 224 |
+
files = splits[splits["split"].isin(folds)]["song_id"].tolist()
|
| 225 |
+
metadata = metadata[metadata["song_id"].isin(files)]
|
| 226 |
+
|
| 227 |
+
super().__init__(
|
| 228 |
+
split=split,
|
| 229 |
+
stems=["mixture"],
|
| 230 |
+
files=files,
|
| 231 |
+
data_path=data_path,
|
| 232 |
+
fs=fs,
|
| 233 |
+
npy_memmap=npy_memmap,
|
| 234 |
+
recompute_mixture=recompute_mixture,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
self.folds = folds
|
| 238 |
+
|
| 239 |
+
self.metadata = metadata.rename(
|
| 240 |
+
columns={k: k.replace(" ", "_") for k in metadata.columns}
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
self.song_to_stem = (
|
| 244 |
+
metadata.set_index("song_id")
|
| 245 |
+
.apply(lambda row: row[row == 1].index.tolist(), axis=1)
|
| 246 |
+
.to_dict()
|
| 247 |
+
)
|
| 248 |
+
self.stem_to_song = (
|
| 249 |
+
metadata.set_index("song_id")
|
| 250 |
+
.transpose()
|
| 251 |
+
.apply(lambda row: row[row == 1].index.tolist(), axis=1)
|
| 252 |
+
.to_dict()
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
self.true_length = len(self.files)
|
| 256 |
+
self.n_channels = 2
|
| 257 |
+
|
| 258 |
+
self.audio_path = os.path.join(data_path, "npy2")
|
| 259 |
+
|
| 260 |
+
self.return_stems = return_stems
|
| 261 |
+
|
| 262 |
+
self.query_file = query_file
|
| 263 |
+
|
| 264 |
+
def get_full_stem(self, *, stem: str, identifier) -> torch.Tensor:
|
| 265 |
+
song_id = identifier["song_id"]
|
| 266 |
+
path = os.path.join(self.data_path, "npy2", song_id)
|
| 267 |
+
# noinspection PyUnresolvedReferences
|
| 268 |
+
|
| 269 |
+
assert self.npy_memmap
|
| 270 |
+
|
| 271 |
+
if os.path.exists(os.path.join(path, f"{stem}.npy")):
|
| 272 |
+
audio = np.load(os.path.join(path, f"{stem}.npy"), mmap_mode="r")
|
| 273 |
+
else:
|
| 274 |
+
audio = None
|
| 275 |
+
|
| 276 |
+
return audio
|
| 277 |
+
|
| 278 |
+
def get_query_stem(self, *, stem: str, identifier) -> torch.Tensor:
|
| 279 |
+
song_id = identifier["song_id"]
|
| 280 |
+
path = os.path.join(self.data_path, "npyq", song_id)
|
| 281 |
+
# noinspection PyUnresolvedReferences
|
| 282 |
+
|
| 283 |
+
if self.npy_memmap:
|
| 284 |
+
# print(self.npy_memmap)
|
| 285 |
+
audio = np.load(
|
| 286 |
+
os.path.join(path, f"{stem}.{self.query_file}.npy"), mmap_mode="r"
|
| 287 |
+
)
|
| 288 |
+
else:
|
| 289 |
+
raise NotImplementedError
|
| 290 |
+
|
| 291 |
+
return audio
|
| 292 |
+
|
| 293 |
+
def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
|
| 294 |
+
audio = self.get_full_stem(stem=stem, identifier=identifier)
|
| 295 |
+
return audio
|
| 296 |
+
|
| 297 |
+
def get_identifier(self, index):
|
| 298 |
+
return dict(song_id=self.files[index % self.true_length])
|
| 299 |
+
|
| 300 |
+
def __getitem__(self, index: int):
|
| 301 |
+
identifier = self.get_identifier(index)
|
| 302 |
+
audio = self.get_audio(identifier)
|
| 303 |
+
|
| 304 |
+
mixture = audio["mixture"].copy()
|
| 305 |
+
|
| 306 |
+
if isinstance(self.return_stems, list):
|
| 307 |
+
sources = {
|
| 308 |
+
stem: audio.get(stem, np.zeros_like(mixture))
|
| 309 |
+
for stem in self.return_stems
|
| 310 |
+
}
|
| 311 |
+
elif isinstance(self.return_stems, bool):
|
| 312 |
+
if self.return_stems:
|
| 313 |
+
sources = {
|
| 314 |
+
stem: audio[stem].copy()
|
| 315 |
+
for stem in self.song_to_stem[identifier["song_id"]]
|
| 316 |
+
}
|
| 317 |
+
else:
|
| 318 |
+
sources = None
|
| 319 |
+
else:
|
| 320 |
+
raise ValueError
|
| 321 |
+
|
| 322 |
+
return input_dict(
|
| 323 |
+
mixture=mixture,
|
| 324 |
+
sources=sources,
|
| 325 |
+
metadata=identifier,
|
| 326 |
+
modality="audio",
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class MoisesDBFullTrackDataset(MoisesDBBaseDataset):
|
| 331 |
+
def __init__(
|
| 332 |
+
self,
|
| 333 |
+
data_root: str,
|
| 334 |
+
split: str,
|
| 335 |
+
return_stems: Union[bool, List[str]] = False,
|
| 336 |
+
npy_memmap=True,
|
| 337 |
+
recompute_mixture=False,
|
| 338 |
+
query_file="query",
|
| 339 |
+
) -> None:
|
| 340 |
+
super().__init__(
|
| 341 |
+
split=split,
|
| 342 |
+
data_path=data_root,
|
| 343 |
+
return_stems=return_stems,
|
| 344 |
+
npy_memmap=npy_memmap,
|
| 345 |
+
recompute_mixture=recompute_mixture,
|
| 346 |
+
query_file=query_file,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
def __len__(self) -> int:
|
| 350 |
+
return self.true_length
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class MoisesDBVDBOFullTrackDataset(MoisesDBFullTrackDataset):
|
| 354 |
+
def __init__(
|
| 355 |
+
self, data_root: str, split: str, npy_memmap=True, recompute_mixture=False
|
| 356 |
+
) -> None:
|
| 357 |
+
super().__init__(
|
| 358 |
+
data_root=data_root,
|
| 359 |
+
split=split,
|
| 360 |
+
return_stems=["vocals", "bass", "drums", "vdbo_others"],
|
| 361 |
+
npy_memmap=npy_memmap,
|
| 362 |
+
recompute_mixture=recompute_mixture,
|
| 363 |
+
query_file=None,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
import torch_audiomentations as audiomentations
|
| 368 |
+
from torch_audiomentations.utils.dsp import convert_decibels_to_amplitude_ratio
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class SmartGain(audiomentations.Gain):
|
| 372 |
+
def __init__(
|
| 373 |
+
self, p=0.5, min_gain_in_db=-6, max_gain_in_db=6, dbfs_threshold=-45.0
|
| 374 |
+
):
|
| 375 |
+
super().__init__(
|
| 376 |
+
p=p, min_gain_in_db=min_gain_in_db, max_gain_in_db=max_gain_in_db
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
self.dbfs_threshold = dbfs_threshold
|
| 380 |
+
|
| 381 |
+
def randomize_parameters(
|
| 382 |
+
self,
|
| 383 |
+
samples: Tensor = None,
|
| 384 |
+
sample_rate: Optional[int] = None,
|
| 385 |
+
targets: Optional[Tensor] = None,
|
| 386 |
+
target_rate: Optional[int] = None,
|
| 387 |
+
):
|
| 388 |
+
|
| 389 |
+
dbfs = 10 * torch.log10(torch.mean(torch.square(samples)) + 1e-6)
|
| 390 |
+
|
| 391 |
+
if dbfs > self.dbfs_threshold:
|
| 392 |
+
low = self.min_gain_in_db
|
| 393 |
+
else:
|
| 394 |
+
low = max(0.0, self.min_gain_in_db)
|
| 395 |
+
|
| 396 |
+
distribution = torch.distributions.Uniform(
|
| 397 |
+
low=torch.tensor(low, dtype=torch.float32, device=samples.device),
|
| 398 |
+
high=torch.tensor(
|
| 399 |
+
self.max_gain_in_db, dtype=torch.float32, device=samples.device
|
| 400 |
+
),
|
| 401 |
+
validate_args=True,
|
| 402 |
+
)
|
| 403 |
+
selected_batch_size = samples.size(0)
|
| 404 |
+
self.transform_parameters["gain_factors"] = (
|
| 405 |
+
convert_decibels_to_amplitude_ratio(
|
| 406 |
+
distribution.sample(sample_shape=(selected_batch_size,))
|
| 407 |
+
)
|
| 408 |
+
.unsqueeze(1)
|
| 409 |
+
.unsqueeze(1)
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class Audiomentations(audiomentations.Compose):
|
| 414 |
+
def __init__(self, augment="gssp", fs: int = 44100):
|
| 415 |
+
|
| 416 |
+
if isinstance(augment, str):
|
| 417 |
+
if augment == "gssp":
|
| 418 |
+
augment = OmegaConf.create(
|
| 419 |
+
[
|
| 420 |
+
dict(
|
| 421 |
+
cls="Shift",
|
| 422 |
+
kwargs=dict(p=1.0, min_shift=-0.5, max_shift=0.5),
|
| 423 |
+
),
|
| 424 |
+
dict(
|
| 425 |
+
cls="Gain",
|
| 426 |
+
kwargs=dict(p=1.0, min_gain_in_db=-6, max_gain_in_db=6),
|
| 427 |
+
),
|
| 428 |
+
dict(cls="ShuffleChannels", kwargs=dict(p=0.5)),
|
| 429 |
+
dict(cls="PolarityInversion", kwargs=dict(p=0.5)),
|
| 430 |
+
]
|
| 431 |
+
)
|
| 432 |
+
else:
|
| 433 |
+
raise ValueError
|
| 434 |
+
|
| 435 |
+
transforms = []
|
| 436 |
+
|
| 437 |
+
for transform in augment:
|
| 438 |
+
|
| 439 |
+
if transform.cls == "Gain":
|
| 440 |
+
transforms.append(SmartGain(**transform.kwargs))
|
| 441 |
+
else:
|
| 442 |
+
transforms.append(
|
| 443 |
+
getattr(audiomentations, transform.cls)(**transform.kwargs)
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
super().__init__(transforms=transforms, shuffle=True)
|
| 447 |
+
|
| 448 |
+
self.fs = fs
|
| 449 |
+
|
| 450 |
+
def forward(
|
| 451 |
+
self,
|
| 452 |
+
samples: torch.Tensor = None,
|
| 453 |
+
) -> ObjectDict:
|
| 454 |
+
return super().forward(samples, sample_rate=self.fs)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
class MoisesDBVDBORandomChunkDataset(MoisesDBVDBOFullTrackDataset):
|
| 458 |
+
def __init__(
|
| 459 |
+
self,
|
| 460 |
+
data_root: str,
|
| 461 |
+
split: str,
|
| 462 |
+
chunk_size_seconds: float = 4.0,
|
| 463 |
+
fs: int = 44100,
|
| 464 |
+
target_length: int = 8192,
|
| 465 |
+
augment=None,
|
| 466 |
+
npy_memmap=True,
|
| 467 |
+
recompute_mixture=True,
|
| 468 |
+
db_threshold=-24.0,
|
| 469 |
+
db_step=-12.0,
|
| 470 |
+
) -> None:
|
| 471 |
+
super().__init__(
|
| 472 |
+
data_root=data_root,
|
| 473 |
+
split=split,
|
| 474 |
+
npy_memmap=npy_memmap,
|
| 475 |
+
recompute_mixture=recompute_mixture,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
self.chunk_size_seconds = chunk_size_seconds
|
| 479 |
+
self.chunk_size_samples = int(chunk_size_seconds * fs)
|
| 480 |
+
self.fs = fs
|
| 481 |
+
|
| 482 |
+
self.target_length = target_length
|
| 483 |
+
|
| 484 |
+
self.db_threshold = db_threshold
|
| 485 |
+
self.db_step = db_step
|
| 486 |
+
|
| 487 |
+
if augment is not None:
|
| 488 |
+
assert self.recompute_mixture
|
| 489 |
+
self.augment = Audiomentations(augment, fs)
|
| 490 |
+
else:
|
| 491 |
+
self.augment = None
|
| 492 |
+
|
| 493 |
+
def __len__(self) -> int:
|
| 494 |
+
return self.target_length
|
| 495 |
+
|
| 496 |
+
def _chunk_audio(self, audio, start, end):
|
| 497 |
+
audio = {k: v[..., start:end] for k, v in audio.items()}
|
| 498 |
+
|
| 499 |
+
return audio
|
| 500 |
+
|
| 501 |
+
def _get_start_end(self, audio, identifier):
|
| 502 |
+
n_samples = audio.shape[-1]
|
| 503 |
+
start = np.random.randint(0, n_samples - self.chunk_size_samples)
|
| 504 |
+
end = start + self.chunk_size_samples
|
| 505 |
+
|
| 506 |
+
return start, end
|
| 507 |
+
|
| 508 |
+
def _get_audio(self, stems, identifier: Dict[str, Any]):
|
| 509 |
+
audio = {}
|
| 510 |
+
|
| 511 |
+
for stem in stems:
|
| 512 |
+
audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)
|
| 513 |
+
|
| 514 |
+
for stem in stems:
|
| 515 |
+
if audio[stem] is None:
|
| 516 |
+
audio[stem] = np.zeros(
|
| 517 |
+
audio[
|
| 518 |
+
(
|
| 519 |
+
"mixture"
|
| 520 |
+
if "mixture" in stems
|
| 521 |
+
else [s for s in stems if audio[s] is not None][0]
|
| 522 |
+
)
|
| 523 |
+
].shape,
|
| 524 |
+
dtype=np.float32,
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
start, end = self._get_start_end(audio[stems[0]], identifier)
|
| 528 |
+
audio = self._chunk_audio(audio, start, end)
|
| 529 |
+
|
| 530 |
+
if self.augment is not None:
|
| 531 |
+
audio = {
|
| 532 |
+
k: self.augment(torch.from_numpy(v[None, :, :]))[0, :, :].numpy()
|
| 533 |
+
for k, v in audio.items()
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
return audio
|
| 537 |
+
|
| 538 |
+
def get_audio(self, identifier: Dict[str, Any]):
|
| 539 |
+
if self.recompute_mixture:
|
| 540 |
+
audio = self._get_audio(
|
| 541 |
+
["vocals", "bass", "drums", "vdbo_others"], identifier=identifier
|
| 542 |
+
)
|
| 543 |
+
audio["mixture"] = self.compute_mixture(audio)
|
| 544 |
+
return audio
|
| 545 |
+
else:
|
| 546 |
+
return self._get_audio(
|
| 547 |
+
["mixture", "vocals", "bass", "drums", "vdbo_others"],
|
| 548 |
+
identifier=identifier,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
def __getitem__(self, index: int):
|
| 552 |
+
|
| 553 |
+
identifier = self.get_identifier(index)
|
| 554 |
+
audio = self.get_audio(identifier=identifier)
|
| 555 |
+
|
| 556 |
+
mixture = audio["mixture"].copy()
|
| 557 |
+
|
| 558 |
+
sources = {
|
| 559 |
+
stem: audio.get(stem, np.zeros_like(mixture)) for stem in self.return_stems
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
return input_dict(
|
| 563 |
+
mixture=mixture,
|
| 564 |
+
sources=sources,
|
| 565 |
+
metadata=identifier,
|
| 566 |
+
modality="audio",
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
class MoisesDBVDBODeterministicChunkDataset(MoisesDBVDBORandomChunkDataset):
|
| 571 |
+
def __init__(
|
| 572 |
+
self,
|
| 573 |
+
data_root: str,
|
| 574 |
+
split: str,
|
| 575 |
+
chunk_size_seconds: float = 4.0,
|
| 576 |
+
hop_size_seconds: float = 8.0,
|
| 577 |
+
fs: int = 44100,
|
| 578 |
+
npy_memmap=True,
|
| 579 |
+
recompute_mixture=False,
|
| 580 |
+
) -> None:
|
| 581 |
+
super().__init__(
|
| 582 |
+
data_root=data_root,
|
| 583 |
+
split=split,
|
| 584 |
+
chunk_size_seconds=chunk_size_seconds,
|
| 585 |
+
npy_memmap=npy_memmap,
|
| 586 |
+
recompute_mixture=recompute_mixture,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
self.hop_size_seconds = hop_size_seconds
|
| 590 |
+
self.hop_size_samples = int(hop_size_seconds * fs)
|
| 591 |
+
|
| 592 |
+
self.index_to_identifiers = self._generate_index()
|
| 593 |
+
self.length = len(self.index_to_identifiers)
|
| 594 |
+
|
| 595 |
+
def __len__(self) -> int:
|
| 596 |
+
return self.length
|
| 597 |
+
|
| 598 |
+
def _generate_index(self):
|
| 599 |
+
|
| 600 |
+
identifiers = []
|
| 601 |
+
|
| 602 |
+
for song_id in self.files:
|
| 603 |
+
audio = self.get_full_stem(stem="mixture", identifier=dict(song_id=song_id))
|
| 604 |
+
n_samples = audio.shape[-1]
|
| 605 |
+
n_chunks = math.floor(
|
| 606 |
+
(n_samples - self.chunk_size_samples) / self.hop_size_samples
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
for i in range(n_chunks):
|
| 610 |
+
chunk_start = i * self.hop_size_samples
|
| 611 |
+
identifiers.append(dict(song_id=song_id, chunk_start=chunk_start))
|
| 612 |
+
|
| 613 |
+
return identifiers
|
| 614 |
+
|
| 615 |
+
def get_identifier(self, index):
|
| 616 |
+
return self.index_to_identifiers[index]
|
| 617 |
+
|
| 618 |
+
def _get_start_end(self, audio, identifier):
|
| 619 |
+
|
| 620 |
+
start = identifier["chunk_start"]
|
| 621 |
+
end = start + self.chunk_size_samples
|
| 622 |
+
|
| 623 |
+
return start, end
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def round_samples(seconds, fs, hop_size, downsample):
|
| 627 |
+
n_frames = math.ceil(seconds * fs / hop_size) + 1
|
| 628 |
+
n_frames_down = math.ceil(n_frames / downsample)
|
| 629 |
+
n_frames = n_frames_down * downsample
|
| 630 |
+
n_samples = (n_frames - 1) * hop_size
|
| 631 |
+
|
| 632 |
+
return int(n_samples)
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class MoisesDBRandomChunkRandomQueryDataset(MoisesDBFullTrackDataset):
|
| 636 |
+
def __init__(
|
| 637 |
+
self,
|
| 638 |
+
data_root: str,
|
| 639 |
+
split: str,
|
| 640 |
+
target_length: int,
|
| 641 |
+
chunk_size_seconds: float = 4.0,
|
| 642 |
+
query_size_seconds: float = 1.0,
|
| 643 |
+
round_query: bool = False,
|
| 644 |
+
min_query_dbfs: float = -40.0,
|
| 645 |
+
min_target_dbfs: float = -36.0,
|
| 646 |
+
min_target_dbfs_step: float = -12.0,
|
| 647 |
+
max_dbfs_tries: int = 10,
|
| 648 |
+
top_k_instrument: int = 10,
|
| 649 |
+
mixture_stem: str = "mixture",
|
| 650 |
+
use_own_query: bool = True,
|
| 651 |
+
npy_memmap=True,
|
| 652 |
+
allowed_stems=None,
|
| 653 |
+
query_file="query",
|
| 654 |
+
augment=None,
|
| 655 |
+
) -> None:
|
| 656 |
+
|
| 657 |
+
super().__init__(
|
| 658 |
+
data_root=data_root,
|
| 659 |
+
split=split,
|
| 660 |
+
npy_memmap=npy_memmap,
|
| 661 |
+
recompute_mixture=augment is not None,
|
| 662 |
+
query_file=query_file,
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
self.mixture_stem = mixture_stem
|
| 666 |
+
|
| 667 |
+
self.chunk_size_seconds = chunk_size_seconds
|
| 668 |
+
self.chunk_size_samples = round_samples(
|
| 669 |
+
self.chunk_size_seconds, self.fs, 512, 2**6
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
self.query_size_seconds = query_size_seconds
|
| 673 |
+
|
| 674 |
+
if round_query:
|
| 675 |
+
self.query_size_samples = round_samples(
|
| 676 |
+
self.query_size_seconds, self.fs, 512, 2**6
|
| 677 |
+
)
|
| 678 |
+
else:
|
| 679 |
+
self.query_size_samples = int(self.query_size_seconds * self.fs)
|
| 680 |
+
|
| 681 |
+
self.target_length = target_length
|
| 682 |
+
|
| 683 |
+
self.min_query_dbfs = min_query_dbfs
|
| 684 |
+
|
| 685 |
+
if min_target_dbfs is None:
|
| 686 |
+
min_target_dbfs = -np.inf
|
| 687 |
+
min_target_dbfs_step = None
|
| 688 |
+
max_dbfs_tries = 1
|
| 689 |
+
|
| 690 |
+
self.min_target_dbfs = min_target_dbfs
|
| 691 |
+
self.min_target_dbfs_step = min_target_dbfs_step
|
| 692 |
+
self.max_dbfs_tries = max_dbfs_tries
|
| 693 |
+
|
| 694 |
+
self.top_k_instrument = top_k_instrument
|
| 695 |
+
|
| 696 |
+
if allowed_stems is None:
|
| 697 |
+
allowed_stems = INST_BY_OCCURRENCE[: self.top_k_instrument]
|
| 698 |
+
else:
|
| 699 |
+
self.top_k_instrument = None
|
| 700 |
+
|
| 701 |
+
self.allowed_stems = allowed_stems
|
| 702 |
+
|
| 703 |
+
self.song_to_all_stems = {
|
| 704 |
+
k: list(set(v) & set(ALL_LEVEL_INSTRUMENTS))
|
| 705 |
+
for k, v in self.song_to_stem.items()
|
| 706 |
+
}
|
| 707 |
+
|
| 708 |
+
self.song_to_stem = {
|
| 709 |
+
k: list(set(v) & set(self.allowed_stems))
|
| 710 |
+
for k, v in self.song_to_stem.items()
|
| 711 |
+
}
|
| 712 |
+
self.stem_to_song = {
|
| 713 |
+
k: list(set(v) & set(self.files)) for k, v in self.stem_to_song.items()
|
| 714 |
+
}
|
| 715 |
+
|
| 716 |
+
self.queriable_songs = [k for k, v in self.song_to_stem.items() if len(v) > 0]
|
| 717 |
+
|
| 718 |
+
self.use_own_query = use_own_query
|
| 719 |
+
|
| 720 |
+
if self.use_own_query:
|
| 721 |
+
self.files = [k for k in self.files if len(self.song_to_stem[k]) > 0]
|
| 722 |
+
self.true_length = len(self.files)
|
| 723 |
+
|
| 724 |
+
if augment is not None:
|
| 725 |
+
assert self.recompute_mixture
|
| 726 |
+
self.augment = Audiomentations(augment, self.fs)
|
| 727 |
+
else:
|
| 728 |
+
self.augment = None
|
| 729 |
+
|
| 730 |
+
def __len__(self) -> int:
|
| 731 |
+
return self.target_length
|
| 732 |
+
|
| 733 |
+
def _chunk_audio(self, audio, start, end):
|
| 734 |
+
audio = {k: v[..., start:end] for k, v in audio.items()}
|
| 735 |
+
|
| 736 |
+
return audio
|
| 737 |
+
|
| 738 |
+
def _get_start_end(self, audio):
|
| 739 |
+
n_samples = audio.shape[-1]
|
| 740 |
+
start = np.random.randint(0, n_samples - self.chunk_size_samples)
|
| 741 |
+
end = start + self.chunk_size_samples
|
| 742 |
+
|
| 743 |
+
return start, end
|
| 744 |
+
|
| 745 |
+
def _target_dbfs(self, audio):
|
| 746 |
+
return 10.0 * np.log10(np.mean(np.square(np.abs(audio))) + 1e-6)
|
| 747 |
+
|
| 748 |
+
def _chunk_and_check_dbfs_threshold(self, audio_, target_stem, threshold):
|
| 749 |
+
|
| 750 |
+
target_dict = {target_stem: audio_[target_stem]}
|
| 751 |
+
|
| 752 |
+
for _ in range(self.max_dbfs_tries):
|
| 753 |
+
start, end = self._get_start_end(audio_[target_stem])
|
| 754 |
+
taudio = self._chunk_audio(target_dict, start, end)
|
| 755 |
+
|
| 756 |
+
dbfs = self._target_dbfs(taudio[target_stem])
|
| 757 |
+
if dbfs > threshold:
|
| 758 |
+
return self._chunk_audio(audio_, start, end)
|
| 759 |
+
|
| 760 |
+
return None
|
| 761 |
+
|
| 762 |
+
def _chunk_and_check_dbfs(self, audio_, target_stem):
|
| 763 |
+
out = self._chunk_and_check_dbfs_threshold(
|
| 764 |
+
audio_, target_stem, self.min_target_dbfs
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
if out is not None:
|
| 768 |
+
return out
|
| 769 |
+
|
| 770 |
+
out = self._chunk_and_check_dbfs_threshold(
|
| 771 |
+
audio_, target_stem, self.min_target_dbfs + self.min_target_dbfs_step
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
if out is not None:
|
| 775 |
+
return out
|
| 776 |
+
|
| 777 |
+
start, end = self._get_start_end(audio_[target_stem])
|
| 778 |
+
audio = self._chunk_audio(audio_, start, end)
|
| 779 |
+
|
| 780 |
+
return audio
|
| 781 |
+
|
| 782 |
+
def _augment(self, audio, target_stem):
|
| 783 |
+
stack_audio = np.stack([v for v in audio.values()], axis=0)
|
| 784 |
+
aug_audio = self.augment(torch.from_numpy(stack_audio)).numpy()
|
| 785 |
+
mixture = np.sum(aug_audio, axis=0)
|
| 786 |
+
|
| 787 |
+
out = {
|
| 788 |
+
"mixture": mixture,
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
+
if target_stem is not None:
|
| 792 |
+
target_idx = list(audio.keys()).index(target_stem)
|
| 793 |
+
out[target_stem] = aug_audio[target_idx]
|
| 794 |
+
|
| 795 |
+
return out
|
| 796 |
+
|
| 797 |
+
def _choose_stems_for_augment(self, identifier, target_stem):
|
| 798 |
+
stems_for_song = set(self.song_to_all_stems[identifier["song_id"]])
|
| 799 |
+
|
| 800 |
+
stems_ = []
|
| 801 |
+
coarse_level_accounted = set()
|
| 802 |
+
|
| 803 |
+
is_none_target = target_stem is None
|
| 804 |
+
is_coarse_target = target_stem in COARSE_LEVEL_INSTRUMENTS
|
| 805 |
+
|
| 806 |
+
if is_coarse_target or is_none_target:
|
| 807 |
+
coarse_target = target_stem
|
| 808 |
+
else:
|
| 809 |
+
coarse_target = FINE_TO_COARSE[target_stem]
|
| 810 |
+
|
| 811 |
+
fine_level_stems = stems_for_song & FINE_LEVEL_INSTRUMENTS
|
| 812 |
+
coarse_level_stems = stems_for_song & COARSE_LEVEL_INSTRUMENTS
|
| 813 |
+
|
| 814 |
+
for s in fine_level_stems:
|
| 815 |
+
coarse_level = FINE_TO_COARSE[s]
|
| 816 |
+
|
| 817 |
+
if is_coarse_target and coarse_level == coarse_target:
|
| 818 |
+
continue
|
| 819 |
+
else:
|
| 820 |
+
stems_.append(s)
|
| 821 |
+
|
| 822 |
+
coarse_level_accounted.add(coarse_level)
|
| 823 |
+
|
| 824 |
+
stems_ += list(coarse_level_stems - coarse_level_accounted)
|
| 825 |
+
|
| 826 |
+
if target_stem is not None:
|
| 827 |
+
assert target_stem in stems_, f"stems: {stems_}, target stem: {target_stem}"
|
| 828 |
+
|
| 829 |
+
if len(stems_for_song) > 1:
|
| 830 |
+
assert (
|
| 831 |
+
len(stems_) > 1
|
| 832 |
+
), f"stems: {stems_}, stems in song: {stems_for_song},\n target stem: {target_stem}"
|
| 833 |
+
|
| 834 |
+
assert "mixture" not in stems_
|
| 835 |
+
|
| 836 |
+
return stems_
|
| 837 |
+
|
| 838 |
+
def _get_audio(
|
| 839 |
+
self, stems, identifier: Dict[str, Any], check_dbfs=True, no_target=False
|
| 840 |
+
):
|
| 841 |
+
|
| 842 |
+
target_stem = stems[0] if not no_target else None
|
| 843 |
+
|
| 844 |
+
if self.augment is not None:
|
| 845 |
+
stems_ = self._choose_stems_for_augment(identifier, target_stem)
|
| 846 |
+
else:
|
| 847 |
+
stems_ = stems
|
| 848 |
+
|
| 849 |
+
audio = {}
|
| 850 |
+
for stem in stems_:
|
| 851 |
+
audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)
|
| 852 |
+
|
| 853 |
+
audio_ = {k: v.copy() for k, v in audio.items()}
|
| 854 |
+
|
| 855 |
+
if check_dbfs:
|
| 856 |
+
assert target_stem is not None
|
| 857 |
+
audio = self._chunk_and_check_dbfs(audio_, target_stem)
|
| 858 |
+
else:
|
| 859 |
+
first_key = list(audio_.keys())[0]
|
| 860 |
+
start, end = self._get_start_end(audio_[first_key])
|
| 861 |
+
audio = self._chunk_audio(audio_, start, end)
|
| 862 |
+
|
| 863 |
+
if self.augment is not None:
|
| 864 |
+
assert "mixture" not in audio
|
| 865 |
+
audio = self._augment(audio, target_stem)
|
| 866 |
+
assert "mixture" in audio
|
| 867 |
+
|
| 868 |
+
return audio
|
| 869 |
+
|
| 870 |
+
def __getitem__(self, index: int):
|
| 871 |
+
|
| 872 |
+
mix_identifier = self.get_identifier(index)
|
| 873 |
+
mix_stems = self.song_to_stem[mix_identifier["song_id"]]
|
| 874 |
+
|
| 875 |
+
if self.use_own_query:
|
| 876 |
+
query_id = mix_identifier["song_id"]
|
| 877 |
+
query_identifier = dict(song_id=query_id)
|
| 878 |
+
possible_stem = mix_stems
|
| 879 |
+
|
| 880 |
+
assert len(possible_stem) > 0
|
| 881 |
+
|
| 882 |
+
zero_target = False
|
| 883 |
+
else:
|
| 884 |
+
query_id = random.choice(self.queriable_songs)
|
| 885 |
+
query_identifier = dict(song_id=query_id)
|
| 886 |
+
query_stems = self.song_to_stem[query_id]
|
| 887 |
+
possible_stem = list(set(mix_stems) & set(query_stems))
|
| 888 |
+
|
| 889 |
+
if len(possible_stem) == 0:
|
| 890 |
+
possible_stem = query_stems
|
| 891 |
+
zero_target = True
|
| 892 |
+
# print(f"Mix {mix_identifier['song_id']} and query {query_id} have no common stems.")
|
| 893 |
+
# return self.__getitem__(index + 1)
|
| 894 |
+
else:
|
| 895 |
+
zero_target = False
|
| 896 |
+
|
| 897 |
+
assert (
|
| 898 |
+
len(possible_stem) > 0
|
| 899 |
+
), f"{mix_identifier['song_id']} and {query_id} have no common stems. zero target is {zero_target}"
|
| 900 |
+
stem = random.choice(possible_stem)
|
| 901 |
+
|
| 902 |
+
if zero_target:
|
| 903 |
+
audio = self._get_audio(
|
| 904 |
+
[self.mixture_stem],
|
| 905 |
+
identifier=mix_identifier,
|
| 906 |
+
check_dbfs=False,
|
| 907 |
+
no_target=True,
|
| 908 |
+
)
|
| 909 |
+
mixture = audio[self.mixture_stem].copy()
|
| 910 |
+
sources = {"target": np.zeros_like(mixture)}
|
| 911 |
+
else:
|
| 912 |
+
audio = self._get_audio(
|
| 913 |
+
[stem, self.mixture_stem], identifier=mix_identifier, check_dbfs=True
|
| 914 |
+
)
|
| 915 |
+
mixture = audio[self.mixture_stem].copy()
|
| 916 |
+
sources = {"target": audio[stem].copy()}
|
| 917 |
+
|
| 918 |
+
query = self.get_query_stem(stem=stem, identifier=query_identifier)
|
| 919 |
+
query = query.copy()
|
| 920 |
+
|
| 921 |
+
assert mixture.shape[-1] == self.chunk_size_samples
|
| 922 |
+
assert query.shape[-1] == self.query_size_samples
|
| 923 |
+
assert sources["target"].shape[-1] == self.chunk_size_samples
|
| 924 |
+
|
| 925 |
+
return input_dict(
|
| 926 |
+
mixture=mixture,
|
| 927 |
+
sources=sources,
|
| 928 |
+
query=query,
|
| 929 |
+
metadata={
|
| 930 |
+
"mix": mix_identifier,
|
| 931 |
+
"query": query_identifier,
|
| 932 |
+
"stem": stem,
|
| 933 |
+
},
|
| 934 |
+
modality="audio",
|
| 935 |
+
)
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
class MoisesDBRandomChunkBalancedRandomQueryDataset(
|
| 939 |
+
MoisesDBRandomChunkRandomQueryDataset
|
| 940 |
+
):
|
| 941 |
+
def __init__(
|
| 942 |
+
self,
|
| 943 |
+
data_root: str,
|
| 944 |
+
split: str,
|
| 945 |
+
target_length: int,
|
| 946 |
+
chunk_size_seconds: float = 4,
|
| 947 |
+
query_size_seconds: float = 1,
|
| 948 |
+
round_query: bool = False,
|
| 949 |
+
min_query_dbfs: float = -40.0,
|
| 950 |
+
min_target_dbfs: float = -36.0,
|
| 951 |
+
min_target_dbfs_step: float = -12.0,
|
| 952 |
+
max_dbfs_tries: int = 10,
|
| 953 |
+
top_k_instrument: int = 10,
|
| 954 |
+
mixture_stem: str = "mixture",
|
| 955 |
+
use_own_query: bool = True,
|
| 956 |
+
npy_memmap=True,
|
| 957 |
+
allowed_stems=None,
|
| 958 |
+
query_file="query",
|
| 959 |
+
augment=None,
|
| 960 |
+
) -> None:
|
| 961 |
+
super().__init__(
|
| 962 |
+
data_root,
|
| 963 |
+
split,
|
| 964 |
+
target_length,
|
| 965 |
+
chunk_size_seconds,
|
| 966 |
+
query_size_seconds,
|
| 967 |
+
round_query,
|
| 968 |
+
min_query_dbfs,
|
| 969 |
+
min_target_dbfs,
|
| 970 |
+
min_target_dbfs_step,
|
| 971 |
+
max_dbfs_tries,
|
| 972 |
+
top_k_instrument,
|
| 973 |
+
mixture_stem,
|
| 974 |
+
use_own_query,
|
| 975 |
+
npy_memmap,
|
| 976 |
+
allowed_stems,
|
| 977 |
+
query_file,
|
| 978 |
+
augment,
|
| 979 |
+
)
|
| 980 |
+
|
| 981 |
+
self.stem_to_n_songs = {k: len(v) for k, v in self.stem_to_song.items()}
|
| 982 |
+
self.trainable_stems = [k for k, v in self.stem_to_n_songs.items() if v > 1]
|
| 983 |
+
self.n_allowed_stems = len(self.allowed_stems)
|
| 984 |
+
|
| 985 |
+
|
| 986 |
+
|
| 987 |
+
def __getitem__(self, index: int):
|
| 988 |
+
|
| 989 |
+
stem = self.allowed_stems[index % self.n_allowed_stems]
|
| 990 |
+
song_ids_with_stem = self.stem_to_song[stem]
|
| 991 |
+
|
| 992 |
+
song_id = song_ids_with_stem[index % self.stem_to_n_songs[stem]]
|
| 993 |
+
|
| 994 |
+
mix_identifier = dict(song_id=song_id)
|
| 995 |
+
|
| 996 |
+
audio = self._get_audio([stem, self.mixture_stem], identifier=mix_identifier, check_dbfs=True)
|
| 997 |
+
mixture = audio[self.mixture_stem].copy()
|
| 998 |
+
|
| 999 |
+
if self.use_own_query:
|
| 1000 |
+
query_id = song_id
|
| 1001 |
+
query_identifier = dict(song_id=query_id)
|
| 1002 |
+
else:
|
| 1003 |
+
query_id = random.choice(song_ids_with_stem)
|
| 1004 |
+
query_identifier = dict(song_id=query_id)
|
| 1005 |
+
|
| 1006 |
+
query = self.get_query_stem(stem=stem, identifier=query_identifier)
|
| 1007 |
+
query = query.copy()
|
| 1008 |
+
|
| 1009 |
+
sources = {"target": audio[stem].copy()}
|
| 1010 |
+
|
| 1011 |
+
return input_dict(
|
| 1012 |
+
mixture=mixture,
|
| 1013 |
+
sources=sources,
|
| 1014 |
+
query=query,
|
| 1015 |
+
metadata={
|
| 1016 |
+
"mix": mix_identifier,
|
| 1017 |
+
"query": query_identifier,
|
| 1018 |
+
"stem": stem,
|
| 1019 |
+
},
|
| 1020 |
+
modality="audio",
|
| 1021 |
+
)
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
class MoisesDBDeterministicChunkDeterministicQueryDataset(
|
| 1027 |
+
MoisesDBRandomChunkRandomQueryDataset
|
| 1028 |
+
):
|
| 1029 |
+
def __init__(
|
| 1030 |
+
self,
|
| 1031 |
+
data_root: str,
|
| 1032 |
+
split: str,
|
| 1033 |
+
chunk_size_seconds: float = 4.0,
|
| 1034 |
+
hop_size_seconds: float = 8.0,
|
| 1035 |
+
query_size_seconds: float = 1.0,
|
| 1036 |
+
min_query_dbfs: float = -40.0,
|
| 1037 |
+
top_k_instrument: int = 10,
|
| 1038 |
+
n_queries_per_chunk: int = 1,
|
| 1039 |
+
mixture_stem: str = "mixture",
|
| 1040 |
+
use_own_query: bool = True,
|
| 1041 |
+
npy_memmap=True,
|
| 1042 |
+
allowed_stems: List[str] = None,
|
| 1043 |
+
query_file="query",
|
| 1044 |
+
) -> None:
|
| 1045 |
+
|
| 1046 |
+
super().__init__(
|
| 1047 |
+
data_root=data_root,
|
| 1048 |
+
split=split,
|
| 1049 |
+
target_length=None,
|
| 1050 |
+
chunk_size_seconds=chunk_size_seconds,
|
| 1051 |
+
query_size_seconds=query_size_seconds,
|
| 1052 |
+
min_query_dbfs=min_query_dbfs,
|
| 1053 |
+
top_k_instrument=top_k_instrument,
|
| 1054 |
+
mixture_stem=mixture_stem,
|
| 1055 |
+
use_own_query=use_own_query,
|
| 1056 |
+
npy_memmap=npy_memmap,
|
| 1057 |
+
allowed_stems=allowed_stems,
|
| 1058 |
+
query_file=query_file,
|
| 1059 |
+
)
|
| 1060 |
+
|
| 1061 |
+
if hop_size_seconds is None:
|
| 1062 |
+
hop_size_seconds = chunk_size_seconds
|
| 1063 |
+
|
| 1064 |
+
self.chunk_hop_size_seconds = hop_size_seconds
|
| 1065 |
+
|
| 1066 |
+
self.chunk_hop_size_samples = int(hop_size_seconds * self.fs)
|
| 1067 |
+
|
| 1068 |
+
self.n_queries_per_chunk = n_queries_per_chunk
|
| 1069 |
+
|
| 1070 |
+
self._overwrite = False
|
| 1071 |
+
|
| 1072 |
+
self.query_tuples = self.find_query_tuples_or_generate()
|
| 1073 |
+
self.n_chunks = len(self.query_tuples)
|
| 1074 |
+
|
| 1075 |
+
def __len__(self) -> int:
|
| 1076 |
+
return self.n_chunks
|
| 1077 |
+
|
| 1078 |
+
def _get_audio(self, stems, identifier: Dict[str, Any]):
|
| 1079 |
+
audio = {}
|
| 1080 |
+
|
| 1081 |
+
for stem in stems:
|
| 1082 |
+
audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)
|
| 1083 |
+
|
| 1084 |
+
start = identifier["chunk_start"]
|
| 1085 |
+
end = start + self.chunk_size_samples
|
| 1086 |
+
audio = self._chunk_audio(audio, start, end)
|
| 1087 |
+
|
| 1088 |
+
return audio
|
| 1089 |
+
|
| 1090 |
+
def find_query_tuples_or_generate(self):
|
| 1091 |
+
query_path = os.path.join(self.data_path, "queries")
|
| 1092 |
+
val_folds = "-".join(map(str, self.folds))
|
| 1093 |
+
|
| 1094 |
+
path_so_far = os.path.join(query_path, val_folds)
|
| 1095 |
+
|
| 1096 |
+
if not os.path.exists(path_so_far):
|
| 1097 |
+
return self.generate_index()
|
| 1098 |
+
|
| 1099 |
+
chunk_specs = f"chunk{self.chunk_size_samples}-hop{self.chunk_hop_size_samples}"
|
| 1100 |
+
path_so_far = os.path.join(path_so_far, chunk_specs)
|
| 1101 |
+
|
| 1102 |
+
if not os.path.exists(path_so_far):
|
| 1103 |
+
return self.generate_index()
|
| 1104 |
+
|
| 1105 |
+
query_specs = f"query{self.query_size_samples}-n{self.n_queries_per_chunk}"
|
| 1106 |
+
path_so_far = os.path.join(path_so_far, query_specs)
|
| 1107 |
+
|
| 1108 |
+
if not os.path.exists(path_so_far):
|
| 1109 |
+
return self.generate_index()
|
| 1110 |
+
|
| 1111 |
+
if self.top_k_instrument is not None:
|
| 1112 |
+
path_so_far = os.path.join(
|
| 1113 |
+
path_so_far, f"queries-top{self.top_k_instrument}.csv"
|
| 1114 |
+
)
|
| 1115 |
+
else:
|
| 1116 |
+
if len(self.allowed_stems) > 5:
|
| 1117 |
+
allowed_stems = (
|
| 1118 |
+
str(len(self.allowed_stems))
|
| 1119 |
+
+ "stems:"
|
| 1120 |
+
+ ":".join([k[0] for k in self.allowed_stems if k != "mixture"])
|
| 1121 |
+
)
|
| 1122 |
+
else:
|
| 1123 |
+
allowed_stems = ":".join(self.allowed_stems)
|
| 1124 |
+
|
| 1125 |
+
path_so_far = os.path.join(path_so_far, f"queries-{allowed_stems}.csv")
|
| 1126 |
+
|
| 1127 |
+
if not os.path.exists(path_so_far):
|
| 1128 |
+
return self.generate_index()
|
| 1129 |
+
|
| 1130 |
+
print(f"Loading query tuples from {path_so_far}")
|
| 1131 |
+
|
| 1132 |
+
return pd.read_csv(path_so_far)
|
| 1133 |
+
|
| 1134 |
+
def _get_index_path(self):
|
| 1135 |
+
query_root = os.path.join(self.data_path, "queries")
|
| 1136 |
+
val_folds = "-".join(map(str, self.folds))
|
| 1137 |
+
chunk_specs = f"chunk{self.chunk_size_samples}-hop{self.chunk_hop_size_samples}"
|
| 1138 |
+
query_specs = f"query{self.query_size_samples}-n{self.n_queries_per_chunk}"
|
| 1139 |
+
query_dir = os.path.join(query_root, val_folds, chunk_specs, query_specs)
|
| 1140 |
+
|
| 1141 |
+
if self.top_k_instrument is not None:
|
| 1142 |
+
query_path = os.path.join(
|
| 1143 |
+
query_dir, f"queries-top{self.top_k_instrument}.csv"
|
| 1144 |
+
)
|
| 1145 |
+
else:
|
| 1146 |
+
if len(self.allowed_stems) > 5:
|
| 1147 |
+
allowed_stems = (
|
| 1148 |
+
str(len(self.allowed_stems))
|
| 1149 |
+
+ "stems:"
|
| 1150 |
+
+ ":".join([k[0] for k in self.allowed_stems if k != "mixture"])
|
| 1151 |
+
)
|
| 1152 |
+
else:
|
| 1153 |
+
allowed_stems = ":".join(self.allowed_stems)
|
| 1154 |
+
query_path = os.path.join(query_dir, f"queries-{allowed_stems}.csv")
|
| 1155 |
+
|
| 1156 |
+
if not self._overwrite:
|
| 1157 |
+
assert not os.path.exists(
|
| 1158 |
+
query_path
|
| 1159 |
+
), f"Query path {query_path} already exists."
|
| 1160 |
+
|
| 1161 |
+
os.makedirs(query_dir, exist_ok=True)
|
| 1162 |
+
|
| 1163 |
+
return query_path
|
| 1164 |
+
|
| 1165 |
+
def generate_index(self):
|
| 1166 |
+
|
| 1167 |
+
query_path = self._get_index_path()
|
| 1168 |
+
|
| 1169 |
+
durations = pd.read_csv(os.path.join(self.data_path, "durations.csv"))
|
| 1170 |
+
durations = (
|
| 1171 |
+
durations[["song_id", "duration"]]
|
| 1172 |
+
.set_index("song_id")["duration"]
|
| 1173 |
+
.to_dict()
|
| 1174 |
+
)
|
| 1175 |
+
|
| 1176 |
+
tuples = []
|
| 1177 |
+
|
| 1178 |
+
stems_without_queries = defaultdict(list)
|
| 1179 |
+
|
| 1180 |
+
for i, song_id in tqdm(enumerate(self.files), total=len(self.files)):
|
| 1181 |
+
song_duration = durations[song_id]
|
| 1182 |
+
mix_stems = self.song_to_stem[song_id]
|
| 1183 |
+
|
| 1184 |
+
n_mix_chunks = math.floor(
|
| 1185 |
+
(song_duration - self.chunk_size_seconds) / self.chunk_hop_size_seconds
|
| 1186 |
+
)
|
| 1187 |
+
|
| 1188 |
+
for stem in mix_stems:
|
| 1189 |
+
possible_queries = self.stem_to_song[stem]
|
| 1190 |
+
if song_id in possible_queries:
|
| 1191 |
+
possible_queries.remove(song_id)
|
| 1192 |
+
|
| 1193 |
+
if len(possible_queries) == 0:
|
| 1194 |
+
stems_without_queries[song_id].append(stem)
|
| 1195 |
+
continue
|
| 1196 |
+
|
| 1197 |
+
for k in tqdm(range(n_mix_chunks), desc=f"song{i + 1}/{stem}"):
|
| 1198 |
+
mix_chunk_start = int(k * self.chunk_hop_size_samples)
|
| 1199 |
+
|
| 1200 |
+
for j in range(self.n_queries_per_chunk):
|
| 1201 |
+
query = random.choice(possible_queries)
|
| 1202 |
+
|
| 1203 |
+
tuples.append(
|
| 1204 |
+
dict(
|
| 1205 |
+
mix=song_id,
|
| 1206 |
+
query=query,
|
| 1207 |
+
stem=stem,
|
| 1208 |
+
mix_chunk_start=mix_chunk_start,
|
| 1209 |
+
)
|
| 1210 |
+
)
|
| 1211 |
+
|
| 1212 |
+
if len(stems_without_queries) > 0:
|
| 1213 |
+
print("Stems without queries:")
|
| 1214 |
+
for song_id, stems in stems_without_queries.items():
|
| 1215 |
+
print(f"{song_id}: {stems}")
|
| 1216 |
+
|
| 1217 |
+
tuples = pd.DataFrame(tuples)
|
| 1218 |
+
|
| 1219 |
+
print(
|
| 1220 |
+
f"Generating query tuples for {self.split} set with {len(tuples)} tuples."
|
| 1221 |
+
)
|
| 1222 |
+
print(f"Saving query tuples to {query_path}")
|
| 1223 |
+
|
| 1224 |
+
tuples.to_csv(query_path, index=False)
|
| 1225 |
+
|
| 1226 |
+
return tuples
|
| 1227 |
+
|
| 1228 |
+
def index_to_identifiers(self, index: int) -> Tuple[str, str, str, int]:
|
| 1229 |
+
|
| 1230 |
+
row = self.query_tuples.iloc[index]
|
| 1231 |
+
mix_id = row["mix"]
|
| 1232 |
+
|
| 1233 |
+
if self.use_own_query:
|
| 1234 |
+
query_id = mix_id
|
| 1235 |
+
else:
|
| 1236 |
+
query_id = row["query"]
|
| 1237 |
+
|
| 1238 |
+
stem = row["stem"]
|
| 1239 |
+
mix_chunk_start = row["mix_chunk_start"]
|
| 1240 |
+
|
| 1241 |
+
return mix_id, query_id, stem, mix_chunk_start
|
| 1242 |
+
|
| 1243 |
+
def __getitem__(self, index: int):
|
| 1244 |
+
|
| 1245 |
+
mix_id, query_id, stem, mix_chunk_start = self.index_to_identifiers(index)
|
| 1246 |
+
|
| 1247 |
+
mix_identifier = dict(song_id=mix_id, chunk_start=mix_chunk_start)
|
| 1248 |
+
query_identifier = dict(song_id=query_id)
|
| 1249 |
+
|
| 1250 |
+
audio = self._get_audio([stem, self.mixture_stem], identifier=mix_identifier)
|
| 1251 |
+
query = self.get_query_stem(stem=stem, identifier=query_identifier)
|
| 1252 |
+
|
| 1253 |
+
mixture = audio[self.mixture_stem].copy()
|
| 1254 |
+
sources = {"target": audio[stem].copy()}
|
| 1255 |
+
query = query.copy()
|
| 1256 |
+
|
| 1257 |
+
assert mixture.shape[-1] == self.chunk_size_samples
|
| 1258 |
+
# print(query.shape[-1], self.query_size_samples)
|
| 1259 |
+
assert query.shape[-1] == self.query_size_samples
|
| 1260 |
+
assert sources["target"].shape[-1] == self.chunk_size_samples
|
| 1261 |
+
|
| 1262 |
+
return input_dict(
|
| 1263 |
+
mixture=mixture,
|
| 1264 |
+
sources=sources,
|
| 1265 |
+
query=query,
|
| 1266 |
+
metadata={
|
| 1267 |
+
"mix": mix_identifier,
|
| 1268 |
+
"query": query_identifier,
|
| 1269 |
+
"stem": stem,
|
| 1270 |
+
},
|
| 1271 |
+
modality="audio",
|
| 1272 |
+
)
|
| 1273 |
+
|
| 1274 |
+
|
| 1275 |
+
class MoisesDBFullTrackTestQueryDataset(MoisesDBFullTrackDataset):
|
| 1276 |
+
def __init__(
|
| 1277 |
+
self,
|
| 1278 |
+
data_root: str,
|
| 1279 |
+
split: str = "test",
|
| 1280 |
+
top_k_instrument: int = 10,
|
| 1281 |
+
mixture_stem: str = "mixture",
|
| 1282 |
+
use_own_query: bool = True,
|
| 1283 |
+
npy_memmap=True,
|
| 1284 |
+
allowed_stems: List[str] = None,
|
| 1285 |
+
query_file="query-10s",
|
| 1286 |
+
) -> None:
|
| 1287 |
+
super().__init__(
|
| 1288 |
+
data_root=data_root,
|
| 1289 |
+
split=split,
|
| 1290 |
+
npy_memmap=npy_memmap,
|
| 1291 |
+
recompute_mixture=False,
|
| 1292 |
+
query_file=query_file,
|
| 1293 |
+
)
|
| 1294 |
+
|
| 1295 |
+
self.use_own_query = use_own_query
|
| 1296 |
+
|
| 1297 |
+
self.allowed_stems = allowed_stems
|
| 1298 |
+
|
| 1299 |
+
test_indices = pd.read_csv(os.path.join(data_root, "test_indices.csv"))
|
| 1300 |
+
|
| 1301 |
+
test_indices = test_indices[test_indices.stem.isin(self.allowed_stems)]
|
| 1302 |
+
|
| 1303 |
+
self.test_indices = test_indices
|
| 1304 |
+
|
| 1305 |
+
self.length = len(self.test_indices)
|
| 1306 |
+
|
| 1307 |
+
def __len__(self) -> int:
|
| 1308 |
+
return self.length
|
| 1309 |
+
|
| 1310 |
+
def index_to_identifiers(self, index: int) -> Tuple[str, str, str]:
|
| 1311 |
+
|
| 1312 |
+
row = self.test_indices.iloc[index]
|
| 1313 |
+
mix_id = row["song_id"]
|
| 1314 |
+
if self.use_own_query:
|
| 1315 |
+
query_id = mix_id
|
| 1316 |
+
else:
|
| 1317 |
+
query_id = row["query_id"]
|
| 1318 |
+
stem = row["stem"]
|
| 1319 |
+
|
| 1320 |
+
return mix_id, query_id, stem
|
| 1321 |
+
|
| 1322 |
+
def _get_audio(self, stems, identifier: Dict[str, Any]):
|
| 1323 |
+
audio = {}
|
| 1324 |
+
|
| 1325 |
+
for stem in stems:
|
| 1326 |
+
audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)
|
| 1327 |
+
|
| 1328 |
+
return audio
|
| 1329 |
+
|
| 1330 |
+
def __getitem__(self, index: int):
|
| 1331 |
+
|
| 1332 |
+
mix_id, query_id, stem = self.index_to_identifiers(index)
|
| 1333 |
+
|
| 1334 |
+
mix_identifier = dict(song_id=mix_id)
|
| 1335 |
+
|
| 1336 |
+
query_identifier = dict(song_id=query_id)
|
| 1337 |
+
|
| 1338 |
+
audio = self._get_audio([stem, "mixture"], identifier=mix_identifier)
|
| 1339 |
+
query = self.get_query_stem(stem=stem, identifier=query_identifier)
|
| 1340 |
+
|
| 1341 |
+
mixture = audio["mixture"].copy()
|
| 1342 |
+
sources = {stem: audio[stem].copy()}
|
| 1343 |
+
query = query.copy()
|
| 1344 |
+
|
| 1345 |
+
return input_dict(
|
| 1346 |
+
mixture=mixture,
|
| 1347 |
+
sources=sources,
|
| 1348 |
+
query=query,
|
| 1349 |
+
metadata={
|
| 1350 |
+
"mix": mix_identifier["song_id"],
|
| 1351 |
+
"query": query_identifier["song_id"],
|
| 1352 |
+
"stem": stem,
|
| 1353 |
+
},
|
| 1354 |
+
modality="audio",
|
| 1355 |
+
)
|
| 1356 |
+
|
| 1357 |
+
|
| 1358 |
+
if __name__ == "__main__":
|
| 1359 |
+
|
| 1360 |
+
print("Beginning")
|
| 1361 |
+
|
| 1362 |
+
config = "/storage/home/hcoda1/1/kwatchar3/coda/config/data/moisesdb-everything-query-d-aug.yml"
|
| 1363 |
+
|
| 1364 |
+
config = OmegaConf.load(config)
|
| 1365 |
+
|
| 1366 |
+
print("Loaded config")
|
| 1367 |
+
|
| 1368 |
+
dataset = MoisesDBRandomChunkRandomQueryDataset(
|
| 1369 |
+
data_root=config.data_root, split="train", **config.train_kwargs
|
| 1370 |
+
)
|
| 1371 |
+
|
| 1372 |
+
print("Loaded dataset")
|
| 1373 |
+
|
| 1374 |
+
for item in tqdm(dataset, total=len(dataset)):
|
| 1375 |
+
target_audio = item["sources"]["target"]["audio"]
|
| 1376 |
+
mixture = item["mixture"]["audio"]
|
| 1377 |
+
|
| 1378 |
+
if target_audio is None:
|
| 1379 |
+
raise ValueError
|
| 1380 |
+
else:
|
| 1381 |
+
tdb = 10.0 * torch.log10(torch.mean(torch.square(target_audio)) + 1e-6)
|
| 1382 |
+
mdb = 10.0 * torch.log10(torch.mean(torch.square(mixture)) + 1e-6)
|
| 1383 |
+
print(f"Target db: {tdb}, Mixture db: {mdb}")
|
core/data/moisesdb/eda.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
core/data/moisesdb/npyify.py
ADDED
|
@@ -0,0 +1,923 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import glob
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
from itertools import chain
|
| 8 |
+
from pprint import pprint
|
| 9 |
+
from types import SimpleNamespace
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
|
| 15 |
+
from tqdm.contrib.concurrent import process_map
|
| 16 |
+
|
| 17 |
+
from tqdm import tqdm as tdqm, tqdm
|
| 18 |
+
import torchaudio as ta
|
| 19 |
+
|
| 20 |
+
import librosa
|
| 21 |
+
|
| 22 |
+
taxonomy = {
|
| 23 |
+
"vocals": [
|
| 24 |
+
"lead male singer",
|
| 25 |
+
"lead female singer",
|
| 26 |
+
"human choir",
|
| 27 |
+
"background vocals",
|
| 28 |
+
"other (vocoder, beatboxing etc)",
|
| 29 |
+
],
|
| 30 |
+
"bass": [
|
| 31 |
+
"bass guitar",
|
| 32 |
+
"bass synthesizer (moog etc)",
|
| 33 |
+
"contrabass/double bass (bass of instrings)",
|
| 34 |
+
"tuba (bass of brass)",
|
| 35 |
+
"bassoon (bass of woodwind)",
|
| 36 |
+
],
|
| 37 |
+
"drums": [
|
| 38 |
+
"snare drum",
|
| 39 |
+
"toms",
|
| 40 |
+
"kick drum",
|
| 41 |
+
"cymbals",
|
| 42 |
+
"overheads",
|
| 43 |
+
"full acoustic drumkit",
|
| 44 |
+
"drum machine",
|
| 45 |
+
"hi-hat"
|
| 46 |
+
],
|
| 47 |
+
"other": [
|
| 48 |
+
"fx/processed sound, scratches, gun shots, explosions etc",
|
| 49 |
+
"click track",
|
| 50 |
+
],
|
| 51 |
+
"guitar": [
|
| 52 |
+
"clean electric guitar",
|
| 53 |
+
"distorted electric guitar",
|
| 54 |
+
"lap steel guitar or slide guitar",
|
| 55 |
+
"acoustic guitar",
|
| 56 |
+
],
|
| 57 |
+
"other plucked": ["banjo, mandolin, ukulele, harp etc"],
|
| 58 |
+
"percussion": [
|
| 59 |
+
"a-tonal percussion (claps, shakers, congas, cowbell etc)",
|
| 60 |
+
"pitched percussion (mallets, glockenspiel, ...)",
|
| 61 |
+
],
|
| 62 |
+
"piano": [
|
| 63 |
+
"grand piano",
|
| 64 |
+
"electric piano (rhodes, wurlitzer, piano sound alike)",
|
| 65 |
+
],
|
| 66 |
+
"other keys": [
|
| 67 |
+
"organ, electric organ",
|
| 68 |
+
"synth pad",
|
| 69 |
+
"synth lead",
|
| 70 |
+
"other sounds (hapischord, melotron etc)",
|
| 71 |
+
],
|
| 72 |
+
"bowed strings": [
|
| 73 |
+
"violin (solo)",
|
| 74 |
+
"viola (solo)",
|
| 75 |
+
"cello (solo)",
|
| 76 |
+
"violin section",
|
| 77 |
+
"viola section",
|
| 78 |
+
"cello section",
|
| 79 |
+
"string section",
|
| 80 |
+
"other strings",
|
| 81 |
+
],
|
| 82 |
+
"wind": [
|
| 83 |
+
"brass (trumpet, trombone, french horn, brass etc)",
|
| 84 |
+
"flutes (piccolo, bamboo flute, panpipes, flutes etc)",
|
| 85 |
+
"reeds (saxophone, clarinets, oboe, english horn, bagpipe)",
|
| 86 |
+
"other wind",
|
| 87 |
+
],
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
def clean_npy_other_vox(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npyq"):
|
| 91 |
+
npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
npys = [npy for npy in npys if "other" in npy]
|
| 95 |
+
npys = [npy for npy in npys if "vdbo_" not in npy]
|
| 96 |
+
npys = [npy for npy in npys if "other_" not in npy]
|
| 97 |
+
|
| 98 |
+
stems = set([
|
| 99 |
+
os.path.basename(npy).split(".")[0] for npy in npys
|
| 100 |
+
])
|
| 101 |
+
|
| 102 |
+
assert len(stems) == 1
|
| 103 |
+
|
| 104 |
+
for npy in tqdm(npys):
|
| 105 |
+
shutil.move(npy, npy.replace("other", "other_vocals"))
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def clean_track_inst(inst):
|
| 111 |
+
|
| 112 |
+
if "vocoder" in inst:
|
| 113 |
+
inst = "other_vocals"
|
| 114 |
+
|
| 115 |
+
if "fx" in inst:
|
| 116 |
+
inst = "fx"
|
| 117 |
+
|
| 118 |
+
if "contrabass_double_bass" in inst:
|
| 119 |
+
inst = "double_bass"
|
| 120 |
+
|
| 121 |
+
if "banjo" in inst:
|
| 122 |
+
return "other_plucked"
|
| 123 |
+
|
| 124 |
+
if "(" in inst:
|
| 125 |
+
inst = inst.split("(")[0]
|
| 126 |
+
|
| 127 |
+
for s in [",", "-"]:
|
| 128 |
+
if s in inst:
|
| 129 |
+
inst = inst.replace(s, "")
|
| 130 |
+
|
| 131 |
+
for s in ["/"]:
|
| 132 |
+
if s in inst:
|
| 133 |
+
inst = inst.replace(s, "_")
|
| 134 |
+
|
| 135 |
+
if inst[-1] == "_":
|
| 136 |
+
inst = inst[:-1]
|
| 137 |
+
|
| 138 |
+
return inst
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
taxonomy = {
|
| 142 |
+
k.replace(" ", "_"): [clean_track_inst(i.replace(" ", "_")) for i in v] for k, v in taxonomy.items()
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
fine_to_coarse = {}
|
| 146 |
+
|
| 147 |
+
for k, v in taxonomy.items():
|
| 148 |
+
for vv in v:
|
| 149 |
+
fine_to_coarse[vv] = k
|
| 150 |
+
|
| 151 |
+
# pprint(fine_to_coarse)
|
| 152 |
+
|
| 153 |
+
def save_taxonomy():
|
| 154 |
+
with open("taxonomy.json", "w") as f:
|
| 155 |
+
json.dump(taxonomy, f, indent=4)
|
| 156 |
+
|
| 157 |
+
taxonomy_coarse = list(taxonomy.keys())
|
| 158 |
+
|
| 159 |
+
with open("taxonomy_coarse.json", "w") as f:
|
| 160 |
+
json.dump(taxonomy_coarse, f, indent=4)
|
| 161 |
+
|
| 162 |
+
taxonomy_fine = list(chain(*taxonomy.values()))
|
| 163 |
+
|
| 164 |
+
count_ = defaultdict(int)
|
| 165 |
+
for t in taxonomy_fine:
|
| 166 |
+
count_[t] += 1
|
| 167 |
+
|
| 168 |
+
with open("taxonomy_fine.json", "w") as f:
|
| 169 |
+
json.dump(taxonomy_fine, f, indent=4)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
possible_coarse = list(taxonomy.keys())
|
| 174 |
+
possible_fine = list(set(chain(*taxonomy.values())))
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def trim_and_mix(audios, length_=None):
|
| 178 |
+
length = min([a.shape[-1] for a in audios])
|
| 179 |
+
|
| 180 |
+
if length_ is not None:
|
| 181 |
+
length = min(length, length_)
|
| 182 |
+
|
| 183 |
+
audios = [a[..., :length] for a in audios]
|
| 184 |
+
return np.sum(np.stack(audios, axis=0), axis=0), length
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def retrim_npys(saved_npy, new_length):
|
| 188 |
+
print("retrimming")
|
| 189 |
+
for npy in saved_npy:
|
| 190 |
+
audio = np.load(npy)
|
| 191 |
+
audio = audio[..., :new_length]
|
| 192 |
+
np.save(npy, audio)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def convert_one(inout):
|
| 196 |
+
input_path = inout.input_path
|
| 197 |
+
output_root = inout.output_root
|
| 198 |
+
|
| 199 |
+
song_id = os.path.basename(input_path)
|
| 200 |
+
output_root = os.path.join(output_root, song_id)
|
| 201 |
+
os.makedirs(output_root, exist_ok=True)
|
| 202 |
+
|
| 203 |
+
metadata = OmegaConf.load(os.path.join(input_path, "data.json"))
|
| 204 |
+
stems = metadata.stems
|
| 205 |
+
|
| 206 |
+
min_length = None
|
| 207 |
+
saved_npy = []
|
| 208 |
+
|
| 209 |
+
all_tracks = []
|
| 210 |
+
other_tracks = []
|
| 211 |
+
|
| 212 |
+
outfile = None
|
| 213 |
+
|
| 214 |
+
added_tracks = set()
|
| 215 |
+
duplicated_tracks = set()
|
| 216 |
+
track_to_stem = defaultdict(list)
|
| 217 |
+
added_stems = set()
|
| 218 |
+
duplicated_stems = set()
|
| 219 |
+
|
| 220 |
+
stem_name_to_stems = defaultdict(list)
|
| 221 |
+
|
| 222 |
+
for stem in stems:
|
| 223 |
+
stem_name = stem.stemName
|
| 224 |
+
stem_name_to_stems[stem_name].append(stem)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
for stem_name in tqdm(stem_name_to_stems):
|
| 228 |
+
stem_tracks = []
|
| 229 |
+
for stem in stem_name_to_stems[stem_name]:
|
| 230 |
+
stem_name = stem.stemName
|
| 231 |
+
|
| 232 |
+
if stem_name in added_stems:
|
| 233 |
+
print(f"Duplicate stem {stem_name} in {song_id}")
|
| 234 |
+
duplicated_stems.add(stem_name)
|
| 235 |
+
|
| 236 |
+
added_stems.add(stem_name)
|
| 237 |
+
|
| 238 |
+
for track in stem.tracks:
|
| 239 |
+
track_inst = track.trackType
|
| 240 |
+
track_inst = clean_track_inst(track_inst)
|
| 241 |
+
|
| 242 |
+
if track_inst in added_tracks:
|
| 243 |
+
if stem_name in track_to_stem[track_inst]:
|
| 244 |
+
continue
|
| 245 |
+
print(f"Duplicate track {track_inst} in {song_id}")
|
| 246 |
+
print(f"Stems: {track_to_stem[track_inst]}")
|
| 247 |
+
duplicated_tracks.add(track_inst)
|
| 248 |
+
raise ValueError
|
| 249 |
+
else:
|
| 250 |
+
added_tracks.add(track_inst)
|
| 251 |
+
|
| 252 |
+
track_to_stem[track_inst].append(stem_name)
|
| 253 |
+
track_id = track.id
|
| 254 |
+
|
| 255 |
+
audio, fs = ta.load(os.path.join(input_path, stem_name, f"{track_id}.wav"))
|
| 256 |
+
|
| 257 |
+
if fs != 44100:
|
| 258 |
+
print(f"fs is {fs} for {track_id}")
|
| 259 |
+
with open(os.path.join(output_root, "fs.txt"), "w") as f:
|
| 260 |
+
f.write(f"{song_id}\t{track_id}\t{fs}\n")
|
| 261 |
+
|
| 262 |
+
if min_length is None:
|
| 263 |
+
min_length = audio.shape[-1]
|
| 264 |
+
else:
|
| 265 |
+
if audio.shape[-1] < min_length:
|
| 266 |
+
min_length = audio.shape[-1]
|
| 267 |
+
|
| 268 |
+
if len(saved_npy) > 0:
|
| 269 |
+
retrim_npys(saved_npy, min_length)
|
| 270 |
+
|
| 271 |
+
audio = audio[..., :min_length]
|
| 272 |
+
audio = audio.numpy()
|
| 273 |
+
audio = audio.astype(np.float32)
|
| 274 |
+
|
| 275 |
+
if audio.shape[0] == 1:
|
| 276 |
+
print("mono")
|
| 277 |
+
if audio.shape[0] > 2:
|
| 278 |
+
print("multi channel")
|
| 279 |
+
|
| 280 |
+
assert outfile is None
|
| 281 |
+
outfile = os.path.join(output_root, f"{track_inst}.npy")
|
| 282 |
+
np.save(outfile, audio)
|
| 283 |
+
saved_npy.append(outfile)
|
| 284 |
+
outfile = None
|
| 285 |
+
stem_tracks.append(audio)
|
| 286 |
+
audio = None
|
| 287 |
+
|
| 288 |
+
stem_track, min_length = trim_and_mix(stem_tracks)
|
| 289 |
+
|
| 290 |
+
assert outfile is None
|
| 291 |
+
outfile = os.path.join(output_root, f"{stem_name}.npy")
|
| 292 |
+
np.save(outfile, stem_track)
|
| 293 |
+
saved_npy.append(outfile)
|
| 294 |
+
outfile = None
|
| 295 |
+
|
| 296 |
+
all_tracks.append(stem_track)
|
| 297 |
+
|
| 298 |
+
if stem_name not in ["vocals", "drums", "bass"]:
|
| 299 |
+
# print(f"Putting {stem_name} in other")
|
| 300 |
+
other_tracks.append(stem_track)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
assert outfile is None
|
| 304 |
+
all_track, min_length_ = trim_and_mix(all_tracks, min_length)
|
| 305 |
+
outfile = os.path.join(output_root, f"mixture.npy")
|
| 306 |
+
np.save(outfile, all_track)
|
| 307 |
+
|
| 308 |
+
if min_length_ != min_length:
|
| 309 |
+
retrim_npys(saved_npy, min_length_)
|
| 310 |
+
min_length = min_length_
|
| 311 |
+
|
| 312 |
+
saved_npy.append(outfile)
|
| 313 |
+
outfile = None
|
| 314 |
+
|
| 315 |
+
other_track, min_length_ = trim_and_mix(other_tracks, min_length)
|
| 316 |
+
np.save(os.path.join(output_root, f"vdbo_others.npy"), other_track)
|
| 317 |
+
|
| 318 |
+
if min_length_ != min_length:
|
| 319 |
+
retrim_npys(saved_npy, min_length_)
|
| 320 |
+
min_length = min_length_
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def convert_to_npy(
|
| 324 |
+
data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/canonical",
|
| 325 |
+
output_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2",
|
| 326 |
+
):
|
| 327 |
+
if output_root is None:
|
| 328 |
+
output_root = os.path.join(os.path.dirname(data_root), "npy")
|
| 329 |
+
|
| 330 |
+
files = os.listdir(data_root)
|
| 331 |
+
files = [
|
| 332 |
+
os.path.join(data_root, f)
|
| 333 |
+
for f in files
|
| 334 |
+
if os.path.isdir(os.path.join(data_root, f))
|
| 335 |
+
]
|
| 336 |
+
|
| 337 |
+
inout = [SimpleNamespace(input_path=f, output_root=output_root) for f in files]
|
| 338 |
+
|
| 339 |
+
process_map(convert_one, inout)
|
| 340 |
+
|
| 341 |
+
# for io in tdqm(inout):
|
| 342 |
+
# convert_one(io)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def make_others_one(input_path, dry_run=False):
|
| 346 |
+
|
| 347 |
+
other_stems = [k for k in taxonomy.keys() if k not in ["vocals", "bass", "drums"]]
|
| 348 |
+
npys = glob.glob(os.path.join(input_path, "**/*.npy"), recursive=True)
|
| 349 |
+
|
| 350 |
+
npys = [npy for npy in npys if ".dbfs" not in npy]
|
| 351 |
+
npys = [npy for npy in npys if ".query" not in npy]
|
| 352 |
+
npys = [npy for npy in npys if "mixture" not in npy]
|
| 353 |
+
npys = [npy for npy in npys if os.path.basename(npy).split(".")[0] in other_stems]
|
| 354 |
+
|
| 355 |
+
print(f"Using stems: {[os.path.basename(npy).split('.')[0] for npy in npys]}")
|
| 356 |
+
|
| 357 |
+
if len(npys) == 0:
|
| 358 |
+
audio = np.zeros_like(np.load(os.path.join(input_path, "mixture.npy")))
|
| 359 |
+
else:
|
| 360 |
+
audio = [np.load(npy) for npy in npys]
|
| 361 |
+
|
| 362 |
+
audio = np.sum(np.stack(audio, axis=0), axis=0)
|
| 363 |
+
assert audio.shape[0] == 2
|
| 364 |
+
|
| 365 |
+
output = os.path.join(input_path, "vdbo_others.npy")
|
| 366 |
+
|
| 367 |
+
if dry_run:
|
| 368 |
+
return
|
| 369 |
+
|
| 370 |
+
np.save(output, audio)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def check_vdbo_one(f):
|
| 374 |
+
s = np.sum(
|
| 375 |
+
np.stack(
|
| 376 |
+
[
|
| 377 |
+
np.load(os.path.join(f, s + ".npy"))
|
| 378 |
+
for s in ["vocals", "drums", "bass", "vdbo_others"]
|
| 379 |
+
if os.path.exists(os.path.join(f, s + ".npy"))
|
| 380 |
+
],
|
| 381 |
+
axis=0,
|
| 382 |
+
),
|
| 383 |
+
axis=0,
|
| 384 |
+
)
|
| 385 |
+
m = np.load(os.path.join(f, "mixture.npy"))
|
| 386 |
+
snr = 10 * np.log10(np.mean(np.square(m)) / np.mean(np.square(s - m)))
|
| 387 |
+
print(snr)
|
| 388 |
+
|
| 389 |
+
return snr
|
| 390 |
+
|
| 391 |
+
def check_vdbo(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2"):
|
| 392 |
+
files = os.listdir(data_root)
|
| 393 |
+
|
| 394 |
+
files = [
|
| 395 |
+
os.path.join(data_root, f)
|
| 396 |
+
for f in files
|
| 397 |
+
if os.path.isdir(os.path.join(data_root, f))
|
| 398 |
+
]
|
| 399 |
+
|
| 400 |
+
snrs = process_map(check_vdbo_one, files)
|
| 401 |
+
|
| 402 |
+
np.save("/storage/home/hcoda1/1/kwatchar3/data/vdbo.npy", np.array(snrs))
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def make_others(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2"):
|
| 406 |
+
|
| 407 |
+
files = os.listdir(data_root)
|
| 408 |
+
|
| 409 |
+
files = [
|
| 410 |
+
os.path.join(data_root, f)
|
| 411 |
+
for f in files
|
| 412 |
+
if os.path.isdir(os.path.join(data_root, f))
|
| 413 |
+
]
|
| 414 |
+
|
| 415 |
+
process_map(make_others_one, files)
|
| 416 |
+
|
| 417 |
+
# for f in tqdm(files):
|
| 418 |
+
# make_others_one(f, dry_run=False)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def extract_metadata_one(input_path):
|
| 422 |
+
song_id = os.path.basename(input_path)
|
| 423 |
+
metadata = OmegaConf.load(os.path.join(input_path, "data.json"))
|
| 424 |
+
|
| 425 |
+
song = metadata.song
|
| 426 |
+
artist = metadata.artist
|
| 427 |
+
genre = metadata.genre
|
| 428 |
+
|
| 429 |
+
stems = metadata.stems
|
| 430 |
+
data_out = []
|
| 431 |
+
|
| 432 |
+
for stem in stems:
|
| 433 |
+
stem_name = stem.stemName
|
| 434 |
+
stem_id = stem.id
|
| 435 |
+
for track in stem.tracks:
|
| 436 |
+
track_inst = track.trackType
|
| 437 |
+
track_id = track.id
|
| 438 |
+
|
| 439 |
+
data_out.append(
|
| 440 |
+
{
|
| 441 |
+
"song_id": song_id,
|
| 442 |
+
"song": song,
|
| 443 |
+
"artist": artist,
|
| 444 |
+
"genre": genre,
|
| 445 |
+
"stem_name": stem_name,
|
| 446 |
+
"stem_id": stem_id,
|
| 447 |
+
"track_inst": track_inst,
|
| 448 |
+
"track_id": track_id,
|
| 449 |
+
"has_bleed": track.has_bleed,
|
| 450 |
+
}
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
return data_out
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def consolidate_metadata(
|
| 457 |
+
data_root="/home/kwatchar3/Documents/data/moisesdb/canonical",
|
| 458 |
+
):
|
| 459 |
+
|
| 460 |
+
files = os.listdir(data_root)
|
| 461 |
+
files = [
|
| 462 |
+
os.path.join(data_root, f)
|
| 463 |
+
for f in files
|
| 464 |
+
if os.path.isdir(os.path.join(data_root, f))
|
| 465 |
+
]
|
| 466 |
+
|
| 467 |
+
data = process_map(extract_metadata_one, files)
|
| 468 |
+
|
| 469 |
+
df = pd.DataFrame.from_records(list(chain(*data)))
|
| 470 |
+
|
| 471 |
+
df.to_csv(os.path.join(os.path.dirname(data_root), "metadata.csv"), index=False)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def clean_canonical(data_root="/home/kwatchar3/Documents/data/moisesdb/canonical"):
|
| 475 |
+
|
| 476 |
+
npy = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
|
| 477 |
+
|
| 478 |
+
for n in tqdm(npy):
|
| 479 |
+
os.remove(n)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def remove_dbfs(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy"):
|
| 483 |
+
npy = glob.glob(os.path.join(data_root, "**/*.dbfs.npy"), recursive=True)
|
| 484 |
+
|
| 485 |
+
for n in tqdm(npy):
|
| 486 |
+
os.remove(n)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def make_split(
|
| 490 |
+
metadata_path="/home/kwatchar3/Documents/data/moisesdb/metadata.csv",
|
| 491 |
+
n_splits=5,
|
| 492 |
+
seed=42,
|
| 493 |
+
):
|
| 494 |
+
|
| 495 |
+
df = pd.read_csv(metadata_path)
|
| 496 |
+
# print(df.columns)
|
| 497 |
+
df = df[["song_id", "genre"]].drop_duplicates()
|
| 498 |
+
|
| 499 |
+
genres = df["genre"].value_counts()
|
| 500 |
+
genres_map = {g: g if c > n_splits else "other" for g, c in genres.items()}
|
| 501 |
+
|
| 502 |
+
df["genre"] = df["genre"].map(genres_map)
|
| 503 |
+
|
| 504 |
+
n_samples = len(df)
|
| 505 |
+
n_per_split = n_samples // n_splits
|
| 506 |
+
|
| 507 |
+
np.random.seed(seed)
|
| 508 |
+
|
| 509 |
+
from sklearn.model_selection import train_test_split
|
| 510 |
+
|
| 511 |
+
splits = []
|
| 512 |
+
|
| 513 |
+
df_ = df.copy()
|
| 514 |
+
|
| 515 |
+
for i in range(n_splits - 1):
|
| 516 |
+
df_, test = train_test_split(
|
| 517 |
+
df_,
|
| 518 |
+
test_size=n_per_split,
|
| 519 |
+
random_state=seed,
|
| 520 |
+
stratify=df_["genre"],
|
| 521 |
+
shuffle=True,
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
dfs = test[["song_id"]].copy().sort_values(by="song_id")
|
| 525 |
+
dfs["split"] = i + 1
|
| 526 |
+
splits.append(dfs)
|
| 527 |
+
|
| 528 |
+
test = df_
|
| 529 |
+
dfs = test[["song_id"]].copy().sort_values(by="song_id")
|
| 530 |
+
dfs["split"] = n_splits
|
| 531 |
+
splits.append(dfs)
|
| 532 |
+
|
| 533 |
+
splits = pd.concat(splits)
|
| 534 |
+
|
| 535 |
+
splits.to_csv(
|
| 536 |
+
os.path.join(os.path.dirname(metadata_path), "splits.csv"), index=False
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def consolidate_stems(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
|
| 541 |
+
|
| 542 |
+
metadata = pd.read_csv(os.path.join(os.path.dirname(data_root), "metadata.csv"))
|
| 543 |
+
|
| 544 |
+
dfg = metadata.groupby("song_id")[["stem_name", "track_inst"]]
|
| 545 |
+
|
| 546 |
+
pprint(dfg)
|
| 547 |
+
|
| 548 |
+
df = []
|
| 549 |
+
|
| 550 |
+
def make_stem_dict(song_id, track_inst, stem_names):
|
| 551 |
+
|
| 552 |
+
d = {"song_id": song_id}
|
| 553 |
+
|
| 554 |
+
for inst in possible_fine:
|
| 555 |
+
d[inst] = int(inst in track_inst)
|
| 556 |
+
|
| 557 |
+
for inst in possible_coarse:
|
| 558 |
+
d[inst] = int(inst in stem_names)
|
| 559 |
+
|
| 560 |
+
return d
|
| 561 |
+
|
| 562 |
+
for song_id, dfgg in dfg:
|
| 563 |
+
|
| 564 |
+
track_inst = dfgg["track_inst"].tolist()
|
| 565 |
+
track_inst = list(set(track_inst))
|
| 566 |
+
track_inst = [clean_track_inst(inst) for inst in track_inst]
|
| 567 |
+
|
| 568 |
+
stem_names = dfgg["stem_name"].tolist()
|
| 569 |
+
stem_names = list(set([clean_track_inst(inst) for inst in stem_names]))
|
| 570 |
+
|
| 571 |
+
d = make_stem_dict(song_id, track_inst, stem_names)
|
| 572 |
+
df.append(d)
|
| 573 |
+
|
| 574 |
+
print(df)
|
| 575 |
+
|
| 576 |
+
df = pd.DataFrame.from_records(df)
|
| 577 |
+
|
| 578 |
+
df.to_csv(os.path.join(os.path.dirname(data_root), "stems.csv"), index=False)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def get_dbfs(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
|
| 582 |
+
npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
|
| 583 |
+
|
| 584 |
+
dbfs = []
|
| 585 |
+
|
| 586 |
+
for npy in tqdm(npys):
|
| 587 |
+
audio = np.load(npy)
|
| 588 |
+
song_id = os.path.basename(os.path.dirname(npy))
|
| 589 |
+
track_id = os.path.basename(npy).split(".")[0]
|
| 590 |
+
|
| 591 |
+
dbfs.append(
|
| 592 |
+
{
|
| 593 |
+
"song_id": song_id,
|
| 594 |
+
"track_id": track_id,
|
| 595 |
+
"dbfs": 10 * np.log10(np.mean(np.square(audio))),
|
| 596 |
+
}
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
dbfs = pd.DataFrame.from_records(dbfs)
|
| 600 |
+
|
| 601 |
+
dbfs.to_csv(os.path.join(os.path.dirname(data_root), "dbfs.csv"), index=False)
|
| 602 |
+
|
| 603 |
+
return dbfs
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def get_dbfs_by_chunk_one(inout):
|
| 607 |
+
|
| 608 |
+
audio = np.load(inout.audio_path, mmap_mode="r")
|
| 609 |
+
chunk_size = inout.chunk_size
|
| 610 |
+
fs = inout.fs
|
| 611 |
+
hop_size = inout.hop_size
|
| 612 |
+
|
| 613 |
+
n_chan, n_samples = audio.shape
|
| 614 |
+
chunk_size_samples = int(chunk_size * fs)
|
| 615 |
+
hop_size_samples = int(hop_size * fs)
|
| 616 |
+
|
| 617 |
+
x2win = np.lib.stride_tricks.sliding_window_view(
|
| 618 |
+
np.square(audio), chunk_size_samples, axis=1
|
| 619 |
+
)[:, ::hop_size_samples, :]
|
| 620 |
+
|
| 621 |
+
x2win_mean = np.mean(x2win, axis=(0, 2))
|
| 622 |
+
x2win_mean[x2win_mean == 0] = 1e-8
|
| 623 |
+
dbfs = 10 * np.log10(x2win_mean)
|
| 624 |
+
|
| 625 |
+
# song_id = os.path.basename(os.path.dirname(inout.audio_path))
|
| 626 |
+
track_id = os.path.basename(inout.audio_path).split(".")[0]
|
| 627 |
+
|
| 628 |
+
np.save(
|
| 629 |
+
os.path.join(os.path.dirname(inout.audio_path), f"{track_id}.dbfs.npy"), dbfs
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def clean_data_root(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
|
| 634 |
+
npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
|
| 635 |
+
|
| 636 |
+
for npy in tqdm(npys):
|
| 637 |
+
if ".dbfs" in npy or ".query" in npy:
|
| 638 |
+
# print("removing", npy)
|
| 639 |
+
os.remove(npy)
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
#
|
| 643 |
+
def get_dbfs_by_chunk(
|
| 644 |
+
data_root="/home/kwatchar3/Documents/data/moisesdb/npy",
|
| 645 |
+
query_root="/home/kwatchar3/Documents/data/moisesdb/npyq",
|
| 646 |
+
):
|
| 647 |
+
npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
|
| 648 |
+
|
| 649 |
+
inout = [
|
| 650 |
+
SimpleNamespace(
|
| 651 |
+
audio_path=npy,
|
| 652 |
+
chunk_size=1,
|
| 653 |
+
hop_size=0.125,
|
| 654 |
+
fs=44100,
|
| 655 |
+
output_path=npy.replace(data_root, query_root).replace(
|
| 656 |
+
".npy", ".query.npy"
|
| 657 |
+
),
|
| 658 |
+
)
|
| 659 |
+
for npy in npys
|
| 660 |
+
]
|
| 661 |
+
|
| 662 |
+
process_map(get_dbfs_by_chunk_one, inout, chunksize=2)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def round_samples(seconds, fs, hop_size, downsample):
|
| 666 |
+
n_frames = math.ceil(seconds * fs / hop_size) + 1
|
| 667 |
+
n_frames_down = math.ceil(n_frames / downsample)
|
| 668 |
+
n_frames = n_frames_down * downsample
|
| 669 |
+
n_samples = (n_frames - 1) * hop_size
|
| 670 |
+
|
| 671 |
+
return int(n_samples)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def get_query_one(inout):
|
| 675 |
+
|
| 676 |
+
audio = np.load(inout.audio_path, mmap_mode="r")
|
| 677 |
+
chunk_size = inout.chunk_size
|
| 678 |
+
fs = inout.fs
|
| 679 |
+
output_path = inout.output_path
|
| 680 |
+
round = inout.round
|
| 681 |
+
hop_size = inout.hop_size
|
| 682 |
+
|
| 683 |
+
if round:
|
| 684 |
+
chunk_size_samples = round_samples(chunk_size, fs, 512, 2**6)
|
| 685 |
+
else:
|
| 686 |
+
chunk_size_samples = int(chunk_size * fs)
|
| 687 |
+
|
| 688 |
+
audio_mono = np.mean(audio, axis=0)
|
| 689 |
+
|
| 690 |
+
onset = librosa.onset.onset_detect(
|
| 691 |
+
y=audio_mono, sr=fs, units="frames", hop_length=hop_size
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
onset_strength = librosa.onset.onset_strength(
|
| 695 |
+
y=audio_mono, sr=fs, hop_length=hop_size
|
| 696 |
+
)
|
| 697 |
+
|
| 698 |
+
n_frames_per_chunk = chunk_size_samples // hop_size
|
| 699 |
+
|
| 700 |
+
onset_strength_slide = np.lib.stride_tricks.sliding_window_view(
|
| 701 |
+
onset_strength, n_frames_per_chunk, axis=0
|
| 702 |
+
)
|
| 703 |
+
|
| 704 |
+
onset_strength = np.mean(onset_strength_slide, axis=1)
|
| 705 |
+
|
| 706 |
+
max_onset_frame = np.argmax(onset_strength)
|
| 707 |
+
|
| 708 |
+
max_onset_samples = librosa.frames_to_samples(max_onset_frame)
|
| 709 |
+
|
| 710 |
+
track_id = os.path.basename(inout.audio_path).split(".")[0]
|
| 711 |
+
|
| 712 |
+
segment = audio[:, max_onset_samples : max_onset_samples + chunk_size_samples]
|
| 713 |
+
|
| 714 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 715 |
+
|
| 716 |
+
np.save(output_path, segment)
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
def get_query_from_onset(
|
| 720 |
+
data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2", # "/home/kwatchar3/Documents/data/moisesdb/npy",
|
| 721 |
+
query_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npyq", # "/home/kwatchar3/Documents/data/moisesdb/npyq",
|
| 722 |
+
query_file="query-10s",
|
| 723 |
+
pmap=True,
|
| 724 |
+
):
|
| 725 |
+
npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
|
| 726 |
+
|
| 727 |
+
npys = [npy for npy in npys if "dbfs" not in npy]
|
| 728 |
+
|
| 729 |
+
inout = [
|
| 730 |
+
SimpleNamespace(
|
| 731 |
+
audio_path=npy,
|
| 732 |
+
chunk_size=10,
|
| 733 |
+
hop_size=512,
|
| 734 |
+
round=False,
|
| 735 |
+
fs=44100,
|
| 736 |
+
output_path=npy.replace(data_root, query_root).replace(
|
| 737 |
+
".npy", f".{query_file}.npy"
|
| 738 |
+
),
|
| 739 |
+
)
|
| 740 |
+
for npy in npys
|
| 741 |
+
]
|
| 742 |
+
|
| 743 |
+
if pmap:
|
| 744 |
+
process_map(get_query_one, inout, chunksize=2, max_workers=24)
|
| 745 |
+
else:
|
| 746 |
+
for io in tqdm(inout):
|
| 747 |
+
get_query_one(io)
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def get_durations(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
|
| 751 |
+
npys = glob.glob(os.path.join(data_root, "**/mixture.npy"), recursive=True)
|
| 752 |
+
|
| 753 |
+
durations = []
|
| 754 |
+
|
| 755 |
+
for npy in tqdm(npys):
|
| 756 |
+
audio = np.load(npy, mmap_mode="r")
|
| 757 |
+
song_id = os.path.basename(os.path.dirname(npy))
|
| 758 |
+
track_id = os.path.basename(npy).split(".")[0]
|
| 759 |
+
|
| 760 |
+
durations.append(
|
| 761 |
+
{
|
| 762 |
+
"song_id": song_id,
|
| 763 |
+
"track_id": track_id,
|
| 764 |
+
"duration": audio.shape[-1] / 44100,
|
| 765 |
+
}
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
durations = pd.DataFrame.from_records(durations)
|
| 769 |
+
|
| 770 |
+
durations.to_csv(
|
| 771 |
+
os.path.join(os.path.dirname(data_root), "durations.csv"), index=False
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
return durations
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
def clean_query_root(
|
| 778 |
+
data_root="/home/kwatchar3/Documents/data/moisesdb/npy",
|
| 779 |
+
query_root="/home/kwatchar3/Documents/data/moisesdb/npyq",
|
| 780 |
+
):
|
| 781 |
+
npys = glob.glob(os.path.join(data_root, "**/*.query.npy"), recursive=True)
|
| 782 |
+
|
| 783 |
+
for npy in tqdm(npys):
|
| 784 |
+
dst = npy.replace(data_root, query_root)
|
| 785 |
+
dstdir = os.path.dirname(dst)
|
| 786 |
+
os.makedirs(dstdir, exist_ok=True)
|
| 787 |
+
shutil.move(npy, dst)
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
def make_test_indices(
|
| 791 |
+
metadata_path="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/metadata.csv",
|
| 792 |
+
stem_path="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/stems.csv",
|
| 793 |
+
splits_path="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/splits.csv",
|
| 794 |
+
test_split=5,
|
| 795 |
+
):
|
| 796 |
+
|
| 797 |
+
coarse_stems = set(taxonomy.keys())
|
| 798 |
+
fine_stems = set(chain(*taxonomy.values()))
|
| 799 |
+
|
| 800 |
+
metadata = pd.read_csv(metadata_path)
|
| 801 |
+
splits = pd.read_csv(splits_path)
|
| 802 |
+
stems = pd.read_csv(stem_path)
|
| 803 |
+
|
| 804 |
+
file_in_test = splits[splits["split"] == test_split]["song_id"].tolist()
|
| 805 |
+
|
| 806 |
+
stems_test = stems[stems["song_id"].isin(file_in_test)]
|
| 807 |
+
metadata_test = metadata[metadata["song_id"].isin(file_in_test)]
|
| 808 |
+
splits_test = splits[splits["split"] == test_split]
|
| 809 |
+
|
| 810 |
+
stems_test = stems_test.set_index("song_id")
|
| 811 |
+
metadata_test = metadata_test.drop_duplicates("song_id").set_index("song_id")
|
| 812 |
+
splits_test = splits_test.set_index("song_id")
|
| 813 |
+
|
| 814 |
+
stem_to_song_id = defaultdict(list)
|
| 815 |
+
song_id_to_stem = defaultdict(list)
|
| 816 |
+
|
| 817 |
+
for song_id in file_in_test:
|
| 818 |
+
|
| 819 |
+
stems_ = stems_test.loc[song_id]
|
| 820 |
+
stem_names = stems_.T
|
| 821 |
+
stem_names = stem_names[stem_names == 1].index.tolist()
|
| 822 |
+
|
| 823 |
+
for stem in stem_names:
|
| 824 |
+
stem_to_song_id[stem].append(song_id)
|
| 825 |
+
|
| 826 |
+
song_id_to_stem[song_id] = stem_names
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
indices = []
|
| 830 |
+
no_query = []
|
| 831 |
+
|
| 832 |
+
for song_id in file_in_test:
|
| 833 |
+
|
| 834 |
+
genre = metadata_test.loc[song_id, "genre"]
|
| 835 |
+
# print(genre)
|
| 836 |
+
artist = metadata_test.loc[song_id, "artist"]
|
| 837 |
+
# print(artist)
|
| 838 |
+
|
| 839 |
+
stems_ = song_id_to_stem[song_id]
|
| 840 |
+
|
| 841 |
+
for stem in stems_:
|
| 842 |
+
possible_query = stem_to_song_id[stem]
|
| 843 |
+
possible_query = [p for p in possible_query if p != song_id]
|
| 844 |
+
|
| 845 |
+
if len(possible_query) == 0:
|
| 846 |
+
print(f"No possible query for {song_id} with {stem}")
|
| 847 |
+
|
| 848 |
+
no_query.append(
|
| 849 |
+
{
|
| 850 |
+
"song_id": song_id,
|
| 851 |
+
"stem": stem
|
| 852 |
+
}
|
| 853 |
+
)
|
| 854 |
+
continue
|
| 855 |
+
|
| 856 |
+
query_df = metadata_test.loc[possible_query, ["genre", "artist"]]
|
| 857 |
+
|
| 858 |
+
assert len(query_df) > 0
|
| 859 |
+
|
| 860 |
+
query_df_ = query_df.copy()
|
| 861 |
+
|
| 862 |
+
same_genre = True
|
| 863 |
+
different_artist = True
|
| 864 |
+
query_df = query_df[(query_df["genre"] == genre) & (query_df["artist"] != artist)]
|
| 865 |
+
|
| 866 |
+
if len(query_df) == 0:
|
| 867 |
+
|
| 868 |
+
same_genre = False
|
| 869 |
+
different_artist = True
|
| 870 |
+
|
| 871 |
+
query_df = query_df_.copy()
|
| 872 |
+
query_df = query_df[(query_df["artist"] != artist)]
|
| 873 |
+
|
| 874 |
+
if len(query_df) == 0:
|
| 875 |
+
|
| 876 |
+
same_genre = True
|
| 877 |
+
different_artist = False
|
| 878 |
+
|
| 879 |
+
query_df = query_df_.copy()
|
| 880 |
+
query_df = query_df[(query_df["genre"] == genre)]
|
| 881 |
+
|
| 882 |
+
if len(query_df) == 0:
|
| 883 |
+
|
| 884 |
+
same_genre = False
|
| 885 |
+
different_artist = False
|
| 886 |
+
|
| 887 |
+
query_df = query_df_.copy()
|
| 888 |
+
|
| 889 |
+
query_id = query_df.sample(1).index[0]
|
| 890 |
+
|
| 891 |
+
indices.append(
|
| 892 |
+
{
|
| 893 |
+
"song_id": song_id,
|
| 894 |
+
"query_id": query_id,
|
| 895 |
+
"stem": stem,
|
| 896 |
+
"same_genre": same_genre,
|
| 897 |
+
"different_artist": different_artist
|
| 898 |
+
}
|
| 899 |
+
)
|
| 900 |
+
|
| 901 |
+
indices = pd.DataFrame.from_records(indices)
|
| 902 |
+
no_query = pd.DataFrame.from_records(no_query)
|
| 903 |
+
|
| 904 |
+
indices.to_csv(
|
| 905 |
+
os.path.join(os.path.dirname(metadata_path), "test_indices.csv"), index=False
|
| 906 |
+
)
|
| 907 |
+
|
| 908 |
+
no_query.to_csv(
|
| 909 |
+
os.path.join(os.path.dirname(metadata_path), "no_query.csv"), index=False
|
| 910 |
+
)
|
| 911 |
+
|
| 912 |
+
print("Total number of queries:", len(indices))
|
| 913 |
+
print("Total number of no queries:", len(no_query))
|
| 914 |
+
|
| 915 |
+
query_type = indices.groupby(["same_genre", "different_artist"]).size()
|
| 916 |
+
|
| 917 |
+
print(query_type)
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
if __name__ == "__main__":
|
| 921 |
+
import fire
|
| 922 |
+
|
| 923 |
+
fire.Fire()
|
core/data/moisesdb/passt.ipynb
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import glob\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"data_root = \"/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/passt\"\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"files = glob.glob(data_root + \"/*.passt.npy\")"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": null,
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": []
|
| 23 |
+
}
|
| 24 |
+
],
|
| 25 |
+
"metadata": {
|
| 26 |
+
"language_info": {
|
| 27 |
+
"name": "python"
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"nbformat": 4,
|
| 31 |
+
"nbformat_minor": 2
|
| 32 |
+
}
|
core/losses/__init__.py
ADDED
|
File without changes
|
core/losses/base.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Union
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch
|
| 4 |
+
from torch.nn.modules.loss import _Loss
|
| 5 |
+
|
| 6 |
+
from core.types import BatchedInputOutput
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
class BaseLossHandler(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self, loss: nn.Module, modality: Union[str, List[str]], name: Optional[str] = None
|
| 12 |
+
) -> None:
|
| 13 |
+
super().__init__()
|
| 14 |
+
|
| 15 |
+
self.loss = loss
|
| 16 |
+
|
| 17 |
+
if isinstance(modality, str):
|
| 18 |
+
modality = [modality]
|
| 19 |
+
|
| 20 |
+
self.modality = modality
|
| 21 |
+
|
| 22 |
+
if name is None:
|
| 23 |
+
name = "loss"
|
| 24 |
+
|
| 25 |
+
if name == "__auto__":
|
| 26 |
+
name = self.loss.__class__.__name__
|
| 27 |
+
|
| 28 |
+
self.name = name
|
| 29 |
+
|
| 30 |
+
def _audio_preprocess(self, y_pred, y_true):
|
| 31 |
+
|
| 32 |
+
n_sample_true = y_true.shape[-1]
|
| 33 |
+
n_sample_pred = y_pred.shape[-1]
|
| 34 |
+
|
| 35 |
+
if n_sample_pred > n_sample_true:
|
| 36 |
+
y_pred = y_pred[..., :n_sample_true]
|
| 37 |
+
elif n_sample_pred < n_sample_true:
|
| 38 |
+
y_true = y_true[..., :n_sample_pred]
|
| 39 |
+
|
| 40 |
+
return y_pred, y_true
|
| 41 |
+
|
| 42 |
+
def forward(self, batch: BatchedInputOutput):
|
| 43 |
+
y_true = batch.sources
|
| 44 |
+
y_pred = batch.estimates
|
| 45 |
+
|
| 46 |
+
loss_contribs = {}
|
| 47 |
+
|
| 48 |
+
stem_contribs = {
|
| 49 |
+
stem: 0.0 for stem in y_pred.keys()
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
for stem in y_pred.keys():
|
| 53 |
+
for modality in self.modality:
|
| 54 |
+
|
| 55 |
+
if modality not in y_pred[stem].keys():
|
| 56 |
+
continue
|
| 57 |
+
|
| 58 |
+
if y_pred[stem][modality].shape[-1] == 0:
|
| 59 |
+
continue
|
| 60 |
+
|
| 61 |
+
y_true_ = y_true[stem][modality]
|
| 62 |
+
y_pred_ = y_pred[stem][modality]
|
| 63 |
+
|
| 64 |
+
if modality == "audio":
|
| 65 |
+
y_pred_, y_true_ = self._audio_preprocess(y_pred_, y_true_)
|
| 66 |
+
elif modality == "spectrogram":
|
| 67 |
+
y_pred_ = torch.view_as_real(y_pred_)
|
| 68 |
+
y_true_ = torch.view_as_real(y_true_)
|
| 69 |
+
|
| 70 |
+
loss_contribs[f"{self.name}/{stem}/{modality}"] = self.loss(
|
| 71 |
+
y_pred_, y_true_
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
stem_contribs[stem] += loss_contribs[f"{self.name}/{stem}/{modality}"]
|
| 75 |
+
|
| 76 |
+
total_loss = sum(stem_contribs.values())
|
| 77 |
+
loss_contribs[self.name] = total_loss
|
| 78 |
+
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
for stem in stem_contribs.keys():
|
| 81 |
+
loss_contribs[f"{self.name}/{stem}"] = stem_contribs[stem]
|
| 82 |
+
|
| 83 |
+
return loss_contribs
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class AdversarialLossHandler(BaseLossHandler):
|
| 87 |
+
def __init__(self, loss: nn.Module, modality: str, name: Optional[str] = "adv_loss"):
|
| 88 |
+
|
| 89 |
+
super().__init__(loss, modality, name)
|
| 90 |
+
|
| 91 |
+
def discriminator_forward(self, batch: BatchedInputOutput):
|
| 92 |
+
|
| 93 |
+
y_true = batch.sources
|
| 94 |
+
y_pred = batch.estimates
|
| 95 |
+
|
| 96 |
+
# g_loss_contribs = {}
|
| 97 |
+
d_loss_contribs = {}
|
| 98 |
+
|
| 99 |
+
for stem in y_pred.keys():
|
| 100 |
+
|
| 101 |
+
if self.modality not in y_pred[stem].keys():
|
| 102 |
+
continue
|
| 103 |
+
|
| 104 |
+
if y_pred[stem][self.modality].shape[-1] == 0:
|
| 105 |
+
continue
|
| 106 |
+
|
| 107 |
+
y_true_ = y_true[stem][self.modality]
|
| 108 |
+
y_pred_ = y_pred[stem][self.modality]
|
| 109 |
+
|
| 110 |
+
if self.modality == "audio":
|
| 111 |
+
y_pred_, y_true_ = self._audio_preprocess(y_pred_, y_true_)
|
| 112 |
+
|
| 113 |
+
# g_loss_contribs[f"{self.name}:g/{stem}"] = self.loss.generator_loss(
|
| 114 |
+
# y_pred_, y_true_
|
| 115 |
+
# )
|
| 116 |
+
|
| 117 |
+
d_loss_contribs[f"{self.name}:d/{stem}"] = self.loss.discriminator_loss(
|
| 118 |
+
y_pred_, y_true_
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# g_total_loss = sum(g_loss_contribs.values())
|
| 122 |
+
d_total_loss = sum(d_loss_contribs.values())
|
| 123 |
+
|
| 124 |
+
# g_loss_contribs["loss"] = g_total_loss
|
| 125 |
+
d_loss_contribs["disc_loss"] = d_total_loss
|
| 126 |
+
|
| 127 |
+
return d_loss_contribs
|
| 128 |
+
|
| 129 |
+
def generator_forward(self, batch: BatchedInputOutput):
|
| 130 |
+
|
| 131 |
+
y_true = batch.sources
|
| 132 |
+
y_pred = batch.estimates
|
| 133 |
+
|
| 134 |
+
g_loss_contribs = {}
|
| 135 |
+
# d_loss_contribs = {}
|
| 136 |
+
|
| 137 |
+
for stem in y_pred.keys():
|
| 138 |
+
|
| 139 |
+
if self.modality not in y_pred[stem].keys():
|
| 140 |
+
continue
|
| 141 |
+
|
| 142 |
+
if y_pred[stem][self.modality].shape[-1] == 0:
|
| 143 |
+
continue
|
| 144 |
+
|
| 145 |
+
y_true_ = y_true[stem][self.modality]
|
| 146 |
+
y_pred_ = y_pred[stem][self.modality]
|
| 147 |
+
|
| 148 |
+
if self.modality == "audio":
|
| 149 |
+
y_pred_, y_true_ = self._audio_preprocess(y_pred_, y_true_)
|
| 150 |
+
|
| 151 |
+
g_loss_contribs[f"{self.name}:g/{stem}"] = self.loss.generator_loss(
|
| 152 |
+
y_pred_, y_true_
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# d_loss_contribs[f"{self.name}:g/{stem}"] = self.loss.discriminator_loss(
|
| 156 |
+
# y_pred_, y_true_
|
| 157 |
+
# )
|
| 158 |
+
|
| 159 |
+
g_total_loss = sum(g_loss_contribs.values())
|
| 160 |
+
# d_total_loss = sum(d_loss_contribs.values())
|
| 161 |
+
|
| 162 |
+
g_loss_contribs["gen_loss"] = g_total_loss
|
| 163 |
+
# d_loss_contribs["loss"] = d_total_loss
|
| 164 |
+
|
| 165 |
+
return g_loss_contribs
|
| 166 |
+
|
| 167 |
+
def forward(self, batch: BatchedInputOutput):
|
| 168 |
+
return {
|
| 169 |
+
"generator": self.generator_forward(batch),
|
| 170 |
+
"discriminator": self.discriminator_forward(batch)
|
| 171 |
+
}
|
core/losses/l1snr.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.nn.modules.loss import _Loss
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class WeightedL1Loss(_Loss):
|
| 6 |
+
def __init__(self, weights=None):
|
| 7 |
+
super().__init__()
|
| 8 |
+
|
| 9 |
+
def forward(self, y_pred, y_true):
|
| 10 |
+
ndim = y_pred.ndim
|
| 11 |
+
dims = list(range(1, ndim))
|
| 12 |
+
loss = F.l1_loss(y_pred, y_true, reduction='none')
|
| 13 |
+
loss = torch.mean(loss, dim=dims)
|
| 14 |
+
weights = torch.mean(torch.abs(y_true), dim=dims)
|
| 15 |
+
|
| 16 |
+
loss = torch.sum(loss * weights) / torch.sum(weights)
|
| 17 |
+
|
| 18 |
+
return loss
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class L1MatchLoss(_Loss):
|
| 22 |
+
def __init__(self):
|
| 23 |
+
super().__init__()
|
| 24 |
+
|
| 25 |
+
def forward(self, y_pred, y_true):
|
| 26 |
+
batch_size = y_pred.shape[0]
|
| 27 |
+
|
| 28 |
+
y_pred = y_pred.reshape(batch_size, -1)
|
| 29 |
+
y_true = y_true.reshape(batch_size, -1)
|
| 30 |
+
|
| 31 |
+
l1_true = torch.mean(torch.abs(y_true), dim=-1)
|
| 32 |
+
l1_pred = torch.mean(torch.abs(y_pred), dim=-1)
|
| 33 |
+
loss = torch.mean(torch.abs(l1_pred - l1_true))
|
| 34 |
+
|
| 35 |
+
return loss
|
| 36 |
+
|
| 37 |
+
class DecibelMatchLoss(_Loss):
|
| 38 |
+
def __init__(self, eps=1e-3):
|
| 39 |
+
super().__init__()
|
| 40 |
+
|
| 41 |
+
self.eps = eps
|
| 42 |
+
|
| 43 |
+
def forward(self, y_pred, y_true):
|
| 44 |
+
batch_size = y_pred.shape[0]
|
| 45 |
+
|
| 46 |
+
y_pred = y_pred.reshape(batch_size, -1)
|
| 47 |
+
y_true = y_true.reshape(batch_size, -1)
|
| 48 |
+
|
| 49 |
+
db_true = 10.0 * torch.log10(self.eps + torch.mean(torch.square(torch.abs(y_true)), dim=-1))
|
| 50 |
+
db_pred = 10.0 * torch.log10(self.eps + torch.mean(torch.square(torch.abs(y_pred)), dim=-1))
|
| 51 |
+
loss = torch.mean(torch.abs(db_pred - db_true))
|
| 52 |
+
|
| 53 |
+
return loss
|
| 54 |
+
|
| 55 |
+
class L1SNRLoss(_Loss):
|
| 56 |
+
def __init__(self, eps=1e-3):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.eps = torch.tensor(eps)
|
| 59 |
+
|
| 60 |
+
def forward(self, y_pred, y_true):
|
| 61 |
+
batch_size = y_pred.shape[0]
|
| 62 |
+
|
| 63 |
+
y_pred = y_pred.reshape(batch_size, -1)
|
| 64 |
+
y_true = y_true.reshape(batch_size, -1)
|
| 65 |
+
|
| 66 |
+
l1_error = torch.mean(torch.abs(y_pred - y_true), dim=-1)
|
| 67 |
+
l1_true = torch.mean(torch.abs(y_true), dim=-1)
|
| 68 |
+
|
| 69 |
+
snr = 20.0 * torch.log10((l1_true + self.eps) / (l1_error + self.eps))
|
| 70 |
+
|
| 71 |
+
return -torch.mean(snr)
|
| 72 |
+
|
| 73 |
+
class L1SNRLossIgnoreSilence(_Loss):
|
| 74 |
+
def __init__(self, eps=1e-3, dbthresh=-20, dbthresh_step=20):
|
| 75 |
+
super().__init__()
|
| 76 |
+
self.eps = torch.tensor(eps)
|
| 77 |
+
self.dbthresh = dbthresh
|
| 78 |
+
self.dbthresh_step = dbthresh_step
|
| 79 |
+
|
| 80 |
+
def forward(self, y_pred, y_true):
|
| 81 |
+
batch_size = y_pred.shape[0]
|
| 82 |
+
|
| 83 |
+
y_pred = y_pred.reshape(batch_size, -1)
|
| 84 |
+
y_true = y_true.reshape(batch_size, -1)
|
| 85 |
+
|
| 86 |
+
l1_error = torch.mean(torch.abs(y_pred - y_true), dim=-1)
|
| 87 |
+
l1_true = torch.mean(torch.abs(y_true), dim=-1)
|
| 88 |
+
|
| 89 |
+
snr = 20.0 * torch.log10((l1_true + self.eps) / (l1_error + self.eps))
|
| 90 |
+
|
| 91 |
+
db = 10.0 * torch.log10(torch.mean(torch.square(y_true), dim=-1) + 1e-6)
|
| 92 |
+
|
| 93 |
+
if torch.sum(db > self.dbthresh) == 0:
|
| 94 |
+
if torch.sum(db > self.dbthresh - self.dbthresh_step) == 0:
|
| 95 |
+
return -torch.mean(snr)
|
| 96 |
+
else:
|
| 97 |
+
return -torch.mean(snr[db > self.dbthresh - self.dbthresh_step])
|
| 98 |
+
|
| 99 |
+
return -torch.mean(snr[db > self.dbthresh])
|
| 100 |
+
|
| 101 |
+
class L1SNRDecibelMatchLoss(_Loss):
|
| 102 |
+
def __init__(self, db_weight=0.1, l1snr_eps=1e-3, dbeps=1e-3):
|
| 103 |
+
super().__init__()
|
| 104 |
+
self.l1snr = L1SNRLoss(l1snr_eps)
|
| 105 |
+
self.decibel_match = DecibelMatchLoss(dbeps)
|
| 106 |
+
self.db_weight = db_weight
|
| 107 |
+
|
| 108 |
+
def forward(self, y_pred, y_true):
|
| 109 |
+
|
| 110 |
+
return self.l1snr(y_pred, y_true) + self.decibel_match(y_pred, y_true)
|