fix(fastapi): use fastembed-vectorstore package
This commit is contained in:
parent
99583238bf
commit
c3add3850e
7 changed files with 583670 additions and 591236 deletions
|
|
@ -7,4 +7,5 @@ build
|
|||
.git
|
||||
.gitignore
|
||||
tmp
|
||||
debug
|
||||
debug
|
||||
.fastembed_cache
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -8,4 +8,5 @@ node_modules
|
|||
out
|
||||
user_data
|
||||
tmp
|
||||
debug
|
||||
debug
|
||||
.fastembed_cache
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -5,17 +5,17 @@ from ppt_generator.models.query_and_prompt_models import (
|
|||
IconCategoryEnum,
|
||||
IconQueryCollectionWithData,
|
||||
)
|
||||
from langchain_core.vectorstores import InMemoryVectorStore
|
||||
from fastembed_vectorstore import FastembedVectorstore
|
||||
|
||||
|
||||
async def get_icon(
|
||||
vector_store: InMemoryVectorStore,
|
||||
vector_store: FastembedVectorstore,
|
||||
input: IconQueryCollectionWithData,
|
||||
) -> str:
|
||||
try:
|
||||
query = input.icon_query
|
||||
results = vector_store.similarity_search(query=query, k=1)
|
||||
icon_name = results[0].page_content
|
||||
results = vector_store.search(query, 1)
|
||||
icon_name = results[0][0].split("||")[0]
|
||||
return get_resource(f"assets/icons/bold/{icon_name}.png")
|
||||
except Exception as e:
|
||||
print("Error finding icon: ", e)
|
||||
|
|
@ -23,7 +23,7 @@ async def get_icon(
|
|||
|
||||
|
||||
async def get_icons(
|
||||
vector_store: InMemoryVectorStore,
|
||||
vector_store: FastembedVectorstore,
|
||||
query: str,
|
||||
page: int,
|
||||
limit: int,
|
||||
|
|
@ -31,7 +31,7 @@ async def get_icons(
|
|||
temp_dir: str,
|
||||
) -> List[str]:
|
||||
|
||||
results = await vector_store.asimilarity_search(query=query, k=limit)
|
||||
icon_names = [result.page_content for result in results]
|
||||
results = vector_store.search(query, limit)
|
||||
icon_names = [result[0].split("||")[0] for result in results]
|
||||
|
||||
return [get_resource(f"assets/icons/bold/{each}.png") for each in icon_names]
|
||||
|
|
|
|||
|
|
@ -1,34 +1,26 @@
|
|||
import json
|
||||
import os
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
from api.utils.utils import get_resource
|
||||
from fastembed_vectorstore import FastembedVectorstore, FastembedEmbeddingModel
|
||||
|
||||
|
||||
def get_icons_vectorstore():
|
||||
vector_store_path = get_resource("assets/icons_vectorstore.json")
|
||||
embedding_model = FastembedEmbeddingModel.BGESmallENV15
|
||||
|
||||
embedding_model = TextEmbedding()
|
||||
|
||||
# if os.path.exists(vector_store_path):
|
||||
# vector_store = InMemoryVectorStore.load(vector_store_path, embeddings)
|
||||
# return vector_store
|
||||
if os.path.exists(vector_store_path):
|
||||
return FastembedVectorstore.load(embedding_model, vector_store_path)
|
||||
|
||||
vector_store = FastembedVectorstore(embedding_model)
|
||||
with open(get_resource("assets/icons.json"), "r") as f:
|
||||
icons = json.load(f)
|
||||
documents = []
|
||||
for each in icons["icons"]:
|
||||
if each["name"].split("-")[-1] == "bold":
|
||||
documents.append(f"{each['name']}||{each['tags']}")
|
||||
|
||||
icon_names = [icon["name"] for icon in icons["icons"]]
|
||||
bold_icon_names = []
|
||||
for each in icon_names:
|
||||
if each.split("-")[-1] == "bold":
|
||||
bold_icon_names.append(each)
|
||||
vector_store.embed_documents(documents)
|
||||
vector_store.save(vector_store_path)
|
||||
|
||||
documents_and_embeddings = {
|
||||
"documents": bold_icon_names,
|
||||
"embeddings": embedding_model.embed(bold_icon_names),
|
||||
}
|
||||
|
||||
with open(vector_store_path, "w") as f:
|
||||
json.dump(documents_and_embeddings, f)
|
||||
|
||||
return documents_and_embeddings
|
||||
return vector_store
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import base64
|
|||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from openai import OpenAI
|
||||
|
||||
from ppt_generator.models.query_and_prompt_models import (
|
||||
|
|
@ -67,22 +66,23 @@ async def generate_image_openai(prompt: str, output_directory: str) -> str:
|
|||
|
||||
|
||||
async def generate_image_google(prompt: str, output_directory: str) -> str:
|
||||
response = await ChatGoogleGenerativeAI(
|
||||
model="gemini-2.0-flash-preview-image-generation"
|
||||
).ainvoke([prompt], generation_config={"response_modalities": ["TEXT", "IMAGE"]})
|
||||
# response = await ChatGoogleGenerativeAI(
|
||||
# model="gemini-2.0-flash-preview-image-generation"
|
||||
# ).ainvoke([prompt], generation_config={"response_modalities": ["TEXT", "IMAGE"]})
|
||||
|
||||
image_block = next(
|
||||
block
|
||||
for block in response.content
|
||||
if isinstance(block, dict) and block.get("image_url")
|
||||
)
|
||||
# image_block = next(
|
||||
# block
|
||||
# for block in response.content
|
||||
# if isinstance(block, dict) and block.get("image_url")
|
||||
# )
|
||||
|
||||
base64_image = image_block["image_url"].get("url").split(",")[-1]
|
||||
image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(base64.b64decode(base64_image))
|
||||
# base64_image = image_block["image_url"].get("url").split(",")[-1]
|
||||
# image_path = os.path.join(output_directory, f"{str(uuid.uuid4())}.jpg")
|
||||
# with open(image_path, "wb") as f:
|
||||
# f.write(base64.b64decode(base64_image))
|
||||
|
||||
return image_path
|
||||
# return image_path
|
||||
return ""
|
||||
|
||||
|
||||
async def get_image_from_pexels(prompt: str, output_directory: str) -> str:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ dnspython==2.7.0
|
|||
email_validator==2.2.0
|
||||
fastapi==0.115.12
|
||||
fastapi-cli==0.0.7
|
||||
fastembed==0.7.0
|
||||
fastembed_vectorstore==0.1.5
|
||||
filelock==3.18.0
|
||||
filetype==1.2.0
|
||||
flatbuffers==25.2.10
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue