271 lines
9.6 KiB
Go
271 lines
9.6 KiB
Go
package milvus
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sort"
|
|
|
|
"easy_rag/internal/models"
|
|
|
|
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
|
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
|
)
|
|
|
|
// InsertDocuments inserts documents into the "documents" collection.
|
|
func (m *Client) InsertDocuments(ctx context.Context, docs []models.Document) error {
|
|
idColumn := entity.NewColumnVarChar("ID", extractIDs(docs))
|
|
contentColumn := entity.NewColumnVarChar("Content", extractContents(docs))
|
|
linkColumn := entity.NewColumnVarChar("Link", extractLinks(docs))
|
|
filenameColumn := entity.NewColumnVarChar("Filename", extractFilenames(docs))
|
|
categoryColumn := entity.NewColumnVarChar("Category", extractCategories(docs))
|
|
embeddingModelColumn := entity.NewColumnVarChar("EmbeddingModel", extractEmbeddingModels(docs))
|
|
summaryColumn := entity.NewColumnVarChar("Summary", extractSummaries(docs))
|
|
metadataColumn := entity.NewColumnVarChar("Metadata", extractMetadata(docs))
|
|
vectorColumn := entity.NewColumnFloatVector("Vector", 1024, extractVectorsDocs(docs))
|
|
// Insert the data
|
|
_, err := m.Instance.Insert(ctx, "documents", "_default", idColumn, contentColumn, linkColumn, filenameColumn,
|
|
categoryColumn, embeddingModelColumn, summaryColumn, metadataColumn, vectorColumn)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to insert documents: %w", err)
|
|
}
|
|
|
|
// Flush the collection
|
|
err = m.Instance.Flush(ctx, "documents", false)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to flush documents collection: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// InsertEmbeddings inserts embeddings into the "chunks" collection.
|
|
func (m *Client) InsertEmbeddings(ctx context.Context, embeddings []models.Embedding) error {
|
|
idColumn := entity.NewColumnVarChar("ID", extractEmbeddingIDs(embeddings))
|
|
documentIDColumn := entity.NewColumnVarChar("DocumentID", extractDocumentIDs(embeddings))
|
|
vectorColumn := entity.NewColumnFloatVector("Vector", 1024, extractVectors(embeddings))
|
|
textChunkColumn := entity.NewColumnVarChar("TextChunk", extractTextChunks(embeddings))
|
|
dimensionColumn := entity.NewColumnInt32("Dimension", extractDimensions(embeddings))
|
|
orderColumn := entity.NewColumnInt32("Order", extractOrders(embeddings))
|
|
|
|
_, err := m.Instance.Insert(ctx, "chunks", "_default", idColumn, documentIDColumn, vectorColumn,
|
|
textChunkColumn, dimensionColumn, orderColumn)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to insert embeddings: %w", err)
|
|
}
|
|
|
|
err = m.Instance.Flush(ctx, "chunks", false)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to flush chunks collection: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetDocumentByID retrieves a document from the "documents" collection by ID.
|
|
func (m *Client) GetDocumentByID(ctx context.Context, id string) (map[string]interface{}, error) {
|
|
collectionName := "documents"
|
|
expr := fmt.Sprintf("ID == '%s'", id)
|
|
projections := []string{"ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata"} // Fetch all fields
|
|
|
|
results, err := m.Instance.Query(ctx, collectionName, nil, expr, projections)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query document by ID: %w", err)
|
|
}
|
|
|
|
if len(results) == 0 {
|
|
return nil, fmt.Errorf("document with ID '%s' not found", id)
|
|
}
|
|
|
|
mp, err := transformResultSet(results, "ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata")
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal document: %w", err)
|
|
}
|
|
|
|
// convert metadata to map
|
|
mp[0]["Metadata"] = convertToMetadata(mp[0]["Metadata"].(string))
|
|
|
|
return mp[0], err
|
|
}
|
|
|
|
// GetAllDocuments retrieves all documents from the "documents" collection.
|
|
func (m *Client) GetAllDocuments(ctx context.Context) ([]models.Document, error) {
|
|
collectionName := "documents"
|
|
projections := []string{"ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata"} // Fetch all fields
|
|
expr := ""
|
|
|
|
rs, err := m.Instance.Query(ctx, collectionName, nil, expr, projections, client.WithLimit(1000))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query all documents: %w", err)
|
|
}
|
|
|
|
if len(rs) == 0 {
|
|
return nil, fmt.Errorf("no documents found in the collection")
|
|
}
|
|
|
|
results, err := transformResultSet(rs, "ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata")
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal all documents: %w", err)
|
|
}
|
|
|
|
var docs []models.Document = make([]models.Document, len(results))
|
|
for i, result := range results {
|
|
docs[i] = models.Document{
|
|
ID: result["ID"].(string),
|
|
Content: result["Content"].(string),
|
|
Link: result["Link"].(string),
|
|
Filename: result["Filename"].(string),
|
|
Category: result["Category"].(string),
|
|
EmbeddingModel: result["EmbeddingModel"].(string),
|
|
Summary: result["Summary"].(string),
|
|
Metadata: convertToMetadata(results[0]["Metadata"].(string)),
|
|
}
|
|
}
|
|
|
|
return docs, nil
|
|
}
|
|
|
|
// GetAllEmbeddingByDocID retrieves all embeddings linked to a specific DocumentID from the "chunks" collection.
|
|
func (m *Client) GetAllEmbeddingByDocID(ctx context.Context, documentID string) ([]models.Embedding, error) {
|
|
collectionName := "chunks"
|
|
projections := []string{"ID", "DocumentID", "TextChunk", "Order"} // Fetch all fields
|
|
expr := fmt.Sprintf("DocumentID == '%s'", documentID)
|
|
|
|
rs, err := m.Instance.Query(ctx, collectionName, nil, expr, projections, client.WithLimit(1000))
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query embeddings by DocumentID: %w", err)
|
|
}
|
|
|
|
if rs.Len() == 0 {
|
|
return nil, fmt.Errorf("no embeddings found for DocumentID '%s'", documentID)
|
|
}
|
|
|
|
results, err := transformResultSet(rs, "ID", "DocumentID", "TextChunk", "Order")
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal all documents: %w", err)
|
|
}
|
|
|
|
var embeddings []models.Embedding = make([]models.Embedding, rs.Len())
|
|
|
|
for i, result := range results {
|
|
embeddings[i] = models.Embedding{
|
|
ID: result["ID"].(string),
|
|
DocumentID: result["DocumentID"].(string),
|
|
TextChunk: result["TextChunk"].(string),
|
|
Order: result["Order"].(int64),
|
|
}
|
|
}
|
|
|
|
return embeddings, nil
|
|
}
|
|
|
|
func (m *Client) Search(ctx context.Context, vectors [][]float32, topK int) ([]models.Embedding, error) {
|
|
const (
|
|
collectionName = "chunks"
|
|
vectorDim = 1024 // Replace with your actual vector dimension
|
|
)
|
|
projections := []string{"ID", "DocumentID", "TextChunk", "Order"}
|
|
metricType := entity.L2 // Default metric type
|
|
|
|
// Validate and convert input vectors
|
|
searchVectors, err := validateAndConvertVectors(vectors, vectorDim)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set search parameters
|
|
searchParams, err := entity.NewIndexIvfFlatSearchParam(16) // 16 is the number of clusters for IVF_FLAT index
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create search params: %w", err)
|
|
}
|
|
|
|
// Perform the search
|
|
searchResults, err := m.Instance.Search(ctx, collectionName, nil, "", projections, searchVectors, "Vector", metricType, topK, searchParams, client.WithLimit(10))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to search collection: %w", err)
|
|
}
|
|
|
|
// Process search results
|
|
embeddings, err := processSearchResults(searchResults)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to process search results: %w", err)
|
|
}
|
|
|
|
return embeddings, nil
|
|
}
|
|
|
|
// validateAndConvertVectors validates vector dimensions and converts them to Milvus-compatible format.
|
|
func validateAndConvertVectors(vectors [][]float32, expectedDim int) ([]entity.Vector, error) {
|
|
searchVectors := make([]entity.Vector, len(vectors))
|
|
for i, vector := range vectors {
|
|
if len(vector) != expectedDim {
|
|
return nil, fmt.Errorf("vector dimension mismatch: expected %d, got %d", expectedDim, len(vector))
|
|
}
|
|
searchVectors[i] = entity.FloatVector(vector)
|
|
}
|
|
return searchVectors, nil
|
|
}
|
|
|
|
// processSearchResults transforms and aggregates the search results into embeddings and sorts by score.
|
|
func processSearchResults(results []client.SearchResult) ([]models.Embedding, error) {
|
|
var embeddings []models.Embedding
|
|
|
|
for _, result := range results {
|
|
for i := 0; i < result.ResultCount; i++ {
|
|
embeddingMap, err := transformSearchResultSet(result, "ID", "DocumentID", "TextChunk", "Order")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to transform search result set: %w", err)
|
|
}
|
|
|
|
for _, embedding := range embeddingMap {
|
|
embeddings = append(embeddings, models.Embedding{
|
|
ID: embedding["ID"].(string),
|
|
DocumentID: embedding["DocumentID"].(string),
|
|
TextChunk: embedding["TextChunk"].(string),
|
|
Order: embedding["Order"].(int64), // Assuming 'Order' is a float64 type
|
|
Score: embedding["Score"].(float32),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// Sort embeddings by score in descending order (higher is better)
|
|
sort.Slice(embeddings, func(i, j int) bool {
|
|
return embeddings[i].Score > embeddings[j].Score
|
|
})
|
|
|
|
return embeddings, nil
|
|
}
|
|
|
|
// DeleteDocument deletes a document from the "documents" collection by ID.
|
|
func (m *Client) DeleteDocument(ctx context.Context, id string) error {
|
|
collectionName := "documents"
|
|
partitionName := "_default"
|
|
expr := fmt.Sprintf("ID == '%s'", id)
|
|
|
|
err := m.Instance.Delete(ctx, collectionName, partitionName, expr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete document by ID: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteEmbedding deletes an embedding from the "chunks" collection by ID.
|
|
func (m *Client) DeleteEmbedding(ctx context.Context, id string) error {
|
|
collectionName := "chunks"
|
|
partitionName := "_default"
|
|
expr := fmt.Sprintf("DocumentID == '%s'", id)
|
|
|
|
err := m.Instance.Delete(ctx, collectionName, partitionName, expr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete embedding by DocumentID: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|