diff --git a/.gitignore b/.gitignore
index 5b90e79..0874f28 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,27 +1,6 @@
-# ---> Go
-# If you prefer the allow list template instead of the deny list, see community template:
-# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
-#
-# Binaries for programs and plugins
-*.exe
-*.exe~
-*.dll
-*.so
-*.dylib
-
-# Test binary, built with `go test -c`
-*.test
-
-# Output of the go coverage tool, specifically when used with LiteIDE
-*.out
-
-# Dependency directories (remove the comment below to include it)
-# vendor/
-
-# Go workspace file
-go.work
-go.work.sum
-
-# env file
-.env
-
+*.env
+*.sum
+id_rsa2
+id_rsa2.pub
+/tests/RAG.postman_collection.json
+/volumes/
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/easy_rag.iml b/.idea/easy_rag.iml
new file mode 100644
index 0000000..b279d48
--- /dev/null
+++ b/.idea/easy_rag.iml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..8536ff3
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
index 89cd039..9e3f99f 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,454 @@
-# easy_rag
+# Easy RAG - Система Retrieval-Augmented Generation
+## Обзор
+
+Easy RAG - это мощная система для управления документами и генерации ответов на основе метода Retrieval-Augmented Generation (RAG). Проект реализует полнофункциональное API для загрузки, хранения, поиска и анализа документов с использованием векторных баз данных и современных языковых моделей.
+
+### Ключевые возможности
+
+- 🔍 **Семантический поиск** по документам с использованием векторных эмбеддингов
+- 📄 **Управление документами** - загрузка, хранение, получение и удаление
+- 🤖 **Интеграция с LLM** - поддержка OpenAI, Ollama и OpenRoute
+- 💾 **Векторная база данных** - использование Milvus для хранения эмбеддингов
+- 🔧 **Модульная архитектура** - легко расширяемая и настраиваемая система
+- 🚀 **RESTful API** - простое и понятное API для интеграции
+- 🐳 **Docker-готовность** - контейнеризация для простого развертывания
+
+---
+
+## Архитектура системы
+
+### Компоненты
+
+1. **API слой** (`api/`) - HTTP API на базе Echo Framework
+2. **Основная логика** (`internal/pkg/rag/`) - ядро RAG системы
+3. **Провайдеры LLM** (`internal/llm/`) - интеграция с языковыми моделями
+4. **Провайдеры эмбеддингов** (`internal/embeddings/`) - генерация векторных представлений
+5. **База данных** (`internal/database/`) - работа с векторной БД Milvus
+6. **Обработка текста** (`internal/pkg/textprocessor/`) - разбиение на чанки
+7. **Конфигурация** (`config/`) - управление настройками
+
+### Поддерживаемые провайдеры
+
+#### LLM (Языковые модели):
+- **OpenAI** - GPT-3.5, GPT-4 и другие модели
+- **Ollama** - локальные открытые модели
+- **OpenRoute** - доступ к различным моделям через единый API
+
+#### Эмбеддинги:
+- **OpenAI Embeddings** - text-embedding-ada-002 и новые модели
+- **Ollama Embeddings** - локальные модели эмбеддингов (например, bge-m3)
+
+#### Векторная база данных:
+- **Milvus** - высокопроизводительная векторная база данных
+
+---
+
+## API Документация
+
+### Базовый URL
+```
+http://localhost:4002/api/v1
+```
+
+### Эндпоинты
+
+#### 1. Получить все документы
+```http
+GET /api/v1/docs
+```
+**Описание**: Получить список всех сохраненных документов.
+
+**Ответ**:
+```json
+{
+ "version": "v1",
+ "docs": [
+ {
+ "id": "document_id",
+ "filename": "document.txt",
+ "summary": "Краткое описание документа",
+ "metadata": {
+ "category": "Техническая документация",
+ "author": "Автор"
+ }
+ }
+ ]
+}
+```
+
+#### 2. Загрузить документы
+```http
+POST /api/v1/upload
+```
+**Описание**: Загрузить один или несколько документов для обработки и индексации.
+
+**Тело запроса**:
+```json
+{
+ "docs": [
+ {
+ "content": "Содержимое документа...",
+ "link": "https://example.com/document",
+ "filename": "document.txt",
+ "category": "Категория",
+ "metadata": {
+ "author": "Автор",
+ "date": "2024-01-01"
+ }
+ }
+ ]
+}
+```
+
+**Ответ**:
+```json
+{
+ "version": "v1",
+ "task_id": "unique_task_id",
+ "expected_time": "10m",
+  "status": "Processing started"
+}
+```
+
+#### 3. Получить документ по ID
+```http
+GET /api/v1/doc/{id}
+```
+**Описание**: Получить детальную информацию о документе по его идентификатору.
+
+**Ответ**:
+```json
+{
+ "version": "v1",
+ "doc": {
+ "id": "document_id",
+ "content": "Полное содержимое документа",
+ "filename": "document.txt",
+ "summary": "Краткое описание",
+ "metadata": {
+ "category": "Категория",
+ "author": "Автор"
+ }
+ }
+}
+```
+
+#### 4. Задать вопрос
+```http
+POST /api/v1/ask
+```
+**Описание**: Задать вопрос на основе проиндексированных документов.
+
+**Тело запроса**:
+```json
+{
+ "question": "Что такое ISO 27001?"
+}
+```
+
+**Ответ**:
+```json
+{
+ "version": "v1",
+ "docs": ["document_id_1", "document_id_2"],
+ "answer": "ISO 27001 - это международный стандарт информационной безопасности..."
+}
+```
+
+#### 5. Удалить документ
+```http
+DELETE /api/v1/doc/{id}
+```
+**Описание**: Удалить документ по его идентификатору.
+
+**Ответ**:
+```json
+{
+ "version": "v1",
+ "docs": null
+}
+```
+
+---
+
+## Структуры данных
+
+### Document (Документ)
+```go
+type Document struct {
+ ID string `json:"id"` // Уникальный идентификатор
+ Content string `json:"content"` // Содержимое документа
+ Link string `json:"link"` // Ссылка на источник
+ Filename string `json:"filename"` // Имя файла
+ Category string `json:"category"` // Категория
+ EmbeddingModel string `json:"embedding_model"` // Модель эмбеддингов
+ Summary string `json:"summary"` // Краткое описание
+ Metadata map[string]string `json:"metadata"` // Метаданные
+ Vector []float32 `json:"vector"` // Векторное представление
+}
+```
+
+### Embedding (Эмбеддинг)
+```go
+type Embedding struct {
+ ID string `json:"id"` // Уникальный идентификатор
+ DocumentID string `json:"document_id"` // ID связанного документа
+ Vector []float32 `json:"vector"` // Векторное представление
+ TextChunk string `json:"text_chunk"` // Фрагмент текста
+ Dimension int64 `json:"dimension"` // Размерность вектора
+ Order int64 `json:"order"` // Порядок фрагмента
+ Score float32 `json:"score"` // Оценка релевантности
+}
+```
+
+---
+
+## Установка и настройка
+
+### Требования
+- Go 1.24.3+
+- Milvus векторная база данных
+- Ollama (для локальных моделей) или API ключи для OpenAI/OpenRoute
+
+### 1. Клонирование репозитория
+```bash
+git clone https://github.com/elchemista/easy_rag.git
+cd easy_rag
+```
+
+### 2. Установка зависимостей
+```bash
+go mod tidy
+```
+
+### 3. Настройка окружения
+Создайте файл `.env` или установите переменные окружения:
+
+```env
+# LLM настройки
+OPENAI_API_KEY=your_openai_api_key
+OPENAI_ENDPOINT=https://api.openai.com/v1
+OPENAI_MODEL=gpt-3.5-turbo
+
+OPENROUTE_API_KEY=your_openroute_api_key
+OPENROUTE_ENDPOINT=https://openrouter.ai/api/v1
+OPENROUTE_MODEL=anthropic/claude-3-haiku
+
+OLLAMA_ENDPOINT=http://localhost:11434/api/chat
+OLLAMA_MODEL=qwen3:latest
+
+# Эмбеддинги
+OPENAI_EMBEDDING_API_KEY=your_openai_api_key
+OPENAI_EMBEDDING_ENDPOINT=https://api.openai.com/v1
+OPENAI_EMBEDDING_MODEL=text-embedding-ada-002
+
+OLLAMA_EMBEDDING_ENDPOINT=http://localhost:11434
+OLLAMA_EMBEDDING_MODEL=bge-m3
+
+# База данных
+MILVUS_HOST=localhost:19530
+```
+
+### 4. Запуск Milvus
+```bash
+# Используя Docker
+docker run -d --name milvus-standalone \
+ -p 19530:19530 -p 9091:9091 \
+ milvusdb/milvus:latest
+```
+
+### 5. Запуск приложения
+```bash
+go run cmd/rag/main.go
+```
+
+API будет доступно по адресу `http://localhost:4002`
+
+---
+
+## Запуск с Docker
+
+### Сборка образа
+```bash
+docker build -f deploy/Dockerfile -t easy-rag .
+```
+
+### Запуск контейнера
+```bash
+docker run -d -p 4002:4002 \
+ -e MILVUS_HOST=your_milvus_host:19530 \
+ -e OLLAMA_ENDPOINT=http://your_ollama_host:11434/api/chat \
+ -e OLLAMA_MODEL=qwen3:latest \
+ easy-rag
+```
+
+---
+
+## Примеры использования
+
+### 1. Загрузка документа
+```bash
+curl -X POST http://localhost:4002/api/v1/upload \
+ -H "Content-Type: application/json" \
+ -d '{
+ "docs": [{
+ "content": "Это тестовый документ для демонстрации RAG системы.",
+ "filename": "test.txt",
+ "category": "Тест",
+ "metadata": {
+ "author": "Пользователь",
+ "type": "demo"
+ }
+ }]
+ }'
+```
+
+### 2. Поиск ответа
+```bash
+curl -X POST http://localhost:4002/api/v1/ask \
+ -H "Content-Type: application/json" \
+ -d '{
+ "question": "О чем этот документ?"
+ }'
+```
+
+### 3. Получение всех документов
+```bash
+curl http://localhost:4002/api/v1/docs
+```
+
+---
+
+## Архитектурные особенности
+
+### Модульность
+Система спроектирована с использованием интерфейсов, что позволяет легко:
+- Переключаться между различными LLM провайдерами
+- Использовать разные модели эмбеддингов
+- Менять векторную базу данных
+- Добавлять новые методы обработки текста
+
+### Обработка текста
+- Автоматическое разбиение документов на чанки
+- Генерация эмбеддингов для каждого фрагмента
+- Сохранение порядка фрагментов для корректной реконструкции
+
+### Поиск и ранжирование
+- Семантический поиск по векторным представлениям
+- Ранжирование результатов по релевантности
+- Контекстная генерация ответов на основе найденных документов
+
+---
+
+## Разработка и тестирование
+
+### Структура проекта
+```
+easy_rag/
+├── api/ # HTTP API обработчики
+├── cmd/rag/ # Точка входа приложения
+├── config/ # Конфигурация
+├── deploy/ # Docker файлы
+├── internal/ # Внутренняя логика
+│ ├── database/ # Интерфейсы БД
+│ ├── embeddings/ # Провайдеры эмбеддингов
+│ ├── llm/ # Провайдеры LLM
+│ ├── models/ # Модели данных
+│ └── pkg/ # Пакеты общего назначения
+├── scripts/ # Вспомогательные скрипты
+└── tests/ # Тесты и коллекции Postman
+```
+
+### Запуск тестов
+```bash
+go test ./...
+```
+
+### Использование Postman
+В папке `tests/` находится коллекция Postman для тестирования API:
+```
+tests/RAG.postman_collection.json
+```
+
+---
+
+## Производительность и масштабирование
+
+### Рекомендации по производительности
+- Используйте SSD для хранения векторной базы данных
+- Настройте индексы Milvus для оптимальной производительности
+- Рассмотрите использование GPU для генерации эмбеддингов
+- Кэшируйте часто запрашиваемые результаты
+
+### Масштабирование
+- Горизонтальное масштабирование Milvus кластера
+- Балансировка нагрузки между несколькими экземплярами API
+- Асинхронная обработка загрузки документов
+- Использование очередей для обработки больших объемов данных
+
+---
+
+## Устранение неполадок
+
+### Частые проблемы
+
+1. **Не удается подключиться к Milvus**
+ - Проверьте, что Milvus запущен и доступен
+ - Убедитесь в правильности MILVUS_HOST
+
+2. **Ошибки LLM провайдера**
+ - Проверьте API ключи
+ - Убедитесь в доступности эндпоинтов
+ - Проверьте правильность названий моделей
+
+3. **Медленная обработка документов**
+ - Уменьшите размер чанков
+ - Используйте более быстрые модели эмбеддингов
+ - Оптимизируйте настройки Milvus
+
+---
+
+## Вклад в проект
+
+Мы приветствуем вклад в развитие проекта! Пожалуйста:
+
+1. Форкните репозиторий
+2. Создайте ветку для новой функции
+3. Внесите изменения
+4. Добавьте тесты
+5. Отправьте Pull Request
+
+---
+
+## Лицензия
+
+Проект распространяется под лицензией MIT. См. файл `LICENSE` для подробностей.
+
+---
+
+## Поддержка
+
+Если у вас есть вопросы или проблемы:
+- Создайте Issue в GitHub репозитории
+- Обратитесь к документации API
+- Проверьте примеры в папке `tests/`
+
+---
+
+## Roadmap
+
+### Ближайшие планы
+- [ ] Поддержка дополнительных форматов документов (PDF, DOCX)
+- [ ] Веб-интерфейс для управления документами
+- [ ] Улучшенная система метаданных
+- [ ] Поддержка multimodal моделей
+- [ ] Кэширование результатов
+- [ ] Мониторинг и метрики
+
+### Долгосрочные цели
+- [ ] Поддержка множественных языков
+- [ ] Федеративный поиск по нескольким источникам
+- [ ] Интеграция с внешними системами (SharePoint, Confluence)
+- [ ] Продвинутая аналитика использования
+- [ ] Система разрешений и ролей
diff --git a/api/api.go b/api/api.go
new file mode 100644
index 0000000..0403c15
--- /dev/null
+++ b/api/api.go
@@ -0,0 +1,35 @@
+package api
+
+import (
+ "fmt"
+ "github.com/labstack/echo/v4"
+
+ "easy_rag/internal/pkg/rag"
+ "github.com/labstack/echo/v4/middleware"
+)
+
+const (
+ // APIVersion is the version of the API
+ APIVersion = "v1"
+)
+
+func NewAPI(e *echo.Echo, rag *rag.Rag) {
+ // Middleware
+ e.Use(middleware.Logger())
+ e.Use(middleware.Recover())
+ // put rag pointer in context
+ e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ c.Set("Rag", rag)
+ return next(c)
+ }
+ })
+
+ api := e.Group(fmt.Sprintf("/api/%s", APIVersion))
+
+ api.POST("/upload", UploadHandler)
+ api.POST("/ask", AskDocHandler)
+ api.GET("/docs", ListAllDocsHandler)
+ api.GET("/doc/:id", GetDocHandler)
+ api.DELETE("/doc/:id", DeleteDocHandler)
+}
diff --git a/api/handler.go b/api/handler.go
new file mode 100644
index 0000000..4fcf3ef
--- /dev/null
+++ b/api/handler.go
@@ -0,0 +1,248 @@
+package api
+
+import (
+ "fmt"
+ "github.com/labstack/echo/v4"
+ "log"
+ "net/http"
+
+ "easy_rag/internal/models"
+ "easy_rag/internal/pkg/rag"
+ "easy_rag/internal/pkg/textprocessor"
+ "github.com/google/uuid"
+)
+
+type UploadDoc struct {
+ Content string `json:"content"`
+ Link string `json:"link"`
+ Filename string `json:"filename"`
+ Category string `json:"category"`
+ Metadata map[string]string `json:"metadata"`
+}
+
+type RequestUpload struct {
+ Docs []UploadDoc `json:"docs"`
+}
+
+type RequestQuestion struct {
+ Question string `json:"question"`
+}
+
+type ResposeQuestion struct {
+ Version string `json:"version"`
+ Docs []models.Document `json:"docs"`
+ Answer string `json:"answer"`
+}
+
+func UploadHandler(c echo.Context) error {
+ // Retrieve the RAG instance from context
+ rag := c.Get("Rag").(*rag.Rag)
+
+ var request RequestUpload
+ if err := c.Bind(&request); err != nil {
+ return ErrorHandler(err, c)
+ }
+
+ // Generate a unique task ID
+ taskID := uuid.NewString()
+
+ // Launch the upload process in a separate goroutine
+ go func(taskID string, request RequestUpload) {
+ log.Printf("Task %s: started processing", taskID)
+ defer log.Printf("Task %s: completed processing", taskID)
+
+ var docs []models.Document
+
+ for idx, doc := range request.Docs {
+ // Generate a unique ID for each document
+ docID := uuid.NewString()
+ log.Printf("Task %s: processing document %d with generated ID %s (filename: %s)", taskID, idx, docID, doc.Filename)
+
+ // Step 1: Create chunks from document content
+ chunks := textprocessor.CreateChunks(doc.Content)
+ log.Printf("Task %s: created %d chunks for document %s", taskID, len(chunks), docID)
+
+ // Step 2: Generate summary for the document
+ var summaryChunks string
+ if len(chunks) < 4 {
+ summaryChunks = doc.Content
+ } else {
+ summaryChunks = textprocessor.ConcatenateStrings(chunks[:3])
+ }
+
+ log.Printf("Task %s: generating summary for document %s", taskID, docID)
+ summary, err := rag.LLM.Generate(fmt.Sprintf("Give me only summary of the following text: %s", summaryChunks))
+ if err != nil {
+ log.Printf("Task %s: error generating summary for document %s: %v", taskID, docID, err)
+ return
+ }
+ log.Printf("Task %s: generated summary for document %s", taskID, docID)
+
+ // Step 3: Vectorize the summary
+ log.Printf("Task %s: vectorizing summary for document %s", taskID, docID)
+ vectorSum, err := rag.Embeddings.Vectorize(summary)
+ if err != nil {
+ log.Printf("Task %s: error vectorizing summary for document %s: %v", taskID, docID, err)
+ return
+ }
+ log.Printf("Task %s: vectorized summary for document %s", taskID, docID)
+
+ // Step 4: Save the document
+ document := models.Document{
+ ID: docID, // Use generated ID
+ Content: "",
+ Link: doc.Link,
+ Filename: doc.Filename,
+ Category: doc.Category,
+ EmbeddingModel: rag.Embeddings.GetModel(),
+ Summary: summary,
+ Vector: vectorSum[0],
+ Metadata: doc.Metadata,
+ }
+ log.Printf("Task %s: saving document %s", taskID, docID)
+ if err := rag.Database.SaveDocument(document); err != nil {
+ log.Printf("Task %s: error saving document %s: %v", taskID, docID, err)
+ return
+ }
+ log.Printf("Task %s: saved document %s", taskID, docID)
+
+ // Step 5: Process and save embeddings for each chunk
+ var embeddings []models.Embedding
+ for order, chunk := range chunks {
+ log.Printf("Task %s: vectorizing chunk %d for document %s", taskID, order, docID)
+ vectorEmbedding, err := rag.Embeddings.Vectorize(chunk)
+ if err != nil {
+ log.Printf("Task %s: error vectorizing chunk %d for document %s: %v", taskID, order, docID, err)
+ return
+ }
+ log.Printf("Task %s: vectorized chunk %d for document %s", taskID, order, docID)
+
+ embedding := models.Embedding{
+ ID: uuid.NewString(),
+ DocumentID: docID,
+ Vector: vectorEmbedding[0],
+ TextChunk: chunk,
+ Dimension: int64(1024),
+ Order: int64(order),
+ }
+ embeddings = append(embeddings, embedding)
+ }
+
+ log.Printf("Task %s: saving %d embeddings for document %s", taskID, len(embeddings), docID)
+ if err := rag.Database.SaveEmbeddings(embeddings); err != nil {
+ log.Printf("Task %s: error saving embeddings for document %s: %v", taskID, docID, err)
+ return
+ }
+ log.Printf("Task %s: saved embeddings for document %s", taskID, docID)
+
+ docs = append(docs, document)
+ }
+ }(taskID, request)
+
+ // Return the task ID and expected completion time
+ return c.JSON(http.StatusAccepted, map[string]interface{}{
+ "version": APIVersion,
+ "task_id": taskID,
+ "expected_time": "10m",
+ "status": "Processing started",
+ })
+}
+
+func ListAllDocsHandler(c echo.Context) error {
+ rag := c.Get("Rag").(*rag.Rag)
+ docs, err := rag.Database.ListDocuments()
+ if err != nil {
+ return ErrorHandler(err, c)
+ }
+ return c.JSON(http.StatusOK, map[string]interface{}{
+ "version": APIVersion,
+ "docs": docs,
+ })
+}
+
+func GetDocHandler(c echo.Context) error {
+ rag := c.Get("Rag").(*rag.Rag)
+ id := c.Param("id")
+ doc, err := rag.Database.GetDocument(id)
+ if err != nil {
+ return ErrorHandler(err, c)
+ }
+ return c.JSON(http.StatusOK, map[string]interface{}{
+ "version": APIVersion,
+ "doc": doc,
+ })
+}
+
+func AskDocHandler(c echo.Context) error {
+ rag := c.Get("Rag").(*rag.Rag)
+
+ var request RequestQuestion
+ err := c.Bind(&request)
+
+ if err != nil {
+ return ErrorHandler(err, c)
+ }
+
+ questionV, err := rag.Embeddings.Vectorize(request.Question)
+
+ if err != nil {
+ return ErrorHandler(err, c)
+ }
+
+ embeddings, err := rag.Database.Search(questionV)
+
+ if err != nil {
+ return ErrorHandler(err, c)
+ }
+
+ if len(embeddings) == 0 {
+ return c.JSON(http.StatusOK, map[string]interface{}{
+ "version": APIVersion,
+ "docs": nil,
+ "answer": "Don't found any relevant documents",
+ })
+ }
+
+ answer, err := rag.LLM.Generate(fmt.Sprintf("Given the following information: %s \nAnswer the question: %s", embeddings[0].TextChunk, request.Question))
+
+ if err != nil {
+ return ErrorHandler(err, c)
+ }
+
+ // Use a map to track unique DocumentIDs
+ docSet := make(map[string]struct{})
+ for _, embedding := range embeddings {
+ docSet[embedding.DocumentID] = struct{}{}
+ }
+
+ // Convert the map keys to a slice
+ docs := make([]string, 0, len(docSet))
+ for docID := range docSet {
+ docs = append(docs, docID)
+ }
+
+ return c.JSON(http.StatusOK, map[string]interface{}{
+ "version": APIVersion,
+ "docs": docs,
+ "answer": answer,
+ })
+}
+
+func DeleteDocHandler(c echo.Context) error {
+ rag := c.Get("Rag").(*rag.Rag)
+ id := c.Param("id")
+ err := rag.Database.DeleteDocument(id)
+ if err != nil {
+ return ErrorHandler(err, c)
+ }
+ return c.JSON(http.StatusOK, map[string]interface{}{
+ "version": APIVersion,
+ "docs": nil,
+ })
+}
+
+func ErrorHandler(err error, c echo.Context) error {
+ return c.JSON(http.StatusBadRequest, map[string]interface{}{
+ "error": err.Error(),
+ })
+}
diff --git a/cmd/rag/main.go b/cmd/rag/main.go
new file mode 100644
index 0000000..2653d2b
--- /dev/null
+++ b/cmd/rag/main.go
@@ -0,0 +1,43 @@
+package main
+
+import (
+ "easy_rag/api"
+ "easy_rag/config"
+ "easy_rag/internal/database"
+ "easy_rag/internal/embeddings"
+ "easy_rag/internal/llm"
+ "easy_rag/internal/pkg/rag"
+ "github.com/labstack/echo/v4"
+)
+
+// Rag is the main struct for the rag application
+
+func main() {
+ cfg := config.NewConfig()
+
+ llm := llm.NewOllama(
+ cfg.OllamaEndpoint,
+ cfg.OllamaModel,
+ )
+ embeddings := embeddings.NewOllamaEmbeddings(
+ cfg.OllamaEmbeddingEndpoint,
+ cfg.OllamaEmbeddingModel,
+ )
+ database := database.NewMilvus(cfg.MilvusHost)
+
+ // Rag instance
+ rag := rag.NewRag(
+ llm,
+ embeddings,
+ database,
+ )
+
+ // Echo WebServer instance
+ e := echo.New()
+
+ // Wrapper for API
+ api.NewAPI(e, rag)
+
+ // Start Server
+ e.Logger.Fatal(e.Start(":4002"))
+}
diff --git a/config/config.go b/config/config.go
new file mode 100644
index 0000000..d6a7e05
--- /dev/null
+++ b/config/config.go
@@ -0,0 +1,38 @@
+package config
+
+import cfg "github.com/eschao/config"
+
// Config holds every runtime setting for the service. Each field can be
// overridden through the environment variable named in its `env` tag;
// see NewConfig for the built-in defaults.
type Config struct {
	// LLM providers
	OpenAIAPIKey      string `env:"OPENAI_API_KEY"`
	OpenAIEndpoint    string `env:"OPENAI_ENDPOINT"`
	OpenAIModel       string `env:"OPENAI_MODEL"`
	OpenRouteAPIKey   string `env:"OPENROUTE_API_KEY"`
	OpenRouteEndpoint string `env:"OPENROUTE_ENDPOINT"`
	OpenRouteModel    string `env:"OPENROUTE_MODEL"`

	OllamaEndpoint string `env:"OLLAMA_ENDPOINT"`
	OllamaModel    string `env:"OLLAMA_MODEL"`

	// Embedding providers
	OpenAIEmbeddingAPIKey   string `env:"OPENAI_EMBEDDING_API_KEY"`
	OpenAIEmbeddingEndpoint string `env:"OPENAI_EMBEDDING_ENDPOINT"`
	OpenAIEmbeddingModel    string `env:"OPENAI_EMBEDDING_MODEL"`
	OllamaEmbeddingEndpoint string `env:"OLLAMA_EMBEDDING_ENDPOINT"`
	OllamaEmbeddingModel    string `env:"OLLAMA_EMBEDDING_MODEL"`

	// Vector database
	MilvusHost string `env:"MILVUS_HOST"` // host:port of the Milvus server
}
+
// NewConfig builds a Config pre-populated with hard-coded development
// defaults and then overrides them from the environment via the `env`
// tags declared on Config.
func NewConfig() Config {
	// NOTE(review): these defaults point at a private LAN host
	// (192.168.10.56) and will not work outside that network; override
	// them via environment variables in any real deployment.
	config := Config{
		MilvusHost:              "192.168.10.56:19530",
		OllamaEmbeddingEndpoint: "http://192.168.10.56:11434",
		OllamaEmbeddingModel:    "bge-m3",
		OllamaEndpoint:          "http://192.168.10.56:11434/api/chat",
		OllamaModel:             "qwen3:latest",
	}
	// NOTE(review): ParseEnv's error is discarded here, so a malformed
	// environment silently falls back to the defaults above — consider
	// surfacing it.
	cfg.ParseEnv(&config)
	return config
}
diff --git a/deploy/Dockerfile b/deploy/Dockerfile
new file mode 100644
index 0000000..ccf9f09
--- /dev/null
+++ b/deploy/Dockerfile
@@ -0,0 +1,16 @@
# Build stage: must be at least the Go version required by go.mod (go 1.24.3);
# the previous golang:1.20 image cannot build this module.
FROM golang:1.24 AS builder

WORKDIR /cmd

COPY . .

RUN go mod download
# CGO_ENABLED=0 produces a statically linked binary that can run on the
# musl-based alpine image below (a glibc-linked binary would fail to start).
RUN CGO_ENABLED=0 go build -o rag ./cmd/rag

FROM alpine:latest

WORKDIR /root/

COPY --from=builder /cmd/rag ./

# The API listens on 4002 (see cmd/rag/main.go).
EXPOSE 4002

CMD ["./rag"]
\ No newline at end of file
diff --git a/deploy/docker-compose.yaml b/deploy/docker-compose.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/embedEtcd.yaml b/embedEtcd.yaml
new file mode 100644
index 0000000..32954fa
--- /dev/null
+++ b/embedEtcd.yaml
@@ -0,0 +1,5 @@
+listen-client-urls: http://0.0.0.0:2379
+advertise-client-urls: http://0.0.0.0:2379
+quota-backend-bytes: 4294967296
+auto-compaction-mode: revision
+auto-compaction-retention: '1000'
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..ce63a30
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,49 @@
+module easy_rag
+
+go 1.24.3
+
+require (
+ github.com/eschao/config v0.1.0
+ github.com/google/uuid v1.6.0
+ github.com/jonathanhecl/chunker v0.0.1
+ github.com/labstack/echo/v4 v4.13.4
+ github.com/milvus-io/milvus-sdk-go/v2 v2.4.2
+ github.com/stretchr/testify v1.10.0
+)
+
+require (
+ github.com/cockroachdb/errors v1.9.1 // indirect
+ github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
+ github.com/cockroachdb/redact v1.1.3 // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/getsentry/sentry-go v0.12.0 // indirect
+ github.com/gogo/protobuf v1.3.2 // indirect
+ github.com/golang/protobuf v1.5.2 // indirect
+ github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
+ github.com/kr/pretty v0.3.0 // indirect
+ github.com/kr/text v0.2.0 // indirect
+ github.com/labstack/gommon v0.4.2 // indirect
+ github.com/mattn/go-colorable v0.1.14 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a // indirect
+ github.com/pkg/errors v0.9.1 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ github.com/rogpeppe/go-internal v1.8.1 // indirect
+ github.com/stretchr/objx v0.5.2 // indirect
+ github.com/tidwall/gjson v1.14.4 // indirect
+ github.com/tidwall/match v1.1.1 // indirect
+ github.com/tidwall/pretty v1.2.0 // indirect
+ github.com/valyala/bytebufferpool v1.0.0 // indirect
+ github.com/valyala/fasttemplate v1.2.2 // indirect
+ golang.org/x/crypto v0.38.0 // indirect
+ golang.org/x/net v0.40.0 // indirect
+ golang.org/x/sync v0.14.0 // indirect
+ golang.org/x/sys v0.33.0 // indirect
+ golang.org/x/text v0.25.0 // indirect
+ golang.org/x/time v0.11.0 // indirect
+ google.golang.org/genproto v0.0.0-20220503193339-ba3ae3f07e29 // indirect
+ google.golang.org/grpc v1.48.0 // indirect
+ google.golang.org/protobuf v1.33.0 // indirect
+ gopkg.in/yaml.v2 v2.2.8 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
+)
diff --git a/internal/database/database.go b/internal/database/database.go
new file mode 100644
index 0000000..c543e7e
--- /dev/null
+++ b/internal/database/database.go
@@ -0,0 +1,20 @@
+package database
+
+import "easy_rag/internal/models"
+
// Database defines the storage operations the RAG service relies on:
// persisting documents and their chunk embeddings, retrieving and deleting
// them, and vector similarity search.
type Database interface {
	SaveDocument(document models.Document) error        // the content will be chunked and saved
	GetDocumentInfo(id string) (models.Document, error) // return the document with the given id without content
	GetDocument(id string) (models.Document, error)     // return the document with the given id with content assembled
	Search(vector [][]float32) ([]models.Embedding, error) // nearest-neighbour search over chunk embeddings
	ListDocuments() ([]models.Document, error)
	DeleteDocument(id string) error
	SaveEmbeddings(embeddings []models.Embedding) error
	// to implement in future
	// SearchByCategory(category []string) ([]Embedding, error)
	// SearchByMetadata(metadata map[string]string) ([]Embedding, error)
	// GetAllEmbeddingByDocumentID(documentID string) ([]Embedding, error)
}
diff --git a/internal/database/database_milvus.go b/internal/database/database_milvus.go
new file mode 100644
index 0000000..1bf6771
--- /dev/null
+++ b/internal/database/database_milvus.go
@@ -0,0 +1,142 @@
+package database
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "sort"
+
+ "easy_rag/internal/models"
+ "easy_rag/internal/pkg/database/milvus"
+)
+
+// implement database interface for milvus
+type Milvus struct {
+ Host string
+ Client *milvus.Client
+}
+
+func NewMilvus(host string) *Milvus {
+
+ milviusClient, err := milvus.NewClient(host)
+ if err != nil {
+ panic(err)
+ }
+
+ return &Milvus{
+ Host: host,
+ Client: milviusClient,
+ }
+}
+
+func (m *Milvus) SaveDocument(document models.Document) error {
+ // for now lets use context background
+ ctx := context.Background()
+
+ return m.Client.InsertDocuments(ctx, []models.Document{document})
+}
+
+func (m *Milvus) SaveEmbeddings(embeddings []models.Embedding) error {
+ ctx := context.Background()
+ return m.Client.InsertEmbeddings(ctx, embeddings)
+}
+
+func (m *Milvus) GetDocumentInfo(id string) (models.Document, error) {
+ ctx := context.Background()
+ doc, err := m.Client.GetDocumentByID(ctx, id)
+
+ if err != nil {
+ return models.Document{}, err
+ }
+
+ if len(doc) == 0 {
+ return models.Document{}, nil
+ }
+ return models.Document{
+ ID: doc["ID"].(string),
+ Link: doc["Link"].(string),
+ Filename: doc["Filename"].(string),
+ Category: doc["Category"].(string),
+ EmbeddingModel: doc["EmbeddingModel"].(string),
+ Summary: doc["Summary"].(string),
+ Metadata: doc["Metadata"].(map[string]string),
+ }, nil
+}
+
+func (m *Milvus) GetDocument(id string) (models.Document, error) {
+ ctx := context.Background()
+ doc, err := m.Client.GetDocumentByID(ctx, id)
+ if err != nil {
+ return models.Document{}, err
+ }
+
+ embeds, err := m.Client.GetAllEmbeddingByDocID(ctx, id)
+
+ if err != nil {
+ return models.Document{}, err
+ }
+
+ // order embed by order
+ sort.Slice(embeds, func(i, j int) bool {
+ return embeds[i].Order < embeds[j].Order
+ })
+
+ // concatenate text chunks
+ var buf bytes.Buffer
+ for _, embed := range embeds {
+ buf.WriteString(embed.TextChunk)
+ }
+
+ textChunks := buf.String()
+
+ if len(doc) == 0 {
+ return models.Document{}, nil
+ }
+ return models.Document{
+ ID: doc["ID"].(string),
+ Content: textChunks,
+ Link: doc["Link"].(string),
+ Filename: doc["Filename"].(string),
+ Category: doc["Category"].(string),
+ EmbeddingModel: doc["EmbeddingModel"].(string),
+ Summary: doc["Summary"].(string),
+ Metadata: doc["Metadata"].(map[string]string),
+ }, nil
+}
+
+func (m *Milvus) Search(vector [][]float32) ([]models.Embedding, error) {
+ ctx := context.Background()
+ results, err := m.Client.Search(ctx, vector, 10)
+ if err != nil {
+ return nil, err
+ }
+
+ return results, nil
+}
+
+func (m *Milvus) ListDocuments() ([]models.Document, error) {
+ ctx := context.Background()
+
+ docs, err := m.Client.GetAllDocuments(ctx)
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to get docs: %w", err)
+ }
+
+ return docs, nil
+}
+
+func (m *Milvus) DeleteDocument(id string) error {
+ ctx := context.Background()
+ err := m.Client.DeleteDocument(ctx, id)
+ if err != nil {
+ return err
+ }
+
+ err = m.Client.DeleteEmbedding(ctx, id)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/embeddings/embeddings.go b/internal/embeddings/embeddings.go
new file mode 100644
index 0000000..77a6fd3
--- /dev/null
+++ b/internal/embeddings/embeddings.go
@@ -0,0 +1,8 @@
+package embeddings
+
// EmbeddingsService abstracts an embedding provider: it converts text into
// dense vectors and reports which model produced them.
type EmbeddingsService interface {
	// Vectorize returns the embedding vectors generated for the given text.
	Vectorize(text string) ([][]float32, error)
	// GetModel returns the name of the embedding model in use.
	GetModel() string
}
diff --git a/internal/embeddings/ollama_embeddings.go b/internal/embeddings/ollama_embeddings.go
new file mode 100644
index 0000000..ac8f918
--- /dev/null
+++ b/internal/embeddings/ollama_embeddings.go
@@ -0,0 +1,78 @@
+package embeddings
+
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)
+
// OllamaEmbeddings generates embeddings through an Ollama server's
// /api/embed endpoint.
type OllamaEmbeddings struct {
	Endpoint string // base URL of the Ollama server, e.g. "http://localhost:11434"
	Model    string // embedding model name, e.g. "bge-m3"
}

// embedClient is shared by all requests. The timeout guards against an
// unresponsive Ollama server hanging a handler forever — the previous
// zero-value http.Client had no timeout at all.
var embedClient = &http.Client{Timeout: 120 * time.Second}

// NewOllamaEmbeddings returns a client for the given Ollama endpoint and
// embedding model.
func NewOllamaEmbeddings(endpoint string, model string) *OllamaEmbeddings {
	return &OllamaEmbeddings{
		Endpoint: endpoint,
		Model:    model,
	}
}

// Vectorize sends text to the Ollama embed API and returns the resulting
// embedding vectors (the API returns one vector per input).
func (o *OllamaEmbeddings) Vectorize(text string) ([][]float32, error) {
	payload := map[string]string{
		"model": o.Model,
		"input": text,
	}

	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request payload: %w", err)
	}

	url := fmt.Sprintf("%s/api/embed", o.Endpoint)
	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := embedClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make HTTP request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("received non-200 response: %s", body)
	}

	// Decode straight from the stream instead of buffering the whole body.
	var response struct {
		Embeddings [][]float32 `json:"embeddings"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}

	return response.Embeddings, nil
}

// GetModel returns the configured embedding model name.
func (o *OllamaEmbeddings) GetModel() string {
	return o.Model
}
diff --git a/internal/embeddings/openai_embeddings.go b/internal/embeddings/openai_embeddings.go
new file mode 100644
index 0000000..f42fbb8
--- /dev/null
+++ b/internal/embeddings/openai_embeddings.go
@@ -0,0 +1,23 @@
+package embeddings
+
+// OpenAIEmbeddings holds connection settings for the OpenAI embeddings
+// API. Vectorize is not implemented yet; only configuration is stored.
+type OpenAIEmbeddings struct {
+ APIKey string
+ Endpoint string
+ Model string
+}
+
+// NewOpenAIEmbeddings constructs an OpenAIEmbeddings client from an
+// API key, endpoint URL and model name.
+func NewOpenAIEmbeddings(apiKey string, endpoint string, model string) *OpenAIEmbeddings {
+	e := new(OpenAIEmbeddings)
+	e.APIKey = apiKey
+	e.Endpoint = endpoint
+	e.Model = model
+	return e
+}
+
+// Vectorize is not implemented yet; it currently returns (nil, nil).
+// NOTE(review): OllamaEmbeddings.Vectorize returns [][]float32, so this
+// []float32 return type looks inconsistent — confirm against the
+// embeddings interface before implementing.
+func (o *OpenAIEmbeddings) Vectorize(text string) ([]float32, error) {
+ return nil, nil
+}
+
+// GetModel reports the configured embedding model name.
+func (o *OpenAIEmbeddings) GetModel() string {
+ return o.Model
+}
diff --git a/internal/llm/llm.go b/internal/llm/llm.go
new file mode 100644
index 0000000..88ef5f3
--- /dev/null
+++ b/internal/llm/llm.go
@@ -0,0 +1,8 @@
+package llm
+
+// LLMService is the interface implemented by all text-generation
+// backends in this package (Ollama, OpenAI, OpenRoute).
+type LLMService interface {
+ // Generate produces a completion for the given prompt.
+ Generate(prompt string) (string, error)
+ // GetModel reports the configured model name.
+ GetModel() string
+}
diff --git a/internal/llm/ollama_llm.go b/internal/llm/ollama_llm.go
new file mode 100644
index 0000000..54d9271
--- /dev/null
+++ b/internal/llm/ollama_llm.go
@@ -0,0 +1,82 @@
+package llm
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+)
+
+// Ollama is an LLM client that talks to an Ollama chat endpoint.
+// Endpoint is used as the full request URL by Generate.
+type Ollama struct {
+ Endpoint string
+ Model string
+}
+
+// NewOllama returns an Ollama chat client for the given endpoint URL
+// and model name.
+func NewOllama(endpoint string, model string) *Ollama {
+	o := new(Ollama)
+	o.Endpoint = endpoint
+	o.Model = model
+	return o
+}
+
+// Response models the non-streaming reply of the Ollama chat API:
+// the model used, a creation timestamp, and one assistant message.
+type Response struct {
+ Model string `json:"model"`
+ CreatedAt string `json:"created_at"`
+ Message struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ } `json:"message"`
+}
+
+// Generate sends a single-turn, non-streaming chat request to the
+// Ollama endpoint and returns the assistant's reply text.
+func (o *Ollama) Generate(prompt string) (string, error) {
+	// One user message, streaming disabled.
+	payload := map[string]interface{}{
+		"model":  o.Model,
+		"stream": false,
+		"messages": []map[string]string{
+			{"role": "user", "content": prompt},
+		},
+	}
+
+	data, err := json.Marshal(payload)
+	if err != nil {
+		return "", fmt.Errorf("failed to marshal payload: %w", err)
+	}
+
+	resp, err := http.Post(o.Endpoint, "application/json", bytes.NewBuffer(data))
+	if err != nil {
+		return "", fmt.Errorf("failed to make request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", fmt.Errorf("failed to read response: %w", err)
+	}
+	if resp.StatusCode != http.StatusOK {
+		return "", fmt.Errorf("API returned error: %s", string(body))
+	}
+
+	// Decode into the fixed Response shape and pull out the message text.
+	var parsed Response
+	if err := json.Unmarshal(body, &parsed); err != nil {
+		return "", fmt.Errorf("failed to unmarshal response: %w", err)
+	}
+	return parsed.Message.Content, nil
+}
+
+// GetModel reports the configured model name.
+func (o *Ollama) GetModel() string {
+ return o.Model
+}
diff --git a/internal/llm/openai_llm.go b/internal/llm/openai_llm.go
new file mode 100644
index 0000000..d88f3f8
--- /dev/null
+++ b/internal/llm/openai_llm.go
@@ -0,0 +1,24 @@
+package llm
+
+// OpenAI holds connection settings for the OpenAI chat API.
+// Generate is not implemented yet; only configuration is stored.
+type OpenAI struct {
+ APIKey string
+ Endpoint string
+ Model string
+}
+
+// NewOpenAI constructs an OpenAI client from an API key, endpoint URL
+// and model name.
+func NewOpenAI(apiKey string, endpoint string, model string) *OpenAI {
+	c := new(OpenAI)
+	c.APIKey = apiKey
+	c.Endpoint = endpoint
+	c.Model = model
+	return c
+}
+
+func (o *OpenAI) Generate(prompt string) (string, error) {
+ return "", nil
+ // TODO: implement
+}
+
+func (o *OpenAI) GetModel() string {
+ return o.Model
+}
diff --git a/internal/llm/openroute/definitions.go b/internal/llm/openroute/definitions.go
new file mode 100644
index 0000000..f427600
--- /dev/null
+++ b/internal/llm/openroute/definitions.go
@@ -0,0 +1,155 @@
+package openroute
+
+// Request is the body of an OpenRouter chat/completions call. Either
+// Messages (chat) or Prompt (plain completion) is set, not both; the
+// remaining fields are optional sampling, tooling and routing knobs
+// and are omitted from the JSON when zero.
+type Request struct {
+ Messages []MessageRequest `json:"messages,omitempty"`
+ Prompt string `json:"prompt,omitempty"`
+ Model string `json:"model,omitempty"`
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ Stop []string `json:"stop,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ Tools []Tool `json:"tools,omitempty"`
+ ToolChoice ToolChoice `json:"tool_choice,omitempty"`
+ Seed int `json:"seed,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ RepetitionPenalty float64 `json:"repetition_penalty,omitempty"`
+ LogitBias map[int]float64 `json:"logit_bias,omitempty"`
+ TopLogprobs int `json:"top_logprobs,omitempty"`
+ MinP float64 `json:"min_p,omitempty"`
+ TopA float64 `json:"top_a,omitempty"`
+ Prediction *Prediction `json:"prediction,omitempty"`
+ Transforms []string `json:"transforms,omitempty"`
+ Models []string `json:"models,omitempty"`
+ Route string `json:"route,omitempty"`
+ Provider *ProviderPreferences `json:"provider,omitempty"`
+ IncludeReasoning bool `json:"include_reasoning,omitempty"`
+}
+
+// ResponseFormat selects the output format of the completion (e.g. a
+// JSON-mode type string).
+type ResponseFormat struct {
+ Type string `json:"type"`
+}
+
+// Prediction carries predicted-output content for the request.
+type Prediction struct {
+ Type string `json:"type"`
+ Content string `json:"content"`
+}
+
+// ProviderPreferences holds attribution values; the client copies them
+// into the HTTP-Referer and X-Title request headers.
+type ProviderPreferences struct {
+ RefererURL string `json:"referer_url,omitempty"`
+ SiteName string `json:"site_name,omitempty"`
+}
+
+// MessageRequest is one message of a chat conversation. Content is
+// either a plain string or a []ContentPart for multimodal input.
+type MessageRequest struct {
+ Role MessageRole `json:"role"`
+ Content interface{} `json:"content"` // Can be string or []ContentPart
+ Name string `json:"name,omitempty"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
+}
+
+// MessageRole identifies the author of a chat message.
+type MessageRole string
+
+const (
+ RoleSystem MessageRole = "system"
+ RoleUser MessageRole = "user"
+ RoleAssistant MessageRole = "assistant"
+)
+
+// ContentPart is one piece of multimodal message content: either text
+// or an image reference, selected by Type.
+type ContentPart struct {
+ Type ContnetType `json:"type"`
+ Text string `json:"text,omitempty"`
+ ImageURL *ImageURL `json:"image_url,omitempty"`
+}
+
+// ContnetType distinguishes text from image content parts.
+// NOTE(review): the name is a typo for "ContentType"; renaming would
+// break any external callers, so it is only flagged here.
+type ContnetType string
+
+const (
+ ContentTypeText ContnetType = "text"
+ ContentTypeImage ContnetType = "image_url"
+)
+
+// ImageURL points at an image, with an optional detail level.
+type ImageURL struct {
+ URL string `json:"url"`
+ Detail string `json:"detail,omitempty"`
+}
+
+// FunctionDescription declares a callable function exposed to the
+// model; Parameters is a JSON Schema object.
+type FunctionDescription struct {
+ Description string `json:"description,omitempty"`
+ Name string `json:"name"`
+ Parameters interface{} `json:"parameters"` // JSON Schema object
+}
+
+// Tool wraps a FunctionDescription with its tool type tag.
+type Tool struct {
+ Type string `json:"type"`
+ Function FunctionDescription `json:"function"`
+}
+
+// ToolChoice forces the model to call a specific named tool.
+type ToolChoice struct {
+ Type string `json:"type"`
+ Function struct {
+ Name string `json:"name"`
+ } `json:"function"`
+}
+
+// Response is the OpenRouter completion reply, for both streaming
+// chunks and full responses.
+type Response struct {
+ ID string `json:"id"`
+ Choices []Choice `json:"choices"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ Object string `json:"object"`
+ SystemFingerprint *string `json:"system_fingerprint,omitempty"`
+ Usage *ResponseUsage `json:"usage,omitempty"`
+}
+
+// ResponseUsage reports token accounting for a completed request.
+type ResponseUsage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+}
+
+// Choice is one generated alternative. Text is set for plain
+// completions, Message for non-streaming chat, Delta for streaming
+// chunks; Error is set when this choice failed.
+type Choice struct {
+ FinishReason string `json:"finish_reason"`
+ Text string `json:"text,omitempty"`
+ Message *MessageResponse `json:"message,omitempty"`
+ Delta *Delta `json:"delta,omitempty"`
+ Error *ErrorResponse `json:"error,omitempty"`
+}
+
+// MessageResponse is a complete assistant message in a non-streaming
+// reply.
+type MessageResponse struct {
+ Content string `json:"content"`
+ Role string `json:"role"`
+ Reasoning string `json:"reasoning,omitempty"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+}
+
+// Delta is an incremental message fragment in a streaming reply.
+type Delta struct {
+ Content string `json:"content"`
+ Role string `json:"role,omitempty"`
+ Reasoning string `json:"reasoning,omitempty"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+}
+
+// ErrorResponse is an error payload attached to a failed choice.
+type ErrorResponse struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+}
+
+// ToolCall is a model-issued invocation of a declared tool.
+type ToolCall struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Function interface{} `json:"function"`
+}
diff --git a/internal/llm/openroute/route_agent.go b/internal/llm/openroute/route_agent.go
new file mode 100644
index 0000000..8fd58bd
--- /dev/null
+++ b/internal/llm/openroute/route_agent.go
@@ -0,0 +1,198 @@
+package openroute
+
+import (
+	"context"
+	"fmt"
+)
+
+// RouterAgentConfig bundles the static sampling and tooling options an
+// agent copies verbatim into every outgoing Request.
+type RouterAgentConfig struct {
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ Stop []string `json:"stop,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ Tools []Tool `json:"tools,omitempty"`
+ ToolChoice ToolChoice `json:"tool_choice,omitempty"`
+ Seed int `json:"seed,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ RepetitionPenalty float64 `json:"repetition_penalty,omitempty"`
+ LogitBias map[int]float64 `json:"logit_bias,omitempty"`
+ TopLogprobs int `json:"top_logprobs,omitempty"`
+ MinP float64 `json:"min_p,omitempty"`
+ TopA float64 `json:"top_a,omitempty"`
+}
+
+// RouterAgent binds an OpenRouter client to one model and a fixed
+// request configuration.
+type RouterAgent struct {
+ client *OpenRouterClient
+ model string
+ config RouterAgentConfig
+}
+
+// NewRouterAgent creates a RouterAgent for the given client, model and
+// configuration.
+func NewRouterAgent(client *OpenRouterClient, model string, config RouterAgentConfig) *RouterAgent {
+ return &RouterAgent{
+ client: client,
+ model: model,
+ config: config,
+ }
+}
+
+// buildRequest assembles a Request from the agent's model and static
+// configuration. Prompt/Messages are left for the caller to fill in;
+// stream selects streaming vs. single-shot delivery. Extracting this
+// removes four identical 20-line literals from the methods below.
+func (agent RouterAgent) buildRequest(stream bool) Request {
+	return Request{
+		Model:             agent.model,
+		ResponseFormat:    agent.config.ResponseFormat,
+		Stop:              agent.config.Stop,
+		MaxTokens:         agent.config.MaxTokens,
+		Temperature:       agent.config.Temperature,
+		Tools:             agent.config.Tools,
+		ToolChoice:        agent.config.ToolChoice,
+		Seed:              agent.config.Seed,
+		TopP:              agent.config.TopP,
+		TopK:              agent.config.TopK,
+		FrequencyPenalty:  agent.config.FrequencyPenalty,
+		PresencePenalty:   agent.config.PresencePenalty,
+		RepetitionPenalty: agent.config.RepetitionPenalty,
+		LogitBias:         agent.config.LogitBias,
+		TopLogprobs:       agent.config.TopLogprobs,
+		MinP:              agent.config.MinP,
+		TopA:              agent.config.TopA,
+		Stream:            stream,
+	}
+}
+
+// Completion performs a non-streaming prompt completion.
+func (agent RouterAgent) Completion(prompt string) (*Response, error) {
+	request := agent.buildRequest(false)
+	request.Prompt = prompt
+	return agent.client.FetchChatCompletions(request)
+}
+
+// CompletionStream performs a streaming prompt completion; chunks,
+// keep-alive signals and errors arrive on the provided channels.
+func (agent RouterAgent) CompletionStream(prompt string, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) {
+	request := agent.buildRequest(true)
+	request.Prompt = prompt
+	agent.client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx)
+}
+
+// Chat performs a non-streaming chat completion over a message history.
+func (agent RouterAgent) Chat(messages []MessageRequest) (*Response, error) {
+	request := agent.buildRequest(false)
+	request.Messages = messages
+	return agent.client.FetchChatCompletions(request)
+}
+
+// ChatStream performs a streaming chat completion over a message
+// history; chunks, keep-alive signals and errors arrive on the
+// provided channels.
+func (agent RouterAgent) ChatStream(messages []MessageRequest, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) {
+	request := agent.buildRequest(true)
+	request.Messages = messages
+	agent.client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx)
+}
+
+// RouterAgentChat is a RouterAgent that also keeps the running message
+// history of a conversation.
+type RouterAgentChat struct {
+	RouterAgent
+	Messages []MessageRequest
+}
+
+// NewRouterAgentChat creates a conversational agent whose history is
+// seeded with the given system prompt.
+func NewRouterAgentChat(client *OpenRouterClient, model string, config RouterAgentConfig, systemPrompt string) RouterAgentChat {
+	seed := []MessageRequest{
+		{Role: RoleSystem, Content: systemPrompt},
+	}
+	return RouterAgentChat{
+		RouterAgent: RouterAgent{client: client, model: model, config: config},
+		Messages:    seed,
+	}
+}
+
+// Chat appends the user message to the conversation, sends the whole
+// history to the model, and records the assistant's reply on success.
+// On failure the user message stays in the history, no assistant
+// message is appended, and the error is returned.
+func (agent *RouterAgentChat) Chat(message string) error {
+	agent.Messages = append(agent.Messages, MessageRequest{
+		Role:    RoleUser,
+		Content: message,
+	})
+	request := Request{
+		Messages:          agent.Messages,
+		Model:             agent.model,
+		ResponseFormat:    agent.config.ResponseFormat,
+		Stop:              agent.config.Stop,
+		MaxTokens:         agent.config.MaxTokens,
+		Temperature:       agent.config.Temperature,
+		Tools:             agent.config.Tools,
+		ToolChoice:        agent.config.ToolChoice,
+		Seed:              agent.config.Seed,
+		TopP:              agent.config.TopP,
+		TopK:              agent.config.TopK,
+		FrequencyPenalty:  agent.config.FrequencyPenalty,
+		PresencePenalty:   agent.config.PresencePenalty,
+		RepetitionPenalty: agent.config.RepetitionPenalty,
+		LogitBias:         agent.config.LogitBias,
+		TopLogprobs:       agent.config.TopLogprobs,
+		MinP:              agent.config.MinP,
+		TopA:              agent.config.TopA,
+		Stream:            false,
+	}
+
+	response, err := agent.client.FetchChatCompletions(request)
+	// BUG FIX: the original dereferenced response.Choices[0] before
+	// checking err, panicking on any transport or API failure.
+	if err != nil {
+		return err
+	}
+	if len(response.Choices) == 0 || response.Choices[0].Message == nil {
+		return fmt.Errorf("no completion choices returned")
+	}
+
+	agent.Messages = append(agent.Messages, MessageRequest{
+		Role:    RoleAssistant,
+		Content: response.Choices[0].Message.Content,
+	})
+	return nil
+}
diff --git a/internal/llm/openroute/route_client.go b/internal/llm/openroute/route_client.go
new file mode 100644
index 0000000..eade483
--- /dev/null
+++ b/internal/llm/openroute/route_client.go
@@ -0,0 +1,188 @@
+package openroute
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+)
+
+// OpenRouterClient is an HTTP client for the OpenRouter API.
+type OpenRouterClient struct {
+ apiKey string
+ apiURL string
+ httpClient *http.Client
+}
+
+// NewOpenRouterClient returns a client for the public OpenRouter API
+// URL using a default http.Client.
+func NewOpenRouterClient(apiKey string) *OpenRouterClient {
+ return &OpenRouterClient{
+ apiKey: apiKey,
+ apiURL: "https://openrouter.ai/api/v1",
+ httpClient: &http.Client{},
+ }
+}
+
+// NewOpenRouterClientFull returns a client with a caller-supplied base
+// URL and http.Client (useful for tests and proxies).
+func NewOpenRouterClientFull(apiKey string, apiUrl string, client *http.Client) *OpenRouterClient {
+ return &OpenRouterClient{
+ apiKey: apiKey,
+ apiURL: apiUrl,
+ httpClient: client,
+ }
+}
+
+// FetchChatCompletions sends a single non-streaming chat/completions
+// request and decodes the JSON reply into a Response.
+func (c *OpenRouterClient) FetchChatCompletions(request Request) (*Response, error) {
+	headers := map[string]string{
+		"Authorization": "Bearer " + c.apiKey,
+		"Content-Type":  "application/json",
+	}
+
+	// Optional attribution headers derived from provider preferences.
+	if request.Provider != nil {
+		headers["HTTP-Referer"] = request.Provider.RefererURL
+		headers["X-Title"] = request.Provider.SiteName
+	}
+
+	body, err := json.Marshal(request)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat/completions", c.apiURL), bytes.NewBuffer(body))
+	if err != nil {
+		return nil, err
+	}
+	for key, value := range headers {
+		req.Header.Set(key, value)
+	}
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	output, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	// Include the response body in the error: the original reported only
+	// the status code, hiding the API's explanation of the failure.
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, output)
+	}
+
+	outputResponse := &Response{}
+	if err := json.Unmarshal(output, outputResponse); err != nil {
+		return nil, err
+	}
+	return outputResponse, nil
+}
+
+// FetchChatCompletionsStream issues a streaming chat/completions
+// request (server-sent events). Decoded chunks are sent on outputChan,
+// keep-alive comment lines signal processingChan, and any failure is
+// reported on errChan. All three channels are closed exactly once when
+// the stream ends or setup fails.
+func (c *OpenRouterClient) FetchChatCompletionsStream(request Request, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) {
+	headers := map[string]string{
+		"Authorization": "Bearer " + c.apiKey,
+		"Content-Type":  "application/json",
+	}
+
+	if request.Provider != nil {
+		headers["HTTP-Referer"] = request.Provider.RefererURL
+		headers["X-Title"] = request.Provider.SiteName
+	}
+
+	// fail reports a setup error and closes every channel exactly once.
+	fail := func(err error) {
+		errChan <- err
+		close(errChan)
+		close(outputChan)
+		close(processingChan)
+	}
+
+	body, err := json.Marshal(request)
+	if err != nil {
+		fail(err)
+		return
+	}
+
+	req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat/completions", c.apiURL), bytes.NewBuffer(body))
+	if err != nil {
+		fail(err)
+		return
+	}
+	for key, value := range headers {
+		req.Header.Set(key, value)
+	}
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		fail(err)
+		return
+	}
+
+	go func() {
+		defer resp.Body.Close()
+		// The deferred closes below are the ONLY place the channels are
+		// closed in this goroutine. The original also closed them
+		// explicitly on the ctx.Done path, which made the deferred closes
+		// panic with "close of closed channel".
+		defer close(errChan)
+		defer close(outputChan)
+		defer close(processingChan)
+
+		if resp.StatusCode != http.StatusOK {
+			errChan <- fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+			return
+		}
+
+		reader := bufio.NewReader(resp.Body)
+		for {
+			select {
+			case <-ctx.Done():
+				errChan <- ctx.Err()
+				return
+			default:
+				line, readErr := reader.ReadString('\n')
+				line = strings.TrimSpace(line)
+
+				// SSE comment lines (": ...") are keep-alive signals.
+				if strings.HasPrefix(line, ":") {
+					select {
+					case processingChan <- true:
+					case <-ctx.Done():
+						errChan <- ctx.Err()
+						return
+					}
+					continue
+				}
+
+				// Only "data: ..." lines carry payloads. The original
+				// sliced line[6:] on ANY non-empty line, which could
+				// panic on short lines or misparse other SSE fields.
+				if strings.HasPrefix(line, "data: ") {
+					data := line[len("data: "):]
+					if data == "[DONE]" {
+						return
+					}
+					response := Response{}
+					if err := json.Unmarshal([]byte(data), &response); err != nil {
+						errChan <- err
+						return
+					}
+					select {
+					case outputChan <- response:
+					case <-ctx.Done():
+						errChan <- ctx.Err()
+						return
+					}
+				}
+
+				if readErr != nil {
+					if readErr == io.EOF {
+						return
+					}
+					errChan <- readErr
+					return
+				}
+			}
+		}
+	}()
+}
diff --git a/internal/llm/openroute_llm.go b/internal/llm/openroute_llm.go
new file mode 100644
index 0000000..bba1168
--- /dev/null
+++ b/internal/llm/openroute_llm.go
@@ -0,0 +1,24 @@
+package llm
+
+// OpenRoute holds connection settings for the OpenRouter API.
+// Generate is not implemented yet; only configuration is stored.
+type OpenRoute struct {
+ APIKey string
+ Endpoint string
+ Model string
+}
+
+// NewOpenRoute constructs an OpenRoute client from an API key,
+// endpoint URL and model name.
+// BUG FIX: the original returned *OpenAI (copy-paste from
+// openai_llm.go), so callers never got an OpenRoute value at all.
+func NewOpenRoute(apiKey string, endpoint string, model string) *OpenRoute {
+	return &OpenRoute{
+		APIKey:   apiKey,
+		Endpoint: endpoint,
+		Model:    model,
+	}
+}
+
+func (o *OpenRoute) Generate(prompt string) (string, error) {
+ return "", nil
+ // TODO: implement
+}
+
+func (o *OpenRoute) GetModel() string {
+ return o.Model
+}
diff --git a/internal/models/models.go b/internal/models/models.go
new file mode 100644
index 0000000..0531a9e
--- /dev/null
+++ b/internal/models/models.go
@@ -0,0 +1,27 @@
+package models
+
+// type VectorEmbedding [][]float32
+// type Vector []float32
+// Document is the stored representation of an uploaded document; the
+// milvus tags map each field onto the "documents" collection schema.
+type Document struct {
+ ID string `json:"id" milvus:"ID"` // Unique identifier for the document
+ Content string `json:"content" milvus:"Content"` // Text content of the document become chunks of data will not be saved
+ Link string `json:"link" milvus:"Link"` // Link to the document
+ Filename string `json:"filename" milvus:"Filename"` // Filename of the document
+ Category string `json:"category" milvus:"Category"` // Category of the document
+ EmbeddingModel string `json:"embedding_model" milvus:"EmbeddingModel"` // Embedding model used to generate the embedding
+ Summary string `json:"summary" milvus:"Summary"` // Summary of the document
+ Metadata map[string]string `json:"metadata" milvus:"Metadata"` // Additional metadata (e.g., author, timestamp)
+ Vector []float32 `json:"vector" milvus:"Vector"`
+}
+
+// Embedding is one text chunk of a Document plus its vector; it maps
+// onto the "chunks" collection.
+// NOTE(review): Dimension and Order are int64 here but the milvus
+// schema declares them as Int32 — confirm which is intended.
+type Embedding struct {
+ ID string `json:"id" milvus:"ID"` // Unique identifier
+ DocumentID string `json:"document_id" milvus:"DocumentID"` // Unique identifier linked to a Document
+ Vector []float32 `json:"vector" milvus:"Vector"` // The embedding vector
+ TextChunk string `json:"text_chunk" milvus:"TextChunk"` // Text chunk of the document
+ Dimension int64 `json:"dimension" milvus:"Dimension"` // Dimensionality of the vector
+ Order int64 `json:"order" milvus:"Order"` // Order of the embedding to build the content back
+ Score float32 `json:"score"` // Score of the embedding
+}
diff --git a/internal/pkg/database/milvus/client.go b/internal/pkg/database/milvus/client.go
new file mode 100644
index 0000000..c4a7e9d
--- /dev/null
+++ b/internal/pkg/database/milvus/client.go
@@ -0,0 +1,169 @@
+package milvus
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "time"
+
+ "github.com/milvus-io/milvus-sdk-go/v2/client"
+ "github.com/milvus-io/milvus-sdk-go/v2/entity"
+)
+
+// Client wraps a Milvus SDK client together with the collection
+// bootstrap logic of this package.
+type Client struct {
+ Instance client.Client
+}
+
+// NewClient connects to Milvus at milvusAddr (10s timeout), ensures
+// the required collections exist, and returns a wrapped client.
+func NewClient(milvusAddr string) (*Client, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	c, err := client.NewClient(ctx, client.Config{Address: milvusAddr})
+	if err != nil {
+		log.Printf("Failed to connect to Milvus: %v", err)
+		return nil, err
+	}
+
+	// Named "wrapper" to avoid shadowing the imported "client" package,
+	// which the original local variable did.
+	wrapper := &Client{Instance: c}
+	if err := wrapper.EnsureCollections(ctx); err != nil {
+		return nil, err
+	}
+	return wrapper, nil
+}
+
+// EnsureCollections ensures that the required collections ("documents"
+// and "chunks") exist, each with a default partition and an IVF_FLAT
+// index on its Vector field, and loads both into memory. All failures
+// are returned as errors; unlike the original, this never calls
+// log.Fatalf, which would terminate the whole process from library code.
+func (m *Client) EnsureCollections(ctx context.Context) error {
+	collections := []struct {
+		Name       string
+		Schema     *entity.Schema
+		IndexField string
+		IndexType  string
+		MetricType entity.MetricType
+		Nlist      int
+	}{
+		{
+			Name:       "documents",
+			Schema:     createDocumentSchema(),
+			IndexField: "Vector", // similarity search runs on the Vector field
+			IndexType:  "IVF_FLAT",
+			MetricType: entity.L2,
+			Nlist:      10, // number of clusters for the IVF_FLAT index
+		},
+		{
+			Name:       "chunks",
+			Schema:     createEmbeddingSchema(),
+			IndexField: "Vector",
+			IndexType:  "IVF_FLAT",
+			MetricType: entity.L2,
+			Nlist:      10,
+		},
+	}
+
+	for _, collection := range collections {
+		// Create the collection if it does not exist yet.
+		exists, err := m.Instance.HasCollection(ctx, collection.Name)
+		if err != nil {
+			return fmt.Errorf("failed to check collection existence: %w", err)
+		}
+		if !exists {
+			if err := m.Instance.CreateCollection(ctx, collection.Schema, entity.DefaultShardNumber); err != nil {
+				return fmt.Errorf("failed to create collection '%s': %w", collection.Name, err)
+			}
+			log.Printf("Collection '%s' created successfully", collection.Name)
+		} else {
+			log.Printf("Collection '%s' already exists", collection.Name)
+		}
+
+		// Ensure the default partition exists.
+		hasPartition, err := m.Instance.HasPartition(ctx, collection.Name, "_default")
+		if err != nil {
+			return fmt.Errorf("failed to check default partition for collection '%s': %w", collection.Name, err)
+		}
+		if !hasPartition {
+			if err := m.Instance.CreatePartition(ctx, collection.Name, "_default"); err != nil {
+				return fmt.Errorf("failed to create default partition for collection '%s': %w", collection.Name, err)
+			}
+			log.Printf("Default partition created for collection '%s'", collection.Name)
+		}
+
+		// Skip index creation when no index field is configured.
+		if collection.IndexField == "" {
+			continue
+		}
+
+		log.Printf("Creating index on field '%s' for collection '%s'", collection.IndexField, collection.Name)
+		idx, err := entity.NewIndexIvfFlat(collection.MetricType, collection.Nlist)
+		if err != nil {
+			return fmt.Errorf("failed to create IVF_FLAT index: %w", err)
+		}
+		if err := m.Instance.CreateIndex(ctx, collection.Name, collection.IndexField, idx, false); err != nil {
+			return fmt.Errorf("failed to create index on field '%s' for collection '%s': %w", collection.IndexField, collection.Name, err)
+		}
+		log.Printf("Index created on field '%s' for collection '%s'", collection.IndexField, collection.Name)
+	}
+
+	// Load both collections so they are available for search.
+	if err := m.Instance.LoadCollection(ctx, "documents", false); err != nil {
+		return fmt.Errorf("failed to load collection 'documents': %w", err)
+	}
+	if err := m.Instance.LoadCollection(ctx, "chunks", false); err != nil {
+		return fmt.Errorf("failed to load collection 'chunks': %w", err)
+	}
+	return nil
+}
+
+// createDocumentSchema builds the Milvus schema for the "documents"
+// collection: a VarChar primary key, the document's text fields, and a
+// 1024-dim float vector (sized for the bge-m3 embedding model).
+func createDocumentSchema() *entity.Schema {
+ return entity.NewSchema().
+ WithName("documents").
+ WithDescription("Collection for storing documents").
+ WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeVarChar).WithIsPrimaryKey(true).WithMaxLength(512)).
+ WithField(entity.NewField().WithName("Content").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
+ WithField(entity.NewField().WithName("Link").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)).
+ WithField(entity.NewField().WithName("Filename").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)).
+ WithField(entity.NewField().WithName("Category").WithDataType(entity.FieldTypeVarChar).WithMaxLength(8048)).
+ WithField(entity.NewField().WithName("EmbeddingModel").WithDataType(entity.FieldTypeVarChar).WithMaxLength(256)).
+ WithField(entity.NewField().WithName("Summary").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
+ WithField(entity.NewField().WithName("Metadata").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
+ WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(1024)) // bge-m3
+}
+
+// createEmbeddingSchema builds the Milvus schema for the "chunks"
+// collection: per-chunk text plus a 1024-dim float vector (bge-m3).
+// NOTE(review): Dimension/Order are Int32 here but int64 on
+// models.Embedding — confirm which is intended.
+func createEmbeddingSchema() *entity.Schema {
+ return entity.NewSchema().
+ WithName("chunks").
+ WithDescription("Collection for storing document embeddings").
+ WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeVarChar).WithIsPrimaryKey(true).WithMaxLength(512)).
+ WithField(entity.NewField().WithName("DocumentID").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)).
+ WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(1024)). // bge-m3
+ WithField(entity.NewField().WithName("TextChunk").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
+ WithField(entity.NewField().WithName("Dimension").WithDataType(entity.FieldTypeInt32)).
+ WithField(entity.NewField().WithName("Order").WithDataType(entity.FieldTypeInt32))
+}
+
+// Close closes the underlying Milvus client connection.
+func (m *Client) Close() {
+ m.Instance.Close()
+}
diff --git a/internal/pkg/database/milvus/client_test.go b/internal/pkg/database/milvus/client_test.go
new file mode 100644
index 0000000..360ca32
--- /dev/null
+++ b/internal/pkg/database/milvus/client_test.go
@@ -0,0 +1,32 @@
+package milvus
+
+import (
+ "reflect"
+ "testing"
+)
+
+// TestNewClient is generated table-driven scaffolding for NewClient.
+// NOTE(review): NewClient dials a live Milvus server, so any real test
+// case added here needs a reachable instance (or a fake) — confirm the
+// test environment before filling in cases.
+func TestNewClient(t *testing.T) {
+ type args struct {
+ milvusAddr string
+ }
+ tests := []struct {
+ name string
+ args args
+ want *Client
+ wantErr bool
+ }{
+ // TODO: Add test cases.
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := NewClient(tt.args.milvusAddr)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("NewClient() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/internal/pkg/database/milvus/helpers.go b/internal/pkg/database/milvus/helpers.go
new file mode 100644
index 0000000..cfecbab
--- /dev/null
+++ b/internal/pkg/database/milvus/helpers.go
@@ -0,0 +1,276 @@
+package milvus
+
+import (
+ "encoding/json"
+ "fmt"
+
+ "easy_rag/internal/models"
+ "github.com/milvus-io/milvus-sdk-go/v2/client"
+ "github.com/milvus-io/milvus-sdk-go/v2/entity"
+)
+
+// Helper functions for extracting per-row column slices from documents and
+// embeddings, in the shape the Milvus column constructors expect.
+
+// extractIDs extracts the "ID" field from the documents.
+func extractIDs(docs []models.Document) []string {
+	ids := make([]string, len(docs))
+	for i, doc := range docs {
+		ids[i] = doc.ID
+	}
+	return ids
+}
+
+// extractLinks extracts the "Link" field from the documents.
+func extractLinks(docs []models.Document) []string {
+	links := make([]string, len(docs))
+	for i, doc := range docs {
+		links[i] = doc.Link
+	}
+	return links
+}
+
+// extractFilenames extracts the "Filename" field from the documents.
+func extractFilenames(docs []models.Document) []string {
+	filenames := make([]string, len(docs))
+	for i, doc := range docs {
+		filenames[i] = doc.Filename
+	}
+	return filenames
+}
+
+// extractCategories extracts the "Category" field from the documents,
+// rendered with %v (so non-string Category values are stringified).
+func extractCategories(docs []models.Document) []string {
+	categories := make([]string, len(docs))
+	for i, doc := range docs {
+		categories[i] = fmt.Sprintf("%v", doc.Category)
+	}
+	return categories
+}
+
+// extractEmbeddingModels extracts the "EmbeddingModel" field from the documents.
+func extractEmbeddingModels(docs []models.Document) []string {
+	models := make([]string, len(docs))
+	for i, doc := range docs {
+		models[i] = doc.EmbeddingModel
+	}
+	return models
+}
+
+// extractSummaries extracts the "Summary" field from the documents.
+func extractSummaries(docs []models.Document) []string {
+	summaries := make([]string, len(docs))
+	for i, doc := range docs {
+		summaries[i] = doc.Summary
+	}
+	return summaries
+}
+
+// extractMetadata extracts the "Metadata" field from the documents as a JSON
+// string. NOTE(review): the Marshal error is deliberately ignored; a
+// non-serializable Metadata value yields an empty string for that row.
+func extractMetadata(docs []models.Document) []string {
+	metadata := make([]string, len(docs))
+	for i, doc := range docs {
+		metaBytes, _ := json.Marshal(doc.Metadata)
+		metadata[i] = string(metaBytes)
+	}
+	return metadata
+}
+
+// convertToMetadata decodes a JSON object string into a map[string]string.
+// The Unmarshal error is swallowed: malformed input returns a nil map.
+func convertToMetadata(metadata string) map[string]string {
+	var metadataMap map[string]string
+	json.Unmarshal([]byte(metadata), &metadataMap)
+	return metadataMap
+}
+
+// extractContents extracts the "Content" field from the documents.
+func extractContents(docs []models.Document) []string {
+	contents := make([]string, len(docs))
+	for i, doc := range docs {
+		contents[i] = doc.Content
+	}
+	return contents
+}
+
+// extractEmbeddingIDs extracts the "ID" field from the embeddings.
+func extractEmbeddingIDs(embeddings []models.Embedding) []string {
+	ids := make([]string, len(embeddings))
+	for i, embedding := range embeddings {
+		ids[i] = embedding.ID
+	}
+	return ids
+}
+
+// extractDocumentIDs extracts the "DocumentID" field from the embeddings.
+func extractDocumentIDs(embeddings []models.Embedding) []string {
+	documentIDs := make([]string, len(embeddings))
+	for i, embedding := range embeddings {
+		documentIDs[i] = embedding.DocumentID
+	}
+	return documentIDs
+}
+
+// extractVectors extracts the "Vector" field from the embeddings.
+func extractVectors(embeddings []models.Embedding) [][]float32 {
+	vectors := make([][]float32, len(embeddings))
+	for i, embedding := range embeddings {
+		vectors[i] = embedding.Vector // Direct assignment since it's already []float32
+	}
+	return vectors
+}
+
+// extractVectorsDocs extracts the "Vector" field from the documents.
+func extractVectorsDocs(docs []models.Document) [][]float32 {
+	vectors := make([][]float32, len(docs))
+	for i, doc := range docs {
+		vectors[i] = doc.Vector // Direct assignment since it's already []float32
+	}
+	return vectors
+}
+
+// extractTextChunks extracts the "TextChunk" field from the embeddings.
+func extractTextChunks(embeddings []models.Embedding) []string {
+	textChunks := make([]string, len(embeddings))
+	for i, embedding := range embeddings {
+		textChunks[i] = embedding.TextChunk
+	}
+	return textChunks
+}
+
+// extractDimensions extracts the "Dimension" field from the embeddings,
+// narrowed to int32 to match the collection schema.
+func extractDimensions(embeddings []models.Embedding) []int32 {
+	dimensions := make([]int32, len(embeddings))
+	for i, embedding := range embeddings {
+		dimensions[i] = int32(embedding.Dimension)
+	}
+	return dimensions
+}
+
+// extractOrders extracts the "Order" field from the embeddings,
+// narrowed to int32 to match the collection schema.
+func extractOrders(embeddings []models.Embedding) []int32 {
+	orders := make([]int32, len(embeddings))
+	for i, embedding := range embeddings {
+		orders[i] = int32(embedding.Order)
+	}
+	return orders
+}
+
+// transformResultSet converts a Milvus query ResultSet into one map per row,
+// keyed by field name, for the requested outputFields. It errors on an empty
+// result set, a missing column, a per-cell read failure, or an unsupported
+// field type.
+func transformResultSet(rs client.ResultSet, outputFields ...string) ([]map[string]interface{}, error) {
+	if rs == nil || rs.Len() == 0 {
+		return nil, fmt.Errorf("empty result set")
+	}
+
+	results := []map[string]interface{}{}
+
+	for i := 0; i < rs.Len(); i++ { // Iterate through rows
+		row := map[string]interface{}{}
+
+		for _, fieldName := range outputFields {
+			column := rs.GetColumn(fieldName)
+			if column == nil {
+				return nil, fmt.Errorf("column %s does not exist in result set", fieldName)
+			}
+
+			switch column.Type() {
+			case entity.FieldTypeInt64:
+				value, err := column.GetAsInt64(i)
+				if err != nil {
+					return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
+				}
+				row[fieldName] = value
+
+			case entity.FieldTypeInt32:
+				// Int32 columns are also read via GetAsInt64, so callers
+				// receive an int64 for these fields too (type-assert int64,
+				// not int32).
+				value, err := column.GetAsInt64(i)
+				if err != nil {
+					return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
+				}
+				row[fieldName] = value
+
+			case entity.FieldTypeFloat:
+				// Float columns are widened to float64 via GetAsDouble.
+				value, err := column.GetAsDouble(i)
+				if err != nil {
+					return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err)
+				}
+				row[fieldName] = value
+
+			case entity.FieldTypeDouble:
+				value, err := column.GetAsDouble(i)
+				if err != nil {
+					return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err)
+				}
+				row[fieldName] = value
+
+			case entity.FieldTypeVarChar:
+				value, err := column.GetAsString(i)
+				if err != nil {
+					return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err)
+				}
+				row[fieldName] = value
+
+			default:
+				return nil, fmt.Errorf("unsupported field type for column %s", fieldName)
+			}
+		}
+
+		results = append(results, row)
+	}
+
+	return results, nil
+}
+
+// transformSearchResultSet converts one Milvus SearchResult into a slice of
+// row maps containing the requested outputFields plus a "Score" entry taken
+// from rs.Scores. It errors on an empty result, a missing column, a per-cell
+// read failure, or an unsupported field type. As in transformResultSet,
+// Int32 cells are read via GetAsInt64 and surface as int64.
+func transformSearchResultSet(rs client.SearchResult, outputFields ...string) ([]map[string]interface{}, error) {
+	if rs.ResultCount == 0 {
+		return nil, fmt.Errorf("empty result set")
+	}
+
+	result := make([]map[string]interface{}, rs.ResultCount)
+
+	for i := 0; i < rs.ResultCount; i++ { // Iterate through rows
+		result[i] = make(map[string]interface{})
+		// The score belongs to the row, not to any column, so set it once per
+		// row (previously it was redundantly re-assigned inside the field
+		// loop, and skipped entirely when outputFields was empty).
+		result[i]["Score"] = rs.Scores[i]
+
+		for _, fieldName := range outputFields {
+			column := rs.Fields.GetColumn(fieldName)
+			if column == nil {
+				return nil, fmt.Errorf("column %s does not exist in result set", fieldName)
+			}
+
+			switch column.Type() {
+			case entity.FieldTypeInt64:
+				value, err := column.GetAsInt64(i)
+				if err != nil {
+					return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
+				}
+				result[i][fieldName] = value
+
+			case entity.FieldTypeInt32:
+				// Widened to int64 by GetAsInt64; callers assert int64.
+				value, err := column.GetAsInt64(i)
+				if err != nil {
+					return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
+				}
+				result[i][fieldName] = value
+
+			case entity.FieldTypeFloat:
+				value, err := column.GetAsDouble(i)
+				if err != nil {
+					return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err)
+				}
+				result[i][fieldName] = value
+
+			case entity.FieldTypeDouble:
+				value, err := column.GetAsDouble(i)
+				if err != nil {
+					return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err)
+				}
+				result[i][fieldName] = value
+
+			case entity.FieldTypeVarChar:
+				value, err := column.GetAsString(i)
+				if err != nil {
+					return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err)
+				}
+				result[i][fieldName] = value
+
+			default:
+				return nil, fmt.Errorf("unsupported field type for column %s", fieldName)
+			}
+		}
+	}
+
+	return result, nil
+}
diff --git a/internal/pkg/database/milvus/operations.go b/internal/pkg/database/milvus/operations.go
new file mode 100644
index 0000000..db27de0
--- /dev/null
+++ b/internal/pkg/database/milvus/operations.go
@@ -0,0 +1,270 @@
+package milvus
+
+import (
+ "context"
+ "fmt"
+ "sort"
+
+ "easy_rag/internal/models"
+
+ "github.com/milvus-io/milvus-sdk-go/v2/client"
+ "github.com/milvus-io/milvus-sdk-go/v2/entity"
+)
+
+// InsertDocuments inserts documents into the "documents" collection.
+func (m *Client) InsertDocuments(ctx context.Context, docs []models.Document) error {
+ idColumn := entity.NewColumnVarChar("ID", extractIDs(docs))
+ contentColumn := entity.NewColumnVarChar("Content", extractContents(docs))
+ linkColumn := entity.NewColumnVarChar("Link", extractLinks(docs))
+ filenameColumn := entity.NewColumnVarChar("Filename", extractFilenames(docs))
+ categoryColumn := entity.NewColumnVarChar("Category", extractCategories(docs))
+ embeddingModelColumn := entity.NewColumnVarChar("EmbeddingModel", extractEmbeddingModels(docs))
+ summaryColumn := entity.NewColumnVarChar("Summary", extractSummaries(docs))
+ metadataColumn := entity.NewColumnVarChar("Metadata", extractMetadata(docs))
+ vectorColumn := entity.NewColumnFloatVector("Vector", 1024, extractVectorsDocs(docs))
+ // Insert the data
+ _, err := m.Instance.Insert(ctx, "documents", "_default", idColumn, contentColumn, linkColumn, filenameColumn,
+ categoryColumn, embeddingModelColumn, summaryColumn, metadataColumn, vectorColumn)
+ if err != nil {
+ return fmt.Errorf("failed to insert documents: %w", err)
+ }
+
+ // Flush the collection
+ err = m.Instance.Flush(ctx, "documents", false)
+ if err != nil {
+ return fmt.Errorf("failed to flush documents collection: %w", err)
+ }
+
+ return nil
+}
+
+// InsertEmbeddings inserts embeddings into the "chunks" collection.
+func (m *Client) InsertEmbeddings(ctx context.Context, embeddings []models.Embedding) error {
+ idColumn := entity.NewColumnVarChar("ID", extractEmbeddingIDs(embeddings))
+ documentIDColumn := entity.NewColumnVarChar("DocumentID", extractDocumentIDs(embeddings))
+ vectorColumn := entity.NewColumnFloatVector("Vector", 1024, extractVectors(embeddings))
+ textChunkColumn := entity.NewColumnVarChar("TextChunk", extractTextChunks(embeddings))
+ dimensionColumn := entity.NewColumnInt32("Dimension", extractDimensions(embeddings))
+ orderColumn := entity.NewColumnInt32("Order", extractOrders(embeddings))
+
+ _, err := m.Instance.Insert(ctx, "chunks", "_default", idColumn, documentIDColumn, vectorColumn,
+ textChunkColumn, dimensionColumn, orderColumn)
+
+ if err != nil {
+ return fmt.Errorf("failed to insert embeddings: %w", err)
+ }
+
+ err = m.Instance.Flush(ctx, "chunks", false)
+ if err != nil {
+ return fmt.Errorf("failed to flush chunks collection: %w", err)
+ }
+
+ return nil
+}
+
+// GetDocumentByID retrieves a document from the "documents" collection by ID.
+func (m *Client) GetDocumentByID(ctx context.Context, id string) (map[string]interface{}, error) {
+ collectionName := "documents"
+ expr := fmt.Sprintf("ID == '%s'", id)
+ projections := []string{"ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata"} // Fetch all fields
+
+ results, err := m.Instance.Query(ctx, collectionName, nil, expr, projections)
+ if err != nil {
+ return nil, fmt.Errorf("failed to query document by ID: %w", err)
+ }
+
+ if len(results) == 0 {
+ return nil, fmt.Errorf("document with ID '%s' not found", id)
+ }
+
+ mp, err := transformResultSet(results, "ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata")
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal document: %w", err)
+ }
+
+ // convert metadata to map
+ mp[0]["Metadata"] = convertToMetadata(mp[0]["Metadata"].(string))
+
+ return mp[0], err
+}
+
+// GetAllDocuments retrieves all documents from the "documents" collection.
+func (m *Client) GetAllDocuments(ctx context.Context) ([]models.Document, error) {
+ collectionName := "documents"
+ projections := []string{"ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata"} // Fetch all fields
+ expr := ""
+
+ rs, err := m.Instance.Query(ctx, collectionName, nil, expr, projections, client.WithLimit(1000))
+ if err != nil {
+ return nil, fmt.Errorf("failed to query all documents: %w", err)
+ }
+
+ if len(rs) == 0 {
+ return nil, fmt.Errorf("no documents found in the collection")
+ }
+
+ results, err := transformResultSet(rs, "ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata")
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal all documents: %w", err)
+ }
+
+ var docs []models.Document = make([]models.Document, len(results))
+ for i, result := range results {
+ docs[i] = models.Document{
+ ID: result["ID"].(string),
+ Content: result["Content"].(string),
+ Link: result["Link"].(string),
+ Filename: result["Filename"].(string),
+ Category: result["Category"].(string),
+ EmbeddingModel: result["EmbeddingModel"].(string),
+ Summary: result["Summary"].(string),
+ Metadata: convertToMetadata(results[0]["Metadata"].(string)),
+ }
+ }
+
+ return docs, nil
+}
+
+// GetAllEmbeddingByDocID retrieves all embeddings linked to a specific DocumentID from the "chunks" collection.
+func (m *Client) GetAllEmbeddingByDocID(ctx context.Context, documentID string) ([]models.Embedding, error) {
+ collectionName := "chunks"
+ projections := []string{"ID", "DocumentID", "TextChunk", "Order"} // Fetch all fields
+ expr := fmt.Sprintf("DocumentID == '%s'", documentID)
+
+ rs, err := m.Instance.Query(ctx, collectionName, nil, expr, projections, client.WithLimit(1000))
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to query embeddings by DocumentID: %w", err)
+ }
+
+ if rs.Len() == 0 {
+ return nil, fmt.Errorf("no embeddings found for DocumentID '%s'", documentID)
+ }
+
+ results, err := transformResultSet(rs, "ID", "DocumentID", "TextChunk", "Order")
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal all documents: %w", err)
+ }
+
+ var embeddings []models.Embedding = make([]models.Embedding, rs.Len())
+
+ for i, result := range results {
+ embeddings[i] = models.Embedding{
+ ID: result["ID"].(string),
+ DocumentID: result["DocumentID"].(string),
+ TextChunk: result["TextChunk"].(string),
+ Order: result["Order"].(int64),
+ }
+ }
+
+ return embeddings, nil
+}
+
+// Search runs an ANN search over the "chunks" collection with the given query
+// vectors and returns matching embeddings sorted by score (descending).
+// Vectors must be 1024-dim (bge-m3); the metric is L2.
+// NOTE(review): topK is passed to Search but client.WithLimit(10) is also
+// set — confirm which one actually bounds the result count in the SDK.
+func (m *Client) Search(ctx context.Context, vectors [][]float32, topK int) ([]models.Embedding, error) {
+	const (
+		collectionName = "chunks"
+		vectorDim      = 1024 // Replace with your actual vector dimension
+	)
+	projections := []string{"ID", "DocumentID", "TextChunk", "Order"}
+	metricType := entity.L2 // Default metric type
+
+	// Validate and convert input vectors
+	searchVectors, err := validateAndConvertVectors(vectors, vectorDim)
+	if err != nil {
+		return nil, err
+	}
+
+	// Set search parameters for an IVF_FLAT index.
+	// NOTE(review): the argument is presumably nprobe (clusters probed at
+	// query time), not the number of clusters in the index — confirm.
+	searchParams, err := entity.NewIndexIvfFlatSearchParam(16) // 16 is the number of clusters for IVF_FLAT index
+	if err != nil {
+		return nil, fmt.Errorf("failed to create search params: %w", err)
+	}
+
+	// Perform the search
+	searchResults, err := m.Instance.Search(ctx, collectionName, nil, "", projections, searchVectors, "Vector", metricType, topK, searchParams, client.WithLimit(10))
+	if err != nil {
+		return nil, fmt.Errorf("failed to search collection: %w", err)
+	}
+
+	// Process search results
+	embeddings, err := processSearchResults(searchResults)
+	if err != nil {
+		return nil, fmt.Errorf("failed to process search results: %w", err)
+	}
+
+	return embeddings, nil
+}
+
+// validateAndConvertVectors validates vector dimensions and converts them to Milvus-compatible format.
+func validateAndConvertVectors(vectors [][]float32, expectedDim int) ([]entity.Vector, error) {
+ searchVectors := make([]entity.Vector, len(vectors))
+ for i, vector := range vectors {
+ if len(vector) != expectedDim {
+ return nil, fmt.Errorf("vector dimension mismatch: expected %d, got %d", expectedDim, len(vector))
+ }
+ searchVectors[i] = entity.FloatVector(vector)
+ }
+ return searchVectors, nil
+}
+
+// processSearchResults transforms and aggregates the search results into embeddings and sorts by score.
+func processSearchResults(results []client.SearchResult) ([]models.Embedding, error) {
+ var embeddings []models.Embedding
+
+ for _, result := range results {
+ for i := 0; i < result.ResultCount; i++ {
+ embeddingMap, err := transformSearchResultSet(result, "ID", "DocumentID", "TextChunk", "Order")
+ if err != nil {
+ return nil, fmt.Errorf("failed to transform search result set: %w", err)
+ }
+
+ for _, embedding := range embeddingMap {
+ embeddings = append(embeddings, models.Embedding{
+ ID: embedding["ID"].(string),
+ DocumentID: embedding["DocumentID"].(string),
+ TextChunk: embedding["TextChunk"].(string),
+ Order: embedding["Order"].(int64), // Assuming 'Order' is a float64 type
+ Score: embedding["Score"].(float32),
+ })
+ }
+ }
+ }
+
+ // Sort embeddings by score in descending order (higher is better)
+ sort.Slice(embeddings, func(i, j int) bool {
+ return embeddings[i].Score > embeddings[j].Score
+ })
+
+ return embeddings, nil
+}
+
+// DeleteDocument deletes a document from the "documents" collection by ID.
+func (m *Client) DeleteDocument(ctx context.Context, id string) error {
+ collectionName := "documents"
+ partitionName := "_default"
+ expr := fmt.Sprintf("ID == '%s'", id)
+
+ err := m.Instance.Delete(ctx, collectionName, partitionName, expr)
+ if err != nil {
+ return fmt.Errorf("failed to delete document by ID: %w", err)
+ }
+
+ return nil
+}
+
+// DeleteEmbedding deletes an embedding from the "chunks" collection by ID.
+func (m *Client) DeleteEmbedding(ctx context.Context, id string) error {
+ collectionName := "chunks"
+ partitionName := "_default"
+ expr := fmt.Sprintf("DocumentID == '%s'", id)
+
+ err := m.Instance.Delete(ctx, collectionName, partitionName, expr)
+ if err != nil {
+ return fmt.Errorf("failed to delete embedding by DocumentID: %w", err)
+ }
+
+ return nil
+}
diff --git a/internal/pkg/rag/rag.go b/internal/pkg/rag/rag.go
new file mode 100644
index 0000000..c367990
--- /dev/null
+++ b/internal/pkg/rag/rag.go
@@ -0,0 +1,21 @@
+package rag
+
+import (
+ "easy_rag/internal/database"
+ "easy_rag/internal/embeddings"
+ "easy_rag/internal/llm"
+)
+
+// Rag bundles the three services the RAG pipeline depends on: an LLM for
+// text generation, an embeddings service for vectorization, and a database
+// for storage and vector search. Handlers pull a *Rag out of the request
+// context to reach these services.
+type Rag struct {
+	LLM        llm.LLMService
+	Embeddings embeddings.EmbeddingsService
+	Database   database.Database
+}
+
+// NewRag constructs a Rag from the given service implementations.
+func NewRag(llm llm.LLMService, embeddings embeddings.EmbeddingsService, database database.Database) *Rag {
+	return &Rag{
+		LLM:        llm,
+		Embeddings: embeddings,
+		Database:   database,
+	}
+}
diff --git a/internal/pkg/textprocessor/textprocessor.go b/internal/pkg/textprocessor/textprocessor.go
new file mode 100644
index 0000000..32ef39e
--- /dev/null
+++ b/internal/pkg/textprocessor/textprocessor.go
@@ -0,0 +1,50 @@
+package textprocessor
+
+import (
+ "bytes"
+ "strings"
+
+ "github.com/jonathanhecl/chunker"
+)
+
+// CreateChunks splits text into chunks of at most maxCharacters characters,
+// breaking only at sentence boundaries (sentences are detected by the chunker
+// library). Sentences within a chunk are rejoined with a single space.
+// NOTE: a single sentence longer than maxCharacters is emitted as one
+// oversized chunk — it is never split mid-sentence.
+func CreateChunks(text string) []string {
+	// Maximum characters per chunk
+	const maxCharacters = 5000 // too slow otherwise
+
+	var chunks []string
+	var currentChunk strings.Builder
+
+	// Use the chunker library to split text into sentences
+	sentences := chunker.ChunkSentences(text)
+
+	for _, sentence := range sentences {
+		// Check if adding the sentence exceeds the character limit
+		if currentChunk.Len()+len(sentence) <= maxCharacters {
+			if currentChunk.Len() > 0 {
+				currentChunk.WriteString(" ") // Add a space between sentences
+			}
+			currentChunk.WriteString(sentence)
+		} else {
+			// Flush the current chunk, but never emit an empty one: the
+			// previous version appended "" when the very first sentence
+			// already exceeded the limit.
+			if currentChunk.Len() > 0 {
+				chunks = append(chunks, currentChunk.String())
+				currentChunk.Reset() // Start a new chunk
+			}
+			currentChunk.WriteString(sentence) // Add the sentence to the new chunk
+		}
+	}
+
+	// Add the last chunk if it has content
+	if currentChunk.Len() > 0 {
+		chunks = append(chunks, currentChunk.String())
+	}
+
+	// Return the chunks
+	return chunks
+}
+
+// ConcatenateStrings joins all parts into a single string with no separator.
+// The parameter was renamed from `strings`, which shadowed the imported
+// standard-library strings package inside this function — harmless today,
+// but a compile error waiting to happen the moment this body needs the
+// package.
+func ConcatenateStrings(parts []string) string {
+	var result bytes.Buffer
+	for _, str := range parts {
+		result.WriteString(str)
+	}
+	return result.String()
+}
diff --git a/scripts/standalone_embed.sh b/scripts/standalone_embed.sh
new file mode 100644
index 0000000..6f29e51
--- /dev/null
+++ b/scripts/standalone_embed.sh
@@ -0,0 +1,169 @@
+#!/usr/bin/env bash
+
+# Licensed to the LF AI & Data foundation under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# run_embed writes the embedded-etcd and user-override config files, then
+# launches a Milvus v2.4.16 standalone container with a Docker healthcheck
+# on the /healthz endpoint. Data is persisted under ./volumes/milvus.
+run_embed() {
+    cat << EOF > embedEtcd.yaml
+listen-client-urls: http://0.0.0.0:2379
+advertise-client-urls: http://0.0.0.0:2379
+quota-backend-bytes: 4294967296
+auto-compaction-mode: revision
+auto-compaction-retention: '1000'
+EOF
+
+    cat << EOF > user.yaml
+# Extra config to override default milvus.yaml
+EOF
+
+    sudo docker run -d \
+        --name milvus-standalone \
+        --security-opt seccomp:unconfined \
+        -e ETCD_USE_EMBED=true \
+        -e ETCD_DATA_DIR=/var/lib/milvus/etcd \
+        -e ETCD_CONFIG_PATH=/milvus/configs/embedEtcd.yaml \
+        -e COMMON_STORAGETYPE=local \
+        -v $(pwd)/volumes/milvus:/var/lib/milvus \
+        -v $(pwd)/embedEtcd.yaml:/milvus/configs/embedEtcd.yaml \
+        -v $(pwd)/user.yaml:/milvus/configs/user.yaml \
+        -p 19530:19530 \
+        -p 9091:9091 \
+        -p 2379:2379 \
+        --health-cmd="curl -f http://localhost:9091/healthz" \
+        --health-interval=30s \
+        --health-start-period=90s \
+        --health-timeout=20s \
+        --health-retries=3 \
+        milvusdb/milvus:v2.4.16 \
+        milvus run standalone 1> /dev/null
+}
+
+# wait_for_milvus_running polls `docker ps` once per second until the
+# milvus-standalone container reports a healthy status. Blocks forever if the
+# container never becomes healthy.
+wait_for_milvus_running() {
+    echo "Wait for Milvus Starting..."
+    while true
+    do
+        res=`sudo docker ps|grep milvus-standalone|grep healthy|wc -l`
+        if [ $res -eq 1 ]
+        then
+            echo "Start successfully."
+            echo "To change the default Milvus configuration, add your settings to the user.yaml file and then restart the service."
+            break
+        fi
+        sleep 1
+    done
+}
+
+# start is idempotent: exits 0 if Milvus is already healthy, restarts an
+# existing stopped container, or creates a fresh one via run_embed, then
+# waits for the healthcheck to pass.
+start() {
+    res=`sudo docker ps|grep milvus-standalone|grep healthy|wc -l`
+    if [ $res -eq 1 ]
+    then
+        echo "Milvus is running."
+        exit 0
+    fi
+
+    res=`sudo docker ps -a|grep milvus-standalone|wc -l`
+    if [ $res -eq 1 ]
+    then
+        sudo docker start milvus-standalone 1> /dev/null
+    else
+        run_embed
+    fi
+
+    # $? here reflects whichever branch ran above (docker start or run_embed).
+    if [ $? -ne 0 ]
+    then
+        echo "Start failed."
+        exit 1
+    fi
+
+    wait_for_milvus_running
+}
+
+# stop halts the container; exits 1 if docker stop fails.
+stop() {
+    sudo docker stop milvus-standalone 1> /dev/null
+
+    if [ $? -ne 0 ]
+    then
+        echo "Stop failed."
+        exit 1
+    fi
+    echo "Stop successfully."
+
+}
+
+# delete_container removes the (stopped) container; refuses to run while the
+# container is still listed by `docker ps`.
+delete_container() {
+    res=`sudo docker ps|grep milvus-standalone|wc -l`
+    if [ $res -eq 1 ]
+    then
+        echo "Please stop Milvus service before delete."
+        exit 1
+    fi
+    sudo docker rm milvus-standalone 1> /dev/null
+    if [ $? -ne 0 ]
+    then
+        echo "Delete milvus container failed."
+        exit 1
+    fi
+    echo "Delete milvus container successfully."
+}
+
+# delete removes the container AND all local state (volumes + generated
+# config files). Destructive: data is not recoverable afterwards.
+delete() {
+    delete_container
+    sudo rm -rf $(pwd)/volumes
+    sudo rm -rf $(pwd)/embedEtcd.yaml
+    sudo rm -rf $(pwd)/user.yaml
+    echo "Delete successfully."
+}
+
+# upgrade (after confirmation) stops and removes the current container, then
+# fetches and runs the latest upstream standalone_embed.sh to start the
+# newest Milvus version. Data under ./volumes is preserved.
+upgrade() {
+    read -p "Please confirm if you'd like to proceed with the upgrade. The default will be to the latest version. Confirm with 'y' for yes or 'n' for no. > " check
+    if [ "$check" == "y" ] ||[ "$check" == "Y" ];then
+        res=`sudo docker ps -a|grep milvus-standalone|wc -l`
+        if [ $res -eq 1 ]
+        then
+            stop
+            delete_container
+        fi
+
+        curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed_latest.sh && \
+        bash standalone_embed_latest.sh start 1> /dev/null && \
+        echo "Upgrade successfully."
+    else
+        echo "Exit upgrade"
+        exit 0
+    fi
+}
+
+# Command dispatch: restart|start|stop|upgrade|delete.
+case $1 in
+    restart)
+        stop
+        start
+        ;;
+    start)
+        start
+        ;;
+    stop)
+        stop
+        ;;
+    upgrade)
+        upgrade
+        ;;
+    delete)
+        delete
+        ;;
+    *)
+        echo "please use bash standalone_embed.sh restart|start|stop|upgrade|delete"
+        ;;
+esac
diff --git a/tests/api_test.go b/tests/api_test.go
new file mode 100644
index 0000000..38bb073
--- /dev/null
+++ b/tests/api_test.go
@@ -0,0 +1,313 @@
+package api_test
+
+import (
+ "bytes"
+ "easy_rag/internal/models"
+ "encoding/json"
+ "github.com/google/uuid"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "easy_rag/api"
+ "easy_rag/internal/pkg/rag"
+ "github.com/labstack/echo/v4"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+// Example: Test UploadHandler
+// TestUploadHandler checks that POST /api/v1/upload immediately returns 202
+// with a task_id; the actual processing happens in a background goroutine,
+// so all mock expectations are .Maybe().
+func TestUploadHandler(t *testing.T) {
+	e := echo.New()
+
+	// Create a mock for the LLM, Embeddings, and Database
+	mockLLM := new(MockLLMService)
+	mockEmbeddings := new(MockEmbeddingsService)
+	mockDB := new(MockDatabase)
+
+	// Setup the Rag object
+	r := &rag.Rag{
+		LLM:        mockLLM,
+		Embeddings: mockEmbeddings,
+		Database:   mockDB,
+	}
+
+	// We expect calls to these mocks in the background goroutine, for each document.
+
+	// The request body
+	requestBody := api.RequestUpload{
+		Docs: []api.UploadDoc{
+			{
+				Content:  "Test document content",
+				Link:     "http://example.com/doc",
+				Filename: "doc1.txt",
+				Category: "TestCategory",
+				Metadata: map[string]string{"Author": "Me"},
+			},
+		},
+	}
+
+	// Convert requestBody to JSON
+	reqBodyBytes, _ := json.Marshal(requestBody)
+
+	// Create a new request
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/upload", bytes.NewReader(reqBodyBytes))
+	req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
+
+	// Create a ResponseRecorder
+	rec := httptest.NewRecorder()
+
+	// New echo context
+	c := e.NewContext(req, rec)
+	// Set the rag object in context
+	c.Set("Rag", r)
+
+	// Because the UploadHandler spawns a goroutine, we only test the immediate HTTP response.
+	// We can still set expectations for the calls that happen in the goroutine to ensure they're invoked.
+	// For example, we expect the summary to be generated, so:
+
+	testSummary := "Test summary from LLM"
+	mockLLM.On("Generate", mock.Anything).Return(testSummary, nil).Maybe() // .Maybe() because the concurrency might not complete by the time we assert
+
+	// The embedding vector returned from the embeddings service
+	testVector := [][]float32{{0.1, 0.2, 0.3, 0.4}}
+
+	// We'll mock calls to Vectorize() for summary and each chunk
+	mockEmbeddings.On("Vectorize", mock.AnythingOfType("string")).Return(testVector, nil).Maybe()
+
+	// The database SaveDocument / SaveEmbeddings calls
+	mockDB.On("SaveDocument", mock.AnythingOfType("models.Document")).Return(nil).Maybe()
+	mockDB.On("SaveEmbeddings", mock.AnythingOfType("[]models.Embedding")).Return(nil).Maybe()
+
+	// Invoke the handler
+	err := api.UploadHandler(c)
+
+	// Check no immediate errors
+	assert.NoError(t, err)
+
+	// Check the response
+	assert.Equal(t, http.StatusAccepted, rec.Code)
+
+	var resp map[string]interface{}
+	_ = json.Unmarshal(rec.Body.Bytes(), &resp)
+
+	// We expect certain fields in the JSON response
+	assert.Equal(t, "v1", resp["version"])
+	assert.NotEmpty(t, resp["task_id"])
+	assert.Equal(t, "Processing started", resp["status"])
+
+	// Typically, you might want to wait or do more advanced concurrency checks if you want to test
+	// the logic in the goroutine, but that goes beyond a simple unit test.
+	// The background process can be tested more thoroughly in integration or end-to-end tests.
+
+	// Optionally assert that our mocks were called
+	mockLLM.AssertExpectations(t)
+	mockEmbeddings.AssertExpectations(t)
+	mockDB.AssertExpectations(t)
+}
+
+// TestListAllDocsHandler checks that GET /api/v1/docs returns 200 and the
+// documents provided by the database mock.
+func TestListAllDocsHandler(t *testing.T) {
+	e := echo.New()
+
+	mockLLM := new(MockLLMService)
+	mockEmbeddings := new(MockEmbeddingsService)
+	mockDB := new(MockDatabase)
+
+	r := &rag.Rag{
+		LLM:        mockLLM,
+		Embeddings: mockEmbeddings,
+		Database:   mockDB,
+	}
+
+	// Mock data
+	doc1 := models.Document{
+		ID:       uuid.NewString(),
+		Filename: "doc1.txt",
+		Summary:  "summary doc1",
+	}
+	doc2 := models.Document{
+		ID:       uuid.NewString(),
+		Filename: "doc2.txt",
+		Summary:  "summary doc2",
+	}
+	docs := []models.Document{doc1, doc2}
+
+	// Expect the database to return the docs
+	mockDB.On("ListDocuments").Return(docs, nil)
+
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/docs", nil)
+	rec := httptest.NewRecorder()
+	c := e.NewContext(req, rec)
+	c.Set("Rag", r)
+
+	err := api.ListAllDocsHandler(c)
+	assert.NoError(t, err)
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	var resp map[string]interface{}
+	_ = json.Unmarshal(rec.Body.Bytes(), &resp)
+	assert.Equal(t, "v1", resp["version"])
+
+	// The "docs" field should match the ones we returned
+	docsInterface, ok := resp["docs"].([]interface{})
+	assert.True(t, ok)
+	assert.Len(t, docsInterface, 2)
+
+	// Verify mocks
+	mockDB.AssertExpectations(t)
+}
+
+// TestGetDocHandler checks that GET /api/v1/doc/:id returns 200 and the
+// document provided by the database mock.
+func TestGetDocHandler(t *testing.T) {
+	e := echo.New()
+
+	mockLLM := new(MockLLMService)
+	mockEmbeddings := new(MockEmbeddingsService)
+	mockDB := new(MockDatabase)
+
+	r := &rag.Rag{
+		LLM:        mockLLM,
+		Embeddings: mockEmbeddings,
+		Database:   mockDB,
+	}
+
+	// Mock a single doc
+	docID := "123"
+	testDoc := models.Document{
+		ID:       docID,
+		Filename: "doc3.txt",
+		Summary:  "summary doc3",
+	}
+
+	mockDB.On("GetDocument", docID).Return(testDoc, nil)
+
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/doc/123", nil)
+	rec := httptest.NewRecorder()
+	c := e.NewContext(req, rec)
+	// set path param
+	c.SetParamNames("id")
+	c.SetParamValues(docID)
+	c.Set("Rag", r)
+
+	err := api.GetDocHandler(c)
+	assert.NoError(t, err)
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	var resp map[string]interface{}
+	_ = json.Unmarshal(rec.Body.Bytes(), &resp)
+	assert.Equal(t, "v1", resp["version"])
+
+	// NOTE(review): assumes models.Document serializes Filename under the
+	// lowercase JSON key "filename" — confirm against the struct tags.
+	docInterface := resp["doc"].(map[string]interface{})
+	assert.Equal(t, "doc3.txt", docInterface["filename"])
+
+	// Verify mocks
+	mockDB.AssertExpectations(t)
+}
+
+// TestAskDocHandler checks the full ask flow: the question is vectorized,
+// the DB is searched, and the LLM generates an answer from the found chunk.
+func TestAskDocHandler(t *testing.T) {
+	e := echo.New()
+
+	mockLLM := new(MockLLMService)
+	mockEmbeddings := new(MockEmbeddingsService)
+	mockDB := new(MockDatabase)
+
+	r := &rag.Rag{
+		LLM:        mockLLM,
+		Embeddings: mockEmbeddings,
+		Database:   mockDB,
+	}
+
+	// 1) We expect to Vectorize the question
+	question := "What is the summary of doc?"
+	questionVector := [][]float32{{0.5, 0.2, 0.1}}
+	mockEmbeddings.On("Vectorize", question).Return(questionVector, nil)
+
+	// 2) We expect a DB search
+	emb := []models.Embedding{
+		{
+			ID:         "emb1",
+			DocumentID: "doc123",
+			TextChunk:  "Relevant content chunk",
+			Score:      0.99,
+		},
+	}
+	mockDB.On("Search", questionVector).Return(emb, nil)
+
+	// 3) We expect the LLM to generate an answer from the chunk
+	generatedAnswer := "Here is an answer from the chunk"
+	// The prompt we pass is something like: "Given the following information: chunk ... Answer the question: question"
+	mockLLM.On("Generate", mock.AnythingOfType("string")).Return(generatedAnswer, nil)
+
+	// Prepare request
+	reqBody := api.RequestQuestion{
+		Question: question,
+	}
+	reqBytes, _ := json.Marshal(reqBody)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/ask", bytes.NewReader(reqBytes))
+	req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
+	rec := httptest.NewRecorder()
+	c := e.NewContext(req, rec)
+	c.Set("Rag", r)
+
+	// Execute
+	err := api.AskDocHandler(c)
+
+	// Verify
+	assert.NoError(t, err)
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	var resp map[string]interface{}
+	_ = json.Unmarshal(rec.Body.Bytes(), &resp)
+	assert.Equal(t, "v1", resp["version"])
+	assert.Equal(t, generatedAnswer, resp["answer"])
+
+	// The docs field should have the docID "doc123"
+	docsInterface := resp["docs"].([]interface{})
+	assert.Len(t, docsInterface, 1)
+	assert.Equal(t, "doc123", docsInterface[0])
+
+	// Verify mocks
+	mockLLM.AssertExpectations(t)
+	mockEmbeddings.AssertExpectations(t)
+	mockDB.AssertExpectations(t)
+}
+
+// TestDeleteDocHandler checks that DELETE /api/v1/doc/:id returns 200 and a
+// nil docs field after the database delete succeeds.
+func TestDeleteDocHandler(t *testing.T) {
+	e := echo.New()
+	mockLLM := new(MockLLMService)
+	mockEmbeddings := new(MockEmbeddingsService)
+	mockDB := new(MockDatabase)
+
+	r := &rag.Rag{
+		LLM:        mockLLM,
+		Embeddings: mockEmbeddings,
+		Database:   mockDB,
+	}
+
+	docID := "abc"
+	mockDB.On("DeleteDocument", docID).Return(nil)
+
+	req := httptest.NewRequest(http.MethodDelete, "/api/v1/doc/abc", nil)
+	rec := httptest.NewRecorder()
+	c := e.NewContext(req, rec)
+	c.SetParamNames("id")
+	c.SetParamValues(docID)
+	c.Set("Rag", r)
+
+	err := api.DeleteDocHandler(c)
+	assert.NoError(t, err)
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	var resp map[string]interface{}
+	_ = json.Unmarshal(rec.Body.Bytes(), &resp)
+	assert.Equal(t, "v1", resp["version"])
+
+	// docs should be nil according to DeleteDocHandler
+	assert.Nil(t, resp["docs"])
+
+	// Verify mocks
+	mockDB.AssertExpectations(t)
+}
diff --git a/tests/mock_test.go b/tests/mock_test.go
new file mode 100644
index 0000000..85e3136
--- /dev/null
+++ b/tests/mock_test.go
@@ -0,0 +1,89 @@
+package api_test
+
+import (
+ "easy_rag/internal/models"
+ "github.com/stretchr/testify/mock"
+)
+
+// --------------------
+// Mock LLM
+// --------------------
+
+// MockLLMService is a testify-based stand-in for the LLM service used by the
+// handlers under test.
+type MockLLMService struct {
+	mock.Mock
+}
+
+// Generate records the call and returns the canned completion and error
+// configured via On("Generate", ...).
+func (m *MockLLMService) Generate(prompt string) (string, error) {
+	call := m.Called(prompt)
+	return call.String(0), call.Error(1)
+}
+
+// GetModel records the call and returns the canned model name configured via
+// On("GetModel").
+func (m *MockLLMService) GetModel() string {
+	call := m.Called()
+	return call.String(0)
+}
+
+// --------------------
+// Mock Embeddings
+// --------------------
+
+// MockEmbeddingsService is a testify-based stand-in for the embeddings service.
+type MockEmbeddingsService struct {
+	mock.Mock
+}
+
+// Vectorize records the call and returns the canned batch of vectors and error
+// configured via On("Vectorize", ...).
+func (m *MockEmbeddingsService) Vectorize(text string) ([][]float32, error) {
+	call := m.Called(text)
+	return call.Get(0).([][]float32), call.Error(1)
+}
+
+// GetModel records the call and returns the canned model name configured via
+// On("GetModel").
+func (m *MockEmbeddingsService) GetModel() string {
+	call := m.Called()
+	return call.String(0)
+}
+
+// --------------------
+// Mock Database
+// --------------------
+
+// MockDatabase is a testify-based stand-in for the document/vector database
+// used by the API handlers under test.
+type MockDatabase struct {
+	mock.Mock
+}
+
+// GetDocumentInfo returns the canned (models.Document, error) pair registered
+// for the given id. NOTE(review): despite the name, the signature returns
+// models.Document, not a separate DocumentInfo type — the original comment
+// claiming (models.DocumentInfo, error) was wrong.
+func (m *MockDatabase) GetDocumentInfo(id string) (models.Document, error) {
+	args := m.Called(id)
+	return args.Get(0).(models.Document), args.Error(1)
+}
+
+// SaveDocument records the document and returns the configured error, if any.
+func (m *MockDatabase) SaveDocument(doc models.Document) error {
+	args := m.Called(doc)
+	return args.Error(0)
+}
+
+// SaveEmbeddings records the embeddings batch and returns the configured error.
+func (m *MockDatabase) SaveEmbeddings(emb []models.Embedding) error {
+	args := m.Called(emb)
+	return args.Error(0)
+}
+
+// ListDocuments returns the canned document list registered via
+// On("ListDocuments").
+func (m *MockDatabase) ListDocuments() ([]models.Document, error) {
+	args := m.Called()
+	return args.Get(0).([]models.Document), args.Error(1)
+}
+
+// GetDocument returns the canned document registered for the given id.
+func (m *MockDatabase) GetDocument(id string) (models.Document, error) {
+	args := m.Called(id)
+	return args.Get(0).(models.Document), args.Error(1)
+}
+
+// DeleteDocument records the id and returns the configured error, if any.
+func (m *MockDatabase) DeleteDocument(id string) error {
+	args := m.Called(id)
+	return args.Error(0)
+}
+
+// Search returns the canned embeddings registered for the given query vectors.
+// NOTE(review): the parameter is [][]float32 (a batch of vectors), not the
+// single []float32 the original comment suggested.
+func (m *MockDatabase) Search(vector [][]float32) ([]models.Embedding, error) {
+	args := m.Called(vector)
+	return args.Get(0).([]models.Embedding), args.Error(1)
+}
diff --git a/tests/openrouter_test.go b/tests/openrouter_test.go
new file mode 100644
index 0000000..accb9d3
--- /dev/null
+++ b/tests/openrouter_test.go
@@ -0,0 +1,125 @@
+package api_test
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"testing"
+
+	openroute2 "easy_rag/internal/llm/openroute"
+)
+
+// TestFetchChatCompletions performs a live, non-streaming chat completion
+// round-trip against OpenRouter. Skipped unless OPENROUTER_API_KEY is set.
+func TestFetchChatCompletions(t *testing.T) {
+	// The API key must come from the environment — the previous hard-coded
+	// key was a committed secret and must be revoked.
+	apiKey := os.Getenv("OPENROUTER_API_KEY")
+	if apiKey == "" {
+		t.Skip("OPENROUTER_API_KEY is not set; skipping integration test")
+	}
+	client := openroute2.NewOpenRouterClient(apiKey)
+
+	request := openroute2.Request{
+		Model: "qwen/qwen2.5-vl-72b-instruct:free",
+		Messages: []openroute2.MessageRequest{
+			{Role: openroute2.RoleUser, Content: "Привет!"},
+		},
+	}
+
+	output, err := client.FetchChatCompletions(request)
+	if err != nil {
+		// Fatalf (not Errorf): dereferencing output below would panic on error.
+		t.Fatalf("error %v", err)
+	}
+	if len(output.Choices) == 0 {
+		t.Fatal("response contains no choices")
+	}
+
+	t.Logf("output: %v", output.Choices[0].Message.Content)
+}
+
+// TestFetchChatCompletionsStreaming performs a live streaming chat completion
+// against OpenRouter. Skipped unless OPENROUTER_API_KEY is set.
+func TestFetchChatCompletionsStreaming(t *testing.T) {
+	// The API key must come from the environment — never commit secrets to VCS.
+	apiKey := os.Getenv("OPENROUTER_API_KEY")
+	if apiKey == "" {
+		t.Skip("OPENROUTER_API_KEY is not set; skipping integration test")
+	}
+	client := openroute2.NewOpenRouterClient(apiKey)
+
+	request := openroute2.Request{
+		Model: "qwen/qwen2.5-vl-72b-instruct:free",
+		Messages: []openroute2.MessageRequest{
+			// Named fields instead of a positional literal; the remaining
+			// string fields keep their zero value ("").
+			{Role: openroute2.RoleUser, Content: "Привет!"},
+		},
+		Stream: true,
+	}
+
+	outputChan := make(chan openroute2.Response)
+	processingChan := make(chan interface{})
+	errChan := make(chan error)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	go client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx)
+
+	// Drain the stream until the producer reports completion (nil error),
+	// an actual error, or context cancellation.
+	for {
+		select {
+		case output := <-outputChan:
+			if len(output.Choices) > 0 {
+				t.Logf("%s", output.Choices[0].Delta.Content)
+			}
+		case <-processingChan:
+			t.Logf("Обработка\n")
+		case err := <-errChan:
+			if err != nil {
+				t.Errorf("Ошибка: %v", err)
+				return
+			}
+			return
+		case <-ctx.Done():
+			fmt.Println("Контекст отменен:", ctx.Err())
+			return
+		}
+	}
+}
+
+// TestFetchChatCompletionsAgentStreaming streams a chat completion through a
+// RouterAgent. Skipped unless OPENROUTER_API_KEY is set.
+func TestFetchChatCompletionsAgentStreaming(t *testing.T) {
+	// The API key must come from the environment — never commit secrets to VCS.
+	apiKey := os.Getenv("OPENROUTER_API_KEY")
+	if apiKey == "" {
+		t.Skip("OPENROUTER_API_KEY is not set; skipping integration test")
+	}
+	client := openroute2.NewOpenRouterClient(apiKey)
+	// Model id fixed: the original had a trailing typo ("...:freet").
+	agent := openroute2.NewRouterAgent(client, "qwen/qwen2.5-vl-72b-instruct:free", openroute2.RouterAgentConfig{
+		Temperature: 0.7,
+		MaxTokens:   100,
+	})
+
+	outputChan := make(chan openroute2.Response)
+	processingChan := make(chan interface{})
+	errChan := make(chan error)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	chat := []openroute2.MessageRequest{
+		{Role: openroute2.RoleSystem, Content: "Вы полезный помощник."},
+		{Role: openroute2.RoleUser, Content: "Привет!"},
+	}
+
+	go agent.ChatStream(chat, outputChan, processingChan, errChan, ctx)
+
+	// Drain the stream until the producer reports completion (nil error),
+	// an actual error, or context cancellation.
+	for {
+		select {
+		case output := <-outputChan:
+			if len(output.Choices) > 0 {
+				t.Logf("%s", output.Choices[0].Delta.Content)
+			}
+		case <-processingChan:
+			t.Logf("Обработка\n")
+		case err := <-errChan:
+			if err != nil {
+				t.Errorf("Ошибка: %v", err)
+				return
+			}
+			return
+		case <-ctx.Done():
+			fmt.Println("Контекст отменен:", ctx.Err())
+			return
+		}
+	}
+}
+
+// TestFetchChatCompletionsAgentSimpleChat drives a two-turn conversation
+// through a RouterAgentChat and logs the accumulated message history.
+// Skipped unless OPENROUTER_API_KEY is set.
+func TestFetchChatCompletionsAgentSimpleChat(t *testing.T) {
+	// The API key must come from the environment — never commit secrets to VCS.
+	apiKey := os.Getenv("OPENROUTER_API_KEY")
+	if apiKey == "" {
+		t.Skip("OPENROUTER_API_KEY is not set; skipping integration test")
+	}
+	client := openroute2.NewOpenRouterClient(apiKey)
+	agent := openroute2.NewRouterAgentChat(client, "qwen/qwen2.5-vl-72b-instruct:free", openroute2.RouterAgentConfig{
+		Temperature: 0.0,
+		MaxTokens:   100,
+	}, "Вы полезный помощник, отвечайте короткими словами.")
+
+	agent.Chat("Запомни это: \"wojtess\"")
+	agent.Chat("Что я просил вас запомнить?")
+
+	// Content is dynamically typed; only string payloads are loggable here.
+	for _, msg := range agent.Messages {
+		content, ok := msg.Content.(string)
+		if ok {
+			t.Logf("%s: %s", msg.Role, content)
+		}
+	}
+}
diff --git a/user.yaml b/user.yaml
new file mode 100644
index 0000000..8d31269
--- /dev/null
+++ b/user.yaml
@@ -0,0 +1 @@
+# Extra config to override default milvus.yaml