277 lines
7.9 KiB
Go
277 lines
7.9 KiB
Go
package milvus
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"easy_rag/internal/models"
|
|
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
|
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
|
)
|
|
|
|
// Helper functions for extracting data
|
|
func extractIDs(docs []models.Document) []string {
|
|
ids := make([]string, len(docs))
|
|
for i, doc := range docs {
|
|
ids[i] = doc.ID
|
|
}
|
|
return ids
|
|
}
|
|
|
|
// extractLinks extracts the "Link" field from the documents.
|
|
func extractLinks(docs []models.Document) []string {
|
|
links := make([]string, len(docs))
|
|
for i, doc := range docs {
|
|
links[i] = doc.Link
|
|
}
|
|
return links
|
|
}
|
|
|
|
// extractFilenames extracts the "Filename" field from the documents.
|
|
func extractFilenames(docs []models.Document) []string {
|
|
filenames := make([]string, len(docs))
|
|
for i, doc := range docs {
|
|
filenames[i] = doc.Filename
|
|
}
|
|
return filenames
|
|
}
|
|
|
|
// extractCategories extracts the "Category" field from the documents as a comma-separated string.
|
|
func extractCategories(docs []models.Document) []string {
|
|
categories := make([]string, len(docs))
|
|
for i, doc := range docs {
|
|
categories[i] = fmt.Sprintf("%v", doc.Category)
|
|
}
|
|
return categories
|
|
}
|
|
|
|
// extractEmbeddingModels extracts the "EmbeddingModel" field from the documents.
|
|
func extractEmbeddingModels(docs []models.Document) []string {
|
|
models := make([]string, len(docs))
|
|
for i, doc := range docs {
|
|
models[i] = doc.EmbeddingModel
|
|
}
|
|
return models
|
|
}
|
|
|
|
// extractSummaries extracts the "Summary" field from the documents.
|
|
func extractSummaries(docs []models.Document) []string {
|
|
summaries := make([]string, len(docs))
|
|
for i, doc := range docs {
|
|
summaries[i] = doc.Summary
|
|
}
|
|
return summaries
|
|
}
|
|
|
|
// extractMetadata extracts the "Metadata" field from the documents as a JSON string.
|
|
func extractMetadata(docs []models.Document) []string {
|
|
metadata := make([]string, len(docs))
|
|
for i, doc := range docs {
|
|
metaBytes, _ := json.Marshal(doc.Metadata)
|
|
metadata[i] = string(metaBytes)
|
|
}
|
|
return metadata
|
|
}
|
|
|
|
// convertToMetadata decodes a JSON object of string values, presumably one
// produced by extractMetadata, into a map[string]string.
//
// Decoding is best-effort and never reports failure. The result is nil for
// empty, "null" or syntactically invalid input. Input that is valid JSON but
// not an object of strings decodes only partially or not at all.
func convertToMetadata(metadata string) map[string]string {
	var decoded map[string]string
	// The decode error is intentionally ignored; see the doc comment.
	_ = json.Unmarshal([]byte(metadata), &decoded)
	return decoded
}
|
|
|
|
func extractContents(docs []models.Document) []string {
|
|
contents := make([]string, len(docs))
|
|
for i, doc := range docs {
|
|
contents[i] = doc.Content
|
|
}
|
|
return contents
|
|
}
|
|
|
|
// extractEmbeddingIDs extracts the "ID" field from the embeddings.
|
|
func extractEmbeddingIDs(embeddings []models.Embedding) []string {
|
|
ids := make([]string, len(embeddings))
|
|
for i, embedding := range embeddings {
|
|
ids[i] = embedding.ID
|
|
}
|
|
return ids
|
|
}
|
|
|
|
// extractDocumentIDs extracts the "DocumentID" field from the embeddings.
|
|
func extractDocumentIDs(embeddings []models.Embedding) []string {
|
|
documentIDs := make([]string, len(embeddings))
|
|
for i, embedding := range embeddings {
|
|
documentIDs[i] = embedding.DocumentID
|
|
}
|
|
return documentIDs
|
|
}
|
|
|
|
// extractVectors extracts the "Vector" field from the embeddings.
|
|
func extractVectors(embeddings []models.Embedding) [][]float32 {
|
|
vectors := make([][]float32, len(embeddings))
|
|
for i, embedding := range embeddings {
|
|
vectors[i] = embedding.Vector // Direct assignment since it's already []float32
|
|
}
|
|
return vectors
|
|
}
|
|
|
|
// extractVectorsDocs extracts the "Vector" field from the documents.
|
|
func extractVectorsDocs(docs []models.Document) [][]float32 {
|
|
vectors := make([][]float32, len(docs))
|
|
for i, doc := range docs {
|
|
vectors[i] = doc.Vector // Direct assignment since it's already []float32
|
|
}
|
|
return vectors
|
|
}
|
|
|
|
// extractTextChunks extracts the "TextChunk" field from the embeddings.
|
|
func extractTextChunks(embeddings []models.Embedding) []string {
|
|
textChunks := make([]string, len(embeddings))
|
|
for i, embedding := range embeddings {
|
|
textChunks[i] = embedding.TextChunk
|
|
}
|
|
return textChunks
|
|
}
|
|
|
|
// extractDimensions extracts the "Dimension" field from the embeddings.
|
|
func extractDimensions(embeddings []models.Embedding) []int32 {
|
|
dimensions := make([]int32, len(embeddings))
|
|
for i, embedding := range embeddings {
|
|
dimensions[i] = int32(embedding.Dimension)
|
|
}
|
|
return dimensions
|
|
}
|
|
|
|
// extractOrders extracts the "Order" field from the embeddings.
|
|
func extractOrders(embeddings []models.Embedding) []int32 {
|
|
orders := make([]int32, len(embeddings))
|
|
for i, embedding := range embeddings {
|
|
orders[i] = int32(embedding.Order)
|
|
}
|
|
return orders
|
|
}
|
|
|
|
func transformResultSet(rs client.ResultSet, outputFields ...string) ([]map[string]interface{}, error) {
|
|
if rs == nil || rs.Len() == 0 {
|
|
return nil, fmt.Errorf("empty result set")
|
|
}
|
|
|
|
results := []map[string]interface{}{}
|
|
|
|
for i := 0; i < rs.Len(); i++ { // Iterate through rows
|
|
row := map[string]interface{}{}
|
|
|
|
for _, fieldName := range outputFields {
|
|
column := rs.GetColumn(fieldName)
|
|
if column == nil {
|
|
return nil, fmt.Errorf("column %s does not exist in result set", fieldName)
|
|
}
|
|
|
|
switch column.Type() {
|
|
case entity.FieldTypeInt64:
|
|
value, err := column.GetAsInt64(i)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
|
|
}
|
|
row[fieldName] = value
|
|
|
|
case entity.FieldTypeInt32:
|
|
value, err := column.GetAsInt64(i)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
|
|
}
|
|
row[fieldName] = value
|
|
|
|
case entity.FieldTypeFloat:
|
|
value, err := column.GetAsDouble(i)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err)
|
|
}
|
|
row[fieldName] = value
|
|
|
|
case entity.FieldTypeDouble:
|
|
value, err := column.GetAsDouble(i)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err)
|
|
}
|
|
row[fieldName] = value
|
|
|
|
case entity.FieldTypeVarChar:
|
|
value, err := column.GetAsString(i)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err)
|
|
}
|
|
row[fieldName] = value
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unsupported field type for column %s", fieldName)
|
|
}
|
|
}
|
|
|
|
results = append(results, row)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func transformSearchResultSet(rs client.SearchResult, outputFields ...string) ([]map[string]interface{}, error) {
|
|
if rs.ResultCount == 0 {
|
|
return nil, fmt.Errorf("empty result set")
|
|
}
|
|
|
|
result := make([]map[string]interface{}, rs.ResultCount)
|
|
|
|
for i := 0; i < rs.ResultCount; i++ { // Iterate through rows
|
|
result[i] = make(map[string]interface{})
|
|
for _, fieldName := range outputFields {
|
|
column := rs.Fields.GetColumn(fieldName)
|
|
result[i]["Score"] = rs.Scores[i]
|
|
|
|
if column == nil {
|
|
return nil, fmt.Errorf("column %s does not exist in result set", fieldName)
|
|
}
|
|
|
|
switch column.Type() {
|
|
case entity.FieldTypeInt64:
|
|
value, err := column.GetAsInt64(i)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
|
|
}
|
|
result[i][fieldName] = value
|
|
|
|
case entity.FieldTypeInt32:
|
|
value, err := column.GetAsInt64(i)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
|
|
}
|
|
result[i][fieldName] = value
|
|
|
|
case entity.FieldTypeFloat:
|
|
value, err := column.GetAsDouble(i)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err)
|
|
}
|
|
result[i][fieldName] = value
|
|
|
|
case entity.FieldTypeDouble:
|
|
value, err := column.GetAsDouble(i)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err)
|
|
}
|
|
result[i][fieldName] = value
|
|
|
|
case entity.FieldTypeVarChar:
|
|
value, err := column.GetAsString(i)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err)
|
|
}
|
|
result[i][fieldName] = value
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unsupported field type for column %s", fieldName)
|
|
}
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|