163 lines
5.5 KiB
Python
163 lines
5.5 KiB
Python
import asyncio
|
|
from typing import List
|
|
|
|
from models.document_chunk import DocumentChunk
|
|
|
|
|
|
class ScoreBasedChunker:
    """Split text into chunks anchored at its highest-scoring headings.

    A heading is any line that starts with ``#`` once stripped (Markdown-style);
    the number of leading ``#`` characters is treated as the heading level.
    Headings are scored, the best ``top_k`` are kept, and each kept heading
    becomes a chunk whose content runs up to the next kept heading.
    """

    def extract_headings(self, text: str) -> List[str]:
        """Return every line of *text* that starts with ``#`` after stripping.

        Lines are returned stripped and in document order. Lines inside code
        fences are not treated specially, so a ``# comment`` line is returned
        as a heading too.
        """
        stripped_lines = (line.strip() for line in text.split("\n"))
        return [line for line in stripped_lines if line.startswith("#")]

    def score_headings(self, headings: List[str]) -> List[float]:
        """Return one score per heading, in the same order as *headings*.

        The score is the sum of:

        * a level score: 10, 8, 6 for levels 1-3, then 4.0 minus 0.5 for every
          level beyond 4 (this can reach zero or below for very deep levels);
        * +5.0 for the first heading in the list;
        * for every later heading, ``min(5.0, 0.5 * distance)`` where distance
          is measured in positions within *headings* since the previous one.

        NOTE(review): the previous position is updated on every iteration, so
        the distance is always 1 and the last term is always 0.5. If the bonus
        was meant to reflect the distance in source lines, it needs the line
        numbers as input; the current output is preserved here.
        """
        scores: List[float] = []
        previous_position = None

        for position, heading in enumerate(headings):
            level = len(heading) - len(heading.lstrip("#"))

            if level <= 3:
                score = 10.0 - (level - 1) * 2.0
            else:
                score = 4.0 - (level - 4) * 0.5

            if previous_position is None:
                score += 5.0
            else:
                score += min(5.0, (position - previous_position) * 0.5)

            previous_position = position
            scores.append(score)

        return scores

    def get_chunks_from_headings(
        self,
        text: str,
        headings: List[str],
        heading_scores: List[float],
        top_k: int = 10,
    ) -> List["DocumentChunk"]:
        """Build chunks for the (at most) *top_k* best-scoring headings.

        Args:
            text: The full document.
            headings: Stripped heading lines; entry ``i`` is matched to the
                first not-yet-matched ``#`` line of *text* with the same
                stripped content.
            heading_scores: Score of each heading. When empty, scores are
                computed with :meth:`score_headings`.
            top_k: Maximum number of headings to keep.

        Returns:
            One ``DocumentChunk`` per kept heading that could be located in
            *text*, in document order of the heading list. A chunk's content
            is the stripped text between its heading line and the next kept
            heading's line (or the end of the text), so sections of headings
            that were not kept are folded into the preceding chunk. Headings
            with a score of zero or less are never kept.
        """
        if not heading_scores:
            heading_scores = self.score_headings(headings)

        selected = self._select_indices(heading_scores, top_k)
        if not selected:
            return []

        lines = text.split("\n")
        positions = self._locate_headings(lines, headings)
        end_of_text = len(lines)

        chunks = []
        for rank, heading_idx in enumerate(selected):
            if heading_idx not in positions:
                continue

            # Content stops at the next kept heading; if there is none, or it
            # could not be located in the text, it runs to the end.
            content_end = end_of_text
            if rank + 1 < len(selected):
                content_end = positions.get(selected[rank + 1], end_of_text)

            body = lines[positions[heading_idx] + 1 : content_end]
            chunks.append(
                DocumentChunk(
                    heading=headings[heading_idx],
                    content="\n".join(body).strip(),
                    heading_index=heading_idx,
                    score=heading_scores[heading_idx],
                )
            )

        return chunks

    def _select_indices(self, heading_scores: List[float], top_k: int) -> List[int]:
        """Return the ascending indices of the headings to keep."""
        candidates = [idx for idx, score in enumerate(heading_scores) if score > 0]
        if len(candidates) <= top_k:
            return candidates

        # Too many candidates: bucket them by rounded score and fill from the
        # best bucket down. A bucket that does not fit entirely contributes
        # members spread evenly across the document instead of just its head.
        buckets = {}
        for idx in candidates:
            buckets.setdefault(round(heading_scores[idx]), []).append(idx)

        selected: List[int] = []
        for rounded_score in sorted(buckets, reverse=True):
            remaining = top_k - len(selected)
            if remaining <= 0:
                break

            members = buckets[rounded_score]  # ascending by construction
            if len(members) <= remaining:
                selected.extend(members)
            else:
                selected.extend(self._spread(members, remaining))

        selected.sort()
        return selected

    @staticmethod
    def _spread(indices: List[int], count: int) -> List[int]:
        """Pick *count* evenly spaced entries of *indices* (1 <= count < len)."""
        if count == 1:
            # A single pick takes the middle entry (also avoids dividing by 0).
            return [indices[len(indices) // 2]]

        # step > 1 here, so the rounded offsets are distinct and run from the
        # first entry to the last one.
        step = (len(indices) - 1) / (count - 1)
        return [indices[int(round(i * step))] for i in range(count)]

    @staticmethod
    def _locate_headings(lines: List[str], headings: List[str]) -> dict:
        """Map heading index -> line number of its occurrence in *lines*.

        Repeated heading texts are matched in order: the first occurrence in
        the text goes to the lowest heading index with that text, and so on.
        """
        # Per heading text, a stack of its indices with the lowest on top, so
        # each '#' line is resolved in O(1) instead of rescanning `headings`.
        pending = {}
        for idx in range(len(headings) - 1, -1, -1):
            pending.setdefault(headings[idx], []).append(idx)

        positions = {}
        for line_number, line in enumerate(lines):
            stripped = line.strip()
            if not stripped.startswith("#"):
                continue
            waiting = pending.get(stripped)
            if waiting:
                positions[waiting.pop()] = line_number

        return positions

    def _chunk_text(self, text: str, n: int) -> List["DocumentChunk"]:
        """Run extraction, scoring and chunk building synchronously."""
        headings = self.extract_headings(text)
        heading_scores = self.score_headings(headings)
        return self.get_chunks_from_headings(text, headings, heading_scores, n)

    async def get_n_chunks(self, text: str, n: int) -> List["DocumentChunk"]:
        """Return the *n* best chunks of *text*, computed off the event loop.

        Raises:
            ValueError: If fewer than *n* chunks could be produced.
        """
        # The pipeline is pure CPU work with no await points in between, so a
        # single worker-thread hop is enough.
        chunks = await asyncio.to_thread(self._chunk_text, text, n)
        if len(chunks) < n:
            raise ValueError(f"Only {len(chunks)} chunks found, requested {n}")
        return chunks
|