refactor: removes nltk from score based chunker

This commit is contained in:
sauravniraula 2025-08-05 23:54:18 +05:45
parent bf16491c73
commit 0fe272d82c
No known key found for this signature in database
GPG key ID: 60FCC1B5A5E83326
3 changed files with 93 additions and 129 deletions

View file

@ -24,7 +24,7 @@ RUN curl -fsSL https://ollama.com/install.sh | sh
# Install dependencies for FastAPI
RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \
pathvalidate pdfplumber nltk chromadb sqlmodel \
pathvalidate pdfplumber chromadb sqlmodel \
anthropic google-genai openai fastmcp
RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu

View file

@ -26,7 +26,7 @@ RUN curl -fsSL http://ollama.com/install.sh | sh
# Install dependencies for FastAPI
RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \
pathvalidate pdfplumber nltk chromadb sqlmodel \
pathvalidate pdfplumber chromadb sqlmodel \
anthropic google-genai openai fastmcp
RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu

View file

@ -1,133 +1,79 @@
import asyncio
from typing import List
import nltk
from models.document_chunk import DocumentChunk
try:
nltk.data.find("tokenizers/punkt", paths=["./nltk"])
except LookupError:
nltk.download("punkt", download_dir="./nltk")
class ScoreBasedChunker:
def extract_sentences(self, text: str, min_sentences: int) -> List[str]:
sentences = self.extract_sentences_markdown(text)
if len(sentences) < min_sentences:
sentences = self.extract_sentences_nltk(text)
if len(sentences) < min_sentences:
sentences = self.extract_sentences_by_stop_words(text)
if len(sentences) < min_sentences:
sentences = self.extract_sentences_by_new_line(text)
if len(sentences) < min_sentences:
raise ValueError(
f"Only {len(sentences)} sentences found, requested {min_sentences}"
)
return sentences
def extract_sentences_markdown(self, text: str) -> List[str]:
def extract_headings(self, text: str) -> List[str]:
lines = text.split("\n")
sentences = []
headings = []
for line in lines:
line = line.strip()
if line:
if line.startswith("#"):
sentences.append(line)
else:
if line.endswith((".", "!", "?")):
sentences.append(line)
else:
sentences.append(line)
return sentences
def extract_sentences_nltk(self, text: str) -> List[str]:
sentences = nltk.sent_tokenize(text)
return sentences
def extract_sentences_by_stop_words(self, text: str) -> List[str]:
sentences = []
current_sentence = ""
for char in text:
current_sentence += char
if char in ".!?":
sentences.append(current_sentence.strip())
current_sentence = ""
if current_sentence.strip():
sentences.append(current_sentence.strip())
return [s for s in sentences if s]
def extract_sentences_by_new_line(self, text: str) -> List[str]:
sentences = text.split("\n")
result = []
for i, sentence in enumerate(sentences):
if i < len(sentences) - 1:
result.append(sentence + "\n")
else:
result.append(sentence)
return result
def score_sentences_for_heading(self, sentences: List[str]) -> List[float]:
sentences_scores = []
if line.startswith("#"):
headings.append(line)
return headings
def score_headings(self, headings: List[str]) -> List[float]:
heading_scores = []
last_heading_index = -1
first_heading_found = False
for i, sentence in enumerate(sentences):
for i, heading in enumerate(headings):
score = 0.0
heading_level = len(heading) - len(heading.lstrip("#"))
if heading_level <= 3:
score += 10.0 - (heading_level - 1) * 2.0
else:
score += 4.0 - (heading_level - 4) * 0.5
if sentence.strip().startswith("#"):
heading_level = len(sentence) - len(sentence.lstrip("#"))
if not first_heading_found:
score += 5.0
first_heading_found = True
if heading_level <= 3:
score += 10.0 - (heading_level - 1) * 2.0
else:
score += 4.0 - (heading_level - 4) * 0.5
if last_heading_index != -1:
distance = i - last_heading_index
distance_bonus = min(5.0, distance * 0.5)
score += distance_bonus
if not first_heading_found:
score += 5.0
first_heading_found = True
last_heading_index = i
heading_scores.append(score)
if last_heading_index != -1:
distance = i - last_heading_index
distance_bonus = min(5.0, distance * 0.5)
score += distance_bonus
return heading_scores
last_heading_index = i
sentences_scores.append(score)
return sentences_scores
def get_chunks(
self, sentences: List[str], sentences_scores: List[float], top_k: int = 10
def get_chunks_from_headings(
self,
text: str,
headings: List[str],
heading_scores: List[float],
top_k: int = 10,
) -> List[DocumentChunk]:
if not sentences_scores:
sentences_scores = self.score_sentences_for_heading(sentences)
if not heading_scores:
heading_scores = self.score_headings(headings)
chunks = []
heading_scores = []
heading_indices = []
for i, score in enumerate(sentences_scores):
for i, score in enumerate(heading_scores):
if score > 0:
heading_scores.append((i, score))
heading_indices.append((i, score))
if len(heading_scores) == 0:
if len(heading_indices) == 0:
return chunks
heading_scores.sort(key=lambda x: (-x[1], x[0]))
heading_indices.sort(key=lambda x: (-x[1], x[0]))
if len(heading_scores) <= top_k:
selected_headings = [idx for idx, _ in heading_scores]
selected_headings.sort()
if len(heading_indices) <= top_k:
selected_indices = [idx for idx, _ in heading_indices]
selected_indices.sort()
else:
score_groups = {}
for idx, score in heading_scores:
for idx, score in heading_indices:
rounded_score = round(score)
if rounded_score not in score_groups:
score_groups[rounded_score] = []
@ -137,62 +83,80 @@ class ScoreBasedChunker:
score_groups.items(), key=lambda x: x[0], reverse=True
)
selected_headings = []
selected_indices = []
for score, headings in sorted_groups:
headings.sort()
remaining_needed = top_k - len(selected_headings)
for score, indices in sorted_groups:
indices.sort()
remaining_needed = top_k - len(selected_indices)
if remaining_needed <= 0:
break
if len(headings) <= remaining_needed:
selected_headings.extend(headings)
if len(indices) <= remaining_needed:
selected_indices.extend(indices)
else:
if remaining_needed == 1:
mid_idx = len(headings) // 2
selected_headings.append(headings[mid_idx])
mid_idx = len(indices) // 2
selected_indices.append(indices[mid_idx])
elif remaining_needed == 2:
selected_headings.append(headings[0])
selected_headings.append(headings[-1])
selected_indices.append(indices[0])
selected_indices.append(indices[-1])
else:
step = (len(headings) - 1) / (remaining_needed - 1)
step = (len(indices) - 1) / (remaining_needed - 1)
for i in range(remaining_needed):
index = int(round(i * step))
if index < len(headings):
selected_headings.append(headings[index])
if index < len(indices):
selected_indices.append(indices[index])
selected_headings.sort()
selected_indices.sort()
for i, heading_idx in enumerate(selected_headings):
heading = sentences[heading_idx]
if i + 1 < len(selected_headings):
next_heading_idx = selected_headings[i + 1]
content_end = next_heading_idx
lines = text.split("\n")
heading_positions = {}
for i, line in enumerate(lines):
line_stripped = line.strip()
if line_stripped.startswith("#"):
for heading_idx, heading in enumerate(headings):
if heading == line_stripped and heading_idx not in heading_positions:
heading_positions[heading_idx] = i
break
for i, heading_idx in enumerate(selected_indices):
if heading_idx not in heading_positions:
continue
heading = headings[heading_idx]
heading_line_idx = heading_positions[heading_idx]
if i + 1 < len(selected_indices):
next_heading_idx = selected_indices[i + 1]
if next_heading_idx in heading_positions:
next_heading_line_idx = heading_positions[next_heading_idx]
content_end = next_heading_line_idx
else:
content_end = len(lines)
else:
content_end = len(sentences)
content_end = len(lines)
content_sentences = sentences[heading_idx + 1 : content_end]
content = " ".join(content_sentences).strip()
content_lines = lines[heading_line_idx + 1 : content_end]
content = "\n".join(content_lines).strip()
chunk = DocumentChunk(
heading=heading,
content=content,
heading_index=heading_idx,
score=sentences_scores[heading_idx],
score=heading_scores[heading_idx],
)
chunks.append(chunk)
return chunks
async def get_n_chunks(self, text: str, n: int) -> List[DocumentChunk]:
sentences = await asyncio.to_thread(self.extract_sentences, text, n)
sentences_scores = await asyncio.to_thread(
self.score_sentences_for_heading, sentences
)
headings = await asyncio.to_thread(self.extract_headings, text)
heading_scores = await asyncio.to_thread(self.score_headings, headings)
chunks = await asyncio.to_thread(
self.get_chunks, sentences, sentences_scores, n
self.get_chunks_from_headings, text, headings, heading_scores, n
)
if len(chunks) < n:
raise ValueError(f"Only {len(chunks)} chunks found, requested {n}")