feat(fastapi): adds a score based chunker to chunk documents for outlines
This commit is contained in:
parent
f299cad078
commit
e3779502bf
7 changed files with 3175 additions and 57 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -12,4 +12,5 @@ tmp
|
|||
debug
|
||||
.fastembed_cache
|
||||
my-doc.txt
|
||||
generated_models
|
||||
generated_models
|
||||
nltk
|
||||
1
servers/fastapi/.python-version
Normal file
1
servers/fastapi/.python-version
Normal file
|
|
@ -0,0 +1 @@
|
|||
3.11
|
||||
6
servers/fastapi/main.py
Normal file
6
servers/fastapi/main.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
def main():
|
||||
print("Hello from presenton-backend!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
27
servers/fastapi/pyproject.toml
Normal file
27
servers/fastapi/pyproject.toml
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
[project]
|
||||
name = "presenton-backend"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11,<3.12"
|
||||
dependencies = [
|
||||
"aiohttp>=3.12.15",
|
||||
"aiomysql>=0.2.0",
|
||||
"aiosqlite>=0.21.0",
|
||||
"anthropic>=0.60.0",
|
||||
"asyncpg>=0.30.0",
|
||||
"chromadb>=1.0.15",
|
||||
"docling>=2.43.0",
|
||||
"fastapi[standard]>=0.116.1",
|
||||
"google-genai>=1.28.0",
|
||||
"nltk>=3.9.1",
|
||||
"openai>=1.98.0",
|
||||
"pathvalidate>=3.3.1",
|
||||
"pdfplumber>=0.11.7",
|
||||
"python-pptx>=1.0.2",
|
||||
"redis>=6.2.0",
|
||||
"sqlmodel>=0.0.24",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
|
@ -1,143 +1,170 @@
|
|||
accelerate==1.9.0
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.12.14
|
||||
aiohttp==3.12.15
|
||||
aiomysql==0.2.0
|
||||
aiosignal==1.4.0
|
||||
aiosqlite==0.21.0
|
||||
annotated-types==0.7.0
|
||||
anthropic==0.60.0
|
||||
anyio==4.9.0
|
||||
argcomplete==3.6.2
|
||||
async-timeout==5.0.1
|
||||
asyncpg==0.30.0
|
||||
attrs==25.3.0
|
||||
backoff==2.2.1
|
||||
bcrypt==4.3.0
|
||||
black==25.1.0
|
||||
build==1.2.2.post1
|
||||
beautifulsoup4==4.13.4
|
||||
build==1.3.0
|
||||
cachetools==5.5.2
|
||||
certifi==2025.7.14
|
||||
certifi==2025.8.3
|
||||
cffi==1.17.1
|
||||
charset-normalizer==3.4.2
|
||||
chromadb==1.0.15
|
||||
click==8.2.1
|
||||
coloredlogs==15.0.1
|
||||
cryptography==45.0.5
|
||||
dill==0.4.0
|
||||
distro==1.9.0
|
||||
dnspython==2.7.0
|
||||
docling==2.43.0
|
||||
docling-core==2.44.1
|
||||
docling-ibm-models==3.9.0
|
||||
docling-parse==4.1.0
|
||||
durationpy==0.10
|
||||
email_validator==2.2.0
|
||||
easyocr==1.7.2
|
||||
email-validator==2.2.0
|
||||
et-xmlfile==2.0.0
|
||||
fastapi==0.116.1
|
||||
fastapi-cli==0.0.8
|
||||
fastapi-cloud-cli==0.1.4
|
||||
fastembed==0.7.1
|
||||
fastapi-cloud-cli==0.1.5
|
||||
filelock==3.18.0
|
||||
filetype==1.2.0
|
||||
flatbuffers==25.2.10
|
||||
frozenlist==1.7.0
|
||||
fsspec==2025.7.0
|
||||
genson==1.3.0
|
||||
google-auth==2.40.3
|
||||
google-genai==1.25.0
|
||||
google-genai==1.28.0
|
||||
googleapis-common-protos==1.70.0
|
||||
greenlet==3.2.3
|
||||
grpcio==1.74.0
|
||||
h11==0.16.0
|
||||
h2==4.2.0
|
||||
hf-xet==1.1.5
|
||||
hpack==4.1.0
|
||||
httpcore==1.0.9
|
||||
httptools==0.6.4
|
||||
httpx==0.28.1
|
||||
huggingface-hub==0.34.1
|
||||
huggingface-hub==0.34.3
|
||||
humanfriendly==10.0
|
||||
hyperframe==6.1.0
|
||||
idna==3.10
|
||||
importlib_metadata==8.7.0
|
||||
importlib_resources==6.5.2
|
||||
inflect==7.5.0
|
||||
iniconfig==2.1.0
|
||||
isort==6.0.1
|
||||
Jinja2==3.1.6
|
||||
imageio==2.37.0
|
||||
importlib-metadata==8.7.0
|
||||
importlib-resources==6.5.2
|
||||
jinja2==3.1.6
|
||||
jiter==0.10.0
|
||||
joblib==1.5.1
|
||||
jsonlines==3.1.0
|
||||
jsonref==1.1.0
|
||||
jsonschema==4.25.0
|
||||
jsonschema-specifications==2025.4.1
|
||||
kubernetes==33.1.0
|
||||
loguru==0.7.3
|
||||
lxml==6.0.0
|
||||
latex2mathml==3.78.0
|
||||
lazy-loader==0.4
|
||||
lxml==5.4.0
|
||||
markdown-it-py==3.0.0
|
||||
MarkupSafe==3.0.2
|
||||
marko==2.1.4
|
||||
markupsafe==3.0.1
|
||||
mdurl==0.1.2
|
||||
mmh3==5.1.0
|
||||
more-itertools==10.7.0
|
||||
mmh3==5.2.0
|
||||
mpire==2.10.2
|
||||
mpmath==1.3.0
|
||||
multidict==6.6.3
|
||||
mypy_extensions==1.1.0
|
||||
multiprocess==0.70.18
|
||||
networkx==3.5
|
||||
ninja==1.11.1.4
|
||||
nltk==3.9.1
|
||||
numpy==2.3.2
|
||||
oauthlib==3.3.1
|
||||
onnxruntime==1.22.1
|
||||
openai==1.95.1
|
||||
opentelemetry-api==1.35.0
|
||||
opentelemetry-exporter-otlp-proto-common==1.35.0
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.35.0
|
||||
opentelemetry-proto==1.35.0
|
||||
opentelemetry-sdk==1.35.0
|
||||
opentelemetry-semantic-conventions==0.56b0
|
||||
openai==1.98.0
|
||||
opencv-python-headless==4.11.0.86
|
||||
openpyxl==3.1.5
|
||||
opentelemetry-api==1.36.0
|
||||
opentelemetry-exporter-otlp-proto-common==1.36.0
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.36.0
|
||||
opentelemetry-proto==1.36.0
|
||||
opentelemetry-sdk==1.36.0
|
||||
opentelemetry-semantic-conventions==0.57b0
|
||||
orjson==3.11.1
|
||||
overrides==7.7.0
|
||||
packaging==25.0
|
||||
pathspec==0.12.1
|
||||
pandas==2.3.1
|
||||
pathvalidate==3.3.1
|
||||
pdfminer.six==20250506
|
||||
pdfminer-six==20250506
|
||||
pdfplumber==0.11.7
|
||||
pillow==11.3.0
|
||||
platformdirs==4.3.8
|
||||
pluggy==1.6.0
|
||||
portalocker==3.2.0
|
||||
posthog==5.4.0
|
||||
propcache==0.3.2
|
||||
protobuf==6.31.1
|
||||
py_rust_stemmers==0.1.5
|
||||
psutil==7.0.0
|
||||
pyasn1==0.6.1
|
||||
pyasn1_modules==0.4.2
|
||||
pyasn1-modules==0.4.2
|
||||
pybase64==1.4.2
|
||||
pyclipper==1.3.0.post6
|
||||
pycparser==2.22
|
||||
pydantic==2.11.7
|
||||
pydantic_core==2.33.2
|
||||
Pygments==2.19.2
|
||||
pypdfium2==4.30.1
|
||||
PyPika==0.48.9
|
||||
pyproject_hooks==1.2.0
|
||||
pytest==8.4.1
|
||||
pydantic-core==2.33.2
|
||||
pydantic-settings==2.10.1
|
||||
pygments==2.19.2
|
||||
pylatexenc==2.10
|
||||
pymysql==1.1.1
|
||||
pypdfium2==4.30.0
|
||||
pypika==0.48.9
|
||||
pyproject-hooks==1.2.0
|
||||
python-bidi==0.6.6
|
||||
python-dateutil==2.9.0.post0
|
||||
python-docx==1.2.0
|
||||
python-dotenv==1.1.1
|
||||
python-multipart==0.0.20
|
||||
python-pptx==1.0.2
|
||||
PyYAML==6.0.2
|
||||
pytz==2025.2
|
||||
pyyaml==6.0.2
|
||||
redis==6.2.0
|
||||
referencing==0.36.2
|
||||
regex==2025.7.34
|
||||
requests==2.32.4
|
||||
requests-oauthlib==2.0.0
|
||||
rich==14.0.0
|
||||
rich-toolkit==0.14.8
|
||||
rignore==0.6.2
|
||||
rich==14.1.0
|
||||
rich-toolkit==0.14.9
|
||||
rignore==0.6.4
|
||||
rpds-py==0.26.0
|
||||
rsa==4.9.1
|
||||
sentry-sdk==2.32.0
|
||||
rtree==1.4.0
|
||||
safetensors==0.5.3
|
||||
scikit-image==0.25.2
|
||||
scipy==1.16.1
|
||||
semchunk==2.2.2
|
||||
sentry-sdk==2.34.1
|
||||
shapely==2.1.1
|
||||
shellingham==1.5.4
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
SQLAlchemy==2.0.41
|
||||
soupsieve==2.7
|
||||
sqlalchemy==2.0.42
|
||||
sqlmodel==0.0.24
|
||||
starlette==0.47.1
|
||||
starlette==0.47.2
|
||||
sympy==1.14.0
|
||||
tabulate==0.9.0
|
||||
tenacity==8.5.0
|
||||
tokenizers==0.21.2
|
||||
tomli==2.2.1
|
||||
tifffile==2025.6.11
|
||||
tokenizers==0.21.4
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.7.1+cpu
|
||||
torchvision==0.22.1+cpu
|
||||
tqdm==4.67.1
|
||||
typeguard==4.4.4
|
||||
transformers==4.54.1
|
||||
typer==0.16.0
|
||||
typing-extensions==4.14.1
|
||||
typing-inspection==0.4.1
|
||||
typing_extensions==4.14.1
|
||||
tzdata==2025.2
|
||||
urllib3==2.5.0
|
||||
uvicorn==0.35.0
|
||||
uvloop==0.21.0
|
||||
|
|
|
|||
197
servers/fastapi/services/score_based_chunker.py
Normal file
197
servers/fastapi/services/score_based_chunker.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
import nltk
|
||||
|
||||
try:
|
||||
nltk.data.find("tokenizers/punkt")
|
||||
except LookupError:
|
||||
nltk.download("punkt", download_dir="./nltk")
|
||||
|
||||
|
||||
class ScoreBasedChunker:
|
||||
|
||||
def extract_sentences(self, text: str, min_sentences: int) -> List[str]:
|
||||
sentences = self.extract_sentences_markdown(text)
|
||||
if len(sentences) < min_sentences:
|
||||
sentences = self.extract_sentences_nltk(text)
|
||||
if len(sentences) < min_sentences:
|
||||
sentences = self.extract_sentences_by_stop_words(text)
|
||||
if len(sentences) < min_sentences:
|
||||
sentences = self.extract_sentences_by_new_line(text)
|
||||
if len(sentences) < min_sentences:
|
||||
raise ValueError(
|
||||
f"Only {len(sentences)} sentences found, requested {min_sentences}"
|
||||
)
|
||||
return sentences
|
||||
|
||||
def extract_sentences_markdown(self, text: str) -> List[str]:
|
||||
lines = text.split("\n")
|
||||
sentences = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line:
|
||||
if line.startswith("#"):
|
||||
sentences.append(line)
|
||||
else:
|
||||
if line.endswith((".", "!", "?")):
|
||||
sentences.append(line)
|
||||
else:
|
||||
sentences.append(line)
|
||||
|
||||
return sentences
|
||||
|
||||
def extract_sentences_nltk(self, text: str) -> List[str]:
|
||||
sentences = nltk.sent_tokenize(text)
|
||||
return sentences
|
||||
|
||||
def extract_sentences_by_stop_words(self, text: str) -> List[str]:
|
||||
sentences = []
|
||||
current_sentence = ""
|
||||
|
||||
for char in text:
|
||||
current_sentence += char
|
||||
if char in ".!?":
|
||||
sentences.append(current_sentence.strip())
|
||||
current_sentence = ""
|
||||
|
||||
if current_sentence.strip():
|
||||
sentences.append(current_sentence.strip())
|
||||
|
||||
return [s for s in sentences if s]
|
||||
|
||||
def extract_sentences_by_new_line(self, text: str) -> List[str]:
|
||||
sentences = text.split("\n")
|
||||
result = []
|
||||
for i, sentence in enumerate(sentences):
|
||||
if i < len(sentences) - 1:
|
||||
result.append(sentence + "\n")
|
||||
else:
|
||||
result.append(sentence)
|
||||
return result
|
||||
|
||||
def score_sentences_for_heading(self, sentences: List[str]) -> List[float]:
|
||||
sentences_scores = []
|
||||
|
||||
last_heading_index = -1
|
||||
first_heading_found = False
|
||||
|
||||
for i, sentence in enumerate(sentences):
|
||||
score = 0.0
|
||||
|
||||
if sentence.strip().startswith("#"):
|
||||
heading_level = len(sentence) - len(sentence.lstrip("#"))
|
||||
|
||||
if heading_level <= 3:
|
||||
score += 10.0 - (heading_level - 1) * 2.0
|
||||
else:
|
||||
score += 4.0 - (heading_level - 4) * 0.5
|
||||
|
||||
if not first_heading_found:
|
||||
score += 5.0
|
||||
first_heading_found = True
|
||||
|
||||
if last_heading_index != -1:
|
||||
distance = i - last_heading_index
|
||||
distance_bonus = min(5.0, distance * 0.5)
|
||||
score += distance_bonus
|
||||
|
||||
last_heading_index = i
|
||||
|
||||
sentences_scores.append(score)
|
||||
|
||||
return sentences_scores
|
||||
|
||||
def get_chunks(
|
||||
self, sentences: List[str], sentences_scores: List[float], top_k: int = 10
|
||||
) -> List[dict]:
|
||||
if not sentences_scores:
|
||||
sentences_scores = self.score_sentences_for_heading(sentences)
|
||||
|
||||
chunks = []
|
||||
heading_scores = []
|
||||
|
||||
for i, score in enumerate(sentences_scores):
|
||||
if score > 0:
|
||||
heading_scores.append((i, score))
|
||||
|
||||
if len(heading_scores) == 0:
|
||||
return chunks
|
||||
|
||||
heading_scores.sort(key=lambda x: (-x[1], x[0]))
|
||||
|
||||
if len(heading_scores) <= top_k:
|
||||
selected_headings = [idx for idx, _ in heading_scores]
|
||||
selected_headings.sort()
|
||||
else:
|
||||
score_groups = {}
|
||||
for idx, score in heading_scores:
|
||||
rounded_score = round(score)
|
||||
if rounded_score not in score_groups:
|
||||
score_groups[rounded_score] = []
|
||||
score_groups[rounded_score].append(idx)
|
||||
|
||||
sorted_groups = sorted(
|
||||
score_groups.items(), key=lambda x: x[0], reverse=True
|
||||
)
|
||||
|
||||
selected_headings = []
|
||||
|
||||
for score, headings in sorted_groups:
|
||||
headings.sort()
|
||||
remaining_needed = top_k - len(selected_headings)
|
||||
|
||||
if remaining_needed <= 0:
|
||||
break
|
||||
|
||||
if len(headings) <= remaining_needed:
|
||||
selected_headings.extend(headings)
|
||||
else:
|
||||
if remaining_needed == 1:
|
||||
mid_idx = len(headings) // 2
|
||||
selected_headings.append(headings[mid_idx])
|
||||
elif remaining_needed == 2:
|
||||
selected_headings.append(headings[0])
|
||||
selected_headings.append(headings[-1])
|
||||
else:
|
||||
step = (len(headings) - 1) / (remaining_needed - 1)
|
||||
|
||||
for i in range(remaining_needed):
|
||||
index = int(round(i * step))
|
||||
if index < len(headings):
|
||||
selected_headings.append(headings[index])
|
||||
|
||||
selected_headings.sort()
|
||||
|
||||
for i, heading_idx in enumerate(selected_headings):
|
||||
heading = sentences[heading_idx]
|
||||
|
||||
if i + 1 < len(selected_headings):
|
||||
next_heading_idx = selected_headings[i + 1]
|
||||
content_end = next_heading_idx
|
||||
else:
|
||||
content_end = len(sentences)
|
||||
|
||||
content_sentences = sentences[heading_idx + 1 : content_end]
|
||||
content = " ".join(content_sentences).strip()
|
||||
|
||||
chunk = {
|
||||
"heading": heading,
|
||||
"content": content,
|
||||
"heading_index": heading_idx,
|
||||
"score": sentences_scores[heading_idx],
|
||||
}
|
||||
chunks.append(chunk)
|
||||
return chunks
|
||||
|
||||
async def get_n_chunks(self, text: str, n: int) -> List[dict]:
|
||||
sentences = await asyncio.to_thread(self.extract_sentences, text, n)
|
||||
sentences_scores = await asyncio.to_thread(
|
||||
self.score_sentences_for_heading, sentences
|
||||
)
|
||||
chunks = await asyncio.to_thread(
|
||||
self.get_chunks, sentences, sentences_scores, n
|
||||
)
|
||||
if len(chunks) < n:
|
||||
raise ValueError(f"Only {len(chunks)} chunks found, requested {n}")
|
||||
return chunks
|
||||
2859
servers/fastapi/uv.lock
generated
Normal file
2859
servers/fastapi/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue