"""Json schema validation functions."""
from __future__ import annotations
import json
import logging
import re
from pathlib import Path
from typing import TYPE_CHECKING
from jsonschema import Draft7Validator, FormatChecker
from referencing import Registry, Resource
from referencing.jsonschema import DRAFT7
if TYPE_CHECKING: # pragma: no cover
from jsonschema.exceptions import ValidationError
class ValidationErrorGroup:
    """
    A group of validation errors that share a schema and a summary message.

    Attributes:
        schema_id: The ID of the schema that the errors belong to.
        message: A message describing the errors.
        errors: A list of validation errors.
    """

    schema_id: str
    message: str
    errors: list[ValidationError]

    def __init__(self, schema_id: str, message: str, errors: list[ValidationError]):
        """Initialize the ValidationErrorGroup class."""
        self.schema_id = schema_id
        self.message = message
        self.errors = errors

    def __repr__(self) -> str:
        """Return a compact debug representation (error objects are summarized by count)."""
        return (
            f"{type(self).__name__}(schema_id={self.schema_id!r}, "
            f"message={self.message!r}, n_errors={len(self.errors)})"
        )
class JsonSchemaRegistry:
    """
    Registry of JSON schemas used to validate JSON data.

    Schemas are stored in a ``referencing.Registry`` under a normalized schema ID
    (see ``normalize_schema_id``). A schema loaded from a file whose ``$id`` differs
    from its filepath can also be looked up through a filepath alias.
    """

    # Normalized filepath -> normalized $id the schema was registered under.
    _schema_aliases: dict[str, str]
    # Replaced (via ``with_resource``) every time a new schema is registered.
    _registry: Registry

    def __init__(self) -> None:
        """Initialize an empty JsonSchemaRegistry."""
        self._registry = Registry()
        self._schema_aliases = {}

    def register_directory(self, directory: str, pattern: str = "**/*.json"):
        """
        Register all schema files in a directory.

        Matches are registered in sorted path order, so the resulting aliases and
        ``$ref`` normalization do not depend on filesystem iteration order. Matches
        that are not regular files (e.g. a directory named ``x.json``) are skipped.

        Args:
            directory (str): Directory to search.
            pattern (str): Glob pattern, relative to ``directory``, selecting schema files.

        Raises:
            ValueError, KeyError: Propagated from ``register_schema``.
        """
        root = Path(directory).resolve()
        for schema_filepath in sorted(root.glob(pattern)):
            if schema_filepath.is_file():
                self.register_schema(str(schema_filepath))

    def register_schema(self, schema: str | dict) -> str:
        """
        Add a schema to the registry.

        When registering a schema from a dict, you can only reference the schema
        using its ``$id``, not a filepath. When registering a schema from a filepath
        and the schema does not have a ``$id``, the schema ID is generated from the
        filepath. If it does have a ``$id``, the normalized filepath is recorded as
        an alias of that ID.

        The stored copy has ``$id`` set to the normalized schema ID and its external
        ``$ref`` values normalized. A dict passed in by the caller is not modified.

        Args:
            schema (str | dict): JSON schema or filepath to a schema.

        Returns:
            str: The schema ID.

        Raises:
            ValueError: If the schema is not a JSON object, or if a dict schema has
                no ``$id`` (schemas loaded from a filepath fall back to the filepath).
            KeyError: If a schema with the same ID but different contents is
                already registered.
        """
        schema_filepath = ""
        if isinstance(schema, str):
            schema_filepath = schema
            with Path(schema).resolve().open(encoding="UTF-8") as file:
                schema = json.load(file)
        if not isinstance(schema, dict):
            raise ValueError("Schema must be a JSON object.")

        raw_id = schema.get("$id") or schema_filepath
        if not raw_id:
            raise ValueError("Schema must have an $id.")
        schema_id = self.normalize_schema_id(raw_id)

        if schema_filepath:
            normalized_filepath = self.normalize_schema_id(schema_filepath)
            if normalized_filepath != schema_id:
                # Recorded before normalizing refs so a schema that references its
                # own filepath is rewritten to its $id.
                self._schema_aliases[normalized_filepath] = schema_id

        # _normalize_schema_refs returns a new dict, so the caller's dict is untouched.
        normalized_schema = self._normalize_schema_refs(schema)
        normalized_schema["$id"] = schema_id

        if self.schema_id_in_registry(schema_id):
            if self.contents(schema_id) != normalized_schema:
                raise KeyError(f"Schema '{schema_id}' already registered with different contents.")
        else:
            self._registry = self._registry.with_resource(
                schema_id, Resource(normalized_schema, DRAFT7)
            )
        return schema_id

    def validate_json_against_schemas(
        self,
        data: dict | list,
        schema_ids: list[str],
    ):
        """
        Validate JSON data against JSON schemas, stopping at the first error.

        Schemas are checked in the order given.

        Args:
            data (dict | list): The JSON data to validate.
            schema_ids (list[str]): A list of schema IDs (or filepath aliases) to
                validate against.

        Raises:
            KeyError: If a schema ID is not registered.
            jsonschema.exceptions.ValidationError: The first validation error.
        """
        for schema_id in schema_ids:
            validator = self._get_validator(self._resolve_registered_id(schema_id))
            validator.validate(data)

    def validate_json_against_schemas_complete(
        self,
        data: dict | list,
        schema_ids: list[str],
    ) -> list[ValidationErrorGroup]:
        """
        Validate JSON data against JSON schemas and return all errors.

        Errors are grouped per schema by ``"<path>: <message>"``. In the path,
        array indices are replaced by ``item``, so the same error on several
        elements of one array ends up in a single group.

        Args:
            data (dict | list): The JSON data to validate.
            schema_ids (list[str]): A list of schema IDs (or filepath aliases) to
                validate against.

        Returns:
            list[ValidationErrorGroup]: A list of validation error groups.

        Raises:
            KeyError: If a schema ID is not registered.
        """
        validation_errors: list[ValidationErrorGroup] = []
        for schema_id in schema_ids:
            resolved_id = self._resolve_registered_id(schema_id)
            validator = self._get_validator(resolved_id)
            grouped: dict[str, list[ValidationError]] = {}
            for error in validator.iter_errors(data):
                path = ".".join(p if isinstance(p, str) else "item" for p in error.path)
                grouped.setdefault(f"{path}: {error.message}", []).append(error)
            validation_errors.extend(
                ValidationErrorGroup(resolved_id, message, group)
                for message, group in grouped.items()
            )
        return validation_errors

    def schema_id_in_registry(self, schema_id: str) -> bool:
        """Check if a schema ID (or filepath alias) is in the registry."""
        return self._resolve_alias(self.normalize_schema_id(schema_id)) in self._registry

    def get_schema_ids(self) -> list[str]:
        """Get all schema IDs in the registry."""
        return list(self._registry)

    def contents(self, schema_id: str) -> dict:
        """Get the stored contents of a schema, looked up by ID or filepath alias."""
        return self._registry.contents(self._resolve_alias(self.normalize_schema_id(schema_id)))

    def get_referenced_schema_ids(self, schema_id: str) -> list[str]:
        """
        Get all schema IDs referenced by a schema.

        If referenced schemas are registered in the registry, this method
        recursively traverses them to find all child references. Each schema is
        traversed at most once, so cyclic references terminate.

        Args:
            schema_id (str): The schema ID to analyse.

        Returns:
            list[str]: The sorted, de-duplicated schema IDs (the part of each
            external ``$ref`` before ``#``).
        """
        resolved_id = self._resolve_alias(self.normalize_schema_id(schema_id))
        references = self._find_references_recursively(
            self._registry.contents(resolved_id), _visited={resolved_id}
        )
        return sorted({ref.split("#", 1)[0] for ref in references if not ref.startswith("#")})

    def get_references(self, schema_id: str) -> list[str]:
        """
        Get all $ref URI references in a schema.

        If referenced schemas are registered in the registry, this method
        recursively traverses them to find all child references. Each schema is
        traversed at most once, so cyclic references terminate. Local refs
        (``#...``) collected from a referenced schema are relative to that schema,
        not to ``schema_id``.

        Args:
            schema_id (str): The schema ID to analyse.

        Returns:
            list[str]: The sorted $ref URIs found directly or transitively.
        """
        resolved_id = self._resolve_alias(self.normalize_schema_id(schema_id))
        return sorted(
            self._find_references_recursively(
                self._registry.contents(resolved_id), _visited={resolved_id}
            )
        )

    def check_unresolvable_refs(self):
        """
        Check all schemas in the registry for unresolvable references.

        Every registered schema is checked against its OWN refs only. Local refs
        are resolved relative to the schema that contains them. Transitive refs
        are covered because every schema in the registry is visited.

        Raises:
            referencing.exceptions.Unresolvable: If a referenced schema is not found.
        """
        resolver = self._registry.resolver()
        for schema_id in self._registry:
            own_refs = self._find_references_recursively(
                self._registry.contents(schema_id), follow_external=False
            )
            for ref in sorted(own_refs):
                resolver.lookup(f"{schema_id}{ref}" if ref.startswith("#") else ref)

    def _resolve_registered_id(self, schema_id: str) -> str:
        """Normalize and alias-resolve ``schema_id``; raise KeyError if it is not registered."""
        resolved_id = self._resolve_alias(self.normalize_schema_id(schema_id))
        if resolved_id not in self._registry:
            raise KeyError(f"Schema '{resolved_id}' not found in registry.")
        return resolved_id

    def _normalize_schema_refs(self, schema: dict) -> dict:
        """
        Return a copy of ``schema`` with external ``$ref`` values normalized to schema IDs.

        The ID part of each non-local string ``$ref`` (before ``#``) is normalized
        and resolved through the filepath aliases. Any fragment is kept. Nested
        dicts, and dicts directly inside lists, are processed recursively. The
        input is not modified.
        """
        result = {}
        for key, value in schema.items():
            if isinstance(value, dict):
                result[key] = self._normalize_schema_refs(value)
            elif isinstance(value, list):
                result[key] = [
                    self._normalize_schema_refs(item) if isinstance(item, dict) else item
                    for item in value
                ]
            elif key == "$ref" and isinstance(value, str) and not value.startswith("#"):
                ref_id, sep, fragment = value.partition("#")
                result[key] = self._resolve_alias(self.normalize_schema_id(ref_id)) + sep + fragment
            else:
                result[key] = value
        return result

    def _find_references_recursively(
        self,
        schema: dict,
        follow_external: bool = True,
        _visited: set[str] | None = None,
    ) -> set[str]:
        """
        Find all ``$ref`` strings in a schema recursively.

        Args:
            schema (dict): Schema (or sub-schema) contents to scan.
            follow_external (bool): When True, registered schemas named by external
                refs are scanned as well. Missing ones are logged as a warning.
            _visited (set[str] | None): Resolved IDs of schemas already scanned,
                shared across the recursion so each schema is scanned at most once.

        Returns:
            set[str]: The ``$ref`` values found.
        """
        visited = set() if _visited is None else _visited
        references: set[str] = set()
        for key, value in schema.items():
            if key == "$ref" and isinstance(value, str):
                references.add(value)
                if follow_external and not value.startswith("#"):
                    ref_id = value.split("#", 1)[0]
                    if not self.schema_id_in_registry(ref_id):
                        logging.getLogger(__name__).warning(
                            "Referenced schema '%s' not found in registry.", ref_id
                        )
                        continue
                    # Resolve the alias the same way the membership check did,
                    # so aliased refs are read from the right resource.
                    target_id = self._resolve_alias(self.normalize_schema_id(ref_id))
                    if target_id not in visited:
                        visited.add(target_id)
                        references.update(
                            self._find_references_recursively(
                                self._registry.contents(target_id), follow_external, visited
                            )
                        )
            elif isinstance(value, dict):
                references.update(
                    self._find_references_recursively(value, follow_external, visited)
                )
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        references.update(
                            self._find_references_recursively(item, follow_external, visited)
                        )
        return references

    def normalize_schema_id(self, schema_id: str) -> str:
        """
        Clean a schema ID or convert a schema filepath to a schema ID.

        The value is lowercased first, so ``.JSON`` is treated like ``.json``.
        Forward and back slashes become underscores, and every ``.json`` substring
        is removed. Characters not in [a-z0-9_] are stripped, runs of underscores
        are collapsed, and leading/trailing underscores are trimmed.

        Args:
            schema_id (str): The schema ID or filepath.

        Returns:
            str: The normalized schema ID, e.g. ``"schemas/Foo-Bar.json"`` ->
            ``"schemas_foobar"``.
        """
        cleaned = (
            schema_id.lower()
            .replace("\\", "_")
            .replace("/", "_")
            .replace(".json", "")
        )
        cleaned = re.sub(r"[^a-z0-9_]", "", cleaned)
        return re.sub(r"_+", "_", cleaned).strip("_")

    def _resolve_alias(self, normalized_id: str) -> str:
        """
        Map a normalized filepath alias to the schema ID it was registered under.

        This is a single lookup. IDs without an alias are returned unchanged.

        Args:
            normalized_id (str): The normalized schema ID.

        Returns:
            str: The aliased schema ID, or ``normalized_id`` if no alias exists.
        """
        return self._schema_aliases.get(normalized_id, normalized_id)

    def _get_validator(self, schema_id: str) -> Draft7Validator:
        """Build a fresh Draft 7 validator (with format checking) for a registered schema ID."""
        return Draft7Validator(
            self._registry.contents(schema_id),
            format_checker=FormatChecker(),
            registry=self._registry,
        )