90 lines
2.1 KiB
Go
90 lines
2.1 KiB
Go
package api_test
|
|
|
|
import (
|
|
"easy_rag/internal/models"
|
|
"github.com/stretchr/testify/mock"
|
|
)
|
|
|
|
// --------------------
|
|
// Mock LLM
|
|
// --------------------
|
|
type MockLLMService struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockLLMService) Generate(prompt string) (string, error) {
|
|
args := m.Called(prompt)
|
|
return args.String(0), args.Error(1)
|
|
}
|
|
|
|
func (m *MockLLMService) GetModel() string {
|
|
args := m.Called()
|
|
return args.String(0)
|
|
}
|
|
|
|
// --------------------
|
|
// Mock Embeddings
|
|
// --------------------
|
|
type MockEmbeddingsService struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockEmbeddingsService) Vectorize(text string) ([][]float32, error) {
|
|
args := m.Called(text)
|
|
return args.Get(0).([][]float32), args.Error(1)
|
|
}
|
|
|
|
func (m *MockEmbeddingsService) GetModel() string {
|
|
args := m.Called()
|
|
return args.String(0)
|
|
}
|
|
|
|
// --------------------
|
|
// Mock Database
|
|
// --------------------
|
|
type MockDatabase struct {
|
|
mock.Mock
|
|
}
|
|
|
|
// GetDocumentInfo(id string) (models.DocumentInfo, error)
|
|
func (m *MockDatabase) GetDocumentInfo(id string) (models.Document, error) {
|
|
args := m.Called(id)
|
|
return args.Get(0).(models.Document), args.Error(1)
|
|
}
|
|
|
|
// SaveDocument(document Document) error
|
|
func (m *MockDatabase) SaveDocument(doc models.Document) error {
|
|
args := m.Called(doc)
|
|
return args.Error(0)
|
|
}
|
|
|
|
// SaveEmbeddings([]Embedding) error
|
|
func (m *MockDatabase) SaveEmbeddings(emb []models.Embedding) error {
|
|
args := m.Called(emb)
|
|
return args.Error(0)
|
|
}
|
|
|
|
// ListDocuments() ([]Document, error)
|
|
func (m *MockDatabase) ListDocuments() ([]models.Document, error) {
|
|
args := m.Called()
|
|
return args.Get(0).([]models.Document), args.Error(1)
|
|
}
|
|
|
|
// GetDocument(id string) (Document, error)
|
|
func (m *MockDatabase) GetDocument(id string) (models.Document, error) {
|
|
args := m.Called(id)
|
|
return args.Get(0).(models.Document), args.Error(1)
|
|
}
|
|
|
|
// DeleteDocument(id string) error
|
|
func (m *MockDatabase) DeleteDocument(id string) error {
|
|
args := m.Called(id)
|
|
return args.Error(0)
|
|
}
|
|
|
|
// Search(vector []float32) ([]models.Embedding, error)
|
|
func (m *MockDatabase) Search(vector [][]float32) ([]models.Embedding, error) {
|
|
args := m.Called(vector)
|
|
return args.Get(0).([]models.Embedding), args.Error(1)
|
|
}
|