from pathlib import Path class ReferenceDocsService: """Service to load and provide reference documents for agents.""" def __init__(self, base_path: str | None = None): """ Initialize the reference docs service. Args: base_path: Path to the reference_docs directory. Defaults to ../reference_docs relative to backend/ """ if base_path is None: # Default to reference_docs at project root (sibling to backend/) base_path = Path(__file__).parent.parent.parent.parent / "reference_docs" self.base_path = Path(base_path) # Path to prompts directory at project root (sibling to backend/) self.prompts_path = Path(__file__).parent.parent.parent.parent / "prompts" # Cache loaded documents self._brand_context: str | None = None self._channel_context: str | None = None self._barclaycard_brand_spec: str | None = None self._barclays_brand_spec: str | None = None self._channel_best_practices_spec: str | None = None self._channel_tech_specs_spec: str | None = None self._legal_spec: str | None = None # DB-backed spec cache (takes priority over file-based) self._db_specs: dict[str, str] = {} async def load_specs_from_db(self, session) -> None: """Load active spec content from DB for all agent keys.""" from app.repositories.knowledge_base_repository import KnowledgeBaseRepository repo = KnowledgeBaseRepository(session) agent_keys = [ "legal", "brand_barclays", "brand_barclaycard", "channel_best_practices", "channel_tech_specs", ] for key in agent_keys: spec = await repo.get_active_spec_by_key(key) if spec and spec.content: self._db_specs[key] = spec.content print(f" Loaded DB spec for {key}: {len(spec.content)} chars (v{spec.version_number})") def invalidate_cache(self, agent_key: str | None = None, new_spec_content: str | None = None) -> None: """Clear cached specs and optionally replace with new content. Args: agent_key: The agent key to invalidate (or None for all). new_spec_content: If provided, immediately populate the DB cache with this content so the next analysis uses it without a restart. """ if agent_key is None: self._db_specs.clear() self._barclaycard_brand_spec = None self._barclays_brand_spec = None self._channel_best_practices_spec = None self._channel_tech_specs_spec = None self._legal_spec = None else: if new_spec_content is not None: self._db_specs[agent_key] = new_spec_content else: self._db_specs.pop(agent_key, None) # Also clear the file-based cache so it won't be stale cache_map = { "legal": "_legal_spec", "brand_barclays": "_barclays_brand_spec", "brand_barclaycard": "_barclaycard_brand_spec", "channel_best_practices": "_channel_best_practices_spec", "channel_tech_specs": "_channel_tech_specs_spec", } attr = cache_map.get(agent_key) if attr: setattr(self, attr, None) def get_brand_context(self) -> str: """Load and return all brand guideline documents as a single context string.""" if self._brand_context is None: brand_path = self.base_path / "brand" self._brand_context = self._load_all_markdown_files(brand_path) return self._brand_context def get_barclaycard_brand_spec(self) -> str: """Load and return the Barclaycard brand specification.""" # Check DB cache first if "brand_barclaycard" in self._db_specs: return self._db_specs["brand_barclaycard"] if self._barclaycard_brand_spec is None: spec_path = self.prompts_path / "brand_barclaycard.md" try: if spec_path.exists(): self._barclaycard_brand_spec = spec_path.read_text(encoding="utf-8") else: print(f"Warning: Barclaycard brand spec not found at {spec_path}") # Fall back to raw brand context self._barclaycard_brand_spec = self.get_brand_context() except Exception as e: print(f"Warning: Could not read Barclaycard brand spec: {e}") self._barclaycard_brand_spec = self.get_brand_context() return self._barclaycard_brand_spec def get_barclays_brand_spec(self) -> str: """Load and return the Barclays brand specification.""" # Check DB cache first if "brand_barclays" in self._db_specs: return self._db_specs["brand_barclays"] # Check cache first if not hasattr(self, '_barclays_brand_spec'): self._barclays_brand_spec = None if self._barclays_brand_spec is None: spec_path = self.prompts_path / "brand_barclays.md" try: if spec_path.exists(): self._barclays_brand_spec = spec_path.read_text(encoding="utf-8") else: print(f"Warning: Barclays brand spec not found at {spec_path}, using raw brand context") # Fall back to raw brand context from reference_docs/brand/ self._barclays_brand_spec = self.get_brand_context() except Exception as e: print(f"Warning: Could not read Barclays brand spec: {e}") self._barclays_brand_spec = self.get_brand_context() return self._barclays_brand_spec def get_channel_context(self) -> str: """Load and return all channel guideline documents as a single context string.""" if self._channel_context is None: channel_path = self.base_path / "channel" self._channel_context = self._load_all_markdown_files(channel_path) return self._channel_context def get_channel_best_practices_spec(self) -> str: """Load and return the Channel Best Practices specification.""" # Check DB cache first if "channel_best_practices" in self._db_specs: return self._db_specs["channel_best_practices"] if self._channel_best_practices_spec is None: spec_path = self.prompts_path / "channel_best_practices.md" try: if spec_path.exists(): self._channel_best_practices_spec = spec_path.read_text(encoding="utf-8") else: print(f"Warning: Channel Best Practices spec not found at {spec_path}") self._channel_best_practices_spec = self.get_channel_context() except Exception as e: print(f"Warning: Could not read Channel Best Practices spec: {e}") self._channel_best_practices_spec = self.get_channel_context() return self._channel_best_practices_spec def get_channel_tech_specs_spec(self) -> str: """Load and return the Channel Tech Specs specification.""" # Check DB cache first if "channel_tech_specs" in self._db_specs: return self._db_specs["channel_tech_specs"] if self._channel_tech_specs_spec is None: spec_path = self.prompts_path / "channel_tech_specs.md" try: if spec_path.exists(): self._channel_tech_specs_spec = spec_path.read_text(encoding="utf-8") else: print(f"Warning: Channel Tech Specs spec not found at {spec_path}") self._channel_tech_specs_spec = self.get_channel_context() except Exception as e: print(f"Warning: Could not read Channel Tech Specs spec: {e}") self._channel_tech_specs_spec = self.get_channel_context() return self._channel_tech_specs_spec def get_legal_spec(self) -> str: """Load and return the Legal specification.""" # Check DB cache first if "legal" in self._db_specs: return self._db_specs["legal"] if self._legal_spec is None: spec_path = self.prompts_path / "legal.md" try: if spec_path.exists(): self._legal_spec = spec_path.read_text(encoding="utf-8") else: print(f"Warning: Legal spec not found at {spec_path}") self._legal_spec = "No legal specification found. Apply general legal compliance checks." except Exception as e: print(f"Warning: Could not read Legal spec: {e}") self._legal_spec = "No legal specification found. Apply general legal compliance checks." return self._legal_spec def _load_all_markdown_files(self, directory: Path) -> str: """ Load all .md files from a directory and concatenate them. Args: directory: Path to the directory containing markdown files Returns: Concatenated content of all markdown files with section headers """ contents = [] if directory.exists(): # Sort files for consistent ordering for md_file in sorted(directory.glob("*.md")): try: content = md_file.read_text(encoding="utf-8") # Add file name as section header contents.append(f"## {md_file.stem}\n\n{content}") except Exception as e: print(f"Warning: Could not read {md_file}: {e}") if not contents: return "No reference documents found." return "\n\n---\n\n".join(contents) def get_context_summary(self) -> dict: """Return summary info about loaded documents.""" brand_path = self.base_path / "brand" channel_path = self.base_path / "channel" brand_files = list(brand_path.glob("*.md")) if brand_path.exists() else [] channel_files = list(channel_path.glob("*.md")) if channel_path.exists() else [] return { "brand_files": [f.name for f in brand_files], "channel_files": [f.name for f in channel_files], "brand_context_length": len(self.get_brand_context()), "channel_context_length": len(self.get_channel_context()), }