from abc import ABC, abstractmethod
from pathlib import Path
import re
from typing import Any, Callable, TypeAlias
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
import yaml
from pathlib import PurePosixPath
from boltzkit.utils.key_value_store import FileKV
from pathlib import PurePosixPath
def strip_repo_prefix(full_path: str, repo_root: str) -> str:
    """
    Express ``full_path`` relative to ``repo_root``.

    Both arguments are interpreted as POSIX-style paths, independent of the
    host operating system.

    Args:
        full_path (str): Path that is located below ``repo_root``.
        repo_root (str): Prefix to remove from ``full_path``.

    Returns:
        str: The remainder of ``full_path`` after ``repo_root`` ("." when
        both paths are equal).

    Raises:
        ValueError: If ``full_path`` is not located below ``repo_root``.
    """
    relative = PurePosixPath(full_path).relative_to(PurePosixPath(repo_root))
    return relative.as_posix()
[docs]
class CachedRepo(ABC):
"""
Abstract base class representing a cached repository.
A CachedRepo provides a unified interface for interacting with
remote repositories (e.g., Huggingface datasets or local directories)
while caching files locally for efficient repeated access.
Attributes:
remote_uri (str): The URI or path of the remote repository.
local_path (Path): The local directory where files are cached.
"""
[docs]
def __init__(self, remote_uri: str, local_repo_path: Path, lazy_load: bool):
"""
Initialize a CachedRepo instance.
Args:
remote_uri (str): The remote repository URI or path.
local_repo_path (Path): Local path where cached files will be stored.
lazy_load (bool): If True, files are loaded on demand; if False, all files are loaded immediately.
"""
super().__init__()
self.__remote_uri = remote_uri
self.__local_path = local_repo_path
self.__lazy_load = lazy_load
[docs]
def post_init(self):
self._key_value_store = FileKV(self.local_path / "cached_config.yaml")
info_path = self.load_file("info.yaml")
with open(info_path) as f:
self._config: dict[str, Any] = yaml.safe_load(f)
if not self.__lazy_load:
self.load_all_files()
[docs]
@abstractmethod
def load_file(self, relative_fpath: str) -> Path:
raise NotImplementedError
[docs]
def try_load_file(self, relative_fpath: str | None) -> Path | None:
if relative_fpath is None:
return None
try:
return self.load_file(relative_fpath)
except Exception:
return None
[docs]
@abstractmethod
def load_all_files(self) -> None:
raise NotImplementedError
[docs]
@abstractmethod
def list_remote_files(self) -> list[str]:
raise NotImplementedError
[docs]
def find_file(self, regex: str) -> list[str]:
"""
Return all remote files matching the given regex pattern.
Args:
regex (str): Regular expression to match against file paths.
Returns:
List[str]: List of matching file paths (repo-relative).
"""
pattern = re.compile(regex)
return [path for path in self.list_remote_files() if pattern.search(path)]
@property
def config(self) -> dict[str, Any]:
if not hasattr(self, "_config"):
raise AttributeError(
"The attribute _config could not be found, perhaps 'post_init' was not called."
)
return self._config.copy()
[docs]
@classmethod
@abstractmethod
def match_uri(cls, uri: str) -> bool:
raise NotImplementedError
[docs]
@classmethod
@abstractmethod
def get_name_from_uri(cls, uri: str) -> str:
raise NotADirectoryError
@property
def remote_uri(self) -> str:
return self.__remote_uri
@property
def local_path(self) -> Path:
return self.__local_path
[docs]
def get_cached_key_value_store(self):
return self._key_value_store
class HuggingfaceRepo(CachedRepo):
    """
    CachedRepo backed by a dataset repository on the Hugging Face Hub.

    The remote URI starts with ``datasets/`` followed by the repository id,
    e.g. ``datasets/<namespace>/<name>``.
    """

    def __init__(self, remote_uri: str, local_repo_path: Path, lazy_load: bool):
        """
        Args:
            remote_uri (str): URI of the form ``datasets/<repo id>``.
            local_repo_path (Path): Local directory the files are downloaded to.
            lazy_load (bool): If False, the whole repository is downloaded
                during initialization.
        """
        super().__init__(remote_uri, local_repo_path, lazy_load)
        self._fs = HfFileSystem()
        # Passed to snapshot_download as ignore patterns and compared as exact
        # repo-relative names in list_remote_files.
        self._ignore_patterns = [".gitattributes"]
        self.post_init()

    @property
    def _repo_id(self) -> str:
        """Hub repository id, i.e. the remote URI without the leading 'datasets/'."""
        # removeprefix strips only the leading marker; str.replace would also
        # remove any later "datasets/" occurrence inside the repo id.
        return self.remote_uri.removeprefix("datasets/")

    def load_file(self, relative_fpath):
        """
        Download a single file of the dataset repository.

        Args:
            relative_fpath (str): Repo-relative path of the file.

        Returns:
            Path: Path of the downloaded file as reported by ``hf_hub_download``.
        """
        local_path = hf_hub_download(
            repo_id=self._repo_id,
            repo_type="dataset",
            local_dir=self.local_path,
            filename=relative_fpath,
        )
        return Path(local_path)

    def load_all_files(self):
        """Download the whole dataset repository, skipping the ignored patterns."""
        snapshot_download(
            repo_id=self._repo_id,
            repo_type="dataset",
            local_dir=self.local_path,
            ignore_patterns=self._ignore_patterns,
        )

    def list_remote_files(self):
        """
        List the repo-relative paths of all files in the remote repository.

        Returns:
            list[str]: Paths relative to the repository root, without the
            ignored files (e.g. ``.gitattributes``).
        """
        relative_paths = [
            strip_repo_prefix(p, self.remote_uri)
            for p in self._fs.find(self.remote_uri)
        ]
        # filter out unwanted files like .gitattributes
        return [p for p in relative_paths if p not in self._ignore_patterns]

    @classmethod
    def match_uri(cls, uri):
        """Return True if ``uri`` starts with ``datasets/`` followed by a name segment."""
        pattern = r"^datasets\/[a-zA-Z0-9._-]+(\/.*)?$"
        return bool(re.match(pattern, uri))

    @classmethod
    def get_name_from_uri(cls, uri):
        """Return the last ``/``-separated segment of ``uri`` as the cache directory name."""
        return uri.split("/")[-1]
class LocalRepo(CachedRepo):
    """
    CachedRepo backed by a directory on the local file system.

    Files are not copied; the cache directory is populated with symlinks
    pointing at the files of the source directory.
    """

    def __init__(self, remote_uri, local_repo_path, lazy_load):
        """
        Args:
            remote_uri (str): Path of the source directory (``~`` is expanded).
            local_repo_path (Path): Local directory the symlinks are created in.
            lazy_load (bool): If False, all files are linked during initialization.
        """
        super().__init__(remote_uri, local_repo_path, lazy_load)
        self.post_init()

    @property
    def _remote_root(self) -> Path:
        """Source directory with ``~`` expanded, consistent with ``match_uri``."""
        return Path(self.remote_uri).expanduser()

    def load_file(self, relative_fpath):
        """
        Link a single file of the source directory into the cache.

        Args:
            relative_fpath (str): Path relative to the source directory.

        Returns:
            Path: Path of the symlink inside the cache directory.

        Raises:
            FileNotFoundError: If the file does not exist in the source directory.
        """
        remote_file = self._remote_root / relative_fpath
        # Fail here instead of creating a dangling symlink, so that
        # try_load_file reports a missing file as None.
        if not remote_file.exists():
            raise FileNotFoundError(f"Remote file does not exist: {remote_file}")
        local_file = self.local_path / relative_fpath
        # create parent directories if needed
        local_file.parent.mkdir(parents=True, exist_ok=True)
        # A dangling link left over from an earlier run reports exists() == False
        # but would still make symlink_to fail, so remove it first.
        if local_file.is_symlink() and not local_file.exists():
            local_file.unlink()
        # create symlink if it doesn't exist
        if not local_file.exists():
            local_file.symlink_to(remote_file.resolve())
        return local_file

    def load_all_files(self):
        """Link every file of the source directory into the cache."""
        for remote_relative_fpath in self.list_remote_files():
            self.load_file(remote_relative_fpath)

    def list_remote_files(self):
        """
        List all files below the source directory (recursively).

        Returns:
            list[str]: Paths relative to the source directory, using ``/`` as
            separator on every platform.
        """
        root = self._remote_root
        return [p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file()]

    @classmethod
    def match_uri(cls, uri):
        """Return True if ``uri`` (with ``~`` expanded) is an existing path."""
        return Path(uri).expanduser().exists()

    @classmethod
    def get_name_from_uri(cls, uri):
        """Return ``local_`` followed by the final path component of ``uri``."""
        return "local_" + Path(uri).name
# Payload accepted for a single entry of a VirtualRepo file tree.
# VirtualRepo.load_file dispatches on the runtime type: str is written as text,
# bytes is written as binary, and any other value is called with the target
# path and has to create the file itself.
Content: TypeAlias = str | bytes | Callable[[Path], None]
"""
str: Write text
bytes: write binary
Callable: creates file at path
"""
[docs]
def normalize_path(path: str | PurePosixPath) -> str:
return str(PurePosixPath(path))
class VirtualRepo(CachedRepo):
    """
    Repository whose "remote" side is a mapping held in memory.

    The cache directory is filled from that mapping instead of from a
    directory or a hosted repository.
    """

    def __init__(
        self,
        remote_uri,
        local_repo_path,
        lazy_load,
        file_tree: dict[str, Content],
    ):
        """
        Store the in-memory file tree and run the shared post-initialization.

        remote_uri must have format 'virtual://<name>', e.g., 'virtual://foo',
        which will create a cache dir with name 'virtual_foo'.

        ``file_tree`` maps repo-relative paths to their content. Since
        ``post_init`` loads 'info.yaml', the mapping needs an entry for it.
        """
        super().__init__(remote_uri, local_repo_path, lazy_load)
        # Normalize the keys once so lookups are independent of path spelling.
        normalized_tree = {}
        for raw_path, payload in file_tree.items():
            normalized_tree[normalize_path(raw_path)] = payload
        self._file_content_tree = normalized_tree
        self.post_init()

    def load_file(self, relative_fpath):
        """
        Materialize one entry of the file tree inside the cache directory.

        The file is (re)written on every call.

        Returns:
            Path: Path of the written file.

        Raises:
            KeyError: If the file tree has no entry for ``relative_fpath``.
            RuntimeError: If a callable entry left no file at the target path.
        """
        relative_fpath = normalize_path(relative_fpath)
        destination = self.local_path / relative_fpath
        destination.parent.mkdir(parents=True, exist_ok=True)
        payload = self._file_content_tree[relative_fpath]

        if isinstance(payload, str):
            destination.write_text(payload)
        elif isinstance(payload, bytes):
            destination.write_bytes(payload)
        else:
            # Callable entry: it is responsible for writing the file itself.
            payload(destination)
            if not destination.exists():
                raise RuntimeError(f"Callable did not create file: {relative_fpath}")
        return destination

    def load_all_files(self):
        """Materialize every entry of the file tree."""
        for known_path in self._file_content_tree:
            self.load_file(known_path)

    def list_remote_files(self):
        """Return the normalized paths of all entries of the file tree."""
        return list(self._file_content_tree)

    @classmethod
    def match_uri(cls, uri):
        """Return True if ``uri`` uses the 'virtual://' scheme."""
        return uri.startswith("virtual://")

    @classmethod
    def get_name_from_uri(cls, uri):
        """Return 'virtual_' followed by the part of ``uri`` after 'virtual://'."""
        return f"virtual_{uri.removeprefix('virtual://')}"
def create_cached_repo(
    uri: str,
    local_repos_dir: str | Path = Path("target_cache"),
    lazy_load: bool = True,
    **kwargs,
) -> CachedRepo:
    """
    Creates CachedRepo object from the given URI (Uniform Resource Identifier).
    The type of the CachedRepo is automatically determined by the given URI:
    the first of HuggingfaceRepo, LocalRepo and VirtualRepo whose ``match_uri``
    accepts the URI is used.

    Args:
        uri (str): Remote URI or path of the repository.
        local_repos_dir (str | Path): Directory that contains the cache
            directories of all repositories.
        lazy_load (bool): Forwarded to the repository; if False, all files
            are loaded immediately.
        **kwargs: Extra keyword arguments forwarded to the repository
            constructor (e.g. ``file_tree`` for VirtualRepo).

    Returns:
        CachedRepo: The repository object; its cache directory is
        ``local_repos_dir / <name derived from uri>``.

    Raises:
        ValueError: If no repository type matches ``uri``.
    """
    classes: list[type[CachedRepo]] = [HuggingfaceRepo, LocalRepo, VirtualRepo]
    if isinstance(local_repos_dir, str):
        local_repos_dir = Path(local_repos_dir)
    for cls in classes:
        if not cls.match_uri(uri):
            continue
        name = cls.get_name_from_uri(uri)
        local_repo_path = local_repos_dir / name
        local_repo_path.mkdir(parents=True, exist_ok=True)
        cached_repo = cls(
            remote_uri=uri,
            local_repo_path=local_repo_path,
            lazy_load=lazy_load,
            **kwargs,
        )
        print(
            f"Created cached repo of type '{cls.__name__}' for remote uri '{uri}' and local path '{local_repo_path.as_posix()}'"
        )
        return cached_repo
    # Previously the function fell through and returned None, which only
    # surfaced later as an AttributeError at the first use of the result.
    raise ValueError(f"No CachedRepo implementation matches uri '{uri}'")
if __name__ == "__main__":
    # Manual smoke test: open a Hugging Face dataset repo and a virtual repo.
    repo_path = "datasets/chrklitz99/alanine_dipeptide"
    # repo_path = "target_cache/alanine_dipeptide"
    sys_info = create_cached_repo(repo_path)
    print("Remote files:")
    print(sys_info.list_remote_files())
    print(sys_info.config)
    # Raw string: "\." is a regex escape, not a valid Python string escape.
    print(sys_info.find_file(r".*\.pdb"))
    # post_init always loads "info.yaml", so the virtual file tree has to
    # provide it; without this entry construction fails with a KeyError.
    virtual_repo = create_cached_repo(
        "virtual://test",
        file_tree={"info.yaml": "name: test\n", "data/test.yaml": "Hello world"},
    )
    virtual_repo.load_file("data/test.yaml")
    print(virtual_repo.list_remote_files())