# presenton/servers/fastapi/services/score_based_chunker.py
#
# 197 lines
# 6.6 KiB
# Python

import asyncio
from typing import List
import nltk
# Tokenizer data is downloaded into "./nltk" below. NLTK only searches the
# directories listed in nltk.data.path, so register that directory first;
# otherwise a model downloaded there is not found by the lookup (unless the
# environment already points NLTK at it).
if "./nltk" not in nltk.data.path:
    nltk.data.path.append("./nltk")

try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    # NOTE(review): newer NLTK releases appear to load "punkt_tab" rather than
    # "punkt" for sent_tokenize - confirm against the pinned NLTK version.
    nltk.download("punkt", download_dir="./nltk")
class ScoreBasedChunker:
    """Split text into chunks anchored on markdown-style ``#`` headings.

    The text is first broken into "sentences" (lines or real sentences,
    depending on which extraction strategy yields enough of them). Each
    sentence is then scored for how heading-like it is, and the best-scoring
    headings become chunk boundaries.
    """

    # Characters that end a sentence for the punctuation-based splitter.
    _SENTENCE_TERMINATORS = ".!?"

    def extract_sentences(self, text: str, min_sentences: int) -> List[str]:
        """Return at least ``min_sentences`` sentences extracted from ``text``.

        Strategies are tried in order (markdown lines, NLTK, punctuation,
        raw newlines); the first one producing enough sentences wins.

        Raises:
            ValueError: if even the last strategy yields too few sentences.
        """
        strategies = (
            self.extract_sentences_markdown,
            self.extract_sentences_nltk,
            self.extract_sentences_by_stop_words,
            self.extract_sentences_by_new_line,
        )
        sentences: List[str] = []
        for strategy in strategies:
            sentences = strategy(text)
            if len(sentences) >= min_sentences:
                return sentences
        raise ValueError(
            f"Only {len(sentences)} sentences found, requested {min_sentences}"
        )

    def extract_sentences_markdown(self, text: str) -> List[str]:
        """Return every non-blank line of ``text``, stripped of whitespace."""
        stripped_lines = (line.strip() for line in text.split("\n"))
        return [line for line in stripped_lines if line]

    def extract_sentences_nltk(self, text: str) -> List[str]:
        """Return the sentences produced by ``nltk.sent_tokenize``."""
        return nltk.sent_tokenize(text)

    def extract_sentences_by_stop_words(self, text: str) -> List[str]:
        """Split ``text`` after every ``.``, ``!`` or ``?`` character.

        Each terminator closes a sentence on its own, so ``"Wait..."`` yields
        ``["Wait.", ".", "."]``. Trailing text without a terminator is kept
        as a final sentence when it is not blank.
        """
        sentences = []
        start = 0
        for position, char in enumerate(text):
            if char in self._SENTENCE_TERMINATORS:
                # Slice instead of growing a string one character at a time.
                sentences.append(text[start : position + 1].strip())
                start = position + 1
        tail = text[start:].strip()
        if tail:
            sentences.append(tail)
        return [sentence for sentence in sentences if sentence]

    def extract_sentences_by_new_line(self, text: str) -> List[str]:
        """Split ``text`` on newlines, keeping the newline on all but the last.

        Blank lines are preserved as entries, so the result always has
        exactly one more element than ``text`` has newline characters.
        """
        lines = text.split("\n")
        return [line + "\n" for line in lines[:-1]] + lines[-1:]

    def score_sentences_for_heading(self, sentences: List[str]) -> List[float]:
        """Score each sentence by how strongly it looks like a heading.

        Non-headings score ``0.0``. A heading (text starting with ``#`` after
        stripping whitespace) gets:

        * a base score by level: 10, 8, 6 for levels 1-3, then 4.0 minus 0.5
          per level beyond 4 (so very deep levels reach zero or below);
        * +5 if it is the first heading encountered;
        * +0.5 per sentence since the previous heading, capped at +5.

        Returns:
            One score per input sentence, in the same order.
        """
        scores = []
        previous_heading_index = -1
        seen_heading = False
        for index, sentence in enumerate(sentences):
            score = 0.0
            stripped = sentence.strip()
            if stripped.startswith("#"):
                # Count the hashes on the stripped text: detection above also
                # ignores leading whitespace, and counting on the raw string
                # would report level 0 for an indented heading.
                heading_level = len(stripped) - len(stripped.lstrip("#"))
                if heading_level <= 3:
                    score += 10.0 - (heading_level - 1) * 2.0
                else:
                    score += 4.0 - (heading_level - 4) * 0.5
                if not seen_heading:
                    score += 5.0
                    seen_heading = True
                if previous_heading_index != -1:
                    distance = index - previous_heading_index
                    score += min(5.0, distance * 0.5)
                previous_heading_index = index
            scores.append(score)
        return scores

    def get_chunks(
        self, sentences: List[str], sentences_scores: List[float], top_k: int = 10
    ) -> List[dict]:
        """Build up to ``top_k`` chunks, each starting at a selected heading.

        Args:
            sentences: The extracted sentences.
            sentences_scores: Score per sentence; when empty, scores are
                computed with ``score_sentences_for_heading``.
            top_k: Maximum number of headings (and thus chunks) to keep.

        Returns:
            Chunks in document order. Each is a dict with ``heading``,
            ``content`` (the sentences up to the next selected heading,
            joined by spaces), ``heading_index`` and ``score``. Empty when no
            sentence has a positive score.
        """
        if not sentences_scores:
            sentences_scores = self.score_sentences_for_heading(sentences)

        heading_scores = [
            (index, score)
            for index, score in enumerate(sentences_scores)
            if score > 0
        ]
        if not heading_scores:
            return []

        selected = self._select_headings(heading_scores, top_k)

        # Each chunk runs until the next selected heading (or end of text).
        boundaries = selected[1:] + [len(sentences)]
        chunks = []
        for heading_idx, content_end in zip(selected, boundaries):
            content = " ".join(sentences[heading_idx + 1 : content_end]).strip()
            chunks.append(
                {
                    "heading": sentences[heading_idx],
                    "content": content,
                    "heading_index": heading_idx,
                    "score": sentences_scores[heading_idx],
                }
            )
        return chunks

    def _select_headings(self, heading_scores: List[tuple], top_k: int) -> List[int]:
        """Pick at most ``top_k`` heading indices, returned in document order."""
        if len(heading_scores) <= top_k:
            return sorted(index for index, _ in heading_scores)

        # Bucket headings by rounded score so near-equal scores compete as a
        # group rather than by tiny float differences.
        score_groups = {}
        for index, score in heading_scores:
            score_groups.setdefault(round(score), []).append(index)

        selected = []
        for _, headings in sorted(
            score_groups.items(), key=lambda item: item[0], reverse=True
        ):
            remaining_needed = top_k - len(selected)
            if remaining_needed <= 0:
                break
            headings.sort()
            if len(headings) <= remaining_needed:
                selected.extend(headings)
            else:
                selected.extend(self._spread_pick(headings, remaining_needed))
        selected.sort()
        return selected

    @staticmethod
    def _spread_pick(headings: List[int], count: int) -> List[int]:
        """Pick ``count`` entries spread across sorted ``headings``.

        Expects ``1 <= count < len(headings)``.
        """
        if count == 1:
            return [headings[len(headings) // 2]]
        if count == 2:
            return [headings[0], headings[-1]]
        # Evenly spaced positions from the first to the last heading; the
        # largest position is len(headings) - 1, so it is always in range.
        step = (len(headings) - 1) / (count - 1)
        return [headings[int(round(i * step))] for i in range(count)]

    def _chunk_text(self, text: str, n: int) -> List[dict]:
        """Run extraction, scoring and chunk selection for ``get_n_chunks``."""
        sentences = self.extract_sentences(text, n)
        sentences_scores = self.score_sentences_for_heading(sentences)
        return self.get_chunks(sentences, sentences_scores, n)

    async def get_n_chunks(self, text: str, n: int) -> List[dict]:
        """Return exactly-or-more-than... no: return ``n`` heading chunks of ``text``.

        The work runs in a worker thread via ``asyncio.to_thread`` so the
        event loop is not blocked.

        Raises:
            ValueError: if fewer than ``n`` sentences or fewer than ``n``
                chunks can be produced.
        """
        # One thread hop for the whole pipeline; the stages are sequential.
        chunks = await asyncio.to_thread(self._chunk_text, text, n)
        if len(chunks) < n:
            raise ValueError(f"Only {len(chunks)} chunks found, requested {n}")
        return chunks