Реализована операции Milvus для управления документами и встраиванием, включая функции вставки, запроса и удаления. Внедрите архитектуру RAG с LLM и сервисами встраивания. Добавьте обработку текста для фрагментации и конкатенации. Создайте автономный скрипт для настройки и управления Milvus. Разработайте комплексные тесты API для обработки документов и взаимодействия с LLM, включая имитации для сервисов. Расширьте возможности конфигурации пользователя с помощью дополнительных настроек YAML.
This commit is contained in:
20
internal/database/database.go
Normal file
20
internal/database/database.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package database
|
||||
|
||||
import "easy_rag/internal/models"
|
||||
|
||||
// database interface
//
// Database defines the interface for interacting with a document store
// that also holds per-chunk vector embeddings. Implementations persist
// documents, reassemble their chunked content, and run similarity search.
type Database interface {
	SaveDocument(document models.Document) error // content is chunked and saved by the implementation
	GetDocumentInfo(id string) (models.Document, error) // document metadata only, without content
	GetDocument(id string) (models.Document, error) // document with content reassembled from its chunks
	Search(vector [][]float32) ([]models.Embedding, error) // similarity search over chunk embeddings
	ListDocuments() ([]models.Document, error)
	DeleteDocument(id string) error
	SaveEmbeddings(embeddings []models.Embedding) error

	// to implement in future
	// SearchByCategory(category []string) ([]Embedding, error)
	// SearchByMetadata(metadata map[string]string) ([]Embedding, error)
	// GetAllEmbeddingByDocumentID(documentID string) ([]Embedding, error)
}
|
||||
142
internal/database/database_milvus.go
Normal file
142
internal/database/database_milvus.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"easy_rag/internal/models"
|
||||
"easy_rag/internal/pkg/database/milvus"
|
||||
)
|
||||
|
||||
// implement database interface for milvus
|
||||
type Milvus struct {
|
||||
Host string
|
||||
Client *milvus.Client
|
||||
}
|
||||
|
||||
func NewMilvus(host string) *Milvus {
|
||||
|
||||
milviusClient, err := milvus.NewClient(host)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &Milvus{
|
||||
Host: host,
|
||||
Client: milviusClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Milvus) SaveDocument(document models.Document) error {
|
||||
// for now lets use context background
|
||||
ctx := context.Background()
|
||||
|
||||
return m.Client.InsertDocuments(ctx, []models.Document{document})
|
||||
}
|
||||
|
||||
func (m *Milvus) SaveEmbeddings(embeddings []models.Embedding) error {
|
||||
ctx := context.Background()
|
||||
return m.Client.InsertEmbeddings(ctx, embeddings)
|
||||
}
|
||||
|
||||
func (m *Milvus) GetDocumentInfo(id string) (models.Document, error) {
|
||||
ctx := context.Background()
|
||||
doc, err := m.Client.GetDocumentByID(ctx, id)
|
||||
|
||||
if err != nil {
|
||||
return models.Document{}, err
|
||||
}
|
||||
|
||||
if len(doc) == 0 {
|
||||
return models.Document{}, nil
|
||||
}
|
||||
return models.Document{
|
||||
ID: doc["ID"].(string),
|
||||
Link: doc["Link"].(string),
|
||||
Filename: doc["Filename"].(string),
|
||||
Category: doc["Category"].(string),
|
||||
EmbeddingModel: doc["EmbeddingModel"].(string),
|
||||
Summary: doc["Summary"].(string),
|
||||
Metadata: doc["Metadata"].(map[string]string),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Milvus) GetDocument(id string) (models.Document, error) {
|
||||
ctx := context.Background()
|
||||
doc, err := m.Client.GetDocumentByID(ctx, id)
|
||||
if err != nil {
|
||||
return models.Document{}, err
|
||||
}
|
||||
|
||||
embeds, err := m.Client.GetAllEmbeddingByDocID(ctx, id)
|
||||
|
||||
if err != nil {
|
||||
return models.Document{}, err
|
||||
}
|
||||
|
||||
// order embed by order
|
||||
sort.Slice(embeds, func(i, j int) bool {
|
||||
return embeds[i].Order < embeds[j].Order
|
||||
})
|
||||
|
||||
// concatenate text chunks
|
||||
var buf bytes.Buffer
|
||||
for _, embed := range embeds {
|
||||
buf.WriteString(embed.TextChunk)
|
||||
}
|
||||
|
||||
textChunks := buf.String()
|
||||
|
||||
if len(doc) == 0 {
|
||||
return models.Document{}, nil
|
||||
}
|
||||
return models.Document{
|
||||
ID: doc["ID"].(string),
|
||||
Content: textChunks,
|
||||
Link: doc["Link"].(string),
|
||||
Filename: doc["Filename"].(string),
|
||||
Category: doc["Category"].(string),
|
||||
EmbeddingModel: doc["EmbeddingModel"].(string),
|
||||
Summary: doc["Summary"].(string),
|
||||
Metadata: doc["Metadata"].(map[string]string),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Milvus) Search(vector [][]float32) ([]models.Embedding, error) {
|
||||
ctx := context.Background()
|
||||
results, err := m.Client.Search(ctx, vector, 10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (m *Milvus) ListDocuments() ([]models.Document, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
docs, err := m.Client.GetAllDocuments(ctx)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get docs: %w", err)
|
||||
}
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
func (m *Milvus) DeleteDocument(id string) error {
|
||||
ctx := context.Background()
|
||||
err := m.Client.DeleteDocument(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.Client.DeleteEmbedding(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
8
internal/embeddings/embeddings.go
Normal file
8
internal/embeddings/embeddings.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package embeddings
|
||||
|
||||
// implement embeddings interface
//
// EmbeddingsService turns text into vector embeddings.
type EmbeddingsService interface {
	// Vectorize generates embedding vectors from text.
	Vectorize(text string) ([][]float32, error)
	// GetModel reports the identifier of the embedding model in use.
	GetModel() string
}
|
||||
78
internal/embeddings/ollama_embeddings.go
Normal file
78
internal/embeddings/ollama_embeddings.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package embeddings
|
||||
|
||||
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)
|
||||
|
||||
// OllamaEmbeddings calls an Ollama server's embedding endpoint.
type OllamaEmbeddings struct {
	Endpoint string
	Model    string
}

// NewOllamaEmbeddings wires up the server endpoint and model choice.
func NewOllamaEmbeddings(endpoint string, model string) *OllamaEmbeddings {
	svc := OllamaEmbeddings{
		Endpoint: endpoint,
		Model:    model,
	}
	return &svc
}
|
||||
|
||||
// Vectorize generates an embedding for the provided text
|
||||
func (o *OllamaEmbeddings) Vectorize(text string) ([][]float32, error) {
|
||||
// Define the request payload
|
||||
payload := map[string]string{
|
||||
"model": o.Model,
|
||||
"input": text,
|
||||
}
|
||||
|
||||
// Convert the payload to JSON
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
|
||||
}
|
||||
|
||||
// Create the HTTP request
|
||||
url := fmt.Sprintf("%s/api/embed", o.Endpoint)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Execute the HTTP request
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make HTTP request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check for non-200 status code
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("received non-200 response: %s", body)
|
||||
}
|
||||
|
||||
// Read and parse the response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
// Assuming the response JSON contains an "embedding" field with a float32 array
|
||||
var response struct {
|
||||
Embeddings [][]float32 `json:"embeddings"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
return response.Embeddings, nil
|
||||
}
|
||||
|
||||
// GetModel reports the configured embedding model identifier.
func (o *OllamaEmbeddings) GetModel() string {
	return o.Model
}
|
||||
23
internal/embeddings/openai_embeddings.go
Normal file
23
internal/embeddings/openai_embeddings.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package embeddings
|
||||
|
||||
// OpenAIEmbeddings talks to an OpenAI-compatible embeddings endpoint.
type OpenAIEmbeddings struct {
	APIKey   string
	Endpoint string
	Model    string
}

// NewOpenAIEmbeddings wires up credentials, endpoint and model choice.
func NewOpenAIEmbeddings(apiKey string, endpoint string, model string) *OpenAIEmbeddings {
	svc := OpenAIEmbeddings{
		APIKey:   apiKey,
		Endpoint: endpoint,
		Model:    model,
	}
	return &svc
}
|
||||
|
||||
func (o *OpenAIEmbeddings) Vectorize(text string) ([]float32, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetModel reports the configured embedding model identifier.
func (o *OpenAIEmbeddings) GetModel() string {
	return o.Model
}
|
||||
8
internal/llm/llm.go
Normal file
8
internal/llm/llm.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package llm
|
||||
|
||||
// implement llm interface
//
// LLMService generates text completions from prompts.
type LLMService interface {
	// Generate produces the model's text response to prompt.
	Generate(prompt string) (string, error)
	// GetModel reports the identifier of the LLM in use.
	GetModel() string
}
|
||||
82
internal/llm/ollama_llm.go
Normal file
82
internal/llm/ollama_llm.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package llm
|
||||
|
||||
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)
|
||||
|
||||
// Ollama calls an Ollama chat endpoint to generate text.
type Ollama struct {
	Endpoint string
	Model    string
}

// NewOllama wires up the server endpoint and model choice.
func NewOllama(endpoint string, model string) *Ollama {
	svc := Ollama{
		Endpoint: endpoint,
		Model:    model,
	}
	return &svc
}
|
||||
|
||||
// Response represents the structure of the expected response from the
// chat API: the model name, creation timestamp, and the generated
// message with its role and content.
type Response struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	Message   struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"message"`
}
|
||||
|
||||
// Generate sends a prompt to the Ollama endpoint and returns the response
|
||||
func (o *Ollama) Generate(prompt string) (string, error) {
|
||||
// Create the request payload
|
||||
payload := map[string]interface{}{
|
||||
"model": o.Model,
|
||||
"messages": []map[string]string{
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
},
|
||||
"stream": false,
|
||||
}
|
||||
|
||||
// Marshal the payload into JSON
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
// Make the POST request
|
||||
resp, err := http.Post(o.Endpoint, "application/json", bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read and parse the response
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("API returned error: %s", string(body))
|
||||
}
|
||||
|
||||
// Unmarshal the response into a predefined structure
|
||||
var response Response
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
// Extract and return the content from the nested structure
|
||||
return response.Message.Content, nil
|
||||
}
|
||||
|
||||
// GetModel reports the configured model identifier.
func (o *Ollama) GetModel() string {
	return o.Model
}
|
||||
24
internal/llm/openai_llm.go
Normal file
24
internal/llm/openai_llm.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package llm
|
||||
|
||||
// OpenAI calls an OpenAI-compatible chat completion endpoint.
type OpenAI struct {
	APIKey   string
	Endpoint string
	Model    string
}

// NewOpenAI wires up credentials, endpoint and model choice.
func NewOpenAI(apiKey string, endpoint string, model string) *OpenAI {
	svc := OpenAI{
		APIKey:   apiKey,
		Endpoint: endpoint,
		Model:    model,
	}
	return &svc
}
|
||||
|
||||
// Generate returns the model's completion for prompt.
//
// Currently a stub: it always returns "" with a nil error.
func (o *OpenAI) Generate(prompt string) (string, error) {
	return "", nil
	// TODO: implement
}
|
||||
|
||||
// GetModel reports the configured model identifier.
func (o *OpenAI) GetModel() string {
	return o.Model
}
|
||||
155
internal/llm/openroute/definitions.go
Normal file
155
internal/llm/openroute/definitions.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package openroute
|
||||
|
||||
// Request represents the main request structure sent to the OpenRouter
// chat/completions API. Either Messages (chat) or Prompt (completion)
// is set, not both; Stream selects SSE streaming.
type Request struct {
	Messages          []MessageRequest     `json:"messages,omitempty"`
	Prompt            string               `json:"prompt,omitempty"`
	Model             string               `json:"model,omitempty"`
	ResponseFormat    *ResponseFormat      `json:"response_format,omitempty"`
	Stop              []string             `json:"stop,omitempty"`
	Stream            bool                 `json:"stream,omitempty"`
	MaxTokens         int                  `json:"max_tokens,omitempty"`
	Temperature       float64              `json:"temperature,omitempty"`
	Tools             []Tool               `json:"tools,omitempty"`
	ToolChoice        ToolChoice           `json:"tool_choice,omitempty"`
	Seed              int                  `json:"seed,omitempty"`
	TopP              float64              `json:"top_p,omitempty"`
	TopK              int                  `json:"top_k,omitempty"`
	FrequencyPenalty  float64              `json:"frequency_penalty,omitempty"`
	PresencePenalty   float64              `json:"presence_penalty,omitempty"`
	RepetitionPenalty float64              `json:"repetition_penalty,omitempty"`
	LogitBias         map[int]float64      `json:"logit_bias,omitempty"`
	TopLogprobs       int                  `json:"top_logprobs,omitempty"`
	MinP              float64              `json:"min_p,omitempty"`
	TopA              float64              `json:"top_a,omitempty"`
	Prediction        *Prediction          `json:"prediction,omitempty"`
	Transforms        []string             `json:"transforms,omitempty"`
	Models            []string             `json:"models,omitempty"`
	Route             string               `json:"route,omitempty"`
	Provider          *ProviderPreferences `json:"provider,omitempty"`
	IncludeReasoning  bool                 `json:"include_reasoning,omitempty"`
}
|
||||
|
||||
// ResponseFormat represents the response format structure; Type names
// the output format the model should produce.
type ResponseFormat struct {
	Type string `json:"type"`
}

// Prediction represents the prediction structure (predicted output
// content passed along with the request).
type Prediction struct {
	Type    string `json:"type"`
	Content string `json:"content"`
}

// ProviderPreferences represents the provider preferences structure.
// When present, the client also sends RefererURL and SiteName as the
// HTTP-Referer and X-Title request headers.
type ProviderPreferences struct {
	RefererURL string `json:"referer_url,omitempty"`
	SiteName   string `json:"site_name,omitempty"`
}
|
||||
|
||||
// MessageRequest represents a single chat message in a request.
type MessageRequest struct {
	Role       MessageRole `json:"role"`
	Content    interface{} `json:"content"` // Can be string or []ContentPart
	Name       string      `json:"name,omitempty"`
	ToolCallID string      `json:"tool_call_id,omitempty"`
}

// MessageRole identifies the author of a chat message.
type MessageRole string

// Roles accepted in MessageRequest.Role.
const (
	RoleSystem    MessageRole = "system"
	RoleUser      MessageRole = "user"
	RoleAssistant MessageRole = "assistant"
)
|
||||
|
||||
// ContentPart represents one element of a multi-part message content
// value (text or an image reference).
type ContentPart struct {
	Type     ContnetType `json:"type"`
	Text     string      `json:"text,omitempty"`
	ImageURL *ImageURL   `json:"image_url,omitempty"`
}

// ContnetType discriminates ContentPart payloads.
// NOTE(review): the name is a typo of "ContentType"; it is exported, so
// renaming it would break package consumers — fix in a coordinated change.
type ContnetType string

// Supported content part types.
const (
	ContentTypeText  ContnetType = "text"
	ContentTypeImage ContnetType = "image_url"
)

// ImageURL represents the image URL structure, optionally with a detail
// hint.
type ImageURL struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}
|
||||
|
||||
// FunctionDescription represents a callable function exposed to the
// model for tool calling.
type FunctionDescription struct {
	Description string      `json:"description,omitempty"`
	Name        string      `json:"name"`
	Parameters  interface{} `json:"parameters"` // JSON Schema object
}

// Tool represents the tool structure wrapping a function definition.
type Tool struct {
	Type     string              `json:"type"`
	Function FunctionDescription `json:"function"`
}

// ToolChoice represents the tool choice structure, constraining which
// tool the model should call.
type ToolChoice struct {
	Type     string `json:"type"`
	Function struct {
		Name string `json:"name"`
	} `json:"function"`
}
|
||||
|
||||
// Response is the completion payload returned by the API; streamed
// chunks decode into the same shape.
type Response struct {
	ID                string         `json:"id"`
	Choices           []Choice       `json:"choices"`
	Created           int64          `json:"created"`
	Model             string         `json:"model"`
	Object            string         `json:"object"`
	SystemFingerprint *string        `json:"system_fingerprint,omitempty"`
	Usage             *ResponseUsage `json:"usage,omitempty"`
}

// ResponseUsage reports token accounting for a completion.
type ResponseUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Choice is one candidate completion. Which of Text, Message or Delta is
// populated depends on the request kind (prompt vs chat vs streaming).
type Choice struct {
	FinishReason string           `json:"finish_reason"`
	Text         string           `json:"text,omitempty"`
	Message      *MessageResponse `json:"message,omitempty"`
	Delta        *Delta           `json:"delta,omitempty"`
	Error        *ErrorResponse   `json:"error,omitempty"`
}

// MessageResponse is a full assistant message in a non-streaming choice.
type MessageResponse struct {
	Content   string     `json:"content"`
	Role      string     `json:"role"`
	Reasoning string     `json:"reasoning,omitempty"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Delta is an incremental message fragment in a streaming choice.
type Delta struct {
	Content   string     `json:"content"`
	Role      string     `json:"role,omitempty"`
	Reasoning string     `json:"reasoning,omitempty"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// ErrorResponse is an error reported inside a choice by the API.
type ErrorResponse struct {
	Code     int                    `json:"code"`
	Message  string                 `json:"message"`
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}

// ToolCall is a model-initiated tool invocation.
type ToolCall struct {
	ID       string      `json:"id"`
	Type     string      `json:"type"`
	Function interface{} `json:"function"`
}
|
||||
198
internal/llm/openroute/route_agent.go
Normal file
198
internal/llm/openroute/route_agent.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package openroute
|
||||
|
||||
import (
	"context"
	"errors"
)
|
||||
|
||||
// RouterAgentConfig holds the per-agent generation settings (sampling,
// tool, and penalty parameters) that are copied into every outgoing
// Request the agent builds.
type RouterAgentConfig struct {
	ResponseFormat    *ResponseFormat `json:"response_format,omitempty"`
	Stop              []string        `json:"stop,omitempty"`
	MaxTokens         int             `json:"max_tokens,omitempty"`
	Temperature       float64         `json:"temperature,omitempty"`
	Tools             []Tool          `json:"tools,omitempty"`
	ToolChoice        ToolChoice      `json:"tool_choice,omitempty"`
	Seed              int             `json:"seed,omitempty"`
	TopP              float64         `json:"top_p,omitempty"`
	TopK              int             `json:"top_k,omitempty"`
	FrequencyPenalty  float64         `json:"frequency_penalty,omitempty"`
	PresencePenalty   float64         `json:"presence_penalty,omitempty"`
	RepetitionPenalty float64         `json:"repetition_penalty,omitempty"`
	LogitBias         map[int]float64 `json:"logit_bias,omitempty"`
	TopLogprobs       int             `json:"top_logprobs,omitempty"`
	MinP              float64         `json:"min_p,omitempty"`
	TopA              float64         `json:"top_a,omitempty"`
}
|
||||
|
||||
type RouterAgent struct {
|
||||
client *OpenRouterClient
|
||||
model string
|
||||
config RouterAgentConfig
|
||||
}
|
||||
|
||||
func NewRouterAgent(client *OpenRouterClient, model string, config RouterAgentConfig) *RouterAgent {
|
||||
return &RouterAgent{
|
||||
client: client,
|
||||
model: model,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (agent RouterAgent) Completion(prompt string) (*Response, error) {
|
||||
request := Request{
|
||||
Prompt: prompt,
|
||||
Model: agent.model,
|
||||
ResponseFormat: agent.config.ResponseFormat,
|
||||
Stop: agent.config.Stop,
|
||||
MaxTokens: agent.config.MaxTokens,
|
||||
Temperature: agent.config.Temperature,
|
||||
Tools: agent.config.Tools,
|
||||
ToolChoice: agent.config.ToolChoice,
|
||||
Seed: agent.config.Seed,
|
||||
TopP: agent.config.TopP,
|
||||
TopK: agent.config.TopK,
|
||||
FrequencyPenalty: agent.config.FrequencyPenalty,
|
||||
PresencePenalty: agent.config.PresencePenalty,
|
||||
RepetitionPenalty: agent.config.RepetitionPenalty,
|
||||
LogitBias: agent.config.LogitBias,
|
||||
TopLogprobs: agent.config.TopLogprobs,
|
||||
MinP: agent.config.MinP,
|
||||
TopA: agent.config.TopA,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
return agent.client.FetchChatCompletions(request)
|
||||
}
|
||||
|
||||
func (agent RouterAgent) CompletionStream(prompt string, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) {
|
||||
request := Request{
|
||||
Prompt: prompt,
|
||||
Model: agent.model,
|
||||
ResponseFormat: agent.config.ResponseFormat,
|
||||
Stop: agent.config.Stop,
|
||||
MaxTokens: agent.config.MaxTokens,
|
||||
Temperature: agent.config.Temperature,
|
||||
Tools: agent.config.Tools,
|
||||
ToolChoice: agent.config.ToolChoice,
|
||||
Seed: agent.config.Seed,
|
||||
TopP: agent.config.TopP,
|
||||
TopK: agent.config.TopK,
|
||||
FrequencyPenalty: agent.config.FrequencyPenalty,
|
||||
PresencePenalty: agent.config.PresencePenalty,
|
||||
RepetitionPenalty: agent.config.RepetitionPenalty,
|
||||
LogitBias: agent.config.LogitBias,
|
||||
TopLogprobs: agent.config.TopLogprobs,
|
||||
MinP: agent.config.MinP,
|
||||
TopA: agent.config.TopA,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
agent.client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx)
|
||||
}
|
||||
|
||||
func (agent RouterAgent) Chat(messages []MessageRequest) (*Response, error) {
|
||||
request := Request{
|
||||
Messages: messages,
|
||||
Model: agent.model,
|
||||
ResponseFormat: agent.config.ResponseFormat,
|
||||
Stop: agent.config.Stop,
|
||||
MaxTokens: agent.config.MaxTokens,
|
||||
Temperature: agent.config.Temperature,
|
||||
Tools: agent.config.Tools,
|
||||
ToolChoice: agent.config.ToolChoice,
|
||||
Seed: agent.config.Seed,
|
||||
TopP: agent.config.TopP,
|
||||
TopK: agent.config.TopK,
|
||||
FrequencyPenalty: agent.config.FrequencyPenalty,
|
||||
PresencePenalty: agent.config.PresencePenalty,
|
||||
RepetitionPenalty: agent.config.RepetitionPenalty,
|
||||
LogitBias: agent.config.LogitBias,
|
||||
TopLogprobs: agent.config.TopLogprobs,
|
||||
MinP: agent.config.MinP,
|
||||
TopA: agent.config.TopA,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
return agent.client.FetchChatCompletions(request)
|
||||
}
|
||||
|
||||
func (agent RouterAgent) ChatStream(messages []MessageRequest, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) {
|
||||
request := Request{
|
||||
Messages: messages,
|
||||
Model: agent.model,
|
||||
ResponseFormat: agent.config.ResponseFormat,
|
||||
Stop: agent.config.Stop,
|
||||
MaxTokens: agent.config.MaxTokens,
|
||||
Temperature: agent.config.Temperature,
|
||||
Tools: agent.config.Tools,
|
||||
ToolChoice: agent.config.ToolChoice,
|
||||
Seed: agent.config.Seed,
|
||||
TopP: agent.config.TopP,
|
||||
TopK: agent.config.TopK,
|
||||
FrequencyPenalty: agent.config.FrequencyPenalty,
|
||||
PresencePenalty: agent.config.PresencePenalty,
|
||||
RepetitionPenalty: agent.config.RepetitionPenalty,
|
||||
LogitBias: agent.config.LogitBias,
|
||||
TopLogprobs: agent.config.TopLogprobs,
|
||||
MinP: agent.config.MinP,
|
||||
TopA: agent.config.TopA,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
agent.client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx)
|
||||
}
|
||||
|
||||
// RouterAgentChat is a RouterAgent that keeps the running message
// history of a conversation.
type RouterAgentChat struct {
	RouterAgent
	// Messages accumulates the conversation: system seed, then
	// alternating user/assistant turns.
	Messages []MessageRequest
}
|
||||
|
||||
func NewRouterAgentChat(client *OpenRouterClient, model string, config RouterAgentConfig, system_prompt string) RouterAgentChat {
|
||||
return RouterAgentChat{
|
||||
RouterAgent: RouterAgent{
|
||||
client: client,
|
||||
model: model,
|
||||
config: config,
|
||||
},
|
||||
Messages: []MessageRequest{
|
||||
{
|
||||
Role: RoleSystem,
|
||||
Content: system_prompt,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (agent *RouterAgentChat) Chat(message string) error {
|
||||
agent.Messages = append(agent.Messages, MessageRequest{
|
||||
Role: RoleUser,
|
||||
Content: message,
|
||||
})
|
||||
request := Request{
|
||||
Messages: agent.Messages,
|
||||
Model: agent.model,
|
||||
ResponseFormat: agent.config.ResponseFormat,
|
||||
Stop: agent.config.Stop,
|
||||
MaxTokens: agent.config.MaxTokens,
|
||||
Temperature: agent.config.Temperature,
|
||||
Tools: agent.config.Tools,
|
||||
ToolChoice: agent.config.ToolChoice,
|
||||
Seed: agent.config.Seed,
|
||||
TopP: agent.config.TopP,
|
||||
TopK: agent.config.TopK,
|
||||
FrequencyPenalty: agent.config.FrequencyPenalty,
|
||||
PresencePenalty: agent.config.PresencePenalty,
|
||||
RepetitionPenalty: agent.config.RepetitionPenalty,
|
||||
LogitBias: agent.config.LogitBias,
|
||||
TopLogprobs: agent.config.TopLogprobs,
|
||||
MinP: agent.config.MinP,
|
||||
TopA: agent.config.TopA,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
response, err := agent.client.FetchChatCompletions(request)
|
||||
|
||||
agent.Messages = append(agent.Messages, MessageRequest{
|
||||
Role: RoleAssistant,
|
||||
Content: response.Choices[0].Message.Content,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
188
internal/llm/openroute/route_client.go
Normal file
188
internal/llm/openroute/route_client.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package openroute
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OpenRouterClient issues chat completion requests against the
// OpenRouter HTTP API.
type OpenRouterClient struct {
	apiKey     string
	apiURL     string
	httpClient *http.Client
}

// NewOpenRouterClient builds a client for the public OpenRouter API
// using a default http.Client.
func NewOpenRouterClient(apiKey string) *OpenRouterClient {
	return NewOpenRouterClientFull(apiKey, "https://openrouter.ai/api/v1", &http.Client{})
}

// NewOpenRouterClientFull builds a client with a caller-chosen base URL
// and *http.Client (useful for tests and custom transports/timeouts).
func NewOpenRouterClientFull(apiKey string, apiUrl string, client *http.Client) *OpenRouterClient {
	return &OpenRouterClient{
		apiKey:     apiKey,
		apiURL:     apiUrl,
		httpClient: client,
	}
}
|
||||
|
||||
func (c *OpenRouterClient) FetchChatCompletions(request Request) (*Response, error) {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + c.apiKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if request.Provider != nil {
|
||||
headers["HTTP-Referer"] = request.Provider.RefererURL
|
||||
headers["X-Title"] = request.Provider.SiteName
|
||||
}
|
||||
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat/completions", c.apiURL), bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
output, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputReponse := &Response{}
|
||||
err = json.Unmarshal(output, outputReponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return outputReponse, nil
|
||||
}
|
||||
|
||||
func (c *OpenRouterClient) FetchChatCompletionsStream(request Request, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + c.apiKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if request.Provider != nil {
|
||||
headers["HTTP-Referer"] = request.Provider.RefererURL
|
||||
headers["X-Title"] = request.Provider.SiteName
|
||||
}
|
||||
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
close(errChan)
|
||||
close(outputChan)
|
||||
close(processingChan)
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat/completions", c.apiURL), bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
close(errChan)
|
||||
close(outputChan)
|
||||
close(processingChan)
|
||||
return
|
||||
}
|
||||
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
close(errChan)
|
||||
close(outputChan)
|
||||
close(processingChan)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer resp.Body.Close()
|
||||
|
||||
defer close(errChan)
|
||||
defer close(outputChan)
|
||||
defer close(processingChan)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errChan <- fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
close(errChan)
|
||||
close(outputChan)
|
||||
close(processingChan)
|
||||
return
|
||||
default:
|
||||
line, err := reader.ReadString('\n')
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, ":") {
|
||||
select {
|
||||
case processingChan <- true:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if line != "" {
|
||||
if strings.Compare(line[6:], "[DONE]") == 0 {
|
||||
return
|
||||
}
|
||||
response := Response{}
|
||||
err = json.Unmarshal([]byte(line[6:]), &response)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
select {
|
||||
case outputChan <- response:
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
24
internal/llm/openroute_llm.go
Normal file
24
internal/llm/openroute_llm.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package llm
|
||||
|
||||
type OpenRoute struct {
|
||||
APIKey string
|
||||
Endpoint string
|
||||
Model string
|
||||
}
|
||||
|
||||
func NewOpenRoute(apiKey string, endpoint string, model string) *OpenAI {
|
||||
return &OpenAI{
|
||||
APIKey: apiKey,
|
||||
Endpoint: endpoint,
|
||||
Model: model,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OpenRoute) Generate(prompt string) (string, error) {
|
||||
return "", nil
|
||||
// TODO: implement
|
||||
}
|
||||
|
||||
func (o *OpenRoute) GetModel() string {
|
||||
return o.Model
|
||||
}
|
||||
27
internal/models/models.go
Normal file
27
internal/models/models.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package models
|
||||
|
||||
// type VectorEmbedding [][]float32
|
||||
// type Vector []float32
|
||||
// Document represents the data structure for storing documents.
type Document struct {
	ID             string            `json:"id" milvus:"ID"`                          // Unique identifier for the document
	Content        string            `json:"content" milvus:"Content"`                // Full text; chunked on save, not persisted as a single field
	Link           string            `json:"link" milvus:"Link"`                      // Link to the document
	Filename       string            `json:"filename" milvus:"Filename"`              // Filename of the document
	Category       string            `json:"category" milvus:"Category"`              // Category of the document
	EmbeddingModel string            `json:"embedding_model" milvus:"EmbeddingModel"` // Embedding model used to generate the embedding
	Summary        string            `json:"summary" milvus:"Summary"`                // Summary of the document
	Metadata       map[string]string `json:"metadata" milvus:"Metadata"`              // Additional metadata (e.g., author, timestamp)
	Vector         []float32         `json:"vector" milvus:"Vector"`                  // Document-level vector
}
|
||||
|
||||
// Embedding represents the vector embedding for a document or query
|
||||
// Embedding represents one vector-embedded text chunk of a document.
type Embedding struct {
	ID         string    `json:"id" milvus:"ID"`                  // unique identifier of the chunk
	DocumentID string    `json:"document_id" milvus:"DocumentID"` // ID of the Document this chunk belongs to
	Vector     []float32 `json:"vector" milvus:"Vector"`          // the embedding vector
	TextChunk  string    `json:"text_chunk" milvus:"TextChunk"`   // text chunk of the document
	Dimension  int64     `json:"dimension" milvus:"Dimension"`    // dimensionality of the vector
	Order      int64     `json:"order" milvus:"Order"`            // position used to reassemble the document content
	Score      float32   `json:"score"`                           // similarity score; populated by search, not stored in Milvus
}
|
||||
169
internal/pkg/database/milvus/client.go
Normal file
169
internal/pkg/database/milvus/client.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package milvus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
||||
)
|
||||
|
||||
// Client wraps the Milvus SDK client and exposes the collection
// management and CRUD helpers used by this project.
type Client struct {
	Instance client.Client // underlying Milvus SDK connection
}
|
||||
|
||||
// InitMilvusClient initializes the Milvus client and returns a wrapper around it.
|
||||
func NewClient(milvusAddr string) (*Client, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
c, err := client.NewClient(ctx, client.Config{Address: milvusAddr})
|
||||
if err != nil {
|
||||
log.Printf("Failed to connect to Milvus: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &Client{Instance: c}
|
||||
|
||||
err = client.EnsureCollections(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// EnsureCollections ensures that the required collections ("documents" and "chunks") exist.
|
||||
// If they don't exist, it creates them based on the predefined structs.
|
||||
func (m *Client) EnsureCollections(ctx context.Context) error {
|
||||
collections := []struct {
|
||||
Name string
|
||||
Schema *entity.Schema
|
||||
IndexField string
|
||||
IndexType string
|
||||
MetricType entity.MetricType
|
||||
Nlist int
|
||||
}{
|
||||
{
|
||||
Name: "documents",
|
||||
Schema: createDocumentSchema(),
|
||||
IndexField: "Vector", // Indexing the Vector field for similarity search
|
||||
IndexType: "IVF_FLAT",
|
||||
MetricType: entity.L2,
|
||||
Nlist: 10, // Number of clusters for IVF_FLAT index
|
||||
},
|
||||
{
|
||||
Name: "chunks",
|
||||
Schema: createEmbeddingSchema(),
|
||||
IndexField: "Vector", // Indexing the Vector field for similarity search
|
||||
IndexType: "IVF_FLAT",
|
||||
MetricType: entity.L2,
|
||||
Nlist: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, collection := range collections {
|
||||
// drop collection
|
||||
// err := m.Instance.DropCollection(ctx, collection.Name)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to drop collection '%s': %w", collection.Name, err)
|
||||
// }
|
||||
|
||||
// Ensure the collection exists
|
||||
exists, err := m.Instance.HasCollection(ctx, collection.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check collection existence: %w", err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
err := m.Instance.CreateCollection(ctx, collection.Schema, entity.DefaultShardNumber)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create collection '%s': %w", collection.Name, err)
|
||||
}
|
||||
log.Printf("Collection '%s' created successfully", collection.Name)
|
||||
} else {
|
||||
log.Printf("Collection '%s' already exists", collection.Name)
|
||||
}
|
||||
|
||||
// Ensure the default partition exists
|
||||
hasPartition, err := m.Instance.HasPartition(ctx, collection.Name, "_default")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check default partition for collection '%s': %w", collection.Name, err)
|
||||
}
|
||||
|
||||
if !hasPartition {
|
||||
err = m.Instance.CreatePartition(ctx, collection.Name, "_default")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create default partition for collection '%s': %w", collection.Name, err)
|
||||
}
|
||||
log.Printf("Default partition created for collection '%s'", collection.Name)
|
||||
}
|
||||
|
||||
// Skip index creation if IndexField is empty
|
||||
if collection.IndexField == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure the index exists
|
||||
log.Printf("Creating index on field '%s' for collection '%s'", collection.IndexField, collection.Name)
|
||||
|
||||
idx, err := entity.NewIndexIvfFlat(collection.MetricType, collection.Nlist)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create IVF_FLAT index: %w", err)
|
||||
}
|
||||
|
||||
err = m.Instance.CreateIndex(ctx, collection.Name, collection.IndexField, idx, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create index on field '%s' for collection '%s': %w", collection.IndexField, collection.Name, err)
|
||||
}
|
||||
|
||||
log.Printf("Index created on field '%s' for collection '%s'", collection.IndexField, collection.Name)
|
||||
}
|
||||
|
||||
err := m.Instance.LoadCollection(ctx, "documents", false)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to load collection, err: %v", err)
|
||||
}
|
||||
|
||||
err = m.Instance.LoadCollection(ctx, "chunks", false)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to load collection, err: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper functions for creating schemas
|
||||
func createDocumentSchema() *entity.Schema {
|
||||
return entity.NewSchema().
|
||||
WithName("documents").
|
||||
WithDescription("Collection for storing documents").
|
||||
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeVarChar).WithIsPrimaryKey(true).WithMaxLength(512)).
|
||||
WithField(entity.NewField().WithName("Content").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
|
||||
WithField(entity.NewField().WithName("Link").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)).
|
||||
WithField(entity.NewField().WithName("Filename").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)).
|
||||
WithField(entity.NewField().WithName("Category").WithDataType(entity.FieldTypeVarChar).WithMaxLength(8048)).
|
||||
WithField(entity.NewField().WithName("EmbeddingModel").WithDataType(entity.FieldTypeVarChar).WithMaxLength(256)).
|
||||
WithField(entity.NewField().WithName("Summary").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
|
||||
WithField(entity.NewField().WithName("Metadata").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
|
||||
WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(1024)) // bge-m3
|
||||
}
|
||||
|
||||
func createEmbeddingSchema() *entity.Schema {
|
||||
return entity.NewSchema().
|
||||
WithName("chunks").
|
||||
WithDescription("Collection for storing document embeddings").
|
||||
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeVarChar).WithIsPrimaryKey(true).WithMaxLength(512)).
|
||||
WithField(entity.NewField().WithName("DocumentID").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)).
|
||||
WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(1024)). // bge-m3
|
||||
WithField(entity.NewField().WithName("TextChunk").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
|
||||
WithField(entity.NewField().WithName("Dimension").WithDataType(entity.FieldTypeInt32)).
|
||||
WithField(entity.NewField().WithName("Order").WithDataType(entity.FieldTypeInt32))
|
||||
}
|
||||
|
||||
// Close closes the Milvus client connection.
|
||||
// Close closes the underlying Milvus client connection.
func (m *Client) Close() {
	m.Instance.Close()
}
|
||||
32
internal/pkg/database/milvus/client_test.go
Normal file
32
internal/pkg/database/milvus/client_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package milvus
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
type args struct {
|
||||
milvusAddr string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Client
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewClient(tt.args.milvusAddr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewClient() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
276
internal/pkg/database/milvus/helpers.go
Normal file
276
internal/pkg/database/milvus/helpers.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package milvus
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"easy_rag/internal/models"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
||||
)
|
||||
|
||||
// Helper functions for extracting data
|
||||
func extractIDs(docs []models.Document) []string {
|
||||
ids := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
ids[i] = doc.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// extractLinks extracts the "Link" field from the documents.
|
||||
func extractLinks(docs []models.Document) []string {
|
||||
links := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
links[i] = doc.Link
|
||||
}
|
||||
return links
|
||||
}
|
||||
|
||||
// extractFilenames extracts the "Filename" field from the documents.
|
||||
func extractFilenames(docs []models.Document) []string {
|
||||
filenames := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
filenames[i] = doc.Filename
|
||||
}
|
||||
return filenames
|
||||
}
|
||||
|
||||
// extractCategories extracts the "Category" field from the documents as a comma-separated string.
|
||||
func extractCategories(docs []models.Document) []string {
|
||||
categories := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
categories[i] = fmt.Sprintf("%v", doc.Category)
|
||||
}
|
||||
return categories
|
||||
}
|
||||
|
||||
// extractEmbeddingModels extracts the "EmbeddingModel" field from the documents.
|
||||
func extractEmbeddingModels(docs []models.Document) []string {
|
||||
models := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
models[i] = doc.EmbeddingModel
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// extractSummaries extracts the "Summary" field from the documents.
|
||||
func extractSummaries(docs []models.Document) []string {
|
||||
summaries := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
summaries[i] = doc.Summary
|
||||
}
|
||||
return summaries
|
||||
}
|
||||
|
||||
// extractMetadata extracts the "Metadata" field from the documents as a JSON string.
|
||||
func extractMetadata(docs []models.Document) []string {
|
||||
metadata := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
metaBytes, _ := json.Marshal(doc.Metadata)
|
||||
metadata[i] = string(metaBytes)
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
|
||||
// convertToMetadata parses a JSON object string back into a metadata
// map. On malformed input it returns a nil map (best-effort decode;
// the decode error is intentionally discarded).
func convertToMetadata(metadata string) map[string]string {
	var decoded map[string]string
	_ = json.Unmarshal([]byte(metadata), &decoded)
	return decoded
}
|
||||
|
||||
func extractContents(docs []models.Document) []string {
|
||||
contents := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
contents[i] = doc.Content
|
||||
}
|
||||
return contents
|
||||
}
|
||||
|
||||
// extractEmbeddingIDs extracts the "ID" field from the embeddings.
|
||||
func extractEmbeddingIDs(embeddings []models.Embedding) []string {
|
||||
ids := make([]string, len(embeddings))
|
||||
for i, embedding := range embeddings {
|
||||
ids[i] = embedding.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// extractDocumentIDs extracts the "DocumentID" field from the embeddings.
|
||||
func extractDocumentIDs(embeddings []models.Embedding) []string {
|
||||
documentIDs := make([]string, len(embeddings))
|
||||
for i, embedding := range embeddings {
|
||||
documentIDs[i] = embedding.DocumentID
|
||||
}
|
||||
return documentIDs
|
||||
}
|
||||
|
||||
// extractVectors extracts the "Vector" field from the embeddings.
|
||||
func extractVectors(embeddings []models.Embedding) [][]float32 {
|
||||
vectors := make([][]float32, len(embeddings))
|
||||
for i, embedding := range embeddings {
|
||||
vectors[i] = embedding.Vector // Direct assignment since it's already []float32
|
||||
}
|
||||
return vectors
|
||||
}
|
||||
|
||||
// extractVectorsDocs extracts the "Vector" field from the documents.
|
||||
func extractVectorsDocs(docs []models.Document) [][]float32 {
|
||||
vectors := make([][]float32, len(docs))
|
||||
for i, doc := range docs {
|
||||
vectors[i] = doc.Vector // Direct assignment since it's already []float32
|
||||
}
|
||||
return vectors
|
||||
}
|
||||
|
||||
// extractTextChunks extracts the "TextChunk" field from the embeddings.
|
||||
func extractTextChunks(embeddings []models.Embedding) []string {
|
||||
textChunks := make([]string, len(embeddings))
|
||||
for i, embedding := range embeddings {
|
||||
textChunks[i] = embedding.TextChunk
|
||||
}
|
||||
return textChunks
|
||||
}
|
||||
|
||||
// extractDimensions extracts the "Dimension" field from the embeddings.
|
||||
func extractDimensions(embeddings []models.Embedding) []int32 {
|
||||
dimensions := make([]int32, len(embeddings))
|
||||
for i, embedding := range embeddings {
|
||||
dimensions[i] = int32(embedding.Dimension)
|
||||
}
|
||||
return dimensions
|
||||
}
|
||||
|
||||
// extractOrders extracts the "Order" field from the embeddings.
|
||||
func extractOrders(embeddings []models.Embedding) []int32 {
|
||||
orders := make([]int32, len(embeddings))
|
||||
for i, embedding := range embeddings {
|
||||
orders[i] = int32(embedding.Order)
|
||||
}
|
||||
return orders
|
||||
}
|
||||
|
||||
func transformResultSet(rs client.ResultSet, outputFields ...string) ([]map[string]interface{}, error) {
|
||||
if rs == nil || rs.Len() == 0 {
|
||||
return nil, fmt.Errorf("empty result set")
|
||||
}
|
||||
|
||||
results := []map[string]interface{}{}
|
||||
|
||||
for i := 0; i < rs.Len(); i++ { // Iterate through rows
|
||||
row := map[string]interface{}{}
|
||||
|
||||
for _, fieldName := range outputFields {
|
||||
column := rs.GetColumn(fieldName)
|
||||
if column == nil {
|
||||
return nil, fmt.Errorf("column %s does not exist in result set", fieldName)
|
||||
}
|
||||
|
||||
switch column.Type() {
|
||||
case entity.FieldTypeInt64:
|
||||
value, err := column.GetAsInt64(i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
|
||||
}
|
||||
row[fieldName] = value
|
||||
|
||||
case entity.FieldTypeInt32:
|
||||
value, err := column.GetAsInt64(i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
|
||||
}
|
||||
row[fieldName] = value
|
||||
|
||||
case entity.FieldTypeFloat:
|
||||
value, err := column.GetAsDouble(i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err)
|
||||
}
|
||||
row[fieldName] = value
|
||||
|
||||
case entity.FieldTypeDouble:
|
||||
value, err := column.GetAsDouble(i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err)
|
||||
}
|
||||
row[fieldName] = value
|
||||
|
||||
case entity.FieldTypeVarChar:
|
||||
value, err := column.GetAsString(i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err)
|
||||
}
|
||||
row[fieldName] = value
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported field type for column %s", fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// transformSearchResultSet converts one Milvus SearchResult into a list
// of row maps restricted to outputFields, attaching each row's search
// score under the "Score" key.
// It returns an error for an empty result, a missing column, or an
// unsupported column type.
func transformSearchResultSet(rs client.SearchResult, outputFields ...string) ([]map[string]interface{}, error) {
	if rs.ResultCount == 0 {
		return nil, fmt.Errorf("empty result set")
	}

	result := make([]map[string]interface{}, rs.ResultCount)

	for i := 0; i < rs.ResultCount; i++ { // iterate through rows
		result[i] = make(map[string]interface{})
		for _, fieldName := range outputFields {
			column := rs.Fields.GetColumn(fieldName)
			// NOTE(review): Score is (re)assigned once per output field; the
			// value is identical each time, but if outputFields is empty no
			// Score key is ever set. Confirm whether this assignment should
			// be hoisted above the field loop.
			result[i]["Score"] = rs.Scores[i]

			if column == nil {
				return nil, fmt.Errorf("column %s does not exist in result set", fieldName)
			}

			switch column.Type() {
			case entity.FieldTypeInt64:
				value, err := column.GetAsInt64(i)
				if err != nil {
					return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
				}
				result[i][fieldName] = value

			case entity.FieldTypeInt32:
				// Int32 columns are widened to int64 via GetAsInt64.
				value, err := column.GetAsInt64(i)
				if err != nil {
					return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
				}
				result[i][fieldName] = value

			case entity.FieldTypeFloat:
				value, err := column.GetAsDouble(i)
				if err != nil {
					return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err)
				}
				result[i][fieldName] = value

			case entity.FieldTypeDouble:
				value, err := column.GetAsDouble(i)
				if err != nil {
					return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err)
				}
				result[i][fieldName] = value

			case entity.FieldTypeVarChar:
				value, err := column.GetAsString(i)
				if err != nil {
					return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err)
				}
				result[i][fieldName] = value

			default:
				return nil, fmt.Errorf("unsupported field type for column %s", fieldName)
			}
		}
	}

	return result, nil
}
|
||||
270
internal/pkg/database/milvus/operations.go
Normal file
270
internal/pkg/database/milvus/operations.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package milvus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"easy_rag/internal/models"
|
||||
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
||||
)
|
||||
|
||||
// InsertDocuments inserts documents into the "documents" collection.
|
||||
func (m *Client) InsertDocuments(ctx context.Context, docs []models.Document) error {
|
||||
idColumn := entity.NewColumnVarChar("ID", extractIDs(docs))
|
||||
contentColumn := entity.NewColumnVarChar("Content", extractContents(docs))
|
||||
linkColumn := entity.NewColumnVarChar("Link", extractLinks(docs))
|
||||
filenameColumn := entity.NewColumnVarChar("Filename", extractFilenames(docs))
|
||||
categoryColumn := entity.NewColumnVarChar("Category", extractCategories(docs))
|
||||
embeddingModelColumn := entity.NewColumnVarChar("EmbeddingModel", extractEmbeddingModels(docs))
|
||||
summaryColumn := entity.NewColumnVarChar("Summary", extractSummaries(docs))
|
||||
metadataColumn := entity.NewColumnVarChar("Metadata", extractMetadata(docs))
|
||||
vectorColumn := entity.NewColumnFloatVector("Vector", 1024, extractVectorsDocs(docs))
|
||||
// Insert the data
|
||||
_, err := m.Instance.Insert(ctx, "documents", "_default", idColumn, contentColumn, linkColumn, filenameColumn,
|
||||
categoryColumn, embeddingModelColumn, summaryColumn, metadataColumn, vectorColumn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert documents: %w", err)
|
||||
}
|
||||
|
||||
// Flush the collection
|
||||
err = m.Instance.Flush(ctx, "documents", false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush documents collection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InsertEmbeddings inserts embeddings into the "chunks" collection.
|
||||
func (m *Client) InsertEmbeddings(ctx context.Context, embeddings []models.Embedding) error {
|
||||
idColumn := entity.NewColumnVarChar("ID", extractEmbeddingIDs(embeddings))
|
||||
documentIDColumn := entity.NewColumnVarChar("DocumentID", extractDocumentIDs(embeddings))
|
||||
vectorColumn := entity.NewColumnFloatVector("Vector", 1024, extractVectors(embeddings))
|
||||
textChunkColumn := entity.NewColumnVarChar("TextChunk", extractTextChunks(embeddings))
|
||||
dimensionColumn := entity.NewColumnInt32("Dimension", extractDimensions(embeddings))
|
||||
orderColumn := entity.NewColumnInt32("Order", extractOrders(embeddings))
|
||||
|
||||
_, err := m.Instance.Insert(ctx, "chunks", "_default", idColumn, documentIDColumn, vectorColumn,
|
||||
textChunkColumn, dimensionColumn, orderColumn)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert embeddings: %w", err)
|
||||
}
|
||||
|
||||
err = m.Instance.Flush(ctx, "chunks", false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush chunks collection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDocumentByID retrieves a document from the "documents" collection by ID.
|
||||
func (m *Client) GetDocumentByID(ctx context.Context, id string) (map[string]interface{}, error) {
|
||||
collectionName := "documents"
|
||||
expr := fmt.Sprintf("ID == '%s'", id)
|
||||
projections := []string{"ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata"} // Fetch all fields
|
||||
|
||||
results, err := m.Instance.Query(ctx, collectionName, nil, expr, projections)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query document by ID: %w", err)
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return nil, fmt.Errorf("document with ID '%s' not found", id)
|
||||
}
|
||||
|
||||
mp, err := transformResultSet(results, "ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata")
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal document: %w", err)
|
||||
}
|
||||
|
||||
// convert metadata to map
|
||||
mp[0]["Metadata"] = convertToMetadata(mp[0]["Metadata"].(string))
|
||||
|
||||
return mp[0], err
|
||||
}
|
||||
|
||||
// GetAllDocuments retrieves all documents from the "documents" collection.
|
||||
func (m *Client) GetAllDocuments(ctx context.Context) ([]models.Document, error) {
|
||||
collectionName := "documents"
|
||||
projections := []string{"ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata"} // Fetch all fields
|
||||
expr := ""
|
||||
|
||||
rs, err := m.Instance.Query(ctx, collectionName, nil, expr, projections, client.WithLimit(1000))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query all documents: %w", err)
|
||||
}
|
||||
|
||||
if len(rs) == 0 {
|
||||
return nil, fmt.Errorf("no documents found in the collection")
|
||||
}
|
||||
|
||||
results, err := transformResultSet(rs, "ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata")
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal all documents: %w", err)
|
||||
}
|
||||
|
||||
var docs []models.Document = make([]models.Document, len(results))
|
||||
for i, result := range results {
|
||||
docs[i] = models.Document{
|
||||
ID: result["ID"].(string),
|
||||
Content: result["Content"].(string),
|
||||
Link: result["Link"].(string),
|
||||
Filename: result["Filename"].(string),
|
||||
Category: result["Category"].(string),
|
||||
EmbeddingModel: result["EmbeddingModel"].(string),
|
||||
Summary: result["Summary"].(string),
|
||||
Metadata: convertToMetadata(results[0]["Metadata"].(string)),
|
||||
}
|
||||
}
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// GetAllEmbeddingByDocID retrieves all embeddings linked to a specific DocumentID from the "chunks" collection.
|
||||
func (m *Client) GetAllEmbeddingByDocID(ctx context.Context, documentID string) ([]models.Embedding, error) {
|
||||
collectionName := "chunks"
|
||||
projections := []string{"ID", "DocumentID", "TextChunk", "Order"} // Fetch all fields
|
||||
expr := fmt.Sprintf("DocumentID == '%s'", documentID)
|
||||
|
||||
rs, err := m.Instance.Query(ctx, collectionName, nil, expr, projections, client.WithLimit(1000))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query embeddings by DocumentID: %w", err)
|
||||
}
|
||||
|
||||
if rs.Len() == 0 {
|
||||
return nil, fmt.Errorf("no embeddings found for DocumentID '%s'", documentID)
|
||||
}
|
||||
|
||||
results, err := transformResultSet(rs, "ID", "DocumentID", "TextChunk", "Order")
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal all documents: %w", err)
|
||||
}
|
||||
|
||||
var embeddings []models.Embedding = make([]models.Embedding, rs.Len())
|
||||
|
||||
for i, result := range results {
|
||||
embeddings[i] = models.Embedding{
|
||||
ID: result["ID"].(string),
|
||||
DocumentID: result["DocumentID"].(string),
|
||||
TextChunk: result["TextChunk"].(string),
|
||||
Order: result["Order"].(int64),
|
||||
}
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// Search performs a vector similarity search over the "chunks"
// collection and returns the matching embeddings sorted by descending
// score (see processSearchResults).
//
// vectors: query vectors; each must be 1024-dimensional to match the
// collection schema. topK: number of nearest neighbours requested per
// query vector.
func (m *Client) Search(ctx context.Context, vectors [][]float32, topK int) ([]models.Embedding, error) {
	const (
		collectionName = "chunks"
		vectorDim      = 1024 // must match the dimension in createEmbeddingSchema
	)
	projections := []string{"ID", "DocumentID", "TextChunk", "Order"}
	metricType := entity.L2 // same metric the IVF_FLAT index was built with

	// Validate dimensions and convert to the SDK's vector type.
	searchVectors, err := validateAndConvertVectors(vectors, vectorDim)
	if err != nil {
		return nil, err
	}

	// Search parameters: probe 16 clusters of the IVF_FLAT index.
	searchParams, err := entity.NewIndexIvfFlatSearchParam(16)
	if err != nil {
		return nil, fmt.Errorf("failed to create search params: %w", err)
	}

	// Perform the search.
	// NOTE(review): client.WithLimit(10) is passed alongside topK — if
	// topK > 10 the option may silently cap the results; confirm whether
	// the limit should track topK instead of being hard-coded.
	searchResults, err := m.Instance.Search(ctx, collectionName, nil, "", projections, searchVectors, "Vector", metricType, topK, searchParams, client.WithLimit(10))
	if err != nil {
		return nil, fmt.Errorf("failed to search collection: %w", err)
	}

	// Flatten, decode, and sort the raw results by score.
	embeddings, err := processSearchResults(searchResults)
	if err != nil {
		return nil, fmt.Errorf("failed to process search results: %w", err)
	}

	return embeddings, nil
}
|
||||
|
||||
// validateAndConvertVectors validates vector dimensions and converts them to Milvus-compatible format.
|
||||
func validateAndConvertVectors(vectors [][]float32, expectedDim int) ([]entity.Vector, error) {
|
||||
searchVectors := make([]entity.Vector, len(vectors))
|
||||
for i, vector := range vectors {
|
||||
if len(vector) != expectedDim {
|
||||
return nil, fmt.Errorf("vector dimension mismatch: expected %d, got %d", expectedDim, len(vector))
|
||||
}
|
||||
searchVectors[i] = entity.FloatVector(vector)
|
||||
}
|
||||
return searchVectors, nil
|
||||
}
|
||||
|
||||
// processSearchResults transforms and aggregates the search results into embeddings and sorts by score.
|
||||
func processSearchResults(results []client.SearchResult) ([]models.Embedding, error) {
|
||||
var embeddings []models.Embedding
|
||||
|
||||
for _, result := range results {
|
||||
for i := 0; i < result.ResultCount; i++ {
|
||||
embeddingMap, err := transformSearchResultSet(result, "ID", "DocumentID", "TextChunk", "Order")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to transform search result set: %w", err)
|
||||
}
|
||||
|
||||
for _, embedding := range embeddingMap {
|
||||
embeddings = append(embeddings, models.Embedding{
|
||||
ID: embedding["ID"].(string),
|
||||
DocumentID: embedding["DocumentID"].(string),
|
||||
TextChunk: embedding["TextChunk"].(string),
|
||||
Order: embedding["Order"].(int64), // Assuming 'Order' is a float64 type
|
||||
Score: embedding["Score"].(float32),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort embeddings by score in descending order (higher is better)
|
||||
sort.Slice(embeddings, func(i, j int) bool {
|
||||
return embeddings[i].Score > embeddings[j].Score
|
||||
})
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// DeleteDocument deletes a document from the "documents" collection by ID.
|
||||
func (m *Client) DeleteDocument(ctx context.Context, id string) error {
|
||||
collectionName := "documents"
|
||||
partitionName := "_default"
|
||||
expr := fmt.Sprintf("ID == '%s'", id)
|
||||
|
||||
err := m.Instance.Delete(ctx, collectionName, partitionName, expr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete document by ID: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteEmbedding deletes an embedding from the "chunks" collection by ID.
|
||||
// DeleteEmbedding removes from the "chunks" collection every embedding
// whose DocumentID matches id — i.e. it deletes ALL chunks of one
// document, not a single embedding row.
// NOTE(review): the name suggests deletion by embedding ID while the
// filter uses DocumentID; confirm the intended semantics.
func (m *Client) DeleteEmbedding(ctx context.Context, id string) error {
	collectionName := "chunks"
	partitionName := "_default"
	expr := fmt.Sprintf("DocumentID == '%s'", id)

	err := m.Instance.Delete(ctx, collectionName, partitionName, expr)
	if err != nil {
		return fmt.Errorf("failed to delete embedding by DocumentID: %w", err)
	}

	return nil
}
|
||||
21
internal/pkg/rag/rag.go
Normal file
21
internal/pkg/rag/rag.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"easy_rag/internal/database"
|
||||
"easy_rag/internal/embeddings"
|
||||
"easy_rag/internal/llm"
|
||||
)
|
||||
|
||||
// Rag bundles the three services that make up the retrieval-augmented
// generation pipeline: a language model, an embeddings provider, and
// the vector database.
type Rag struct {
	LLM        llm.LLMService               // text generation backend
	Embeddings embeddings.EmbeddingsService // embedding backend
	Database   database.Database            // vector store (e.g. the Milvus implementation)
}
|
||||
|
||||
func NewRag(llm llm.LLMService, embeddings embeddings.EmbeddingsService, database database.Database) *Rag {
|
||||
return &Rag{
|
||||
LLM: llm,
|
||||
Embeddings: embeddings,
|
||||
Database: database,
|
||||
}
|
||||
}
|
||||
50
internal/pkg/textprocessor/textprocessor.go
Normal file
50
internal/pkg/textprocessor/textprocessor.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package textprocessor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/jonathanhecl/chunker"
|
||||
)
|
||||
|
||||
func CreateChunks(text string) []string {
|
||||
// Maximum characters per chunk
|
||||
const maxCharacters = 5000 // too slow otherwise
|
||||
|
||||
var chunks []string
|
||||
var currentChunk strings.Builder
|
||||
|
||||
// Use the chunker library to split text into sentences
|
||||
sentences := chunker.ChunkSentences(text)
|
||||
|
||||
for _, sentence := range sentences {
|
||||
// Check if adding the sentence exceeds the character limit
|
||||
if currentChunk.Len()+len(sentence) <= maxCharacters {
|
||||
if currentChunk.Len() > 0 {
|
||||
currentChunk.WriteString(" ") // Add a space between sentences
|
||||
}
|
||||
currentChunk.WriteString(sentence)
|
||||
} else {
|
||||
// Add the completed chunk to the chunks slice
|
||||
chunks = append(chunks, currentChunk.String())
|
||||
currentChunk.Reset() // Start a new chunk
|
||||
currentChunk.WriteString(sentence) // Add the sentence to the new chunk
|
||||
}
|
||||
}
|
||||
|
||||
// Add the last chunk if it has content
|
||||
if currentChunk.Len() > 0 {
|
||||
chunks = append(chunks, currentChunk.String())
|
||||
}
|
||||
|
||||
// Return the chunks
|
||||
return chunks
|
||||
}
|
||||
|
||||
// ConcatenateStrings joins all parts into one string, in order, with
// no separator.
//
// Fixes: the parameter was previously named "strings", shadowing the
// standard-library package of the same name inside the function; it is
// renamed to parts (Go call sites are positional, so this is
// backward-compatible). The bytes.Buffer is replaced with the
// idiomatic strings.Builder, whose String() is zero-copy.
func ConcatenateStrings(parts []string) string {
	var b strings.Builder
	for _, s := range parts {
		b.WriteString(s)
	}
	return b.String()
}
|
||||
Reference in New Issue
Block a user