File size: 4,197 Bytes
b5ae7e6
36f3d38
1849dad
d5c8a0f
36f3d38
fe70438
d5c8a0f
fe70438
d5c8a0f
b5ae7e6
1849dad
d5c8a0f
 
1849dad
 
b113398
fe70438
8f5a1d4
 
 
fe70438
b113398
31cee3d
d5c8a0f
 
 
 
 
 
 
 
 
 
b113398
 
43b496d
 
 
d5c8a0f
 
 
 
 
 
 
 
 
 
 
b113398
31cee3d
8f5a1d4
fe70438
8f5a1d4
 
1849dad
b113398
 
8f5a1d4
 
 
 
 
 
 
fe70438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5c8a0f
 
 
 
8f5a1d4
31cee3d
b5ae7e6
 
 
 
1849dad
b5ae7e6
43b496d
 
d5c8a0f
43b496d
 
1849dad
 
 
 
43b496d
b5ae7e6
 
 
 
 
 
1849dad
 
b5ae7e6
1849dad
b5ae7e6
1849dad
b5ae7e6
1849dad
b5ae7e6
 
1849dad
36f3d38
b5ae7e6
1849dad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import importlib
import inspect
import os
from pathlib import Path

from .artifact import Artifact, Catalogs
from .catalog import EnvironmentLocalCatalog, GithubCatalog, LocalCatalog
from .error_utils import Documentation, UnitxtError, UnitxtWarning
from .settings_utils import get_constants, get_settings
from .utils import Singleton

# Module-wide handles obtained once, at import time.
# `constants` is read below for `env_local_catalogs_paths_sep` (separator used
# to split environment catalog path lists) and `non_registered_files` (files
# excluded from artifact auto-registration).
constants = get_constants()
# `settings` is read below for `catalogs` and `artifactories`, the
# environment-provided catalog path lists handled by `_reset_env_local_catalogs`.
settings = get_settings()


def _register_catalog(catalog: LocalCatalog):
    """Add ``catalog`` to the registry returned by ``Catalogs()``."""
    registry = Catalogs()
    registry.register(catalog)


def _unregister_catalog(catalog: LocalCatalog):
    """Remove ``catalog`` from the registry returned by ``Catalogs()``."""
    registry = Catalogs()
    registry.unregister(catalog)


def is_local_catalog_registered(catalog_path: str) -> bool:
    """Return whether a registered ``LocalCatalog`` points at ``catalog_path``.

    Locations are compared after ``Path.resolve()``, so different spellings of
    the same directory (relative vs. absolute, ``..`` components, symlinks)
    compare equal. Registered catalogs whose location is not an existing
    directory are ignored.

    Args:
        catalog_path: Directory to look for among the registered catalogs.

    Returns:
        ``True`` if ``catalog_path`` is an existing directory and some
        registered ``LocalCatalog`` resolves to the same location, otherwise
        ``False``.
    """
    if not os.path.isdir(catalog_path):
        return False
    # Resolve the target once rather than on every loop iteration.
    target = Path(catalog_path).resolve()
    return any(
        isinstance(catalog, LocalCatalog)
        and os.path.isdir(catalog.location)
        and Path(catalog.location).resolve() == target
        for catalog in _catalogs_list()
    )


def register_local_catalog(catalog_path: str):
    """Register the directory ``catalog_path`` as a ``LocalCatalog``.

    Registration is skipped when ``is_local_catalog_registered`` reports that
    a ``LocalCatalog`` for the same directory already exists, so repeated
    calls do not add duplicates.

    Args:
        catalog_path: Path to an existing catalog directory.

    Raises:
        AssertionError: If ``catalog_path`` does not exist or is not a
            directory. Raised explicitly rather than via ``assert`` so the
            validation is not stripped when Python runs with ``-O``; the
            exception type is kept for backward compatibility with callers.
    """
    if not os.path.exists(catalog_path):
        raise AssertionError(f"Catalog path {catalog_path} does not exist.")
    if not os.path.isdir(catalog_path):
        raise AssertionError(f"Catalog path {catalog_path} is not a directory.")
    if not is_local_catalog_registered(catalog_path=catalog_path):
        _register_catalog(LocalCatalog(location=catalog_path))


def unregister_local_catalog(catalog_path: str):
    """Unregister every ``LocalCatalog`` whose location is ``catalog_path``.

    Locations are compared after ``Path.resolve()``. Registered catalogs whose
    location is not an existing directory are ignored. Does nothing when
    ``catalog_path`` is not an existing directory or when no registered
    catalog matches.

    Args:
        catalog_path: Directory whose catalog registrations should be removed.
    """
    if not os.path.isdir(catalog_path):
        return
    # Resolve the target once; a single pass both finds and removes matches,
    # instead of scanning once to check and again to unregister.
    target = Path(catalog_path).resolve()
    # _catalogs_list() returns a fresh list, so unregistering while iterating
    # does not disturb the loop.
    for catalog in _catalogs_list():
        if (
            isinstance(catalog, LocalCatalog)
            and os.path.isdir(catalog.location)
            and Path(catalog.location).resolve() == target
        ):
            _unregister_catalog(catalog)


def _catalogs_list():
    """Return a new list holding the catalogs yielded by ``Catalogs()``.

    Because the result is a copy, callers may register/unregister catalogs
    while iterating over it.
    """
    registered = Catalogs()
    return [*registered]


def _register_all_catalogs():
    """Register the default catalogs, then the environment-configured ones.

    A ``GithubCatalog`` and a default ``LocalCatalog`` are registered in that
    order; ``_reset_env_local_catalogs`` then handles catalogs listed in the
    settings.
    """
    # Build each default catalog right before registering it, preserving the
    # construct/register interleaving and ordering.
    for make_catalog in (GithubCatalog, LocalCatalog):
        _register_catalog(make_catalog())
    _reset_env_local_catalogs()


def _reset_env_local_catalogs():
    """Re-register the catalogs configured through settings.

    Every currently registered ``EnvironmentLocalCatalog`` is unregistered
    first. Then one ``EnvironmentLocalCatalog`` is registered per path listed
    in ``settings.catalogs``, or in the deprecated ``settings.artifactories``.
    Path lists are split on ``constants.env_local_catalogs_paths_sep``; empty
    entries (e.g. from a trailing separator) are skipped.

    Raises:
        UnitxtError: If both ``settings.catalogs`` and
            ``settings.artifactories`` are set.
    """
    for catalog in _catalogs_list():
        if isinstance(catalog, EnvironmentLocalCatalog):
            _unregister_catalog(catalog)

    if settings.catalogs and settings.artifactories:
        raise UnitxtError(
            "Both UNITXT_CATALOGS and UNITXT_ARTIFACTORIES are set. "
            "Use only UNITXT_CATALOGS. UNITXT_ARTIFACTORIES is deprecated.\n"
            f"UNITXT_CATALOGS: {settings.catalogs}\n"
            f"UNITXT_ARTIFACTORIES: {settings.artifactories}\n",
            Documentation.CATALOG,
        )

    if settings.catalogs:
        _register_env_catalog_paths(settings.catalogs)

    if settings.artifactories:
        # NOTE(review): the UnitxtWarning instance is discarded; presumably its
        # constructor reports the warning — confirm in error_utils.
        UnitxtWarning(
            "UNITXT_ARTIFACTORIES is set but is deprecated, use UNITXT_CATALOGS instead.",
            Documentation.CATALOG,
        )
        _register_env_catalog_paths(settings.artifactories)


def _register_env_catalog_paths(paths: str):
    """Register an ``EnvironmentLocalCatalog`` for each non-empty entry of ``paths``."""
    for path in paths.split(constants.env_local_catalogs_paths_sep):
        # Skip empty segments so "a:b:" does not register a catalog at "".
        if path:
            _register_catalog(EnvironmentLocalCatalog(location=path))


def _register_all_artifacts():
    """Import this package's sibling modules and register their Artifact classes.

    Scans the directory containing this file for ``.py`` files. It skips this
    file itself and every name in ``constants.non_registered_files``. Each
    remaining module is imported relative to ``__package__``, and every class
    found in its namespace that is a strict subclass of ``Artifact`` is passed
    to ``Artifact.register_class``.

    ``inspect.getmembers`` also returns classes a module merely imports, so the
    same class may be passed more than once.
    """
    package_dir = os.path.dirname(__file__)
    this_file = os.path.basename(__file__)

    # Sort so import/registration order is deterministic; os.listdir order is
    # arbitrary and platform-dependent.
    for entry in sorted(os.listdir(package_dir)):
        if (
            not entry.endswith(".py")
            or entry in constants.non_registered_files
            or entry == this_file
        ):
            continue

        # splitext strips only the trailing extension, whereas
        # str.replace(".py", "") would also remove ".py" elsewhere in the name.
        module_name = os.path.splitext(entry)[0]
        module = importlib.import_module("." + module_name, __package__)

        for _name, obj in inspect.getmembers(module, inspect.isclass):
            # Register strict subclasses of Artifact only (not Artifact itself).
            if issubclass(obj, Artifact) and obj is not Artifact:
                Artifact.register_class(obj)


class ProjectArtifactRegisterer(metaclass=Singleton):
    """Run catalog and artifact registration when first initialised.

    ``__init__`` calls ``_register_all_catalogs`` and then
    ``_register_all_artifacts``, unless this instance's ``_registered`` flag is
    already truthy. With the ``Singleton`` metaclass (presumably one shared
    instance — confirm in utils), this makes registration a one-time step.
    """

    def __init__(self):
        if getattr(self, "_registered", False):
            # This instance already completed registration; nothing to redo.
            return
        # Mark as not-yet-registered before starting, so the flag exists even
        # if a registration step raises.
        self._registered = False
        _register_all_catalogs()
        _register_all_artifacts()
        self._registered = True


def register_all_artifacts():
    """Ensure the project's catalogs and Artifact classes are registered.

    Instantiates ``ProjectArtifactRegisterer``, whose ``__init__`` performs the
    registration and skips it when that instance has already registered.
    """
    ProjectArtifactRegisterer()