1
0
Files
easy_rag/internal/pkg/database/milvus/helpers.go

277 lines
7.9 KiB
Go

package milvus
import (
"encoding/json"
"fmt"
"easy_rag/internal/models"
"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)
// Helper functions for extracting data

// extractIDs returns the ID of every document, index-aligned with docs.
func extractIDs(docs []models.Document) []string {
	out := make([]string, 0, len(docs))
	for idx := range docs {
		out = append(out, docs[idx].ID)
	}
	return out
}
// extractLinks collects each document's Link, preserving input order.
func extractLinks(docs []models.Document) []string {
	links := make([]string, 0, len(docs))
	for _, d := range docs {
		links = append(links, d.Link)
	}
	return links
}
// extractFilenames maps each document to its Filename; the result has the
// same length and order as docs.
func extractFilenames(docs []models.Document) []string {
	names := make([]string, len(docs))
	for n := range docs {
		names[n] = docs[n].Filename
	}
	return names
}
// extractCategories renders each document's Category with the %v verb, one
// string per document in input order.
//
// NOTE(review): an earlier comment described the output as a comma-separated
// string, but %v only produces that if Category already is one; a []string
// would render as "[a b c]". Confirm against models.Document and the
// collection schema.
func extractCategories(docs []models.Document) []string {
	rendered := make([]string, 0, len(docs))
	for _, d := range docs {
		rendered = append(rendered, fmt.Sprintf("%v", d.Category))
	}
	return rendered
}
// extractEmbeddingModels returns the EmbeddingModel of every document,
// index-aligned with docs.
//
// The local slice is deliberately not named "models": that would shadow the
// imported models package for the rest of the function body.
func extractEmbeddingModels(docs []models.Document) []string {
	modelNames := make([]string, len(docs))
	for i := range docs {
		modelNames[i] = docs[i].EmbeddingModel
	}
	return modelNames
}
// extractSummaries gathers the Summary of each document in input order.
func extractSummaries(docs []models.Document) []string {
	out := make([]string, 0, len(docs))
	for idx := range docs {
		out = append(out, docs[idx].Summary)
	}
	return out
}
// extractMetadata serializes each document's Metadata to a JSON string, one
// entry per document in input order.
//
// Marshal errors are intentionally ignored so one bad document does not abort
// the batch: json.Marshal returns no bytes on error, so that entry becomes "".
func extractMetadata(docs []models.Document) []string {
	encoded := make([]string, 0, len(docs))
	for _, d := range docs {
		raw, _ := json.Marshal(d.Metadata) // best effort; see doc comment
		encoded = append(encoded, string(raw))
	}
	return encoded
}
func convertToMetadata(metadata string) map[string]string {
var metadataMap map[string]string
json.Unmarshal([]byte(metadata), &metadataMap)
return metadataMap
}
// extractContents returns the Content of every document, index-aligned with
// docs.
func extractContents(docs []models.Document) []string {
	bodies := make([]string, 0, len(docs))
	for _, d := range docs {
		bodies = append(bodies, d.Content)
	}
	return bodies
}
// extractEmbeddingIDs returns the ID of every embedding, in input order.
func extractEmbeddingIDs(embeddings []models.Embedding) []string {
	result := make([]string, 0, len(embeddings))
	for _, e := range embeddings {
		result = append(result, e.ID)
	}
	return result
}
// extractDocumentIDs returns, for each embedding, the DocumentID it belongs
// to; the result is index-aligned with embeddings.
func extractDocumentIDs(embeddings []models.Embedding) []string {
	owners := make([]string, len(embeddings))
	for idx := range embeddings {
		owners[idx] = embeddings[idx].DocumentID
	}
	return owners
}
// extractVectors returns the Vector of every embedding, in input order.
// The inner slices are not copied: they share backing arrays with the
// embeddings, so mutating one is visible through the other.
func extractVectors(embeddings []models.Embedding) [][]float32 {
	vecs := make([][]float32, 0, len(embeddings))
	for _, e := range embeddings {
		vecs = append(vecs, e.Vector)
	}
	return vecs
}
// extractVectorsDocs returns the Vector of every document, in input order.
// The inner slices are not copied: they share backing arrays with the
// documents' vectors.
func extractVectorsDocs(docs []models.Document) [][]float32 {
	vecs := make([][]float32, len(docs))
	for n := range docs {
		vecs[n] = docs[n].Vector
	}
	return vecs
}
// extractTextChunks returns the TextChunk of every embedding, index-aligned
// with embeddings.
func extractTextChunks(embeddings []models.Embedding) []string {
	chunks := make([]string, 0, len(embeddings))
	for idx := range embeddings {
		chunks = append(chunks, embeddings[idx].TextChunk)
	}
	return chunks
}
// extractDimensions returns each embedding's Dimension converted to int32,
// in input order.
//
// NOTE(review): if Dimension is a wider integer type, int32(...) silently
// wraps out-of-range values — confirm the field's type and expected range.
func extractDimensions(embeddings []models.Embedding) []int32 {
	dims := make([]int32, 0, len(embeddings))
	for _, e := range embeddings {
		dims = append(dims, int32(e.Dimension))
	}
	return dims
}
// extractOrders returns each embedding's Order converted to int32, in input
// order.
//
// NOTE(review): if Order is a wider integer type, int32(...) silently wraps
// out-of-range values — confirm the field's type and expected range.
func extractOrders(embeddings []models.Embedding) []int32 {
	positions := make([]int32, len(embeddings))
	for idx := range embeddings {
		positions[idx] = int32(embeddings[idx].Order)
	}
	return positions
}
// transformResultSet converts a Milvus query ResultSet into one map per row,
// keyed by the requested output field names.
//
// Int64 and Int32 columns are returned as int64, Float and Double columns as
// float64, and VarChar columns as string. An error is returned when:
//   - the result set is nil or has no rows;
//   - a requested field is missing;
//   - a column has an unsupported type;
//   - a cell cannot be read.
func transformResultSet(rs client.ResultSet, outputFields ...string) ([]map[string]interface{}, error) {
	if rs == nil || rs.Len() == 0 {
		return nil, fmt.Errorf("empty result set")
	}

	// The column lookup does not depend on the row, so resolve every
	// requested column once instead of once per row.
	columns := make([]entity.Column, len(outputFields))
	for j, fieldName := range outputFields {
		column := rs.GetColumn(fieldName)
		if column == nil {
			return nil, fmt.Errorf("column %s does not exist in result set", fieldName)
		}
		columns[j] = column
	}

	rowCount := rs.Len()
	results := make([]map[string]interface{}, 0, rowCount)
	for i := 0; i < rowCount; i++ {
		row := make(map[string]interface{}, len(outputFields))
		for j, fieldName := range outputFields {
			column := columns[j]
			switch column.Type() {
			case entity.FieldTypeInt64, entity.FieldTypeInt32:
				value, err := column.GetAsInt64(i)
				if err != nil {
					return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value
			case entity.FieldTypeFloat:
				value, err := column.GetAsDouble(i)
				if err != nil {
					return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value
			case entity.FieldTypeDouble:
				value, err := column.GetAsDouble(i)
				if err != nil {
					return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value
			case entity.FieldTypeVarChar:
				value, err := column.GetAsString(i)
				if err != nil {
					return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value
			default:
				return nil, fmt.Errorf("unsupported field type for column %s", fieldName)
			}
		}
		results = append(results, row)
	}
	return results, nil
}
// transformSearchResultSet converts a single Milvus SearchResult into one map
// per hit.
//
// Every map carries the hit's score (from rs.Scores) under the "Score" key,
// plus one entry per requested output field. Int64 and Int32 columns are
// returned as int64, Float and Double columns as float64, and VarChar
// columns as string. An error is returned when:
//   - there are no hits;
//   - fewer scores than hits are present;
//   - a requested field is missing;
//   - a column has an unsupported type;
//   - a cell cannot be read.
func transformSearchResultSet(rs client.SearchResult, outputFields ...string) ([]map[string]interface{}, error) {
	if rs.ResultCount == 0 {
		return nil, fmt.Errorf("empty result set")
	}
	// Guard the rs.Scores[i] index below; a short slice would otherwise panic.
	if len(rs.Scores) < rs.ResultCount {
		return nil, fmt.Errorf("search result has %d scores for %d hits", len(rs.Scores), rs.ResultCount)
	}

	// The column lookup does not depend on the hit, so resolve every
	// requested column once instead of once per hit.
	columns := make([]entity.Column, len(outputFields))
	for j, fieldName := range outputFields {
		column := rs.Fields.GetColumn(fieldName)
		if column == nil {
			return nil, fmt.Errorf("column %s does not exist in result set", fieldName)
		}
		columns[j] = column
	}

	result := make([]map[string]interface{}, rs.ResultCount)
	for i := 0; i < rs.ResultCount; i++ {
		row := make(map[string]interface{}, len(outputFields)+1)
		// Set the score once per hit, outside the field loop, so it is present
		// even when no output fields were requested.
		row["Score"] = rs.Scores[i]
		for j, fieldName := range outputFields {
			column := columns[j]
			switch column.Type() {
			case entity.FieldTypeInt64, entity.FieldTypeInt32:
				value, err := column.GetAsInt64(i)
				if err != nil {
					return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value
			case entity.FieldTypeFloat:
				value, err := column.GetAsDouble(i)
				if err != nil {
					return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value
			case entity.FieldTypeDouble:
				value, err := column.GetAsDouble(i)
				if err != nil {
					return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value
			case entity.FieldTypeVarChar:
				value, err := column.GetAsString(i)
				if err != nil {
					return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value
			default:
				return nil, fmt.Errorf("unsupported field type for column %s", fieldName)
			}
		}
		result[i] = row
	}
	return result, nil
}