refactor: removes nltk from score based chunker
This commit is contained in:
parent
bf16491c73
commit
0fe272d82c
3 changed files with 93 additions and 129 deletions
|
|
@ -24,7 +24,7 @@ RUN curl -fsSL https://ollama.com/install.sh | sh
|
|||
|
||||
# Install dependencies for FastAPI
|
||||
RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \
|
||||
pathvalidate pdfplumber nltk chromadb sqlmodel \
|
||||
pathvalidate pdfplumber chromadb sqlmodel \
|
||||
anthropic google-genai openai fastmcp
|
||||
RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ RUN curl -fsSL http://ollama.com/install.sh | sh
|
|||
|
||||
# Install dependencies for FastAPI
|
||||
RUN pip install aiohttp aiomysql aiosqlite asyncpg fastapi[standard] \
|
||||
pathvalidate pdfplumber nltk chromadb sqlmodel \
|
||||
pathvalidate pdfplumber chromadb sqlmodel \
|
||||
anthropic google-genai openai fastmcp
|
||||
RUN pip install docling --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
|
|
|
|||
|
|
@ -1,133 +1,79 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
import nltk
|
||||
|
||||
from models.document_chunk import DocumentChunk
|
||||
|
||||
try:
|
||||
nltk.data.find("tokenizers/punkt", paths=["./nltk"])
|
||||
except LookupError:
|
||||
nltk.download("punkt", download_dir="./nltk")
|
||||
|
||||
|
||||
class ScoreBasedChunker:
|
||||
|
||||
def extract_sentences(self, text: str, min_sentences: int) -> List[str]:
    """Split ``text`` into sentences, escalating through fallback splitters.

    The splitters are tried in a fixed order: markdown lines, NLTK
    tokenizer, punctuation-based, newline-based. The first result that
    holds at least ``min_sentences`` items is returned.

    Args:
        text: The raw document text to split.
        min_sentences: Minimum number of items a splitter must yield for
            its result to be accepted.

    Returns:
        The list produced by the first splitter that meets the minimum.

    Raises:
        ValueError: If even the last splitter yields fewer than
            ``min_sentences`` items.
    """
    # Ordered from most structure-aware to most naive.
    strategies = (
        self.extract_sentences_markdown,
        self.extract_sentences_nltk,
        self.extract_sentences_by_stop_words,
        self.extract_sentences_by_new_line,
    )

    found: List[str] = []
    for strategy in strategies:
        found = strategy(text)
        if len(found) >= min_sentences:
            return found

    # The error reports the count from the last strategy attempted.
    raise ValueError(
        f"Only {len(found)} sentences found, requested {min_sentences}"
    )
|
||||
|
||||
def extract_sentences_markdown(self, text: str) -> List[str]:
|
||||
def extract_headings(self, text: str) -> List[str]:
|
||||
lines = text.split("\n")
|
||||
sentences = []
|
||||
|
||||
headings = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line:
|
||||
if line.startswith("#"):
|
||||
sentences.append(line)
|
||||
else:
|
||||
if line.endswith((".", "!", "?")):
|
||||
sentences.append(line)
|
||||
else:
|
||||
sentences.append(line)
|
||||
|
||||
return sentences
|
||||
|
||||
def extract_sentences_nltk(self, text: str) -> List[str]:
    """Return the result of ``nltk.sent_tokenize`` applied to ``text``."""
    return nltk.sent_tokenize(text)
|
||||
|
||||
def extract_sentences_by_stop_words(self, text: str) -> List[str]:
    """Split ``text`` after every ``.``, ``!`` or ``?`` character.

    Each terminator closes a piece on its own, so a run such as ``...``
    yields one piece per dot. Pieces are whitespace-stripped, and any
    trailing text without a terminator is kept as a final piece.

    Args:
        text: The raw text to split.

    Returns:
        The non-empty stripped pieces, in document order.
    """
    terminators = ".!?"
    pieces = []
    start = 0

    for pos, symbol in enumerate(text):
        if symbol not in terminators:
            continue
        # Slice up to and including the terminator, then start a new piece.
        pieces.append(text[start : pos + 1].strip())
        start = pos + 1

    tail = text[start:].strip()
    if tail:
        pieces.append(tail)

    return [piece for piece in pieces if piece]
|
||||
|
||||
def extract_sentences_by_new_line(self, text: str) -> List[str]:
    """Split ``text`` on ``"\\n"``, keeping the newline on each piece.

    Every piece except the last gets its ``"\\n"`` re-attached, so
    concatenating the result reproduces ``text``. Empty lines are kept.

    Args:
        text: The raw text to split.

    Returns:
        One entry per line; an empty ``text`` yields ``[""]``.
    """
    parts = text.split("\n")
    # str.split always returns at least one element, so parts[-1:] is
    # never empty.
    return [part + "\n" for part in parts[:-1]] + parts[-1:]
|
||||
|
||||
def score_sentences_for_heading(self, sentences: List[str]) -> List[float]:
|
||||
sentences_scores = []
|
||||
if line.startswith("#"):
|
||||
headings.append(line)
|
||||
|
||||
return headings
|
||||
|
||||
def score_headings(self, headings: List[str]) -> List[float]:
|
||||
heading_scores = []
|
||||
last_heading_index = -1
|
||||
first_heading_found = False
|
||||
|
||||
for i, sentence in enumerate(sentences):
|
||||
for i, heading in enumerate(headings):
|
||||
score = 0.0
|
||||
|
||||
heading_level = len(heading) - len(heading.lstrip("#"))
|
||||
|
||||
if heading_level <= 3:
|
||||
score += 10.0 - (heading_level - 1) * 2.0
|
||||
else:
|
||||
score += 4.0 - (heading_level - 4) * 0.5
|
||||
|
||||
if sentence.strip().startswith("#"):
|
||||
heading_level = len(sentence) - len(sentence.lstrip("#"))
|
||||
if not first_heading_found:
|
||||
score += 5.0
|
||||
first_heading_found = True
|
||||
|
||||
if heading_level <= 3:
|
||||
score += 10.0 - (heading_level - 1) * 2.0
|
||||
else:
|
||||
score += 4.0 - (heading_level - 4) * 0.5
|
||||
if last_heading_index != -1:
|
||||
distance = i - last_heading_index
|
||||
distance_bonus = min(5.0, distance * 0.5)
|
||||
score += distance_bonus
|
||||
|
||||
if not first_heading_found:
|
||||
score += 5.0
|
||||
first_heading_found = True
|
||||
last_heading_index = i
|
||||
heading_scores.append(score)
|
||||
|
||||
if last_heading_index != -1:
|
||||
distance = i - last_heading_index
|
||||
distance_bonus = min(5.0, distance * 0.5)
|
||||
score += distance_bonus
|
||||
return heading_scores
|
||||
|
||||
last_heading_index = i
|
||||
|
||||
sentences_scores.append(score)
|
||||
|
||||
return sentences_scores
|
||||
|
||||
def get_chunks(
|
||||
self, sentences: List[str], sentences_scores: List[float], top_k: int = 10
|
||||
def get_chunks_from_headings(
|
||||
self,
|
||||
text: str,
|
||||
headings: List[str],
|
||||
heading_scores: List[float],
|
||||
top_k: int = 10,
|
||||
) -> List[DocumentChunk]:
|
||||
if not sentences_scores:
|
||||
sentences_scores = self.score_sentences_for_heading(sentences)
|
||||
if not heading_scores:
|
||||
heading_scores = self.score_headings(headings)
|
||||
|
||||
chunks = []
|
||||
heading_scores = []
|
||||
heading_indices = []
|
||||
|
||||
for i, score in enumerate(sentences_scores):
|
||||
for i, score in enumerate(heading_scores):
|
||||
if score > 0:
|
||||
heading_scores.append((i, score))
|
||||
heading_indices.append((i, score))
|
||||
|
||||
if len(heading_scores) == 0:
|
||||
if len(heading_indices) == 0:
|
||||
return chunks
|
||||
|
||||
heading_scores.sort(key=lambda x: (-x[1], x[0]))
|
||||
heading_indices.sort(key=lambda x: (-x[1], x[0]))
|
||||
|
||||
if len(heading_scores) <= top_k:
|
||||
selected_headings = [idx for idx, _ in heading_scores]
|
||||
selected_headings.sort()
|
||||
if len(heading_indices) <= top_k:
|
||||
selected_indices = [idx for idx, _ in heading_indices]
|
||||
selected_indices.sort()
|
||||
else:
|
||||
score_groups = {}
|
||||
for idx, score in heading_scores:
|
||||
for idx, score in heading_indices:
|
||||
rounded_score = round(score)
|
||||
if rounded_score not in score_groups:
|
||||
score_groups[rounded_score] = []
|
||||
|
|
@ -137,62 +83,80 @@ class ScoreBasedChunker:
|
|||
score_groups.items(), key=lambda x: x[0], reverse=True
|
||||
)
|
||||
|
||||
selected_headings = []
|
||||
selected_indices = []
|
||||
|
||||
for score, headings in sorted_groups:
|
||||
headings.sort()
|
||||
remaining_needed = top_k - len(selected_headings)
|
||||
for score, indices in sorted_groups:
|
||||
indices.sort()
|
||||
remaining_needed = top_k - len(selected_indices)
|
||||
|
||||
if remaining_needed <= 0:
|
||||
break
|
||||
|
||||
if len(headings) <= remaining_needed:
|
||||
selected_headings.extend(headings)
|
||||
if len(indices) <= remaining_needed:
|
||||
selected_indices.extend(indices)
|
||||
else:
|
||||
if remaining_needed == 1:
|
||||
mid_idx = len(headings) // 2
|
||||
selected_headings.append(headings[mid_idx])
|
||||
mid_idx = len(indices) // 2
|
||||
selected_indices.append(indices[mid_idx])
|
||||
elif remaining_needed == 2:
|
||||
selected_headings.append(headings[0])
|
||||
selected_headings.append(headings[-1])
|
||||
selected_indices.append(indices[0])
|
||||
selected_indices.append(indices[-1])
|
||||
else:
|
||||
step = (len(headings) - 1) / (remaining_needed - 1)
|
||||
step = (len(indices) - 1) / (remaining_needed - 1)
|
||||
|
||||
for i in range(remaining_needed):
|
||||
index = int(round(i * step))
|
||||
if index < len(headings):
|
||||
selected_headings.append(headings[index])
|
||||
if index < len(indices):
|
||||
selected_indices.append(indices[index])
|
||||
|
||||
selected_headings.sort()
|
||||
selected_indices.sort()
|
||||
|
||||
for i, heading_idx in enumerate(selected_headings):
|
||||
heading = sentences[heading_idx]
|
||||
|
||||
if i + 1 < len(selected_headings):
|
||||
next_heading_idx = selected_headings[i + 1]
|
||||
content_end = next_heading_idx
|
||||
lines = text.split("\n")
|
||||
heading_positions = {}
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
line_stripped = line.strip()
|
||||
if line_stripped.startswith("#"):
|
||||
for heading_idx, heading in enumerate(headings):
|
||||
if heading == line_stripped and heading_idx not in heading_positions:
|
||||
heading_positions[heading_idx] = i
|
||||
break
|
||||
|
||||
for i, heading_idx in enumerate(selected_indices):
|
||||
if heading_idx not in heading_positions:
|
||||
continue
|
||||
|
||||
heading = headings[heading_idx]
|
||||
heading_line_idx = heading_positions[heading_idx]
|
||||
|
||||
if i + 1 < len(selected_indices):
|
||||
next_heading_idx = selected_indices[i + 1]
|
||||
if next_heading_idx in heading_positions:
|
||||
next_heading_line_idx = heading_positions[next_heading_idx]
|
||||
content_end = next_heading_line_idx
|
||||
else:
|
||||
content_end = len(lines)
|
||||
else:
|
||||
content_end = len(sentences)
|
||||
content_end = len(lines)
|
||||
|
||||
content_sentences = sentences[heading_idx + 1 : content_end]
|
||||
content = " ".join(content_sentences).strip()
|
||||
content_lines = lines[heading_line_idx + 1 : content_end]
|
||||
content = "\n".join(content_lines).strip()
|
||||
|
||||
chunk = DocumentChunk(
|
||||
heading=heading,
|
||||
content=content,
|
||||
heading_index=heading_idx,
|
||||
score=sentences_scores[heading_idx],
|
||||
score=heading_scores[heading_idx],
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
async def get_n_chunks(self, text: str, n: int) -> List[DocumentChunk]:
|
||||
sentences = await asyncio.to_thread(self.extract_sentences, text, n)
|
||||
sentences_scores = await asyncio.to_thread(
|
||||
self.score_sentences_for_heading, sentences
|
||||
)
|
||||
headings = await asyncio.to_thread(self.extract_headings, text)
|
||||
heading_scores = await asyncio.to_thread(self.score_headings, headings)
|
||||
chunks = await asyncio.to_thread(
|
||||
self.get_chunks, sentences, sentences_scores, n
|
||||
self.get_chunks_from_headings, text, headings, heading_scores, n
|
||||
)
|
||||
if len(chunks) < n:
|
||||
raise ValueError(f"Only {len(chunks)} chunks found, requested {n}")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue