From 636096fd34d39255c2d059f3cbc10269ea54ca80 Mon Sep 17 00:00:00 2001 From: Dmitriy Fofanov Date: Fri, 19 Sep 2025 11:38:31 +0300 Subject: [PATCH] =?UTF-8?q?=D0=A0=D0=B5=D0=B0=D0=BB=D0=B8=D0=B7=D0=BE?= =?UTF-8?q?=D0=B2=D0=B0=D0=BD=D0=B0=20=D0=BE=D0=BF=D0=B5=D1=80=D0=B0=D1=86?= =?UTF-8?q?=D0=B8=D0=B8=20Milvus=20=D0=B4=D0=BB=D1=8F=20=D1=83=D0=BF=D1=80?= =?UTF-8?q?=D0=B0=D0=B2=D0=BB=D0=B5=D0=BD=D0=B8=D1=8F=20=D0=B4=D0=BE=D0=BA?= =?UTF-8?q?=D1=83=D0=BC=D0=B5=D0=BD=D1=82=D0=B0=D0=BC=D0=B8=20=D0=B8=20?= =?UTF-8?q?=D0=B2=D1=81=D1=82=D1=80=D0=B0=D0=B8=D0=B2=D0=B0=D0=BD=D0=B8?= =?UTF-8?q?=D0=B5=D0=BC,=20=D0=B2=D0=BA=D0=BB=D1=8E=D1=87=D0=B0=D1=8F=20?= =?UTF-8?q?=D1=84=D1=83=D0=BD=D0=BA=D1=86=D0=B8=D0=B8=20=D0=B2=D1=81=D1=82?= =?UTF-8?q?=D0=B0=D0=B2=D0=BA=D0=B8,=20=D0=B7=D0=B0=D0=BF=D1=80=D0=BE?= =?UTF-8?q?=D1=81=D0=B0=20=D0=B8=20=D1=83=D0=B4=D0=B0=D0=BB=D0=B5=D0=BD?= =?UTF-8?q?=D0=B8=D1=8F.=20=D0=92=D0=BD=D0=B5=D0=B4=D1=80=D0=B8=D1=82?= =?UTF-8?q?=D0=B5=20=D0=B0=D1=80=D1=85=D0=B8=D1=82=D0=B5=D0=BA=D1=82=D1=83?= =?UTF-8?q?=D1=80=D1=83=20RAG=20=D1=81=20LLM=20=D0=B8=20=D1=81=D0=B5=D1=80?= =?UTF-8?q?=D0=B2=D0=B8=D1=81=D0=B0=D0=BC=D0=B8=20=D0=B2=D1=81=D1=82=D1=80?= =?UTF-8?q?=D0=B0=D0=B8=D0=B2=D0=B0=D0=BD=D0=B8=D1=8F.=20=D0=94=D0=BE?= =?UTF-8?q?=D0=B1=D0=B0=D0=B2=D1=8C=D1=82=D0=B5=20=D0=BE=D0=B1=D1=80=D0=B0?= =?UTF-8?q?=D0=B1=D0=BE=D1=82=D0=BA=D1=83=20=D1=82=D0=B5=D0=BA=D1=81=D1=82?= =?UTF-8?q?=D0=B0=20=D0=B4=D0=BB=D1=8F=20=D1=84=D1=80=D0=B0=D0=B3=D0=BC?= =?UTF-8?q?=D0=B5=D0=BD=D1=82=D0=B0=D1=86=D0=B8=D0=B8=20=D0=B8=20=D0=BA?= =?UTF-8?q?=D0=BE=D0=BD=D0=BA=D0=B0=D1=82=D0=B5=D0=BD=D0=B0=D1=86=D0=B8?= =?UTF-8?q?=D0=B8.=20=D0=A1=D0=BE=D0=B7=D0=B4=D0=B0=D0=B9=D1=82=D0=B5=20?= =?UTF-8?q?=D0=B0=D0=B2=D1=82=D0=BE=D0=BD=D0=BE=D0=BC=D0=BD=D1=8B=D0=B9=20?= =?UTF-8?q?=D1=81=D0=BA=D1=80=D0=B8=D0=BF=D1=82=20=D0=B4=D0=BB=D1=8F=20?= =?UTF-8?q?=D0=BD=D0=B0=D1=81=D1=82=D1=80=D0=BE=D0=B9=D0=BA=D0=B8=20=D0=B8?= =?UTF-8?q?=20=D1=83=D0=BF=D1=80=D0=B0=D0=B2=D0=BB=D0=B5=D0=BD=D0=B8=D1=8F?= 
=?UTF-8?q?=20Milvus.=20=D0=A0=D0=B0=D0=B7=D1=80=D0=B0=D0=B1=D0=BE=D1=82?= =?UTF-8?q?=D0=B0=D0=B9=D1=82=D0=B5=20=D0=BA=D0=BE=D0=BC=D0=BF=D0=BB=D0=B5?= =?UTF-8?q?=D0=BA=D1=81=D0=BD=D1=8B=D0=B5=20=D1=82=D0=B5=D1=81=D1=82=D1=8B?= =?UTF-8?q?=20API=20=D0=B4=D0=BB=D1=8F=20=D0=BE=D0=B1=D1=80=D0=B0=D0=B1?= =?UTF-8?q?=D0=BE=D1=82=D0=BA=D0=B8=20=D0=B4=D0=BE=D0=BA=D1=83=D0=BC=D0=B5?= =?UTF-8?q?=D0=BD=D1=82=D0=BE=D0=B2=20=D0=B8=20=D0=B2=D0=B7=D0=B0=D0=B8?= =?UTF-8?q?=D0=BC=D0=BE=D0=B4=D0=B5=D0=B9=D1=81=D1=82=D0=B2=D0=B8=D1=8F=20?= =?UTF-8?q?=D1=81=20LLM,=20=D0=B2=D0=BA=D0=BB=D1=8E=D1=87=D0=B0=D1=8F=20?= =?UTF-8?q?=D0=B8=D0=BC=D0=B8=D1=82=D0=B0=D1=86=D0=B8=D0=B8=20=D0=B4=D0=BB?= =?UTF-8?q?=D1=8F=20=D1=81=D0=B5=D1=80=D0=B2=D0=B8=D1=81=D0=BE=D0=B2.=20?= =?UTF-8?q?=D0=A0=D0=B0=D1=81=D1=88=D0=B8=D1=80=D1=8C=D1=82=D0=B5=20=D0=B2?= =?UTF-8?q?=D0=BE=D0=B7=D0=BC=D0=BE=D0=B6=D0=BD=D0=BE=D1=81=D1=82=D0=B8=20?= =?UTF-8?q?=D0=BA=D0=BE=D0=BD=D1=84=D0=B8=D0=B3=D1=83=D1=80=D0=B0=D1=86?= =?UTF-8?q?=D0=B8=D0=B8=20=D0=BF=D0=BE=D0=BB=D1=8C=D0=B7=D0=BE=D0=B2=D0=B0?= =?UTF-8?q?=D1=82=D0=B5=D0=BB=D1=8F=20=D1=81=20=D0=BF=D0=BE=D0=BC=D0=BE?= =?UTF-8?q?=D1=89=D1=8C=D1=8E=20=D0=B4=D0=BE=D0=BF=D0=BE=D0=BB=D0=BD=D0=B8?= =?UTF-8?q?=D1=82=D0=B5=D0=BB=D1=8C=D0=BD=D1=8B=D1=85=20=D0=BD=D0=B0=D1=81?= =?UTF-8?q?=D1=82=D1=80=D0=BE=D0=B5=D0=BA=20YAML.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 33 +- .idea/.gitignore | 8 + .idea/easy_rag.iml | 13 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + README.md | 454 +++++++++++++++++++- api/api.go | 35 ++ api/handler.go | 248 +++++++++++ cmd/rag/main.go | 43 ++ config/config.go | 38 ++ deploy/Dockerfile | 16 + deploy/docker-compose.yaml | 0 embedEtcd.yaml | 5 + go.mod | 49 +++ internal/database/database.go | 20 + internal/database/database_milvus.go | 142 ++++++ internal/embeddings/embeddings.go | 8 + internal/embeddings/ollama_embeddings.go | 78 ++++ internal/embeddings/openai_embeddings.go | 23 + 
internal/llm/llm.go | 8 + internal/llm/ollama_llm.go | 82 ++++ internal/llm/openai_llm.go | 24 ++ internal/llm/openroute/definitions.go | 155 +++++++ internal/llm/openroute/route_agent.go | 198 +++++++++ internal/llm/openroute/route_client.go | 188 ++++++++ internal/llm/openroute_llm.go | 24 ++ internal/models/models.go | 27 ++ internal/pkg/database/milvus/client.go | 169 ++++++++ internal/pkg/database/milvus/client_test.go | 32 ++ internal/pkg/database/milvus/helpers.go | 276 ++++++++++++ internal/pkg/database/milvus/operations.go | 270 ++++++++++++ internal/pkg/rag/rag.go | 21 + internal/pkg/textprocessor/textprocessor.go | 50 +++ scripts/standalone_embed.sh | 169 ++++++++ tests/api_test.go | 313 ++++++++++++++ tests/mock_test.go | 89 ++++ tests/openrouter_test.go | 125 ++++++ user.yaml | 1 + 38 files changed, 3420 insertions(+), 28 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 .idea/easy_rag.iml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 api/api.go create mode 100644 api/handler.go create mode 100644 cmd/rag/main.go create mode 100644 config/config.go create mode 100644 deploy/Dockerfile create mode 100644 deploy/docker-compose.yaml create mode 100644 embedEtcd.yaml create mode 100644 go.mod create mode 100644 internal/database/database.go create mode 100644 internal/database/database_milvus.go create mode 100644 internal/embeddings/embeddings.go create mode 100644 internal/embeddings/ollama_embeddings.go create mode 100644 internal/embeddings/openai_embeddings.go create mode 100644 internal/llm/llm.go create mode 100644 internal/llm/ollama_llm.go create mode 100644 internal/llm/openai_llm.go create mode 100644 internal/llm/openroute/definitions.go create mode 100644 internal/llm/openroute/route_agent.go create mode 100644 internal/llm/openroute/route_client.go create mode 100644 internal/llm/openroute_llm.go create mode 100644 internal/models/models.go create mode 100644 
internal/pkg/database/milvus/client.go create mode 100644 internal/pkg/database/milvus/client_test.go create mode 100644 internal/pkg/database/milvus/helpers.go create mode 100644 internal/pkg/database/milvus/operations.go create mode 100644 internal/pkg/rag/rag.go create mode 100644 internal/pkg/textprocessor/textprocessor.go create mode 100644 scripts/standalone_embed.sh create mode 100644 tests/api_test.go create mode 100644 tests/mock_test.go create mode 100644 tests/openrouter_test.go create mode 100644 user.yaml diff --git a/.gitignore b/.gitignore index 5b90e79..0874f28 100644 --- a/.gitignore +++ b/.gitignore @@ -1,27 +1,6 @@ -# ---> Go -# If you prefer the allow list template instead of the deny list, see community template: -# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore -# -# Binaries for programs and plugins -*.exe -*.exe~ -*.dll -*.so -*.dylib - -# Test binary, built with `go test -c` -*.test - -# Output of the go coverage tool, specifically when used with LiteIDE -*.out - -# Dependency directories (remove the comment below to include it) -# vendor/ - -# Go workspace file -go.work -go.work.sum - -# env file -.env - +*.env +*.sum +id_rsa2 +id_rsa2.pub +/tests/RAG.postman_collection.json +/volumes/ \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/easy_rag.iml b/.idea/easy_rag.iml new file mode 100644 index 0000000..b279d48 --- /dev/null +++ b/.idea/easy_rag.iml @@ -0,0 +1,13 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..8536ff3 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ 
No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/README.md b/README.md index 89cd039..9e3f99f 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,454 @@ -# easy_rag +# Easy RAG - Система Retrieval-Augmented Generation +## Обзор + +Easy RAG - это мощная система для управления документами и генерации ответов на основе метода Retrieval-Augmented Generation (RAG). Проект реализует полнофункциональное API для загрузки, хранения, поиска и анализа документов с использованием векторных баз данных и современных языковых моделей. + +### Ключевые возможности + +- 🔍 **Семантический поиск** по документам с использованием векторных эмбеддингов +- 📄 **Управление документами** - загрузка, хранение, получение и удаление +- 🤖 **Интеграция с LLM** - поддержка OpenAI, Ollama и OpenRoute +- 💾 **Векторная база данных** - использование Milvus для хранения эмбеддингов +- 🔧 **Модульная архитектура** - легко расширяемая и настраиваемая система +- 🚀 **RESTful API** - простое и понятное API для интеграции +- 🐳 **Docker-готовность** - контейнеризация для простого развертывания + +--- + +## Архитектура системы + +### Компоненты + +1. **API слой** (`api/`) - HTTP API на базе Echo Framework +2. **Основная логика** (`internal/pkg/rag/`) - ядро RAG системы +3. **Провайдеры LLM** (`internal/llm/`) - интеграция с языковыми моделями +4. **Провайдеры эмбеддингов** (`internal/embeddings/`) - генерация векторных представлений +5. **База данных** (`internal/database/`) - работа с векторной БД Milvus +6. **Обработка текста** (`internal/pkg/textprocessor/`) - разбиение на чанки +7. 
**Конфигурация** (`config/`) - управление настройками + +### Поддерживаемые провайдеры + +#### LLM (Языковые модели): +- **OpenAI** - GPT-3.5, GPT-4 и другие модели +- **Ollama** - локальные открытые модели +- **OpenRoute** - доступ к различным моделям через единый API + +#### Эмбеддинги: +- **OpenAI Embeddings** - text-embedding-ada-002 и новые модели +- **Ollama Embeddings** - локальные модели эмбеддингов (например, bge-m3) + +#### Векторная база данных: +- **Milvus** - высокопроизводительная векторная база данных + +--- + +## API Документация + +### Базовый URL +``` +http://localhost:4002/api/v1 +``` + +### Эндпоинты + +#### 1. Получить все документы +```http +GET /api/v1/docs +``` +**Описание**: Получить список всех сохраненных документов. + +**Ответ**: +```json +{ + "version": "v1", + "docs": [ + { + "id": "document_id", + "filename": "document.txt", + "summary": "Краткое описание документа", + "metadata": { + "category": "Техническая документация", + "author": "Автор" + } + } + ] +} +``` + +#### 2. Загрузить документы +```http +POST /api/v1/upload +``` +**Описание**: Загрузить один или несколько документов для обработки и индексации. + +**Тело запроса**: +```json +{ + "docs": [ + { + "content": "Содержимое документа...", + "link": "https://example.com/document", + "filename": "document.txt", + "category": "Категория", + "metadata": { + "author": "Автор", + "date": "2024-01-01" + } + } + ] +} +``` + +**Ответ**: +```json +{ + "version": "v1", + "task_id": "unique_task_id", + "expected_time": "10m", + "status": "Обработка начата" +} +``` + +#### 3. Получить документ по ID +```http +GET /api/v1/doc/{id} +``` +**Описание**: Получить детальную информацию о документе по его идентификатору. + +**Ответ**: +```json +{ + "version": "v1", + "doc": { + "id": "document_id", + "content": "Полное содержимое документа", + "filename": "document.txt", + "summary": "Краткое описание", + "metadata": { + "category": "Категория", + "author": "Автор" + } + } +} +``` + +#### 4. 
Задать вопрос +```http +POST /api/v1/ask +``` +**Описание**: Задать вопрос на основе проиндексированных документов. + +**Тело запроса**: +```json +{ + "question": "Что такое ISO 27001?" +} +``` + +**Ответ**: +```json +{ + "version": "v1", + "docs": ["document_id_1", "document_id_2"], + "answer": "ISO 27001 - это международный стандарт информационной безопасности..." +} +``` + +#### 5. Удалить документ +```http +DELETE /api/v1/doc/{id} +``` +**Описание**: Удалить документ по его идентификатору. + +**Ответ**: +```json +{ + "version": "v1", + "docs": null +} +``` + +--- + +## Структуры данных + +### Document (Документ) +```go +type Document struct { + ID string `json:"id"` // Уникальный идентификатор + Content string `json:"content"` // Содержимое документа + Link string `json:"link"` // Ссылка на источник + Filename string `json:"filename"` // Имя файла + Category string `json:"category"` // Категория + EmbeddingModel string `json:"embedding_model"` // Модель эмбеддингов + Summary string `json:"summary"` // Краткое описание + Metadata map[string]string `json:"metadata"` // Метаданные + Vector []float32 `json:"vector"` // Векторное представление +} +``` + +### Embedding (Эмбеддинг) +```go +type Embedding struct { + ID string `json:"id"` // Уникальный идентификатор + DocumentID string `json:"document_id"` // ID связанного документа + Vector []float32 `json:"vector"` // Векторное представление + TextChunk string `json:"text_chunk"` // Фрагмент текста + Dimension int64 `json:"dimension"` // Размерность вектора + Order int64 `json:"order"` // Порядок фрагмента + Score float32 `json:"score"` // Оценка релевантности +} +``` + +--- + +## Установка и настройка + +### Требования +- Go 1.24.3+ +- Milvus векторная база данных +- Ollama (для локальных моделей) или API ключи для OpenAI/OpenRoute + +### 1. Клонирование репозитория +```bash +git clone https://github.com/elchemista/easy_rag.git +cd easy_rag +``` + +### 2. Установка зависимостей +```bash +go mod tidy +``` + +### 3. 
Настройка окружения +Создайте файл `.env` или установите переменные окружения: + +```env +# LLM настройки +OPENAI_API_KEY=your_openai_api_key +OPENAI_ENDPOINT=https://api.openai.com/v1 +OPENAI_MODEL=gpt-3.5-turbo + +OPENROUTE_API_KEY=your_openroute_api_key +OPENROUTE_ENDPOINT=https://openrouter.ai/api/v1 +OPENROUTE_MODEL=anthropic/claude-3-haiku + +OLLAMA_ENDPOINT=http://localhost:11434/api/chat +OLLAMA_MODEL=qwen3:latest + +# Эмбеддинги +OPENAI_EMBEDDING_API_KEY=your_openai_api_key +OPENAI_EMBEDDING_ENDPOINT=https://api.openai.com/v1 +OPENAI_EMBEDDING_MODEL=text-embedding-ada-002 + +OLLAMA_EMBEDDING_ENDPOINT=http://localhost:11434 +OLLAMA_EMBEDDING_MODEL=bge-m3 + +# База данных +MILVUS_HOST=localhost:19530 +``` + +### 4. Запуск Milvus +```bash +# Используя Docker +docker run -d --name milvus-standalone \ + -p 19530:19530 -p 9091:9091 \ + milvusdb/milvus:latest +``` + +### 5. Запуск приложения +```bash +go run cmd/rag/main.go +``` + +API будет доступно по адресу `http://localhost:4002` + +--- + +## Запуск с Docker + +### Сборка образа +```bash +docker build -f deploy/Dockerfile -t easy-rag . +``` + +### Запуск контейнера +```bash +docker run -d -p 4002:4002 \ + -e MILVUS_HOST=your_milvus_host:19530 \ + -e OLLAMA_ENDPOINT=http://your_ollama_host:11434/api/chat \ + -e OLLAMA_MODEL=qwen3:latest \ + easy-rag +``` + +--- + +## Примеры использования + +### 1. Загрузка документа +```bash +curl -X POST http://localhost:4002/api/v1/upload \ + -H "Content-Type: application/json" \ + -d '{ + "docs": [{ + "content": "Это тестовый документ для демонстрации RAG системы.", + "filename": "test.txt", + "category": "Тест", + "metadata": { + "author": "Пользователь", + "type": "demo" + } + }] + }' +``` + +### 2. Поиск ответа +```bash +curl -X POST http://localhost:4002/api/v1/ask \ + -H "Content-Type: application/json" \ + -d '{ + "question": "О чем этот документ?" + }' +``` + +### 3. 
Получение всех документов +```bash +curl http://localhost:4002/api/v1/docs +``` + +--- + +## Архитектурные особенности + +### Модульность +Система спроектирована с использованием интерфейсов, что позволяет легко: +- Переключаться между различными LLM провайдерами +- Использовать разные модели эмбеддингов +- Менять векторную базу данных +- Добавлять новые методы обработки текста + +### Обработка текста +- Автоматическое разбиение документов на чанки +- Генерация эмбеддингов для каждого фрагмента +- Сохранение порядка фрагментов для корректной реконструкции + +### Поиск и ранжирование +- Семантический поиск по векторным представлениям +- Ранжирование результатов по релевантности +- Контекстная генерация ответов на основе найденных документов + +--- + +## Разработка и тестирование + +### Структура проекта +``` +easy_rag/ +├── api/ # HTTP API обработчики +├── cmd/rag/ # Точка входа приложения +├── config/ # Конфигурация +├── deploy/ # Docker файлы +├── internal/ # Внутренняя логика +│ ├── database/ # Интерфейсы БД +│ ├── embeddings/ # Провайдеры эмбеддингов +│ ├── llm/ # Провайдеры LLM +│ ├── models/ # Модели данных +│ └── pkg/ # Пакеты общего назначения +├── scripts/ # Вспомогательные скрипты +└── tests/ # Тесты и коллекции Postman +``` + +### Запуск тестов +```bash +go test ./... 
+``` + +### Использование Postman +В папке `tests/` находится коллекция Postman для тестирования API: +``` +tests/RAG.postman_collection.json +``` + +--- + +## Производительность и масштабирование + +### Рекомендации по производительности +- Используйте SSD для хранения векторной базы данных +- Настройте индексы Milvus для оптимальной производительности +- Рассмотрите использование GPU для генерации эмбеддингов +- Кэшируйте часто запрашиваемые результаты + +### Масштабирование +- Горизонтальное масштабирование Milvus кластера +- Балансировка нагрузки между несколькими экземплярами API +- Асинхронная обработка загрузки документов +- Использование очередей для обработки больших объемов данных + +--- + +## Устранение неполадок + +### Частые проблемы + +1. **Не удается подключиться к Milvus** + - Проверьте, что Milvus запущен и доступен + - Убедитесь в правильности MILVUS_HOST + +2. **Ошибки LLM провайдера** + - Проверьте API ключи + - Убедитесь в доступности эндпоинтов + - Проверьте правильность названий моделей + +3. **Медленная обработка документов** + - Уменьшите размер чанков + - Используйте более быстрые модели эмбеддингов + - Оптимизируйте настройки Milvus + +--- + +## Вклад в проект + +Мы приветствуем вклад в развитие проекта! Пожалуйста: + +1. Форкните репозиторий +2. Создайте ветку для новой функции +3. Внесите изменения +4. Добавьте тесты +5. Отправьте Pull Request + +--- + +## Лицензия + +Проект распространяется под лицензией MIT. См. файл `LICENSE` для подробностей. 
+ +--- + +## Поддержка + +Если у вас есть вопросы или проблемы: +- Создайте Issue в GitHub репозитории +- Обратитесь к документации API +- Проверьте примеры в папке `tests/` + +--- + +## Roadmap + +### Ближайшие планы +- [ ] Поддержка дополнительных форматов документов (PDF, DOCX) +- [ ] Веб-интерфейс для управления документами +- [ ] Улучшенная система метаданных +- [ ] Поддержка multimodal моделей +- [ ] Кэширование результатов +- [ ] Мониторинг и метрики + +### Долгосрочные цели +- [ ] Поддержка множественных языков +- [ ] Федеративный поиск по нескольким источникам +- [ ] Интеграция с внешними системами (SharePoint, Confluence) +- [ ] Продвинутая аналитика использования +- [ ] Система разрешений и ролей diff --git a/api/api.go b/api/api.go new file mode 100644 index 0000000..0403c15 --- /dev/null +++ b/api/api.go @@ -0,0 +1,35 @@ +package api + +import ( + "fmt" + "github.com/labstack/echo/v4" + + "easy_rag/internal/pkg/rag" + "github.com/labstack/echo/v4/middleware" +) + +const ( + // APIVersion is the version of the API + APIVersion = "v1" +) + +func NewAPI(e *echo.Echo, rag *rag.Rag) { + // Middleware + e.Use(middleware.Logger()) + e.Use(middleware.Recover()) + // put rag pointer in context + e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("Rag", rag) + return next(c) + } + }) + + api := e.Group(fmt.Sprintf("/api/%s", APIVersion)) + + api.POST("/upload", UploadHandler) + api.POST("/ask", AskDocHandler) + api.GET("/docs", ListAllDocsHandler) + api.GET("/doc/:id", GetDocHandler) + api.DELETE("/doc/:id", DeleteDocHandler) +} diff --git a/api/handler.go b/api/handler.go new file mode 100644 index 0000000..4fcf3ef --- /dev/null +++ b/api/handler.go @@ -0,0 +1,248 @@ +package api + +import ( + "fmt" + "github.com/labstack/echo/v4" + "log" + "net/http" + + "easy_rag/internal/models" + "easy_rag/internal/pkg/rag" + "easy_rag/internal/pkg/textprocessor" + "github.com/google/uuid" +) + +type UploadDoc struct { + 
Content string `json:"content"` + Link string `json:"link"` + Filename string `json:"filename"` + Category string `json:"category"` + Metadata map[string]string `json:"metadata"` +} + +type RequestUpload struct { + Docs []UploadDoc `json:"docs"` +} + +type RequestQuestion struct { + Question string `json:"question"` +} + +type ResposeQuestion struct { + Version string `json:"version"` + Docs []models.Document `json:"docs"` + Answer string `json:"answer"` +} + +func UploadHandler(c echo.Context) error { + // Retrieve the RAG instance from context + rag := c.Get("Rag").(*rag.Rag) + + var request RequestUpload + if err := c.Bind(&request); err != nil { + return ErrorHandler(err, c) + } + + // Generate a unique task ID + taskID := uuid.NewString() + + // Launch the upload process in a separate goroutine + go func(taskID string, request RequestUpload) { + log.Printf("Task %s: started processing", taskID) + defer log.Printf("Task %s: completed processing", taskID) + + var docs []models.Document + + for idx, doc := range request.Docs { + // Generate a unique ID for each document + docID := uuid.NewString() + log.Printf("Task %s: processing document %d with generated ID %s (filename: %s)", taskID, idx, docID, doc.Filename) + + // Step 1: Create chunks from document content + chunks := textprocessor.CreateChunks(doc.Content) + log.Printf("Task %s: created %d chunks for document %s", taskID, len(chunks), docID) + + // Step 2: Generate summary for the document + var summaryChunks string + if len(chunks) < 4 { + summaryChunks = doc.Content + } else { + summaryChunks = textprocessor.ConcatenateStrings(chunks[:3]) + } + + log.Printf("Task %s: generating summary for document %s", taskID, docID) + summary, err := rag.LLM.Generate(fmt.Sprintf("Give me only summary of the following text: %s", summaryChunks)) + if err != nil { + log.Printf("Task %s: error generating summary for document %s: %v", taskID, docID, err) + return + } + log.Printf("Task %s: generated summary for document 
%s", taskID, docID) + + // Step 3: Vectorize the summary + log.Printf("Task %s: vectorizing summary for document %s", taskID, docID) + vectorSum, err := rag.Embeddings.Vectorize(summary) + if err != nil { + log.Printf("Task %s: error vectorizing summary for document %s: %v", taskID, docID, err) + return + } + log.Printf("Task %s: vectorized summary for document %s", taskID, docID) + + // Step 4: Save the document + document := models.Document{ + ID: docID, // Use generated ID + Content: "", + Link: doc.Link, + Filename: doc.Filename, + Category: doc.Category, + EmbeddingModel: rag.Embeddings.GetModel(), + Summary: summary, + Vector: vectorSum[0], + Metadata: doc.Metadata, + } + log.Printf("Task %s: saving document %s", taskID, docID) + if err := rag.Database.SaveDocument(document); err != nil { + log.Printf("Task %s: error saving document %s: %v", taskID, docID, err) + return + } + log.Printf("Task %s: saved document %s", taskID, docID) + + // Step 5: Process and save embeddings for each chunk + var embeddings []models.Embedding + for order, chunk := range chunks { + log.Printf("Task %s: vectorizing chunk %d for document %s", taskID, order, docID) + vectorEmbedding, err := rag.Embeddings.Vectorize(chunk) + if err != nil { + log.Printf("Task %s: error vectorizing chunk %d for document %s: %v", taskID, order, docID, err) + return + } + log.Printf("Task %s: vectorized chunk %d for document %s", taskID, order, docID) + + embedding := models.Embedding{ + ID: uuid.NewString(), + DocumentID: docID, + Vector: vectorEmbedding[0], + TextChunk: chunk, + Dimension: int64(1024), + Order: int64(order), + } + embeddings = append(embeddings, embedding) + } + + log.Printf("Task %s: saving %d embeddings for document %s", taskID, len(embeddings), docID) + if err := rag.Database.SaveEmbeddings(embeddings); err != nil { + log.Printf("Task %s: error saving embeddings for document %s: %v", taskID, docID, err) + return + } + log.Printf("Task %s: saved embeddings for document %s", taskID, 
docID) + + docs = append(docs, document) + } + }(taskID, request) + + // Return the task ID and expected completion time + return c.JSON(http.StatusAccepted, map[string]interface{}{ + "version": APIVersion, + "task_id": taskID, + "expected_time": "10m", + "status": "Processing started", + }) +} + +func ListAllDocsHandler(c echo.Context) error { + rag := c.Get("Rag").(*rag.Rag) + docs, err := rag.Database.ListDocuments() + if err != nil { + return ErrorHandler(err, c) + } + return c.JSON(http.StatusOK, map[string]interface{}{ + "version": APIVersion, + "docs": docs, + }) +} + +func GetDocHandler(c echo.Context) error { + rag := c.Get("Rag").(*rag.Rag) + id := c.Param("id") + doc, err := rag.Database.GetDocument(id) + if err != nil { + return ErrorHandler(err, c) + } + return c.JSON(http.StatusOK, map[string]interface{}{ + "version": APIVersion, + "doc": doc, + }) +} + +func AskDocHandler(c echo.Context) error { + rag := c.Get("Rag").(*rag.Rag) + + var request RequestQuestion + err := c.Bind(&request) + + if err != nil { + return ErrorHandler(err, c) + } + + questionV, err := rag.Embeddings.Vectorize(request.Question) + + if err != nil { + return ErrorHandler(err, c) + } + + embeddings, err := rag.Database.Search(questionV) + + if err != nil { + return ErrorHandler(err, c) + } + + if len(embeddings) == 0 { + return c.JSON(http.StatusOK, map[string]interface{}{ + "version": APIVersion, + "docs": nil, + "answer": "Don't found any relevant documents", + }) + } + + answer, err := rag.LLM.Generate(fmt.Sprintf("Given the following information: %s \nAnswer the question: %s", embeddings[0].TextChunk, request.Question)) + + if err != nil { + return ErrorHandler(err, c) + } + + // Use a map to track unique DocumentIDs + docSet := make(map[string]struct{}) + for _, embedding := range embeddings { + docSet[embedding.DocumentID] = struct{}{} + } + + // Convert the map keys to a slice + docs := make([]string, 0, len(docSet)) + for docID := range docSet { + docs = append(docs, 
docID) + } + + return c.JSON(http.StatusOK, map[string]interface{}{ + "version": APIVersion, + "docs": docs, + "answer": answer, + }) +} + +func DeleteDocHandler(c echo.Context) error { + rag := c.Get("Rag").(*rag.Rag) + id := c.Param("id") + err := rag.Database.DeleteDocument(id) + if err != nil { + return ErrorHandler(err, c) + } + return c.JSON(http.StatusOK, map[string]interface{}{ + "version": APIVersion, + "docs": nil, + }) +} + +func ErrorHandler(err error, c echo.Context) error { + return c.JSON(http.StatusBadRequest, map[string]interface{}{ + "error": err.Error(), + }) +} diff --git a/cmd/rag/main.go b/cmd/rag/main.go new file mode 100644 index 0000000..2653d2b --- /dev/null +++ b/cmd/rag/main.go @@ -0,0 +1,43 @@ +package main + +import ( + "easy_rag/api" + "easy_rag/config" + "easy_rag/internal/database" + "easy_rag/internal/embeddings" + "easy_rag/internal/llm" + "easy_rag/internal/pkg/rag" + "github.com/labstack/echo/v4" +) + +// Rag is the main struct for the rag application + +func main() { + cfg := config.NewConfig() + + llm := llm.NewOllama( + cfg.OllamaEndpoint, + cfg.OllamaModel, + ) + embeddings := embeddings.NewOllamaEmbeddings( + cfg.OllamaEmbeddingEndpoint, + cfg.OllamaEmbeddingModel, + ) + database := database.NewMilvus(cfg.MilvusHost) + + // Rag instance + rag := rag.NewRag( + llm, + embeddings, + database, + ) + + // Echo WebServer instance + e := echo.New() + + // Wrapper for API + api.NewAPI(e, rag) + + // Start Server + e.Logger.Fatal(e.Start(":4002")) +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..d6a7e05 --- /dev/null +++ b/config/config.go @@ -0,0 +1,38 @@ +package config + +import cfg "github.com/eschao/config" + +type Config struct { + // LLM + OpenAIAPIKey string `env:"OPENAI_API_KEY"` + OpenAIEndpoint string `env:"OPENAI_ENDPOINT"` + OpenAIModel string `env:"OPENAI_MODEL"` + OpenRouteAPIKey string `env:"OPENROUTE_API_KEY"` + OpenRouteEndpoint string `env:"OPENROUTE_ENDPOINT"` + 
OpenRouteModel string `env:"OPENROUTE_MODEL"` + + OllamaEndpoint string `env:"OLLAMA_ENDPOINT"` + OllamaModel string `env:"OLLAMA_MODEL"` + + // Embeddings + OpenAIEmbeddingAPIKey string `env:"OPENAI_EMBEDDING_API_KEY"` + OpenAIEmbeddingEndpoint string `env:"OPENAI_EMBEDDING_ENDPOINT"` + OpenAIEmbeddingModel string `env:"OPENAI_EMBEDDING_MODEL"` + OllamaEmbeddingEndpoint string `env:"OLLAMA_EMBEDDING_ENDPOINT"` + OllamaEmbeddingModel string `env:"OLLAMA_EMBEDDING_MODEL"` + + // Database + MilvusHost string `env:"MILVUS_HOST"` +} + +func NewConfig() Config { + config := Config{ + MilvusHost: "192.168.10.56:19530", + OllamaEmbeddingEndpoint: "http://192.168.10.56:11434", + OllamaEmbeddingModel: "bge-m3", + OllamaEndpoint: "http://192.168.10.56:11434/api/chat", + OllamaModel: "qwen3:latest", + } + cfg.ParseEnv(&config) + return config +} diff --git a/deploy/Dockerfile b/deploy/Dockerfile new file mode 100644 index 0000000..ccf9f09 --- /dev/null +++ b/deploy/Dockerfile @@ -0,0 +1,16 @@ +FROM golang:1.20 AS builder + +WORKDIR /cmd + +COPY . . 
+ +RUN go mod download +RUN go build -o rag ./cmd/rag + +FROM alpine:latest + +WORKDIR /root/ + +COPY --from=builder /cmd/rag ./ + +CMD ["./rag"] \ No newline at end of file diff --git a/deploy/docker-compose.yaml b/deploy/docker-compose.yaml new file mode 100644 index 0000000..e69de29 diff --git a/embedEtcd.yaml b/embedEtcd.yaml new file mode 100644 index 0000000..32954fa --- /dev/null +++ b/embedEtcd.yaml @@ -0,0 +1,5 @@ +listen-client-urls: http://0.0.0.0:2379 +advertise-client-urls: http://0.0.0.0:2379 +quota-backend-bytes: 4294967296 +auto-compaction-mode: revision +auto-compaction-retention: '1000' diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ce63a30 --- /dev/null +++ b/go.mod @@ -0,0 +1,49 @@ +module easy_rag + +go 1.24.3 + +require ( + github.com/eschao/config v0.1.0 + github.com/google/uuid v1.6.0 + github.com/jonathanhecl/chunker v0.0.1 + github.com/labstack/echo/v4 v4.13.4 + github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 + github.com/stretchr/testify v1.10.0 +) + +require ( + github.com/cockroachdb/errors v1.9.1 // indirect + github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect + github.com/cockroachdb/redact v1.1.3 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/getsentry/sentry-go v0.12.0 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect + github.com/kr/pretty v0.3.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/labstack/gommon v0.4.2 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.8.1 // indirect + github.com/stretchr/objx v0.5.2 // indirect + 
github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasttemplate v1.2.2 // indirect + golang.org/x/crypto v0.38.0 // indirect + golang.org/x/net v0.40.0 // indirect + golang.org/x/sync v0.14.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect + golang.org/x/time v0.11.0 // indirect + google.golang.org/genproto v0.0.0-20220503193339-ba3ae3f07e29 // indirect + google.golang.org/grpc v1.48.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v2 v2.2.8 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/internal/database/database.go b/internal/database/database.go new file mode 100644 index 0000000..c543e7e --- /dev/null +++ b/internal/database/database.go @@ -0,0 +1,20 @@ +package database + +import "easy_rag/internal/models" + +// database interface + +// Database defines the interface for interacting with a database +type Database interface { + SaveDocument(document models.Document) error // the content will be chunked and saved + GetDocumentInfo(id string) (models.Document, error) // return the document with the given id without content + GetDocument(id string) (models.Document, error) // return the document with the given id with content assembled + Search(vector [][]float32) ([]models.Embedding, error) + ListDocuments() ([]models.Document, error) + DeleteDocument(id string) error + SaveEmbeddings(embeddings []models.Embedding) error + // to implement in future + // SearchByCategory(category []string) ([]Embedding, error) + // SearchByMetadata(metadata map[string]string) ([]Embedding, error) + // GetAllEmbeddingByDocumentID(documentID string) ([]Embedding, error) +} diff --git a/internal/database/database_milvus.go b/internal/database/database_milvus.go new file mode 100644 index 0000000..1bf6771 --- /dev/null +++ 
b/internal/database/database_milvus.go @@ -0,0 +1,142 @@ +package database + +import ( + "bytes" + "context" + "fmt" + "sort" + + "easy_rag/internal/models" + "easy_rag/internal/pkg/database/milvus" +) + +// implement database interface for milvus +type Milvus struct { + Host string + Client *milvus.Client +} + +func NewMilvus(host string) *Milvus { + + milviusClient, err := milvus.NewClient(host) + if err != nil { + panic(err) + } + + return &Milvus{ + Host: host, + Client: milviusClient, + } +} + +func (m *Milvus) SaveDocument(document models.Document) error { + // for now lets use context background + ctx := context.Background() + + return m.Client.InsertDocuments(ctx, []models.Document{document}) +} + +func (m *Milvus) SaveEmbeddings(embeddings []models.Embedding) error { + ctx := context.Background() + return m.Client.InsertEmbeddings(ctx, embeddings) +} + +func (m *Milvus) GetDocumentInfo(id string) (models.Document, error) { + ctx := context.Background() + doc, err := m.Client.GetDocumentByID(ctx, id) + + if err != nil { + return models.Document{}, err + } + + if len(doc) == 0 { + return models.Document{}, nil + } + return models.Document{ + ID: doc["ID"].(string), + Link: doc["Link"].(string), + Filename: doc["Filename"].(string), + Category: doc["Category"].(string), + EmbeddingModel: doc["EmbeddingModel"].(string), + Summary: doc["Summary"].(string), + Metadata: doc["Metadata"].(map[string]string), + }, nil +} + +func (m *Milvus) GetDocument(id string) (models.Document, error) { + ctx := context.Background() + doc, err := m.Client.GetDocumentByID(ctx, id) + if err != nil { + return models.Document{}, err + } + + embeds, err := m.Client.GetAllEmbeddingByDocID(ctx, id) + + if err != nil { + return models.Document{}, err + } + + // order embed by order + sort.Slice(embeds, func(i, j int) bool { + return embeds[i].Order < embeds[j].Order + }) + + // concatenate text chunks + var buf bytes.Buffer + for _, embed := range embeds { + 
buf.WriteString(embed.TextChunk) + } + + textChunks := buf.String() + + if len(doc) == 0 { + return models.Document{}, nil + } + return models.Document{ + ID: doc["ID"].(string), + Content: textChunks, + Link: doc["Link"].(string), + Filename: doc["Filename"].(string), + Category: doc["Category"].(string), + EmbeddingModel: doc["EmbeddingModel"].(string), + Summary: doc["Summary"].(string), + Metadata: doc["Metadata"].(map[string]string), + }, nil +} + +func (m *Milvus) Search(vector [][]float32) ([]models.Embedding, error) { + ctx := context.Background() + results, err := m.Client.Search(ctx, vector, 10) + if err != nil { + return nil, err + } + + return results, nil +} + +func (m *Milvus) ListDocuments() ([]models.Document, error) { + ctx := context.Background() + + docs, err := m.Client.GetAllDocuments(ctx) + + if err != nil { + return nil, fmt.Errorf("failed to get docs: %w", err) + } + + return docs, nil +} + +func (m *Milvus) DeleteDocument(id string) error { + ctx := context.Background() + err := m.Client.DeleteDocument(ctx, id) + if err != nil { + return err + } + + err = m.Client.DeleteEmbedding(ctx, id) + if err != nil { + return err + } + + return nil +} diff --git a/internal/embeddings/embeddings.go b/internal/embeddings/embeddings.go new file mode 100644 index 0000000..77a6fd3 --- /dev/null +++ b/internal/embeddings/embeddings.go @@ -0,0 +1,8 @@ +package embeddings + +// implement embeddings interface +type EmbeddingsService interface { + // generate embedding from text + Vectorize(text string) ([][]float32, error) + GetModel() string +} diff --git a/internal/embeddings/ollama_embeddings.go b/internal/embeddings/ollama_embeddings.go new file mode 100644 index 0000000..ac8f918 --- /dev/null +++ b/internal/embeddings/ollama_embeddings.go @@ -0,0 +1,78 @@ +package embeddings + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" +) + +type OllamaEmbeddings struct { + Endpoint string + Model string +} + +func NewOllamaEmbeddings(endpoint 
string, model string) *OllamaEmbeddings { + return &OllamaEmbeddings{ + Endpoint: endpoint, + Model: model, + } +} + +// Vectorize generates an embedding for the provided text +func (o *OllamaEmbeddings) Vectorize(text string) ([][]float32, error) { + // Define the request payload + payload := map[string]string{ + "model": o.Model, + "input": text, + } + + // Convert the payload to JSON + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request payload: %w", err) + } + + // Create the HTTP request + url := fmt.Sprintf("%s/api/embed", o.Endpoint) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + // Execute the HTTP request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to make HTTP request: %w", err) + } + defer resp.Body.Close() + + // Check for non-200 status code + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("received non-200 response: %s", body) + } + + // Read and parse the response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Assuming the response JSON contains an "embedding" field with a float32 array + var response struct { + Embeddings [][]float32 `json:"embeddings"` + } + if err := json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return response.Embeddings, nil +} + +func (o *OllamaEmbeddings) GetModel() string { + return o.Model +} diff --git a/internal/embeddings/openai_embeddings.go b/internal/embeddings/openai_embeddings.go new file mode 100644 index 0000000..f42fbb8 --- /dev/null +++ b/internal/embeddings/openai_embeddings.go @@ -0,0 +1,23 @@ +package embeddings + 
+type OpenAIEmbeddings struct { + APIKey string + Endpoint string + Model string +} + +func NewOpenAIEmbeddings(apiKey string, endpoint string, model string) *OpenAIEmbeddings { + return &OpenAIEmbeddings{ + APIKey: apiKey, + Endpoint: endpoint, + Model: model, + } +} + +func (o *OpenAIEmbeddings) Vectorize(text string) ([]float32, error) { + return nil, nil +} + +func (o *OpenAIEmbeddings) GetModel() string { + return o.Model +} diff --git a/internal/llm/llm.go b/internal/llm/llm.go new file mode 100644 index 0000000..88ef5f3 --- /dev/null +++ b/internal/llm/llm.go @@ -0,0 +1,8 @@ +package llm + +// implement llm interface +type LLMService interface { + // generate text from prompt + Generate(prompt string) (string, error) + GetModel() string +} diff --git a/internal/llm/ollama_llm.go b/internal/llm/ollama_llm.go new file mode 100644 index 0000000..54d9271 --- /dev/null +++ b/internal/llm/ollama_llm.go @@ -0,0 +1,82 @@ +package llm + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" +) + +type Ollama struct { + Endpoint string + Model string +} + +func NewOllama(endpoint string, model string) *Ollama { + return &Ollama{ + Endpoint: endpoint, + Model: model, + } +} + +// Response represents the structure of the expected response from the API. 
type Response struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	// Message holds the assistant's reply (role + content).
	Message struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"message"`
}

// Generate sends a prompt to the Ollama endpoint and returns the response.
// It issues a single non-streaming chat request ("stream": false) with the
// prompt as one user message and returns the assistant message content.
//
// NOTE(review): the POST goes to o.Endpoint verbatim, so callers must pass
// the full chat URL (e.g. ".../api/chat"); the embeddings client appends
// its path itself — confirm this asymmetry is intended.
func (o *Ollama) Generate(prompt string) (string, error) {
	// Create the request payload
	payload := map[string]interface{}{
		"model": o.Model,
		"messages": []map[string]string{
			{
				"role":    "user",
				"content": prompt,
			},
		},
		"stream": false,
	}

	// Marshal the payload into JSON
	data, err := json.Marshal(payload)
	if err != nil {
		return "", fmt.Errorf("failed to marshal payload: %w", err)
	}

	// Make the POST request
	resp, err := http.Post(o.Endpoint, "application/json", bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	// Read the body first so a non-200 error can include it verbatim.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("API returned error: %s", string(body))
	}

	// Unmarshal the response into a predefined structure
	var response Response
	if err := json.Unmarshal(body, &response); err != nil {
		return "", fmt.Errorf("failed to unmarshal response: %w", err)
	}

	// Extract and return the content from the nested structure
	return response.Message.Content, nil
}

// GetModel reports the configured Ollama model name.
func (o *Ollama) GetModel() string {
	return o.Model
}

package llm

// OpenAI is an LLM client for the OpenAI API (not yet implemented).
type OpenAI struct {
	APIKey   string
	Endpoint string
	Model    string
}

// NewOpenAI builds an OpenAI client with the given credentials and model.
func NewOpenAI(apiKey string, endpoint string, model string) *OpenAI {
	return &OpenAI{
		APIKey:   apiKey,
		Endpoint: endpoint,
		Model:    model,
	}
}

func (o *OpenAI)
Generate(prompt string) (string, error) {
	return "", nil
	// TODO: implement
}

// GetModel reports the configured model name.
func (o *OpenAI) GetModel() string {
	return o.Model
}

package openroute

// Request represents the main request structure for the OpenRouter
// chat/completions endpoint. Either Messages (chat) or Prompt (completion)
// is set; the remaining fields are optional sampling/routing parameters.
type Request struct {
	Messages          []MessageRequest     `json:"messages,omitempty"`
	Prompt            string               `json:"prompt,omitempty"`
	Model             string               `json:"model,omitempty"`
	ResponseFormat    *ResponseFormat      `json:"response_format,omitempty"`
	Stop              []string             `json:"stop,omitempty"`
	Stream            bool                 `json:"stream,omitempty"`
	MaxTokens         int                  `json:"max_tokens,omitempty"`
	Temperature       float64              `json:"temperature,omitempty"`
	Tools             []Tool               `json:"tools,omitempty"`
	ToolChoice        ToolChoice           `json:"tool_choice,omitempty"`
	Seed              int                  `json:"seed,omitempty"`
	TopP              float64              `json:"top_p,omitempty"`
	TopK              int                  `json:"top_k,omitempty"`
	FrequencyPenalty  float64              `json:"frequency_penalty,omitempty"`
	PresencePenalty   float64              `json:"presence_penalty,omitempty"`
	RepetitionPenalty float64              `json:"repetition_penalty,omitempty"`
	LogitBias         map[int]float64      `json:"logit_bias,omitempty"`
	TopLogprobs       int                  `json:"top_logprobs,omitempty"`
	MinP              float64              `json:"min_p,omitempty"`
	TopA              float64              `json:"top_a,omitempty"`
	Prediction        *Prediction          `json:"prediction,omitempty"`
	Transforms        []string             `json:"transforms,omitempty"`
	Models            []string             `json:"models,omitempty"`
	Route             string               `json:"route,omitempty"`
	Provider          *ProviderPreferences `json:"provider,omitempty"`
	IncludeReasoning  bool                 `json:"include_reasoning,omitempty"`
}

// ResponseFormat represents the response format structure.
type ResponseFormat struct {
	Type string `json:"type"`
}

// Prediction represents the prediction structure.
type Prediction struct {
	Type    string `json:"type"`
	Content string `json:"content"`
}

// ProviderPreferences represents the provider preferences structure;
// RefererURL and SiteName are sent as attribution headers by the client.
type ProviderPreferences struct {
	RefererURL string `json:"referer_url,omitempty"`
	SiteName   string `json:"site_name,omitempty"`
}

// MessageRequest represents one chat message in a request.
type MessageRequest struct {
	Role       MessageRole `json:"role"`
	Content    interface{} `json:"content"` // Can be string or []ContentPart
	Name       string      `json:"name,omitempty"`
	ToolCallID string      `json:"tool_call_id,omitempty"`
}

// MessageRole identifies the author of a chat message.
type MessageRole string

const (
	RoleSystem    MessageRole = "system"
	RoleUser      MessageRole = "user"
	RoleAssistant MessageRole = "assistant"
)

// ContentPart represents the content part structure of a multi-part message.
type ContentPart struct {
	Type     ContnetType `json:"type"`
	Text     string      `json:"text,omitempty"`
	ImageURL *ImageURL   `json:"image_url,omitempty"`
}

// ContnetType discriminates content parts.
// NOTE(review): the name is a typo for "ContentType", but it is exported
// and referenced by the constants below — renaming would break callers.
type ContnetType string

const (
	ContentTypeText  ContnetType = "text"
	ContentTypeImage ContnetType = "image_url"
)

// ImageURL represents the image URL structure.
type ImageURL struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// FunctionDescription represents the function description structure.
type FunctionDescription struct {
	Description string      `json:"description,omitempty"`
	Name        string      `json:"name"`
	Parameters  interface{} `json:"parameters"` // JSON Schema object
}

// Tool represents the tool structure.
type Tool struct {
	Type     string              `json:"type"`
	Function FunctionDescription `json:"function"`
}

// ToolChoice represents the tool choice structure.
type ToolChoice struct {
	Type     string `json:"type"`
	Function struct {
		Name string `json:"name"`
	} `json:"function"`
}

// Response is a chat/completions response — a full reply or one streamed chunk.
type Response struct {
	ID                string         `json:"id"`
	Choices           []Choice       `json:"choices"`
	Created           int64          `json:"created"`
	Model             string         `json:"model"`
	Object            string         `json:"object"`
	SystemFingerprint *string        `json:"system_fingerprint,omitempty"`
	Usage             *ResponseUsage `json:"usage,omitempty"`
}

// ResponseUsage reports token accounting for a completion.
type ResponseUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Choice is one candidate completion. Text is set for prompt completions,
// Message for chat completions, Delta for streamed chunks.
type Choice struct {
	FinishReason string           `json:"finish_reason"`
	Text         string           `json:"text,omitempty"`
	Message      *MessageResponse `json:"message,omitempty"`
	Delta        *Delta           `json:"delta,omitempty"`
	Error        *ErrorResponse   `json:"error,omitempty"`
}

// MessageResponse is the assistant message in a non-streamed choice.
type MessageResponse struct {
	Content   string     `json:"content"`
	Role      string     `json:"role"`
	Reasoning string     `json:"reasoning,omitempty"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Delta is the incremental message fragment in a streamed choice.
type Delta struct {
	Content   string     `json:"content"`
	Role      string     `json:"role,omitempty"`
	Reasoning string     `json:"reasoning,omitempty"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// ErrorResponse carries a provider error embedded in a choice.
type ErrorResponse struct {
	Code     int                    `json:"code"`
	Message  string                 `json:"message"`
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}

// ToolCall is a tool invocation requested by the model.
type ToolCall struct {
	ID       string      `json:"id"`
	Type     string      `json:"type"`
	Function interface{} `json:"function"`
}

package openroute

import "context"

// RouterAgentConfig bundles the static sampling options applied to every
// request the agent sends.
type RouterAgentConfig struct {
	ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
	Stop           []string        `json:"stop,omitempty"`
	MaxTokens      int             `json:"max_tokens,omitempty"`
	Temperature    float64
`json:"temperature,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice ToolChoice `json:"tool_choice,omitempty"` + Seed int `json:"seed,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + RepetitionPenalty float64 `json:"repetition_penalty,omitempty"` + LogitBias map[int]float64 `json:"logit_bias,omitempty"` + TopLogprobs int `json:"top_logprobs,omitempty"` + MinP float64 `json:"min_p,omitempty"` + TopA float64 `json:"top_a,omitempty"` +} + +type RouterAgent struct { + client *OpenRouterClient + model string + config RouterAgentConfig +} + +func NewRouterAgent(client *OpenRouterClient, model string, config RouterAgentConfig) *RouterAgent { + return &RouterAgent{ + client: client, + model: model, + config: config, + } +} + +func (agent RouterAgent) Completion(prompt string) (*Response, error) { + request := Request{ + Prompt: prompt, + Model: agent.model, + ResponseFormat: agent.config.ResponseFormat, + Stop: agent.config.Stop, + MaxTokens: agent.config.MaxTokens, + Temperature: agent.config.Temperature, + Tools: agent.config.Tools, + ToolChoice: agent.config.ToolChoice, + Seed: agent.config.Seed, + TopP: agent.config.TopP, + TopK: agent.config.TopK, + FrequencyPenalty: agent.config.FrequencyPenalty, + PresencePenalty: agent.config.PresencePenalty, + RepetitionPenalty: agent.config.RepetitionPenalty, + LogitBias: agent.config.LogitBias, + TopLogprobs: agent.config.TopLogprobs, + MinP: agent.config.MinP, + TopA: agent.config.TopA, + Stream: false, + } + + return agent.client.FetchChatCompletions(request) +} + +func (agent RouterAgent) CompletionStream(prompt string, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) { + request := Request{ + Prompt: prompt, + Model: agent.model, + ResponseFormat: agent.config.ResponseFormat, + Stop: 
agent.config.Stop, + MaxTokens: agent.config.MaxTokens, + Temperature: agent.config.Temperature, + Tools: agent.config.Tools, + ToolChoice: agent.config.ToolChoice, + Seed: agent.config.Seed, + TopP: agent.config.TopP, + TopK: agent.config.TopK, + FrequencyPenalty: agent.config.FrequencyPenalty, + PresencePenalty: agent.config.PresencePenalty, + RepetitionPenalty: agent.config.RepetitionPenalty, + LogitBias: agent.config.LogitBias, + TopLogprobs: agent.config.TopLogprobs, + MinP: agent.config.MinP, + TopA: agent.config.TopA, + Stream: true, + } + + agent.client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx) +} + +func (agent RouterAgent) Chat(messages []MessageRequest) (*Response, error) { + request := Request{ + Messages: messages, + Model: agent.model, + ResponseFormat: agent.config.ResponseFormat, + Stop: agent.config.Stop, + MaxTokens: agent.config.MaxTokens, + Temperature: agent.config.Temperature, + Tools: agent.config.Tools, + ToolChoice: agent.config.ToolChoice, + Seed: agent.config.Seed, + TopP: agent.config.TopP, + TopK: agent.config.TopK, + FrequencyPenalty: agent.config.FrequencyPenalty, + PresencePenalty: agent.config.PresencePenalty, + RepetitionPenalty: agent.config.RepetitionPenalty, + LogitBias: agent.config.LogitBias, + TopLogprobs: agent.config.TopLogprobs, + MinP: agent.config.MinP, + TopA: agent.config.TopA, + Stream: false, + } + + return agent.client.FetchChatCompletions(request) +} + +func (agent RouterAgent) ChatStream(messages []MessageRequest, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) { + request := Request{ + Messages: messages, + Model: agent.model, + ResponseFormat: agent.config.ResponseFormat, + Stop: agent.config.Stop, + MaxTokens: agent.config.MaxTokens, + Temperature: agent.config.Temperature, + Tools: agent.config.Tools, + ToolChoice: agent.config.ToolChoice, + Seed: agent.config.Seed, + TopP: agent.config.TopP, + TopK: agent.config.TopK, + 
FrequencyPenalty: agent.config.FrequencyPenalty, + PresencePenalty: agent.config.PresencePenalty, + RepetitionPenalty: agent.config.RepetitionPenalty, + LogitBias: agent.config.LogitBias, + TopLogprobs: agent.config.TopLogprobs, + MinP: agent.config.MinP, + TopA: agent.config.TopA, + Stream: true, + } + + agent.client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx) +} + +type RouterAgentChat struct { + RouterAgent + Messages []MessageRequest +} + +func NewRouterAgentChat(client *OpenRouterClient, model string, config RouterAgentConfig, system_prompt string) RouterAgentChat { + return RouterAgentChat{ + RouterAgent: RouterAgent{ + client: client, + model: model, + config: config, + }, + Messages: []MessageRequest{ + { + Role: RoleSystem, + Content: system_prompt, + }, + }, + } +} + +func (agent *RouterAgentChat) Chat(message string) error { + agent.Messages = append(agent.Messages, MessageRequest{ + Role: RoleUser, + Content: message, + }) + request := Request{ + Messages: agent.Messages, + Model: agent.model, + ResponseFormat: agent.config.ResponseFormat, + Stop: agent.config.Stop, + MaxTokens: agent.config.MaxTokens, + Temperature: agent.config.Temperature, + Tools: agent.config.Tools, + ToolChoice: agent.config.ToolChoice, + Seed: agent.config.Seed, + TopP: agent.config.TopP, + TopK: agent.config.TopK, + FrequencyPenalty: agent.config.FrequencyPenalty, + PresencePenalty: agent.config.PresencePenalty, + RepetitionPenalty: agent.config.RepetitionPenalty, + LogitBias: agent.config.LogitBias, + TopLogprobs: agent.config.TopLogprobs, + MinP: agent.config.MinP, + TopA: agent.config.TopA, + Stream: false, + } + + response, err := agent.client.FetchChatCompletions(request) + + agent.Messages = append(agent.Messages, MessageRequest{ + Role: RoleAssistant, + Content: response.Choices[0].Message.Content, + }) + + return err +} diff --git a/internal/llm/openroute/route_client.go b/internal/llm/openroute/route_client.go new file mode 100644 index 
0000000..eade483 --- /dev/null +++ b/internal/llm/openroute/route_client.go @@ -0,0 +1,188 @@ +package openroute + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +type OpenRouterClient struct { + apiKey string + apiURL string + httpClient *http.Client +} + +func NewOpenRouterClient(apiKey string) *OpenRouterClient { + return &OpenRouterClient{ + apiKey: apiKey, + apiURL: "https://openrouter.ai/api/v1", + httpClient: &http.Client{}, + } +} + +func NewOpenRouterClientFull(apiKey string, apiUrl string, client *http.Client) *OpenRouterClient { + return &OpenRouterClient{ + apiKey: apiKey, + apiURL: apiUrl, + httpClient: client, + } +} + +func (c *OpenRouterClient) FetchChatCompletions(request Request) (*Response, error) { + headers := map[string]string{ + "Authorization": "Bearer " + c.apiKey, + "Content-Type": "application/json", + } + + if request.Provider != nil { + headers["HTTP-Referer"] = request.Provider.RefererURL + headers["X-Title"] = request.Provider.SiteName + } + + body, err := json.Marshal(request) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat/completions", c.apiURL), bytes.NewBuffer(body)) + if err != nil { + return nil, err + } + + for key, value := range headers { + req.Header.Set(key, value) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + output, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + outputReponse := &Response{} + err = json.Unmarshal(output, outputReponse) + if err != nil { + return nil, err + } + + return outputReponse, nil +} + +func (c *OpenRouterClient) FetchChatCompletionsStream(request Request, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) { + headers := 
map[string]string{ + "Authorization": "Bearer " + c.apiKey, + "Content-Type": "application/json", + } + + if request.Provider != nil { + headers["HTTP-Referer"] = request.Provider.RefererURL + headers["X-Title"] = request.Provider.SiteName + } + + body, err := json.Marshal(request) + if err != nil { + errChan <- err + close(errChan) + close(outputChan) + close(processingChan) + return + } + + req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat/completions", c.apiURL), bytes.NewBuffer(body)) + if err != nil { + errChan <- err + close(errChan) + close(outputChan) + close(processingChan) + return + } + + for key, value := range headers { + req.Header.Set(key, value) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + errChan <- err + close(errChan) + close(outputChan) + close(processingChan) + return + } + + go func() { + defer resp.Body.Close() + + defer close(errChan) + defer close(outputChan) + defer close(processingChan) + if resp.StatusCode != http.StatusOK { + errChan <- fmt.Errorf("unexpected status code: %d", resp.StatusCode) + return + } + + reader := bufio.NewReader(resp.Body) + for { + select { + case <-ctx.Done(): + errChan <- ctx.Err() + close(errChan) + close(outputChan) + close(processingChan) + return + default: + line, err := reader.ReadString('\n') + line = strings.TrimSpace(line) + if strings.HasPrefix(line, ":") { + select { + case processingChan <- true: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + continue + } + + if line != "" { + if strings.Compare(line[6:], "[DONE]") == 0 { + return + } + response := Response{} + err = json.Unmarshal([]byte(line[6:]), &response) + if err != nil { + errChan <- err + return + } + select { + case outputChan <- response: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + } + + if err != nil { + if err == io.EOF { + return + } + errChan <- err + return + } + } + } + }() +} diff --git a/internal/llm/openroute_llm.go b/internal/llm/openroute_llm.go new file mode 100644 index 
0000000..bba1168 --- /dev/null +++ b/internal/llm/openroute_llm.go @@ -0,0 +1,24 @@ +package llm + +type OpenRoute struct { + APIKey string + Endpoint string + Model string +} + +func NewOpenRoute(apiKey string, endpoint string, model string) *OpenAI { + return &OpenAI{ + APIKey: apiKey, + Endpoint: endpoint, + Model: model, + } +} + +func (o *OpenRoute) Generate(prompt string) (string, error) { + return "", nil + // TODO: implement +} + +func (o *OpenRoute) GetModel() string { + return o.Model +} diff --git a/internal/models/models.go b/internal/models/models.go new file mode 100644 index 0000000..0531a9e --- /dev/null +++ b/internal/models/models.go @@ -0,0 +1,27 @@ +package models + +// type VectorEmbedding [][]float32 +// type Vector []float32 +// Document represents the data structure for storing documents +type Document struct { + ID string `json:"id" milvus:"ID"` // Unique identifier for the document + Content string `json:"content" milvus:"Content"` // Text content of the document become chunks of data will not be saved + Link string `json:"link" milvus:"Link"` // Link to the document + Filename string `json:"filename" milvus:"Filename"` // Filename of the document + Category string `json:"category" milvus:"Category"` // Category of the document + EmbeddingModel string `json:"embedding_model" milvus:"EmbeddingModel"` // Embedding model used to generate the embedding + Summary string `json:"summary" milvus:"Summary"` // Summary of the document + Metadata map[string]string `json:"metadata" milvus:"Metadata"` // Additional metadata (e.g., author, timestamp) + Vector []float32 `json:"vector" milvus:"Vector"` +} + +// Embedding represents the vector embedding for a document or query +type Embedding struct { + ID string `json:"id" milvus:"ID"` // Unique identifier + DocumentID string `json:"document_id" milvus:"DocumentID"` // Unique identifier linked to a Document + Vector []float32 `json:"vector" milvus:"Vector"` // The embedding vector + TextChunk string 
`json:"text_chunk" milvus:"TextChunk"` // Text chunk of the document + Dimension int64 `json:"dimension" milvus:"Dimension"` // Dimensionality of the vector + Order int64 `json:"order" milvus:"Order"` // Order of the embedding to build the content back + Score float32 `json:"score"` // Score of the embedding +} diff --git a/internal/pkg/database/milvus/client.go b/internal/pkg/database/milvus/client.go new file mode 100644 index 0000000..c4a7e9d --- /dev/null +++ b/internal/pkg/database/milvus/client.go @@ -0,0 +1,169 @@ +package milvus + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/milvus-io/milvus-sdk-go/v2/client" + "github.com/milvus-io/milvus-sdk-go/v2/entity" +) + +type Client struct { + Instance client.Client +} + +// InitMilvusClient initializes the Milvus client and returns a wrapper around it. +func NewClient(milvusAddr string) (*Client, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + c, err := client.NewClient(ctx, client.Config{Address: milvusAddr}) + if err != nil { + log.Printf("Failed to connect to Milvus: %v", err) + return nil, err + } + + client := &Client{Instance: c} + + err = client.EnsureCollections(ctx) + if err != nil { + return nil, err + } + + return client, nil +} + +// EnsureCollections ensures that the required collections ("documents" and "chunks") exist. +// If they don't exist, it creates them based on the predefined structs. 
+func (m *Client) EnsureCollections(ctx context.Context) error { + collections := []struct { + Name string + Schema *entity.Schema + IndexField string + IndexType string + MetricType entity.MetricType + Nlist int + }{ + { + Name: "documents", + Schema: createDocumentSchema(), + IndexField: "Vector", // Indexing the Vector field for similarity search + IndexType: "IVF_FLAT", + MetricType: entity.L2, + Nlist: 10, // Number of clusters for IVF_FLAT index + }, + { + Name: "chunks", + Schema: createEmbeddingSchema(), + IndexField: "Vector", // Indexing the Vector field for similarity search + IndexType: "IVF_FLAT", + MetricType: entity.L2, + Nlist: 10, + }, + } + + for _, collection := range collections { + // drop collection + // err := m.Instance.DropCollection(ctx, collection.Name) + // if err != nil { + // return fmt.Errorf("failed to drop collection '%s': %w", collection.Name, err) + // } + + // Ensure the collection exists + exists, err := m.Instance.HasCollection(ctx, collection.Name) + if err != nil { + return fmt.Errorf("failed to check collection existence: %w", err) + } + + if !exists { + err := m.Instance.CreateCollection(ctx, collection.Schema, entity.DefaultShardNumber) + if err != nil { + return fmt.Errorf("failed to create collection '%s': %w", collection.Name, err) + } + log.Printf("Collection '%s' created successfully", collection.Name) + } else { + log.Printf("Collection '%s' already exists", collection.Name) + } + + // Ensure the default partition exists + hasPartition, err := m.Instance.HasPartition(ctx, collection.Name, "_default") + if err != nil { + return fmt.Errorf("failed to check default partition for collection '%s': %w", collection.Name, err) + } + + if !hasPartition { + err = m.Instance.CreatePartition(ctx, collection.Name, "_default") + if err != nil { + return fmt.Errorf("failed to create default partition for collection '%s': %w", collection.Name, err) + } + log.Printf("Default partition created for collection '%s'", collection.Name) + 
} + + // Skip index creation if IndexField is empty + if collection.IndexField == "" { + continue + } + + // Ensure the index exists + log.Printf("Creating index on field '%s' for collection '%s'", collection.IndexField, collection.Name) + + idx, err := entity.NewIndexIvfFlat(collection.MetricType, collection.Nlist) + if err != nil { + return fmt.Errorf("failed to create IVF_FLAT index: %w", err) + } + + err = m.Instance.CreateIndex(ctx, collection.Name, collection.IndexField, idx, false) + if err != nil { + return fmt.Errorf("failed to create index on field '%s' for collection '%s': %w", collection.IndexField, collection.Name, err) + } + + log.Printf("Index created on field '%s' for collection '%s'", collection.IndexField, collection.Name) + } + + err := m.Instance.LoadCollection(ctx, "documents", false) + if err != nil { + log.Fatalf("failed to load collection, err: %v", err) + } + + err = m.Instance.LoadCollection(ctx, "chunks", false) + if err != nil { + log.Fatalf("failed to load collection, err: %v", err) + } + + return nil +} + +// Helper functions for creating schemas +func createDocumentSchema() *entity.Schema { + return entity.NewSchema(). + WithName("documents"). + WithDescription("Collection for storing documents"). + WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeVarChar).WithIsPrimaryKey(true).WithMaxLength(512)). + WithField(entity.NewField().WithName("Content").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)). + WithField(entity.NewField().WithName("Link").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)). + WithField(entity.NewField().WithName("Filename").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)). + WithField(entity.NewField().WithName("Category").WithDataType(entity.FieldTypeVarChar).WithMaxLength(8048)). + WithField(entity.NewField().WithName("EmbeddingModel").WithDataType(entity.FieldTypeVarChar).WithMaxLength(256)). 
+ WithField(entity.NewField().WithName("Summary").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)). + WithField(entity.NewField().WithName("Metadata").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)). + WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(1024)) // bge-m3 +} + +func createEmbeddingSchema() *entity.Schema { + return entity.NewSchema(). + WithName("chunks"). + WithDescription("Collection for storing document embeddings"). + WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeVarChar).WithIsPrimaryKey(true).WithMaxLength(512)). + WithField(entity.NewField().WithName("DocumentID").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)). + WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(1024)). // bge-m3 + WithField(entity.NewField().WithName("TextChunk").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)). + WithField(entity.NewField().WithName("Dimension").WithDataType(entity.FieldTypeInt32)). + WithField(entity.NewField().WithName("Order").WithDataType(entity.FieldTypeInt32)) +} + +// Close closes the Milvus client connection. +func (m *Client) Close() { + m.Instance.Close() +} diff --git a/internal/pkg/database/milvus/client_test.go b/internal/pkg/database/milvus/client_test.go new file mode 100644 index 0000000..360ca32 --- /dev/null +++ b/internal/pkg/database/milvus/client_test.go @@ -0,0 +1,32 @@ +package milvus + +import ( + "reflect" + "testing" +) + +func TestNewClient(t *testing.T) { + type args struct { + milvusAddr string + } + tests := []struct { + name string + args args + want *Client + wantErr bool + }{ + // TODO: Add test cases. 
+ } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewClient(tt.args.milvusAddr) + if (err != nil) != tt.wantErr { + t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewClient() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/pkg/database/milvus/helpers.go b/internal/pkg/database/milvus/helpers.go new file mode 100644 index 0000000..cfecbab --- /dev/null +++ b/internal/pkg/database/milvus/helpers.go @@ -0,0 +1,276 @@ +package milvus + +import ( + "encoding/json" + "fmt" + + "easy_rag/internal/models" + "github.com/milvus-io/milvus-sdk-go/v2/client" + "github.com/milvus-io/milvus-sdk-go/v2/entity" +) + +// Helper functions for extracting data +func extractIDs(docs []models.Document) []string { + ids := make([]string, len(docs)) + for i, doc := range docs { + ids[i] = doc.ID + } + return ids +} + +// extractLinks extracts the "Link" field from the documents. +func extractLinks(docs []models.Document) []string { + links := make([]string, len(docs)) + for i, doc := range docs { + links[i] = doc.Link + } + return links +} + +// extractFilenames extracts the "Filename" field from the documents. +func extractFilenames(docs []models.Document) []string { + filenames := make([]string, len(docs)) + for i, doc := range docs { + filenames[i] = doc.Filename + } + return filenames +} + +// extractCategories extracts the "Category" field from the documents as a comma-separated string. +func extractCategories(docs []models.Document) []string { + categories := make([]string, len(docs)) + for i, doc := range docs { + categories[i] = fmt.Sprintf("%v", doc.Category) + } + return categories +} + +// extractEmbeddingModels extracts the "EmbeddingModel" field from the documents. 
+func extractEmbeddingModels(docs []models.Document) []string { + models := make([]string, len(docs)) + for i, doc := range docs { + models[i] = doc.EmbeddingModel + } + return models +} + +// extractSummaries extracts the "Summary" field from the documents. +func extractSummaries(docs []models.Document) []string { + summaries := make([]string, len(docs)) + for i, doc := range docs { + summaries[i] = doc.Summary + } + return summaries +} + +// extractMetadata extracts the "Metadata" field from the documents as a JSON string. +func extractMetadata(docs []models.Document) []string { + metadata := make([]string, len(docs)) + for i, doc := range docs { + metaBytes, _ := json.Marshal(doc.Metadata) + metadata[i] = string(metaBytes) + } + return metadata +} + +func convertToMetadata(metadata string) map[string]string { + var metadataMap map[string]string + json.Unmarshal([]byte(metadata), &metadataMap) + return metadataMap +} + +func extractContents(docs []models.Document) []string { + contents := make([]string, len(docs)) + for i, doc := range docs { + contents[i] = doc.Content + } + return contents +} + +// extractEmbeddingIDs extracts the "ID" field from the embeddings. +func extractEmbeddingIDs(embeddings []models.Embedding) []string { + ids := make([]string, len(embeddings)) + for i, embedding := range embeddings { + ids[i] = embedding.ID + } + return ids +} + +// extractDocumentIDs extracts the "DocumentID" field from the embeddings. +func extractDocumentIDs(embeddings []models.Embedding) []string { + documentIDs := make([]string, len(embeddings)) + for i, embedding := range embeddings { + documentIDs[i] = embedding.DocumentID + } + return documentIDs +} + +// extractVectors extracts the "Vector" field from the embeddings. 
+func extractVectors(embeddings []models.Embedding) [][]float32 { + vectors := make([][]float32, len(embeddings)) + for i, embedding := range embeddings { + vectors[i] = embedding.Vector // Direct assignment since it's already []float32 + } + return vectors +} + +// extractVectorsDocs extracts the "Vector" field from the documents. +func extractVectorsDocs(docs []models.Document) [][]float32 { + vectors := make([][]float32, len(docs)) + for i, doc := range docs { + vectors[i] = doc.Vector // Direct assignment since it's already []float32 + } + return vectors +} + +// extractTextChunks extracts the "TextChunk" field from the embeddings. +func extractTextChunks(embeddings []models.Embedding) []string { + textChunks := make([]string, len(embeddings)) + for i, embedding := range embeddings { + textChunks[i] = embedding.TextChunk + } + return textChunks +} + +// extractDimensions extracts the "Dimension" field from the embeddings. +func extractDimensions(embeddings []models.Embedding) []int32 { + dimensions := make([]int32, len(embeddings)) + for i, embedding := range embeddings { + dimensions[i] = int32(embedding.Dimension) + } + return dimensions +} + +// extractOrders extracts the "Order" field from the embeddings. 
+func extractOrders(embeddings []models.Embedding) []int32 { + orders := make([]int32, len(embeddings)) + for i, embedding := range embeddings { + orders[i] = int32(embedding.Order) + } + return orders +} + +func transformResultSet(rs client.ResultSet, outputFields ...string) ([]map[string]interface{}, error) { + if rs == nil || rs.Len() == 0 { + return nil, fmt.Errorf("empty result set") + } + + results := []map[string]interface{}{} + + for i := 0; i < rs.Len(); i++ { // Iterate through rows + row := map[string]interface{}{} + + for _, fieldName := range outputFields { + column := rs.GetColumn(fieldName) + if column == nil { + return nil, fmt.Errorf("column %s does not exist in result set", fieldName) + } + + switch column.Type() { + case entity.FieldTypeInt64: + value, err := column.GetAsInt64(i) + if err != nil { + return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err) + } + row[fieldName] = value + + case entity.FieldTypeInt32: + value, err := column.GetAsInt64(i) + if err != nil { + return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err) + } + row[fieldName] = value + + case entity.FieldTypeFloat: + value, err := column.GetAsDouble(i) + if err != nil { + return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err) + } + row[fieldName] = value + + case entity.FieldTypeDouble: + value, err := column.GetAsDouble(i) + if err != nil { + return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err) + } + row[fieldName] = value + + case entity.FieldTypeVarChar: + value, err := column.GetAsString(i) + if err != nil { + return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err) + } + row[fieldName] = value + + default: + return nil, fmt.Errorf("unsupported field type for column %s", fieldName) + } + } + + results = append(results, row) + } + + return results, nil +} + +func 
transformSearchResultSet(rs client.SearchResult, outputFields ...string) ([]map[string]interface{}, error) { + if rs.ResultCount == 0 { + return nil, fmt.Errorf("empty result set") + } + + result := make([]map[string]interface{}, rs.ResultCount) + + for i := 0; i < rs.ResultCount; i++ { // Iterate through rows + result[i] = make(map[string]interface{}) + for _, fieldName := range outputFields { + column := rs.Fields.GetColumn(fieldName) + result[i]["Score"] = rs.Scores[i] + + if column == nil { + return nil, fmt.Errorf("column %s does not exist in result set", fieldName) + } + + switch column.Type() { + case entity.FieldTypeInt64: + value, err := column.GetAsInt64(i) + if err != nil { + return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err) + } + result[i][fieldName] = value + + case entity.FieldTypeInt32: + value, err := column.GetAsInt64(i) + if err != nil { + return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err) + } + result[i][fieldName] = value + + case entity.FieldTypeFloat: + value, err := column.GetAsDouble(i) + if err != nil { + return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err) + } + result[i][fieldName] = value + + case entity.FieldTypeDouble: + value, err := column.GetAsDouble(i) + if err != nil { + return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err) + } + result[i][fieldName] = value + + case entity.FieldTypeVarChar: + value, err := column.GetAsString(i) + if err != nil { + return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err) + } + result[i][fieldName] = value + + default: + return nil, fmt.Errorf("unsupported field type for column %s", fieldName) + } + } + } + + return result, nil +} diff --git a/internal/pkg/database/milvus/operations.go b/internal/pkg/database/milvus/operations.go new file mode 100644 index 0000000..db27de0 --- 
/dev/null +++ b/internal/pkg/database/milvus/operations.go @@ -0,0 +1,270 @@ +package milvus + +import ( + "context" + "fmt" + "sort" + + "easy_rag/internal/models" + + "github.com/milvus-io/milvus-sdk-go/v2/client" + "github.com/milvus-io/milvus-sdk-go/v2/entity" +) + +// InsertDocuments inserts documents into the "documents" collection. +func (m *Client) InsertDocuments(ctx context.Context, docs []models.Document) error { + idColumn := entity.NewColumnVarChar("ID", extractIDs(docs)) + contentColumn := entity.NewColumnVarChar("Content", extractContents(docs)) + linkColumn := entity.NewColumnVarChar("Link", extractLinks(docs)) + filenameColumn := entity.NewColumnVarChar("Filename", extractFilenames(docs)) + categoryColumn := entity.NewColumnVarChar("Category", extractCategories(docs)) + embeddingModelColumn := entity.NewColumnVarChar("EmbeddingModel", extractEmbeddingModels(docs)) + summaryColumn := entity.NewColumnVarChar("Summary", extractSummaries(docs)) + metadataColumn := entity.NewColumnVarChar("Metadata", extractMetadata(docs)) + vectorColumn := entity.NewColumnFloatVector("Vector", 1024, extractVectorsDocs(docs)) + // Insert the data + _, err := m.Instance.Insert(ctx, "documents", "_default", idColumn, contentColumn, linkColumn, filenameColumn, + categoryColumn, embeddingModelColumn, summaryColumn, metadataColumn, vectorColumn) + if err != nil { + return fmt.Errorf("failed to insert documents: %w", err) + } + + // Flush the collection + err = m.Instance.Flush(ctx, "documents", false) + if err != nil { + return fmt.Errorf("failed to flush documents collection: %w", err) + } + + return nil +} + +// InsertEmbeddings inserts embeddings into the "chunks" collection. 
+func (m *Client) InsertEmbeddings(ctx context.Context, embeddings []models.Embedding) error { + idColumn := entity.NewColumnVarChar("ID", extractEmbeddingIDs(embeddings)) + documentIDColumn := entity.NewColumnVarChar("DocumentID", extractDocumentIDs(embeddings)) + vectorColumn := entity.NewColumnFloatVector("Vector", 1024, extractVectors(embeddings)) + textChunkColumn := entity.NewColumnVarChar("TextChunk", extractTextChunks(embeddings)) + dimensionColumn := entity.NewColumnInt32("Dimension", extractDimensions(embeddings)) + orderColumn := entity.NewColumnInt32("Order", extractOrders(embeddings)) + + _, err := m.Instance.Insert(ctx, "chunks", "_default", idColumn, documentIDColumn, vectorColumn, + textChunkColumn, dimensionColumn, orderColumn) + + if err != nil { + return fmt.Errorf("failed to insert embeddings: %w", err) + } + + err = m.Instance.Flush(ctx, "chunks", false) + if err != nil { + return fmt.Errorf("failed to flush chunks collection: %w", err) + } + + return nil +} + +// GetDocumentByID retrieves a document from the "documents" collection by ID. 
+func (m *Client) GetDocumentByID(ctx context.Context, id string) (map[string]interface{}, error) { + collectionName := "documents" + expr := fmt.Sprintf("ID == '%s'", id) + projections := []string{"ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata"} // Fetch all fields + + results, err := m.Instance.Query(ctx, collectionName, nil, expr, projections) + if err != nil { + return nil, fmt.Errorf("failed to query document by ID: %w", err) + } + + if len(results) == 0 { + return nil, fmt.Errorf("document with ID '%s' not found", id) + } + + mp, err := transformResultSet(results, "ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata") + + if err != nil { + return nil, fmt.Errorf("failed to unmarshal document: %w", err) + } + + // convert metadata to map + mp[0]["Metadata"] = convertToMetadata(mp[0]["Metadata"].(string)) + + return mp[0], err +} + +// GetAllDocuments retrieves all documents from the "documents" collection. +func (m *Client) GetAllDocuments(ctx context.Context) ([]models.Document, error) { + collectionName := "documents" + projections := []string{"ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata"} // Fetch all fields + expr := "" + + rs, err := m.Instance.Query(ctx, collectionName, nil, expr, projections, client.WithLimit(1000)) + if err != nil { + return nil, fmt.Errorf("failed to query all documents: %w", err) + } + + if len(rs) == 0 { + return nil, fmt.Errorf("no documents found in the collection") + } + + results, err := transformResultSet(rs, "ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata") + + if err != nil { + return nil, fmt.Errorf("failed to unmarshal all documents: %w", err) + } + + var docs []models.Document = make([]models.Document, len(results)) + for i, result := range results { + docs[i] = models.Document{ + ID: result["ID"].(string), + Content: result["Content"].(string), + Link: 
result["Link"].(string), + Filename: result["Filename"].(string), + Category: result["Category"].(string), + EmbeddingModel: result["EmbeddingModel"].(string), + Summary: result["Summary"].(string), + Metadata: convertToMetadata(results[0]["Metadata"].(string)), + } + } + + return docs, nil +} + +// GetAllEmbeddingByDocID retrieves all embeddings linked to a specific DocumentID from the "chunks" collection. +func (m *Client) GetAllEmbeddingByDocID(ctx context.Context, documentID string) ([]models.Embedding, error) { + collectionName := "chunks" + projections := []string{"ID", "DocumentID", "TextChunk", "Order"} // Fetch all fields + expr := fmt.Sprintf("DocumentID == '%s'", documentID) + + rs, err := m.Instance.Query(ctx, collectionName, nil, expr, projections, client.WithLimit(1000)) + + if err != nil { + return nil, fmt.Errorf("failed to query embeddings by DocumentID: %w", err) + } + + if rs.Len() == 0 { + return nil, fmt.Errorf("no embeddings found for DocumentID '%s'", documentID) + } + + results, err := transformResultSet(rs, "ID", "DocumentID", "TextChunk", "Order") + + if err != nil { + return nil, fmt.Errorf("failed to unmarshal all documents: %w", err) + } + + var embeddings []models.Embedding = make([]models.Embedding, rs.Len()) + + for i, result := range results { + embeddings[i] = models.Embedding{ + ID: result["ID"].(string), + DocumentID: result["DocumentID"].(string), + TextChunk: result["TextChunk"].(string), + Order: result["Order"].(int64), + } + } + + return embeddings, nil +} + +func (m *Client) Search(ctx context.Context, vectors [][]float32, topK int) ([]models.Embedding, error) { + const ( + collectionName = "chunks" + vectorDim = 1024 // Replace with your actual vector dimension + ) + projections := []string{"ID", "DocumentID", "TextChunk", "Order"} + metricType := entity.L2 // Default metric type + + // Validate and convert input vectors + searchVectors, err := validateAndConvertVectors(vectors, vectorDim) + if err != nil { + return nil, 
err + } + + // Set search parameters + searchParams, err := entity.NewIndexIvfFlatSearchParam(16) // 16 is the number of clusters for IVF_FLAT index + if err != nil { + return nil, fmt.Errorf("failed to create search params: %w", err) + } + + // Perform the search + searchResults, err := m.Instance.Search(ctx, collectionName, nil, "", projections, searchVectors, "Vector", metricType, topK, searchParams, client.WithLimit(10)) + if err != nil { + return nil, fmt.Errorf("failed to search collection: %w", err) + } + + // Process search results + embeddings, err := processSearchResults(searchResults) + if err != nil { + return nil, fmt.Errorf("failed to process search results: %w", err) + } + + return embeddings, nil +} + +// validateAndConvertVectors validates vector dimensions and converts them to Milvus-compatible format. +func validateAndConvertVectors(vectors [][]float32, expectedDim int) ([]entity.Vector, error) { + searchVectors := make([]entity.Vector, len(vectors)) + for i, vector := range vectors { + if len(vector) != expectedDim { + return nil, fmt.Errorf("vector dimension mismatch: expected %d, got %d", expectedDim, len(vector)) + } + searchVectors[i] = entity.FloatVector(vector) + } + return searchVectors, nil +} + +// processSearchResults transforms and aggregates the search results into embeddings and sorts by score. 
+func processSearchResults(results []client.SearchResult) ([]models.Embedding, error) { + var embeddings []models.Embedding + + for _, result := range results { + for i := 0; i < result.ResultCount; i++ { + embeddingMap, err := transformSearchResultSet(result, "ID", "DocumentID", "TextChunk", "Order") + if err != nil { + return nil, fmt.Errorf("failed to transform search result set: %w", err) + } + + for _, embedding := range embeddingMap { + embeddings = append(embeddings, models.Embedding{ + ID: embedding["ID"].(string), + DocumentID: embedding["DocumentID"].(string), + TextChunk: embedding["TextChunk"].(string), + Order: embedding["Order"].(int64), // Assuming 'Order' is a float64 type + Score: embedding["Score"].(float32), + }) + } + } + } + + // Sort embeddings by score in descending order (higher is better) + sort.Slice(embeddings, func(i, j int) bool { + return embeddings[i].Score > embeddings[j].Score + }) + + return embeddings, nil +} + +// DeleteDocument deletes a document from the "documents" collection by ID. +func (m *Client) DeleteDocument(ctx context.Context, id string) error { + collectionName := "documents" + partitionName := "_default" + expr := fmt.Sprintf("ID == '%s'", id) + + err := m.Instance.Delete(ctx, collectionName, partitionName, expr) + if err != nil { + return fmt.Errorf("failed to delete document by ID: %w", err) + } + + return nil +} + +// DeleteEmbedding deletes an embedding from the "chunks" collection by ID. 
+func (m *Client) DeleteEmbedding(ctx context.Context, id string) error { + collectionName := "chunks" + partitionName := "_default" + expr := fmt.Sprintf("DocumentID == '%s'", id) + + err := m.Instance.Delete(ctx, collectionName, partitionName, expr) + if err != nil { + return fmt.Errorf("failed to delete embedding by DocumentID: %w", err) + } + + return nil +} diff --git a/internal/pkg/rag/rag.go b/internal/pkg/rag/rag.go new file mode 100644 index 0000000..c367990 --- /dev/null +++ b/internal/pkg/rag/rag.go @@ -0,0 +1,21 @@ +package rag + +import ( + "easy_rag/internal/database" + "easy_rag/internal/embeddings" + "easy_rag/internal/llm" +) + +type Rag struct { + LLM llm.LLMService + Embeddings embeddings.EmbeddingsService + Database database.Database +} + +func NewRag(llm llm.LLMService, embeddings embeddings.EmbeddingsService, database database.Database) *Rag { + return &Rag{ + LLM: llm, + Embeddings: embeddings, + Database: database, + } +} diff --git a/internal/pkg/textprocessor/textprocessor.go b/internal/pkg/textprocessor/textprocessor.go new file mode 100644 index 0000000..32ef39e --- /dev/null +++ b/internal/pkg/textprocessor/textprocessor.go @@ -0,0 +1,50 @@ +package textprocessor + +import ( + "bytes" + "strings" + + "github.com/jonathanhecl/chunker" +) + +func CreateChunks(text string) []string { + // Maximum characters per chunk + const maxCharacters = 5000 // too slow otherwise + + var chunks []string + var currentChunk strings.Builder + + // Use the chunker library to split text into sentences + sentences := chunker.ChunkSentences(text) + + for _, sentence := range sentences { + // Check if adding the sentence exceeds the character limit + if currentChunk.Len()+len(sentence) <= maxCharacters { + if currentChunk.Len() > 0 { + currentChunk.WriteString(" ") // Add a space between sentences + } + currentChunk.WriteString(sentence) + } else { + // Add the completed chunk to the chunks slice + chunks = append(chunks, currentChunk.String()) + 
currentChunk.Reset() // Start a new chunk + currentChunk.WriteString(sentence) // Add the sentence to the new chunk + } + } + + // Add the last chunk if it has content + if currentChunk.Len() > 0 { + chunks = append(chunks, currentChunk.String()) + } + + // Return the chunks + return chunks +} + +func ConcatenateStrings(strings []string) string { + var result bytes.Buffer + for _, str := range strings { + result.WriteString(str) + } + return result.String() +} diff --git a/scripts/standalone_embed.sh b/scripts/standalone_embed.sh new file mode 100644 index 0000000..6f29e51 --- /dev/null +++ b/scripts/standalone_embed.sh @@ -0,0 +1,169 @@ +#!/usr/bin/env bash + +# Licensed to the LF AI & Data foundation under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +run_embed() { + cat << EOF > embedEtcd.yaml +listen-client-urls: http://0.0.0.0:2379 +advertise-client-urls: http://0.0.0.0:2379 +quota-backend-bytes: 4294967296 +auto-compaction-mode: revision +auto-compaction-retention: '1000' +EOF + + cat << EOF > user.yaml +# Extra config to override default milvus.yaml +EOF + + sudo docker run -d \ + --name milvus-standalone \ + --security-opt seccomp:unconfined \ + -e ETCD_USE_EMBED=true \ + -e ETCD_DATA_DIR=/var/lib/milvus/etcd \ + -e ETCD_CONFIG_PATH=/milvus/configs/embedEtcd.yaml \ + -e COMMON_STORAGETYPE=local \ + -v $(pwd)/volumes/milvus:/var/lib/milvus \ + -v $(pwd)/embedEtcd.yaml:/milvus/configs/embedEtcd.yaml \ + -v $(pwd)/user.yaml:/milvus/configs/user.yaml \ + -p 19530:19530 \ + -p 9091:9091 \ + -p 2379:2379 \ + --health-cmd="curl -f http://localhost:9091/healthz" \ + --health-interval=30s \ + --health-start-period=90s \ + --health-timeout=20s \ + --health-retries=3 \ + milvusdb/milvus:v2.4.16 \ + milvus run standalone 1> /dev/null +} + +wait_for_milvus_running() { + echo "Wait for Milvus Starting..." + while true + do + res=`sudo docker ps|grep milvus-standalone|grep healthy|wc -l` + if [ $res -eq 1 ] + then + echo "Start successfully." + echo "To change the default Milvus configuration, add your settings to the user.yaml file and then restart the service." + break + fi + sleep 1 + done +} + +start() { + res=`sudo docker ps|grep milvus-standalone|grep healthy|wc -l` + if [ $res -eq 1 ] + then + echo "Milvus is running." + exit 0 + fi + + res=`sudo docker ps -a|grep milvus-standalone|wc -l` + if [ $res -eq 1 ] + then + sudo docker start milvus-standalone 1> /dev/null + else + run_embed + fi + + if [ $? -ne 0 ] + then + echo "Start failed." + exit 1 + fi + + wait_for_milvus_running +} + +stop() { + sudo docker stop milvus-standalone 1> /dev/null + + if [ $? -ne 0 ] + then + echo "Stop failed." + exit 1 + fi + echo "Stop successfully." 
+ +} + +delete_container() { + res=`sudo docker ps|grep milvus-standalone|wc -l` + if [ $res -eq 1 ] + then + echo "Please stop Milvus service before delete." + exit 1 + fi + sudo docker rm milvus-standalone 1> /dev/null + if [ $? -ne 0 ] + then + echo "Delete milvus container failed." + exit 1 + fi + echo "Delete milvus container successfully." +} + +delete() { + delete_container + sudo rm -rf $(pwd)/volumes + sudo rm -rf $(pwd)/embedEtcd.yaml + sudo rm -rf $(pwd)/user.yaml + echo "Delete successfully." +} + +upgrade() { + read -p "Please confirm if you'd like to proceed with the upgrade. The default will be to the latest version. Confirm with 'y' for yes or 'n' for no. > " check + if [ "$check" == "y" ] ||[ "$check" == "Y" ];then + res=`sudo docker ps -a|grep milvus-standalone|wc -l` + if [ $res -eq 1 ] + then + stop + delete_container + fi + + curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed_latest.sh && \ + bash standalone_embed_latest.sh start 1> /dev/null && \ + echo "Upgrade successfully." 
+ else + echo "Exit upgrade" + exit 0 + fi +} + +case $1 in + restart) + stop + start + ;; + start) + start + ;; + stop) + stop + ;; + upgrade) + upgrade + ;; + delete) + delete + ;; + *) + echo "please use bash standalone_embed.sh restart|start|stop|upgrade|delete" + ;; +esac diff --git a/tests/api_test.go b/tests/api_test.go new file mode 100644 index 0000000..38bb073 --- /dev/null +++ b/tests/api_test.go @@ -0,0 +1,313 @@ +package api_test + +import ( + "bytes" + "easy_rag/internal/models" + "encoding/json" + "github.com/google/uuid" + "net/http" + "net/http/httptest" + "testing" + + "easy_rag/api" + "easy_rag/internal/pkg/rag" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// Example: Test UploadHandler +func TestUploadHandler(t *testing.T) { + e := echo.New() + + // Create a mock for the LLM, Embeddings, and Database + mockLLM := new(MockLLMService) + mockEmbeddings := new(MockEmbeddingsService) + mockDB := new(MockDatabase) + + // Setup the Rag object + r := &rag.Rag{ + LLM: mockLLM, + Embeddings: mockEmbeddings, + Database: mockDB, + } + + // We expect calls to these mocks in the background goroutine, for each document. 
+ + // The request body + requestBody := api.RequestUpload{ + Docs: []api.UploadDoc{ + { + Content: "Test document content", + Link: "http://example.com/doc", + Filename: "doc1.txt", + Category: "TestCategory", + Metadata: map[string]string{"Author": "Me"}, + }, + }, + } + + // Convert requestBody to JSON + reqBodyBytes, _ := json.Marshal(requestBody) + + // Create a new request + req := httptest.NewRequest(http.MethodPost, "/api/v1/upload", bytes.NewReader(reqBodyBytes)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + + // Create a ResponseRecorder + rec := httptest.NewRecorder() + + // New echo context + c := e.NewContext(req, rec) + // Set the rag object in context + c.Set("Rag", r) + + // Because the UploadHandler spawns a goroutine, we only test the immediate HTTP response. + // We can still set expectations for the calls that happen in the goroutine to ensure they're invoked. + // For example, we expect the summary to be generated, so: + + testSummary := "Test summary from LLM" + mockLLM.On("Generate", mock.Anything).Return(testSummary, nil).Maybe() // .Maybe() because the concurrency might not complete by the time we assert + + // The embedding vector returned from the embeddings service + testVector := [][]float32{{0.1, 0.2, 0.3, 0.4}} + + // We'll mock calls to Vectorize() for summary and each chunk + mockEmbeddings.On("Vectorize", mock.AnythingOfType("string")).Return(testVector, nil).Maybe() + + // The database SaveDocument / SaveEmbeddings calls + mockDB.On("SaveDocument", mock.AnythingOfType("models.Document")).Return(nil).Maybe() + mockDB.On("SaveEmbeddings", mock.AnythingOfType("[]models.Embedding")).Return(nil).Maybe() + + // Invoke the handler + err := api.UploadHandler(c) + + // Check no immediate errors + assert.NoError(t, err) + + // Check the response + assert.Equal(t, http.StatusAccepted, rec.Code) + + var resp map[string]interface{} + _ = json.Unmarshal(rec.Body.Bytes(), &resp) + + // We expect certain fields in the JSON 
response + assert.Equal(t, "v1", resp["version"]) + assert.NotEmpty(t, resp["task_id"]) + assert.Equal(t, "Processing started", resp["status"]) + + // Typically, you might want to wait or do more advanced concurrency checks if you want to test + // the logic in the goroutine, but that goes beyond a simple unit test. + // The background process can be tested more thoroughly in integration or end-to-end tests. + + // Optionally assert that our mocks were called + mockLLM.AssertExpectations(t) + mockEmbeddings.AssertExpectations(t) + mockDB.AssertExpectations(t) +} + +// Example: Test ListAllDocsHandler +func TestListAllDocsHandler(t *testing.T) { + e := echo.New() + + mockLLM := new(MockLLMService) + mockEmbeddings := new(MockEmbeddingsService) + mockDB := new(MockDatabase) + + r := &rag.Rag{ + LLM: mockLLM, + Embeddings: mockEmbeddings, + Database: mockDB, + } + + // Mock data + doc1 := models.Document{ + ID: uuid.NewString(), + Filename: "doc1.txt", + Summary: "summary doc1", + } + doc2 := models.Document{ + ID: uuid.NewString(), + Filename: "doc2.txt", + Summary: "summary doc2", + } + docs := []models.Document{doc1, doc2} + + // Expect the database to return the docs + mockDB.On("ListDocuments").Return(docs, nil) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/docs", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.Set("Rag", r) + + err := api.ListAllDocsHandler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]interface{} + _ = json.Unmarshal(rec.Body.Bytes(), &resp) + assert.Equal(t, "v1", resp["version"]) + + // The "docs" field should match the ones we returned + docsInterface, ok := resp["docs"].([]interface{}) + assert.True(t, ok) + assert.Len(t, docsInterface, 2) + + // Verify mocks + mockDB.AssertExpectations(t) +} + +// Example: Test GetDocHandler +func TestGetDocHandler(t *testing.T) { + e := echo.New() + + mockLLM := new(MockLLMService) + mockEmbeddings := 
new(MockEmbeddingsService) + mockDB := new(MockDatabase) + + r := &rag.Rag{ + LLM: mockLLM, + Embeddings: mockEmbeddings, + Database: mockDB, + } + + // Mock a single doc + docID := "123" + testDoc := models.Document{ + ID: docID, + Filename: "doc3.txt", + Summary: "summary doc3", + } + + mockDB.On("GetDocument", docID).Return(testDoc, nil) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/doc/123", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + // set path param + c.SetParamNames("id") + c.SetParamValues(docID) + c.Set("Rag", r) + + err := api.GetDocHandler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]interface{} + _ = json.Unmarshal(rec.Body.Bytes(), &resp) + assert.Equal(t, "v1", resp["version"]) + + docInterface := resp["doc"].(map[string]interface{}) + assert.Equal(t, "doc3.txt", docInterface["filename"]) + + // Verify mocks + mockDB.AssertExpectations(t) +} + +// Example: Test AskDocHandler +func TestAskDocHandler(t *testing.T) { + e := echo.New() + + mockLLM := new(MockLLMService) + mockEmbeddings := new(MockEmbeddingsService) + mockDB := new(MockDatabase) + + r := &rag.Rag{ + LLM: mockLLM, + Embeddings: mockEmbeddings, + Database: mockDB, + } + + // 1) We expect to Vectorize the question + question := "What is the summary of doc?" + questionVector := [][]float32{{0.5, 0.2, 0.1}} + mockEmbeddings.On("Vectorize", question).Return(questionVector, nil) + + // 2) We expect a DB search + emb := []models.Embedding{ + { + ID: "emb1", + DocumentID: "doc123", + TextChunk: "Relevant content chunk", + Score: 0.99, + }, + } + mockDB.On("Search", questionVector).Return(emb, nil) + + // 3) We expect the LLM to generate an answer from the chunk + generatedAnswer := "Here is an answer from the chunk" + // The prompt we pass is something like: "Given the following information: chunk ... 
Answer the question: question" + mockLLM.On("Generate", mock.AnythingOfType("string")).Return(generatedAnswer, nil) + + // Prepare request + reqBody := api.RequestQuestion{ + Question: question, + } + reqBytes, _ := json.Marshal(reqBody) + req := httptest.NewRequest(http.MethodPost, "/api/v1/ask", bytes.NewReader(reqBytes)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.Set("Rag", r) + + // Execute + err := api.AskDocHandler(c) + + // Verify + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]interface{} + _ = json.Unmarshal(rec.Body.Bytes(), &resp) + assert.Equal(t, "v1", resp["version"]) + assert.Equal(t, generatedAnswer, resp["answer"]) + + // The docs field should have the docID "doc123" + docsInterface := resp["docs"].([]interface{}) + assert.Len(t, docsInterface, 1) + assert.Equal(t, "doc123", docsInterface[0]) + + // Verify mocks + mockLLM.AssertExpectations(t) + mockEmbeddings.AssertExpectations(t) + mockDB.AssertExpectations(t) +} + +// Example: Test DeleteDocHandler +func TestDeleteDocHandler(t *testing.T) { + e := echo.New() + mockLLM := new(MockLLMService) + mockEmbeddings := new(MockEmbeddingsService) + mockDB := new(MockDatabase) + + r := &rag.Rag{ + LLM: mockLLM, + Embeddings: mockEmbeddings, + Database: mockDB, + } + + docID := "abc" + mockDB.On("DeleteDocument", docID).Return(nil) + + req := httptest.NewRequest(http.MethodDelete, "/api/v1/doc/abc", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetParamNames("id") + c.SetParamValues(docID) + c.Set("Rag", r) + + err := api.DeleteDocHandler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]interface{} + _ = json.Unmarshal(rec.Body.Bytes(), &resp) + assert.Equal(t, "v1", resp["version"]) + + // docs should be nil according to DeleteDocHandler + assert.Nil(t, resp["docs"]) + + // Verify mocks + 
mockDB.AssertExpectations(t) +} diff --git a/tests/mock_test.go b/tests/mock_test.go new file mode 100644 index 0000000..85e3136 --- /dev/null +++ b/tests/mock_test.go @@ -0,0 +1,89 @@ +package api_test + +import ( + "easy_rag/internal/models" + "github.com/stretchr/testify/mock" +) + +// -------------------- +// Mock LLM +// -------------------- +type MockLLMService struct { + mock.Mock +} + +func (m *MockLLMService) Generate(prompt string) (string, error) { + args := m.Called(prompt) + return args.String(0), args.Error(1) +} + +func (m *MockLLMService) GetModel() string { + args := m.Called() + return args.String(0) +} + +// -------------------- +// Mock Embeddings +// -------------------- +type MockEmbeddingsService struct { + mock.Mock +} + +func (m *MockEmbeddingsService) Vectorize(text string) ([][]float32, error) { + args := m.Called(text) + return args.Get(0).([][]float32), args.Error(1) +} + +func (m *MockEmbeddingsService) GetModel() string { + args := m.Called() + return args.String(0) +} + +// -------------------- +// Mock Database +// -------------------- +type MockDatabase struct { + mock.Mock +} + +// GetDocumentInfo(id string) (models.DocumentInfo, error) +func (m *MockDatabase) GetDocumentInfo(id string) (models.Document, error) { + args := m.Called(id) + return args.Get(0).(models.Document), args.Error(1) +} + +// SaveDocument(document Document) error +func (m *MockDatabase) SaveDocument(doc models.Document) error { + args := m.Called(doc) + return args.Error(0) +} + +// SaveEmbeddings([]Embedding) error +func (m *MockDatabase) SaveEmbeddings(emb []models.Embedding) error { + args := m.Called(emb) + return args.Error(0) +} + +// ListDocuments() ([]Document, error) +func (m *MockDatabase) ListDocuments() ([]models.Document, error) { + args := m.Called() + return args.Get(0).([]models.Document), args.Error(1) +} + +// GetDocument(id string) (Document, error) +func (m *MockDatabase) GetDocument(id string) (models.Document, error) { + args := 
m.Called(id) + return args.Get(0).(models.Document), args.Error(1) +} + +// DeleteDocument(id string) error +func (m *MockDatabase) DeleteDocument(id string) error { + args := m.Called(id) + return args.Error(0) +} + +// Search(vector [][]float32) ([]models.Embedding, error) +func (m *MockDatabase) Search(vector [][]float32) ([]models.Embedding, error) { + args := m.Called(vector) + return args.Get(0).([]models.Embedding), args.Error(1) +} diff --git a/tests/openrouter_test.go b/tests/openrouter_test.go new file mode 100644 index 0000000..accb9d3 --- /dev/null +++ b/tests/openrouter_test.go @@ -0,0 +1,125 @@ +package api_test + +import ( + "context" + openroute2 "easy_rag/internal/llm/openroute" + "fmt" + "testing" +) + +func TestFetchChatCompletions(t *testing.T) { + client := openroute2.NewOpenRouterClient("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") + + request := openroute2.Request{ + Model: "qwen/qwen2.5-vl-72b-instruct:free", + Messages: []openroute2.MessageRequest{ + {openroute2.RoleUser, "Привет!", "", ""}, + }, + } + + output, err := client.FetchChatCompletions(request) + if err != nil { + t.Errorf("error %v", err) + } + + t.Logf("output: %v", output.Choices[0].Message.Content) +} + +func TestFetchChatCompletionsStreaming(t *testing.T) { + client := openroute2.NewOpenRouterClient("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") + + request := openroute2.Request{ + Model: "qwen/qwen2.5-vl-72b-instruct:free", + Messages: []openroute2.MessageRequest{ + {openroute2.RoleUser, "Привет!", "", ""}, + }, + Stream: true, + } + + outputChan := make(chan openroute2.Response) + processingChan := make(chan interface{}) + errChan := make(chan error) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx) + + for { + select { + case output := <-outputChan: + if len(output.Choices) > 0 { + t.Logf("%s", 
output.Choices[0].Delta.Content) + } + case <-processingChan: + t.Logf("Обработка\n") + case err := <-errChan: + if err != nil { + t.Errorf("Ошибка: %v", err) + return + } + return + case <-ctx.Done(): + fmt.Println("Контекст отменен:", ctx.Err()) + return + } + } +} + +func TestFetchChatCompletionsAgentStreaming(t *testing.T) { + client := openroute2.NewOpenRouterClient("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") + agent := openroute2.NewRouterAgent(client, "qwen/qwen2.5-vl-72b-instruct:free", openroute2.RouterAgentConfig{ + Temperature: 0.7, + MaxTokens: 100, + }) + + outputChan := make(chan openroute2.Response) + processingChan := make(chan interface{}) + errChan := make(chan error) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + chat := []openroute2.MessageRequest{ + {Role: openroute2.RoleSystem, Content: "Вы полезный помощник."}, + {Role: openroute2.RoleUser, Content: "Привет!"}, + } + + go agent.ChatStream(chat, outputChan, processingChan, errChan, ctx) + + for { + select { + case output := <-outputChan: + if len(output.Choices) > 0 { + t.Logf("%s", output.Choices[0].Delta.Content) + } + case <-processingChan: + t.Logf("Обработка\n") + case err := <-errChan: + if err != nil { + t.Errorf("Ошибка: %v", err) + return + } + return + case <-ctx.Done(): + fmt.Println("Контекст отменен:", ctx.Err()) + return + } + } +} + +func TestFetchChatCompletionsAgentSimpleChat(t *testing.T) { + client := openroute2.NewOpenRouterClient("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") + agent := openroute2.NewRouterAgentChat(client, "qwen/qwen2.5-vl-72b-instruct:free", openroute2.RouterAgentConfig{ + Temperature: 0.0, + MaxTokens: 100, + }, "Вы полезный помощник, отвечайте короткими словами.") + + agent.Chat("Запомни это: \"wojtess\"") + agent.Chat("Что я просил вас запомнить?") + + for _, msg := range agent.Messages { + content, ok := msg.Content.(string) + if ok { + t.Logf("%s: %s", 
msg.Role, content) + } + } +} diff --git a/user.yaml b/user.yaml new file mode 100644 index 0000000..8d31269 --- /dev/null +++ b/user.yaml @@ -0,0 +1 @@ +# Extra config to override default milvus.yaml