diff --git a/Dockerfile b/Dockerfile index 3815c237..220d6876 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,7 +24,7 @@ RUN curl -fsSL https://ollama.com/install.sh | sh # Install dependencies for FastAPI RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \ - pathvalidate pdfplumber nltk chromadb sqlmodel \ + pathvalidate pdfplumber chromadb sqlmodel \ anthropic google-genai openai fastmcp RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/Dockerfile.dev b/Dockerfile.dev index f4e860a1..4f3e80e5 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -26,7 +26,7 @@ RUN curl -fsSL http://ollama.com/install.sh | sh # Install dependencies for FastAPI RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \ - pathvalidate pdfplumber nltk chromadb sqlmodel \ + pathvalidate pdfplumber chromadb sqlmodel \ anthropic google-genai openai fastmcp RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/servers/fastapi/services/score_based_chunker.py b/servers/fastapi/services/score_based_chunker.py index 0af245a2..c67de796 100644 --- a/servers/fastapi/services/score_based_chunker.py +++ b/servers/fastapi/services/score_based_chunker.py @@ -1,133 +1,79 @@ import asyncio from typing import List -import nltk from models.document_chunk import DocumentChunk -try: - nltk.data.find("tokenizers/punkt", paths=["./nltk"]) -except LookupError: - nltk.download("punkt", download_dir="./nltk") - class ScoreBasedChunker: - def extract_sentences(self, text: str, min_sentences: int) -> List[str]: - sentences = self.extract_sentences_markdown(text) - if len(sentences) < min_sentences: - sentences = self.extract_sentences_nltk(text) - if len(sentences) < min_sentences: - sentences = self.extract_sentences_by_stop_words(text) - if len(sentences) < min_sentences: - sentences = self.extract_sentences_by_new_line(text) - if len(sentences) < min_sentences: - raise ValueError( - f"Only {len(sentences)} sentences found, requested {min_sentences}" - ) - return sentences - - def extract_sentences_markdown(self, text: str) -> List[str]: + def extract_headings(self, text: str) -> List[str]: lines = text.split("\n") - sentences = [] - + headings = [] + for line in lines: line = line.strip() - if line: - if line.startswith("#"): - sentences.append(line) - else: - if line.endswith((".", "!", "?")): - sentences.append(line) - else: - sentences.append(line) - - return sentences - - def extract_sentences_nltk(self, text: str) -> List[str]: - sentences = nltk.sent_tokenize(text) - return sentences - - def extract_sentences_by_stop_words(self, text: str) -> List[str]: - sentences = [] - current_sentence = "" - - for char in text: - current_sentence += char - if char in ".!?": - sentences.append(current_sentence.strip()) - current_sentence = "" - - if current_sentence.strip(): - sentences.append(current_sentence.strip()) - - return [s for s in sentences if s] - - def extract_sentences_by_new_line(self, text: str) -> List[str]: - sentences = text.split("\n") - result = [] - for i, sentence in enumerate(sentences): - if i < len(sentences) - 1: - result.append(sentence + "\n") - else: - result.append(sentence) - return result - - def score_sentences_for_heading(self, sentences: List[str]) -> List[float]: - sentences_scores = [] + if line.startswith("#"): + headings.append(line) + + return headings + def score_headings(self, headings: List[str]) -> List[float]: + heading_scores = [] last_heading_index = -1 first_heading_found = False - for i, sentence in enumerate(sentences): + for i, heading in enumerate(headings): score = 0.0 + + heading_level = len(heading) - len(heading.lstrip("#")) + + if heading_level <= 3: + score += 10.0 - (heading_level - 1) * 2.0 + else: + score += 4.0 - (heading_level - 4) * 0.5 - if sentence.strip().startswith("#"): - heading_level = len(sentence) - len(sentence.lstrip("#")) + if not first_heading_found: + score += 5.0 + first_heading_found = True - if heading_level <= 3: - score += 10.0 - (heading_level - 1) * 2.0 - else: - score += 4.0 - (heading_level - 4) * 0.5 + if last_heading_index != -1: + distance = i - last_heading_index + distance_bonus = min(5.0, distance * 0.5) + score += distance_bonus - if not first_heading_found: - score += 5.0 - first_heading_found = True + last_heading_index = i + heading_scores.append(score) - if last_heading_index != -1: - distance = i - last_heading_index - distance_bonus = min(5.0, distance * 0.5) - score += distance_bonus + return heading_scores - last_heading_index = i - - sentences_scores.append(score) - - return sentences_scores - - def get_chunks( - self, sentences: List[str], sentences_scores: List[float], top_k: int = 10 + def get_chunks_from_headings( + self, + text: str, + headings: List[str], + heading_scores: List[float], + top_k: int = 10, ) -> List[DocumentChunk]: - if not sentences_scores: - sentences_scores = self.score_sentences_for_heading(sentences) + if not heading_scores: + heading_scores = self.score_headings(headings) chunks = [] - heading_scores = [] + heading_indices = [] - for i, score in enumerate(sentences_scores): + for i, score in enumerate(heading_scores): if score > 0: - heading_scores.append((i, score)) + heading_indices.append((i, score)) - if len(heading_scores) == 0: + if len(heading_indices) == 0: return chunks - heading_scores.sort(key=lambda x: (-x[1], x[0])) + heading_indices.sort(key=lambda x: (-x[1], x[0])) - if len(heading_scores) <= top_k: - selected_headings = [idx for idx, _ in heading_scores] - selected_headings.sort() + if len(heading_indices) <= top_k: + selected_indices = [idx for idx, _ in heading_indices] + selected_indices.sort() else: score_groups = {} - for idx, score in heading_scores: + for idx, score in heading_indices: rounded_score = round(score) if rounded_score not in score_groups: score_groups[rounded_score] = [] @@ -137,62 +83,80 @@ class ScoreBasedChunker: score_groups.items(), key=lambda x: x[0], reverse=True ) - selected_headings = [] + selected_indices = [] - for score, headings in sorted_groups: - headings.sort() - remaining_needed = top_k - len(selected_headings) + for score, indices in sorted_groups: + indices.sort() + remaining_needed = top_k - len(selected_indices) if remaining_needed <= 0: break - if len(headings) <= remaining_needed: - selected_headings.extend(headings) + if len(indices) <= remaining_needed: + selected_indices.extend(indices) else: if remaining_needed == 1: - mid_idx = len(headings) // 2 - selected_headings.append(headings[mid_idx]) + mid_idx = len(indices) // 2 + selected_indices.append(indices[mid_idx]) elif remaining_needed == 2: - selected_headings.append(headings[0]) - selected_headings.append(headings[-1]) + selected_indices.append(indices[0]) + selected_indices.append(indices[-1]) else: - step = (len(headings) - 1) / (remaining_needed - 1) + step = (len(indices) - 1) / (remaining_needed - 1) for i in range(remaining_needed): index = int(round(i * step)) - if index < len(headings): - selected_headings.append(headings[index]) + if index < len(indices): + selected_indices.append(indices[index]) - selected_headings.sort() + selected_indices.sort() - for i, heading_idx in enumerate(selected_headings): - heading = sentences[heading_idx] - - if i + 1 < len(selected_headings): - next_heading_idx = selected_headings[i + 1] - content_end = next_heading_idx + lines = text.split("\n") + heading_positions = {} + + for i, line in enumerate(lines): + line_stripped = line.strip() + if line_stripped.startswith("#"): + for heading_idx, heading in enumerate(headings): + if heading == line_stripped and heading_idx not in heading_positions: + heading_positions[heading_idx] = i + break + + for i, heading_idx in enumerate(selected_indices): + if heading_idx not in heading_positions: + continue + + heading = headings[heading_idx] + heading_line_idx = heading_positions[heading_idx] + + if i + 1 < len(selected_indices): + next_heading_idx = selected_indices[i + 1] + if next_heading_idx in heading_positions: + next_heading_line_idx = heading_positions[next_heading_idx] + content_end = next_heading_line_idx + else: + content_end = len(lines) else: - content_end = len(sentences) + content_end = len(lines) - content_sentences = sentences[heading_idx + 1 : content_end] - content = " ".join(content_sentences).strip() + content_lines = lines[heading_line_idx + 1 : content_end] + content = "\n".join(content_lines).strip() chunk = DocumentChunk( heading=heading, content=content, heading_index=heading_idx, - score=sentences_scores[heading_idx], + score=heading_scores[heading_idx], ) chunks.append(chunk) + return chunks async def get_n_chunks(self, text: str, n: int) -> List[DocumentChunk]: - sentences = await asyncio.to_thread(self.extract_sentences, text, n) - sentences_scores = await asyncio.to_thread( - self.score_sentences_for_heading, sentences - ) + headings = await asyncio.to_thread(self.extract_headings, text) + heading_scores = await asyncio.to_thread(self.score_headings, headings) chunks = await asyncio.to_thread( - self.get_chunks, sentences, sentences_scores, n + self.get_chunks_from_headings, text, headings, heading_scores, n ) if len(chunks) < n: raise ValueError(f"Only {len(chunks)} chunks found, requested {n}")