File size: 4,197 Bytes
b5ae7e6 36f3d38 1849dad d5c8a0f 36f3d38 fe70438 d5c8a0f fe70438 d5c8a0f b5ae7e6 1849dad d5c8a0f 1849dad b113398 fe70438 8f5a1d4 fe70438 b113398 31cee3d d5c8a0f b113398 43b496d d5c8a0f b113398 31cee3d 8f5a1d4 fe70438 8f5a1d4 1849dad b113398 8f5a1d4 fe70438 d5c8a0f 8f5a1d4 31cee3d b5ae7e6 1849dad b5ae7e6 43b496d d5c8a0f 43b496d 1849dad 43b496d b5ae7e6 1849dad b5ae7e6 1849dad b5ae7e6 1849dad b5ae7e6 1849dad b5ae7e6 1849dad 36f3d38 b5ae7e6 1849dad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import importlib
import inspect
import os
from pathlib import Path
from .artifact import Artifact, Catalogs
from .catalog import EnvironmentLocalCatalog, GithubCatalog, LocalCatalog
from .error_utils import Documentation, UnitxtError, UnitxtWarning
from .settings_utils import get_constants, get_settings
from .utils import Singleton
# Module-level handles to the project's constants and settings objects; the
# registration helpers in this module read from them.
constants, settings = get_constants(), get_settings()
def _register_catalog(catalog: LocalCatalog):
    """Add *catalog* to the ``Catalogs`` registry.

    Args:
        catalog: The catalog instance to register.
    """
    registry = Catalogs()
    registry.register(catalog)
def _unregister_catalog(catalog: LocalCatalog):
    """Remove *catalog* from the ``Catalogs`` registry.

    Args:
        catalog: The catalog instance to unregister.
    """
    registry = Catalogs()
    registry.unregister(catalog)
def is_local_catalog_registered(catalog_path: str):
    """Return whether a registered ``LocalCatalog`` points at *catalog_path*.

    Paths are compared after ``Path.resolve()``, so different spellings of
    the same directory (relative paths, ``..`` segments, symlinks) match.
    Only directories take part in the comparison: if *catalog_path* is not an
    existing directory the answer is ``False``, and registered catalogs
    whose location is not an existing directory are ignored.

    Args:
        catalog_path: Filesystem path of the catalog directory to look up.

    Returns:
        bool: ``True`` if a matching registered local catalog exists.
    """
    if not os.path.isdir(catalog_path):
        return False
    # Resolve the target once instead of on every loop iteration.
    target = Path(catalog_path).resolve()
    return any(
        isinstance(catalog, LocalCatalog)
        and os.path.isdir(catalog.location)
        and Path(catalog.location).resolve() == target
        for catalog in _catalogs_list()
    )
def register_local_catalog(catalog_path: str):
    """Register the directory *catalog_path* as a ``LocalCatalog``.

    Nothing is registered if ``is_local_catalog_registered`` reports that a
    local catalog for this directory is already present, so repeated calls do
    not add duplicates.

    Args:
        catalog_path: Filesystem path of an existing catalog directory.

    Raises:
        AssertionError: If *catalog_path* does not exist or is not a
            directory.
    """
    # Validate with explicit raises rather than ``assert``: assert statements
    # are stripped under ``python -O``, which would let a bad path through.
    # AssertionError (and the messages) are kept so existing callers that
    # catch it keep working.
    if not os.path.exists(catalog_path):
        raise AssertionError(f"Catalog path {catalog_path} does not exist.")
    if not os.path.isdir(catalog_path):
        raise AssertionError(f"Catalog path {catalog_path} is not a directory.")
    if not is_local_catalog_registered(catalog_path=catalog_path):
        _register_catalog(LocalCatalog(location=catalog_path))
def unregister_local_catalog(catalog_path: str):
    """Unregister every ``LocalCatalog`` whose location resolves to *catalog_path*.

    Does nothing if *catalog_path* is not an existing directory or if no
    registered local catalog points at it. Matching uses ``Path.resolve()``
    on both sides and ignores catalogs whose location is not an existing
    directory.

    Args:
        catalog_path: Filesystem path of the catalog directory to remove.
    """
    if not os.path.isdir(catalog_path):
        return
    # Resolve the target once; a single pass both finds and removes matches,
    # instead of scanning the registry once to check and again to remove.
    target = Path(catalog_path).resolve()
    # _catalogs_list() builds a new list, so unregistering entries while
    # iterating over it is safe.
    for catalog in _catalogs_list():
        if (
            isinstance(catalog, LocalCatalog)
            and os.path.isdir(catalog.location)
            and Path(catalog.location).resolve() == target
        ):
            _unregister_catalog(catalog)
def _catalogs_list():
    """Return the catalogs currently in the ``Catalogs`` registry as a new list.

    Because a fresh list is built on each call, callers may register or
    unregister catalogs while iterating over the result.
    """
    return [*Catalogs()]
def _register_all_catalogs():
    """Register the built-in catalogs, then the environment-configured ones.

    Order: a ``GithubCatalog``, a ``LocalCatalog`` built with default
    arguments, and finally whatever ``_reset_env_local_catalogs`` sets up.
    """
    # Each catalog is constructed and registered before the next is built.
    for make_catalog in (GithubCatalog, LocalCatalog):
        _register_catalog(make_catalog())
    _reset_env_local_catalogs()
def _reset_env_local_catalogs():
    """Rebuild the catalogs configured through environment settings.

    Every registered ``EnvironmentLocalCatalog`` is unregistered first. Then
    one ``EnvironmentLocalCatalog`` is registered for each non-empty path in
    ``settings.catalogs`` or, for backward compatibility, the deprecated
    ``settings.artifactories``. Paths are split on
    ``constants.env_local_catalogs_paths_sep``.

    Raises:
        UnitxtError: If both ``UNITXT_CATALOGS`` and ``UNITXT_ARTIFACTORIES``
            are set.
    """
    # _catalogs_list() returns a new list, so unregistering while iterating is safe.
    for catalog in _catalogs_list():
        if isinstance(catalog, EnvironmentLocalCatalog):
            _unregister_catalog(catalog)

    if settings.catalogs and settings.artifactories:
        raise UnitxtError(
            "Both UNITXT_CATALOGS and UNITXT_ARTIFACTORIES are set. "
            "Use only UNITXT_CATALOGS. UNITXT_ARTIFACTORIES is deprecated.\n"
            f"UNITXT_CATALOGS: {settings.catalogs}\n"
            f"UNITXT_ARTIFACTORIES: {settings.artifactories}\n",
            Documentation.CATALOG,
        )

    if settings.artifactories:
        # The warning object is only constructed, never raised; presumably
        # UnitxtWarning reports itself on construction (verify in error_utils).
        UnitxtWarning(
            "UNITXT_ARTIFACTORIES is set but is deprecated, use UNITXT_CATALOGS instead.",
            Documentation.CATALOG,
        )

    # After the check above at most one of the two settings is set.
    configured_paths = settings.catalogs or settings.artifactories
    if not configured_paths:
        return
    for path in configured_paths.split(constants.env_local_catalogs_paths_sep):
        # Skip empty segments (e.g. from a trailing separator) instead of
        # registering a catalog with an empty location.
        if path:
            _register_catalog(EnvironmentLocalCatalog(location=path))
def _register_all_artifacts():
    """Import this package's modules and register their ``Artifact`` classes.

    Every ``.py`` file in this module's directory is imported, except this
    file itself and any name listed in ``constants.non_registered_files``.
    Each class reachable as a module attribute that is a proper subclass of
    ``Artifact`` is passed to ``Artifact.register_class``.
    """
    package_dir = os.path.dirname(__file__)
    this_file = os.path.basename(__file__)
    for file_name in os.listdir(package_dir):
        if (
            not file_name.endswith(".py")
            or file_name in constants.non_registered_files
            or file_name == this_file
        ):
            continue
        # Strip only the trailing ".py"; str.replace would also remove any
        # ".py" that appears earlier in the file name.
        module_name = os.path.splitext(file_name)[0]
        module = importlib.import_module("." + module_name, __package__)
        # getmembers also returns classes the module merely imports, so the
        # same class can be offered more than once. Presumably register_class
        # tolerates re-registering the same class; verify in artifact.py.
        for _name, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, Artifact) and obj is not Artifact:
                Artifact.register_class(obj)
class ProjectArtifactRegisterer(metaclass=Singleton):
    """Register all catalogs and artifact classes, at most once per instance.

    Uses the project's ``Singleton`` metaclass, presumably so that every
    ``ProjectArtifactRegisterer()`` call yields the same instance; verify in
    ``utils``.
    """

    def __init__(self):
        # If this instance has already completed registration (e.g. __init__
        # is re-run on a shared instance), do nothing.
        if getattr(self, "_registered", False):
            return
        # Stays False if either step below raises, so a later call retries.
        self._registered = False
        _register_all_catalogs()
        _register_all_artifacts()
        self._registered = True
def register_all_artifacts():
    """Register the project's catalogs and ``Artifact`` classes.

    Delegates to ``ProjectArtifactRegisterer``. Its constructor does the
    work and skips it when the instance has already completed registration.
    """
    _ = ProjectArtifactRegisterer()
|