# Source code for dlordinal.datasets.hci

import warnings
from hashlib import md5
from pathlib import Path
from shutil import move, rmtree
from typing import Callable, Optional, Union

import pandas as pd
from PIL import Image, UnidentifiedImageError
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import IMG_EXTENSIONS, default_loader
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
    extract_archive,
)


class HCI(ImageFolder):
    """
    Historical Color Images (HCI) Decade Database dataset
    :footcite:t:`palermo2012dating`.

    This dataset contains colour photographs from five decades (1930s-1970s),
    organised for decade classification. Upon first use, the dataset is
    automatically downloaded, verified, preprocessed, and split into training
    and test subsets.

    The preprocessing pipeline includes:

    - verifying and downloading the dataset archive if necessary;
    - extracting and normalising directory names according to class labels;
    - resizing all images to 224x224 pixels (aspect ratio is not preserved);
    - creating a stratified 70/30 train/test split;
    - generating an MD5 checksum file for future integrity checks.

    Parameters
    ----------
    root : str or Path
        Root directory where the dataset will be stored and processed.
    transform : callable, optional
        A function/transform applied to each loaded PIL image.
    target_transform : callable, optional
        A function/transform applied to the target label.
    is_valid_file : callable, optional
        A function that takes a file path and returns ``True`` if the file
        should be included.
    train : bool, default=True
        If ``True``, loads the training split; otherwise, loads the test split.

    Attributes
    ----------
    URL : str
        Download URL for the dataset archive.
    MD5 : str
        MD5 checksum used to verify the downloaded archive.
    CATEGORIES : dict
        Mapping from decade names to numeric class labels (as strings).
    base_root : Path
        Base directory for dataset storage and processing.
    train : bool
        Indicates whether the dataset instance is for training or testing.

    Examples
    --------
    >>> from dlordinal.datasets.hci import HCI
    >>> dataset = HCI(root="data", train=True)
    >>> img, label = dataset[0]

    Notes
    -----
    The train/test split is stratified by decade, with 70% of the images in
    the training set and 30% in the test set. Files are ordered by path before
    the seeded sampling, so the split does not depend on directory listing
    order. Preprocessing is only performed the first time the dataset is
    initialised, or again whenever the stored checksums no longer match the
    files on disk.
    """

    URL = "http://graphics.cs.cmu.edu/projects/historicalColor/HistoricalColor-ECCV2012-DecadeDatabase.tar"
    MD5 = "afb4c47b7da105c4afd1f27e06bea171"
    CATEGORIES = {"1930s": "0", "1940s": "1", "1950s": "2", "1960s": "3", "1970s": "4"}

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        train: bool = True,
    ):
        self.base_root = Path(root)
        self.train = train
        # Download and preprocess on first use; skipped when checksums verify.
        self._prepare_dataset()
        super().__init__(
            root=str(self.base_root / "HCI" / ("train" if self.train else "test")),
            loader=default_loader,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )

    def _prepare_dataset(self) -> bool:
        """
        Build ``<root>/HCI`` unless a checksum-verified copy already exists.

        Returns
        -------
        bool
            ``True`` if preprocessing was performed, ``False`` if the existing
            folder passed :meth:`_verify_md5sums`.
        """
        target_folder = self.base_root / "HCI"
        if target_folder.exists() and self._verify_md5sums():
            return False

        # Discard partial or corrupted output left by a previous run.
        if target_folder.exists():
            rmtree(target_folder, ignore_errors=True)

        self._download_and_extract()

        extracted_root = self.base_root / "HistoricalColor-ECCV2012"
        extracted_folder = extracted_root / "data" / "imgs" / "decade_database"
        extracted_folder.rename(target_folder)
        rmtree(extracted_root, ignore_errors=True)

        self._rename_categories(target_folder)
        self._resize_images(target_folder)
        self._split_train_test(target_folder)
        # Written last: a valid manifest marks preprocessing as complete.
        self._create_md5sums_file()
        return True

    def _download_and_extract(self) -> None:
        """Extract the archive into ``base_root``, downloading it first if it
        is missing or fails the MD5 check."""
        tar_path = self.base_root / "HistoricalColor-ECCV2012-DecadeDatabase.tar"
        if tar_path.exists() and check_integrity(str(tar_path), self.MD5):
            extract_archive(str(tar_path), str(self.base_root), False)
        else:
            download_and_extract_archive(self.URL, str(self.base_root), md5=self.MD5)

    def _rename_categories(self, folder: Path) -> None:
        """Rename decade folders (e.g. ``1930s``) to their numeric labels."""
        for old_name, new_name in self.CATEGORIES.items():
            (folder / old_name).rename(folder / new_name)

    def _resize_images(self, folder: Path, size=(224, 224)) -> None:
        """Resize every image under the category folders to ``size`` in place,
        deleting files that PIL cannot identify as images."""
        for cat in self.CATEGORIES.values():
            # Snapshot the listing: files are rewritten or deleted below.
            for img_path in sorted((folder / cat).glob("*")):
                if img_path.suffix.lower() not in IMG_EXTENSIONS:
                    continue
                try:
                    with Image.open(img_path) as img:
                        # resize() returns a new, fully loaded image, so it
                        # stays usable after the source file is closed.
                        resized = img.resize(size, Image.Resampling.LANCZOS)
                except UnidentifiedImageError:
                    warnings.warn(f"Removing corrupted image: {img_path}")
                    img_path.unlink()
                    continue
                # Overwrite only after the source file handle has been closed.
                resized.save(img_path)

    def _split_train_test(self, folder: Path, train_frac: float = 0.7, seed: int = 0):
        """
        Move images from ``folder/<label>`` into ``folder/train/<label>`` and
        ``folder/test/<label>`` using a per-label (stratified) random sample.

        Parameters
        ----------
        folder : Path
            Folder containing one sub-folder per numeric label.
        train_frac : float, default=0.7
            Fraction of each label's images assigned to the training split.
        seed : int, default=0
            Random seed for the sampling.

        Raises
        ------
        RuntimeError
            If no images are found to split.
        """
        records = [
            {"path": str(p.resolve()), "label": int(cat)}
            for cat in self.CATEGORIES.values()
            # Sorted so the seeded sample is identical on every filesystem.
            for p in sorted((folder / cat).glob("*"))
            if p.suffix.lower() in IMG_EXTENSIONS
        ]
        if not records:
            raise RuntimeError(f"No images found under {folder} to split.")

        df = pd.DataFrame(records)
        train_df = df.groupby("label", group_keys=False).sample(
            frac=train_frac, random_state=seed
        )
        test_df = df.drop(train_df.index)

        for df_subset, subset_name in ((train_df, "train"), (test_df, "test")):
            subset_folder = folder / subset_name
            for row in df_subset.itertuples(index=False):
                dest_folder = subset_folder / str(row.label)
                dest_folder.mkdir(parents=True, exist_ok=True)
                move(row.path, dest_folder / Path(row.path).name)

        # Remove the now-empty original category folders.
        for cat in self.CATEGORIES.values():
            cat_folder = folder / cat
            if cat_folder.exists() and not any(cat_folder.iterdir()):
                cat_folder.rmdir()

    @staticmethod
    def _file_md5(path: Path, chunk_size: int = 1 << 20) -> str:
        """Return the hex MD5 digest of ``path``, reading it in chunks."""
        digest = md5()
        with open(path, "rb") as fh:
            for chunk in iter(lambda: fh.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def _create_md5sums_file(self):
        """Write ``HCI/md5sums.txt`` with one ``<md5> <relative path>`` line per
        image file under ``HCI``."""
        hci_root = self.base_root / "HCI"
        md5sum_path = hci_root / "md5sums.txt"
        tmp_path = md5sum_path.with_name(md5sum_path.name + ".tmp")
        lines = [
            f"{self._file_md5(img_path)} {img_path.relative_to(hci_root).as_posix()}\n"
            for img_path in sorted(hci_root.rglob("*"))
            if img_path.is_file() and img_path.suffix.lower() in IMG_EXTENSIONS
        ]
        # Write to a temporary file then rename, so an interrupted run never
        # leaves a truncated manifest that would later pass verification.
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.writelines(lines)
        tmp_path.replace(md5sum_path)

    def _verify_md5sums(self) -> bool:
        """
        Check the files under ``HCI`` against ``HCI/md5sums.txt``.

        Returns
        -------
        bool
            ``True`` only if the manifest exists, is well-formed, lists at
            least one file, and every listed file exists with a matching MD5.
        """
        hci_root = self.base_root / "HCI"
        md5sum_path = hci_root / "md5sums.txt"
        if not md5sum_path.is_file():
            return False
        try:
            with open(md5sum_path, "r", encoding="utf-8") as f:
                entries = [line.rstrip("\r\n") for line in f if line.strip()]
        except (OSError, ValueError):
            # Unreadable or undecodable manifest: treat as invalid.
            return False
        if not entries:
            return False
        for entry in entries:
            expected_hash, sep, relative_path = entry.partition(" ")
            if not sep:
                return False
            img_path = hci_root / relative_path
            if not img_path.is_file() or self._file_md5(img_path) != expected_hash:
                return False
        return True