Реализованы операции Milvus для управления документами и встраиванием, включая функции вставки, запроса и удаления. Внедрена архитектура RAG с LLM и сервисами встраивания. Добавлена обработка текста для фрагментации и конкатенации. Создан автономный скрипт для настройки и управления Milvus. Разработаны комплексные тесты API для обработки документов и взаимодействия с LLM, включая имитации сервисов. Расширены возможности конфигурации пользователя с помощью дополнительных настроек YAML.
This commit is contained in:
33
.gitignore
vendored
33
.gitignore
vendored
@@ -1,27 +1,6 @@
|
|||||||
# ---> Go
|
*.env
|
||||||
# If you prefer the allow list template instead of the deny list, see community template:
|
*.sum
|
||||||
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
|
id_rsa2
|
||||||
#
|
id_rsa2.pub
|
||||||
# Binaries for programs and plugins
|
/tests/RAG.postman_collection.json
|
||||||
*.exe
|
/volumes/
|
||||||
*.exe~
|
|
||||||
*.dll
|
|
||||||
*.so
|
|
||||||
*.dylib
|
|
||||||
|
|
||||||
# Test binary, built with `go test -c`
|
|
||||||
*.test
|
|
||||||
|
|
||||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
|
||||||
*.out
|
|
||||||
|
|
||||||
# Dependency directories (remove the comment below to include it)
|
|
||||||
# vendor/
|
|
||||||
|
|
||||||
# Go workspace file
|
|
||||||
go.work
|
|
||||||
go.work.sum
|
|
||||||
|
|
||||||
# env file
|
|
||||||
.env
|
|
||||||
|
|
||||||
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
13
.idea/easy_rag.iml
generated
Normal file
13
.idea/easy_rag.iml
generated
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="WEB_MODULE" version="4">
|
||||||
|
<component name="Go" enabled="true">
|
||||||
|
<buildTags>
|
||||||
|
<option name="arch" value="amd64" />
|
||||||
|
</buildTags>
|
||||||
|
</component>
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/easy_rag.iml" filepath="$PROJECT_DIR$/.idea/easy_rag.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
454
README.md
454
README.md
@@ -1,2 +1,454 @@
|
|||||||
# easy_rag
|
# Easy RAG - Система Retrieval-Augmented Generation
|
||||||
|
|
||||||
|
## Обзор
|
||||||
|
|
||||||
|
Easy RAG - это мощная система для управления документами и генерации ответов на основе метода Retrieval-Augmented Generation (RAG). Проект реализует полнофункциональное API для загрузки, хранения, поиска и анализа документов с использованием векторных баз данных и современных языковых моделей.
|
||||||
|
|
||||||
|
### Ключевые возможности
|
||||||
|
|
||||||
|
- 🔍 **Семантический поиск** по документам с использованием векторных эмбеддингов
|
||||||
|
- 📄 **Управление документами** - загрузка, хранение, получение и удаление
|
||||||
|
- 🤖 **Интеграция с LLM** - поддержка OpenAI, Ollama и OpenRoute
|
||||||
|
- 💾 **Векторная база данных** - использование Milvus для хранения эмбеддингов
|
||||||
|
- 🔧 **Модульная архитектура** - легко расширяемая и настраиваемая система
|
||||||
|
- 🚀 **RESTful API** - простое и понятное API для интеграции
|
||||||
|
- 🐳 **Docker-готовность** - контейнеризация для простого развертывания
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Архитектура системы
|
||||||
|
|
||||||
|
### Компоненты
|
||||||
|
|
||||||
|
1. **API слой** (`api/`) - HTTP API на базе Echo Framework
|
||||||
|
2. **Основная логика** (`internal/pkg/rag/`) - ядро RAG системы
|
||||||
|
3. **Провайдеры LLM** (`internal/llm/`) - интеграция с языковыми моделями
|
||||||
|
4. **Провайдеры эмбеддингов** (`internal/embeddings/`) - генерация векторных представлений
|
||||||
|
5. **База данных** (`internal/database/`) - работа с векторной БД Milvus
|
||||||
|
6. **Обработка текста** (`internal/pkg/textprocessor/`) - разбиение на чанки
|
||||||
|
7. **Конфигурация** (`config/`) - управление настройками
|
||||||
|
|
||||||
|
### Поддерживаемые провайдеры
|
||||||
|
|
||||||
|
#### LLM (Языковые модели):
|
||||||
|
- **OpenAI** - GPT-3.5, GPT-4 и другие модели
|
||||||
|
- **Ollama** - локальные открытые модели
|
||||||
|
- **OpenRoute** - доступ к различным моделям через единый API
|
||||||
|
|
||||||
|
#### Эмбеддинги:
|
||||||
|
- **OpenAI Embeddings** - text-embedding-ada-002 и новые модели
|
||||||
|
- **Ollama Embeddings** - локальные модели эмбеддингов (например, bge-m3)
|
||||||
|
|
||||||
|
#### Векторная база данных:
|
||||||
|
- **Milvus** - высокопроизводительная векторная база данных
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Документация
|
||||||
|
|
||||||
|
### Базовый URL
|
||||||
|
```
|
||||||
|
http://localhost:4002/api/v1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Эндпоинты
|
||||||
|
|
||||||
|
#### 1. Получить все документы
|
||||||
|
```http
|
||||||
|
GET /api/v1/docs
|
||||||
|
```
|
||||||
|
**Описание**: Получить список всех сохраненных документов.
|
||||||
|
|
||||||
|
**Ответ**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": "v1",
|
||||||
|
"docs": [
|
||||||
|
{
|
||||||
|
"id": "document_id",
|
||||||
|
"filename": "document.txt",
|
||||||
|
"summary": "Краткое описание документа",
|
||||||
|
"metadata": {
|
||||||
|
"category": "Техническая документация",
|
||||||
|
"author": "Автор"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Загрузить документы
|
||||||
|
```http
|
||||||
|
POST /api/v1/upload
|
||||||
|
```
|
||||||
|
**Описание**: Загрузить один или несколько документов для обработки и индексации.
|
||||||
|
|
||||||
|
**Тело запроса**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"docs": [
|
||||||
|
{
|
||||||
|
"content": "Содержимое документа...",
|
||||||
|
"link": "https://example.com/document",
|
||||||
|
"filename": "document.txt",
|
||||||
|
"category": "Категория",
|
||||||
|
"metadata": {
|
||||||
|
"author": "Автор",
|
||||||
|
"date": "2024-01-01"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ответ**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": "v1",
|
||||||
|
"task_id": "unique_task_id",
|
||||||
|
"expected_time": "10m",
|
||||||
|
"status": "Обработка начата"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. Получить документ по ID
|
||||||
|
```http
|
||||||
|
GET /api/v1/doc/{id}
|
||||||
|
```
|
||||||
|
**Описание**: Получить детальную информацию о документе по его идентификатору.
|
||||||
|
|
||||||
|
**Ответ**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": "v1",
|
||||||
|
"doc": {
|
||||||
|
"id": "document_id",
|
||||||
|
"content": "Полное содержимое документа",
|
||||||
|
"filename": "document.txt",
|
||||||
|
"summary": "Краткое описание",
|
||||||
|
"metadata": {
|
||||||
|
"category": "Категория",
|
||||||
|
"author": "Автор"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4. Задать вопрос
|
||||||
|
```http
|
||||||
|
POST /api/v1/ask
|
||||||
|
```
|
||||||
|
**Описание**: Задать вопрос на основе проиндексированных документов.
|
||||||
|
|
||||||
|
**Тело запроса**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"question": "Что такое ISO 27001?"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ответ**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": "v1",
|
||||||
|
"docs": ["document_id_1", "document_id_2"],
|
||||||
|
"answer": "ISO 27001 - это международный стандарт информационной безопасности..."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 5. Удалить документ
|
||||||
|
```http
|
||||||
|
DELETE /api/v1/doc/{id}
|
||||||
|
```
|
||||||
|
**Описание**: Удалить документ по его идентификатору.
|
||||||
|
|
||||||
|
**Ответ**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"version": "v1",
|
||||||
|
"docs": null
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Структуры данных
|
||||||
|
|
||||||
|
### Document (Документ)
|
||||||
|
```go
|
||||||
|
type Document struct {
|
||||||
|
ID string `json:"id"` // Уникальный идентификатор
|
||||||
|
Content string `json:"content"` // Содержимое документа
|
||||||
|
Link string `json:"link"` // Ссылка на источник
|
||||||
|
Filename string `json:"filename"` // Имя файла
|
||||||
|
Category string `json:"category"` // Категория
|
||||||
|
EmbeddingModel string `json:"embedding_model"` // Модель эмбеддингов
|
||||||
|
Summary string `json:"summary"` // Краткое описание
|
||||||
|
Metadata map[string]string `json:"metadata"` // Метаданные
|
||||||
|
Vector []float32 `json:"vector"` // Векторное представление
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Embedding (Эмбеддинг)
|
||||||
|
```go
|
||||||
|
type Embedding struct {
|
||||||
|
ID string `json:"id"` // Уникальный идентификатор
|
||||||
|
DocumentID string `json:"document_id"` // ID связанного документа
|
||||||
|
Vector []float32 `json:"vector"` // Векторное представление
|
||||||
|
TextChunk string `json:"text_chunk"` // Фрагмент текста
|
||||||
|
Dimension int64 `json:"dimension"` // Размерность вектора
|
||||||
|
Order int64 `json:"order"` // Порядок фрагмента
|
||||||
|
Score float32 `json:"score"` // Оценка релевантности
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Установка и настройка
|
||||||
|
|
||||||
|
### Требования
|
||||||
|
- Go 1.24.3+
|
||||||
|
- Milvus векторная база данных
|
||||||
|
- Ollama (для локальных моделей) или API ключи для OpenAI/OpenRoute
|
||||||
|
|
||||||
|
### 1. Клонирование репозитория
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/elchemista/easy_rag.git
|
||||||
|
cd easy_rag
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Установка зависимостей
|
||||||
|
```bash
|
||||||
|
go mod tidy
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Настройка окружения
|
||||||
|
Создайте файл `.env` или установите переменные окружения:
|
||||||
|
|
||||||
|
```env
|
||||||
|
# LLM настройки
|
||||||
|
OPENAI_API_KEY=your_openai_api_key
|
||||||
|
OPENAI_ENDPOINT=https://api.openai.com/v1
|
||||||
|
OPENAI_MODEL=gpt-3.5-turbo
|
||||||
|
|
||||||
|
OPENROUTE_API_KEY=your_openroute_api_key
|
||||||
|
OPENROUTE_ENDPOINT=https://openrouter.ai/api/v1
|
||||||
|
OPENROUTE_MODEL=anthropic/claude-3-haiku
|
||||||
|
|
||||||
|
OLLAMA_ENDPOINT=http://localhost:11434/api/chat
|
||||||
|
OLLAMA_MODEL=qwen3:latest
|
||||||
|
|
||||||
|
# Эмбеддинги
|
||||||
|
OPENAI_EMBEDDING_API_KEY=your_openai_api_key
|
||||||
|
OPENAI_EMBEDDING_ENDPOINT=https://api.openai.com/v1
|
||||||
|
OPENAI_EMBEDDING_MODEL=text-embedding-ada-002
|
||||||
|
|
||||||
|
OLLAMA_EMBEDDING_ENDPOINT=http://localhost:11434
|
||||||
|
OLLAMA_EMBEDDING_MODEL=bge-m3
|
||||||
|
|
||||||
|
# База данных
|
||||||
|
MILVUS_HOST=localhost:19530
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Запуск Milvus
|
||||||
|
```bash
|
||||||
|
# Используя Docker
|
||||||
|
docker run -d --name milvus-standalone \
|
||||||
|
-p 19530:19530 -p 9091:9091 \
|
||||||
|
milvusdb/milvus:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Запуск приложения
|
||||||
|
```bash
|
||||||
|
go run cmd/rag/main.go
|
||||||
|
```
|
||||||
|
|
||||||
|
API будет доступно по адресу `http://localhost:4002`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Запуск с Docker
|
||||||
|
|
||||||
|
### Сборка образа
|
||||||
|
```bash
|
||||||
|
docker build -f deploy/Dockerfile -t easy-rag .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Запуск контейнера
|
||||||
|
```bash
|
||||||
|
docker run -d -p 4002:4002 \
|
||||||
|
-e MILVUS_HOST=your_milvus_host:19530 \
|
||||||
|
-e OLLAMA_ENDPOINT=http://your_ollama_host:11434/api/chat \
|
||||||
|
-e OLLAMA_MODEL=qwen3:latest \
|
||||||
|
easy-rag
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Примеры использования
|
||||||
|
|
||||||
|
### 1. Загрузка документа
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:4002/api/v1/upload \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"docs": [{
|
||||||
|
"content": "Это тестовый документ для демонстрации RAG системы.",
|
||||||
|
"filename": "test.txt",
|
||||||
|
"category": "Тест",
|
||||||
|
"metadata": {
|
||||||
|
"author": "Пользователь",
|
||||||
|
"type": "demo"
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Поиск ответа
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:4002/api/v1/ask \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"question": "О чем этот документ?"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Получение всех документов
|
||||||
|
```bash
|
||||||
|
curl http://localhost:4002/api/v1/docs
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Архитектурные особенности
|
||||||
|
|
||||||
|
### Модульность
|
||||||
|
Система спроектирована с использованием интерфейсов, что позволяет легко:
|
||||||
|
- Переключаться между различными LLM провайдерами
|
||||||
|
- Использовать разные модели эмбеддингов
|
||||||
|
- Менять векторную базу данных
|
||||||
|
- Добавлять новые методы обработки текста
|
||||||
|
|
||||||
|
### Обработка текста
|
||||||
|
- Автоматическое разбиение документов на чанки
|
||||||
|
- Генерация эмбеддингов для каждого фрагмента
|
||||||
|
- Сохранение порядка фрагментов для корректной реконструкции
|
||||||
|
|
||||||
|
### Поиск и ранжирование
|
||||||
|
- Семантический поиск по векторным представлениям
|
||||||
|
- Ранжирование результатов по релевантности
|
||||||
|
- Контекстная генерация ответов на основе найденных документов
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Разработка и тестирование
|
||||||
|
|
||||||
|
### Структура проекта
|
||||||
|
```
|
||||||
|
easy_rag/
|
||||||
|
├── api/ # HTTP API обработчики
|
||||||
|
├── cmd/rag/ # Точка входа приложения
|
||||||
|
├── config/ # Конфигурация
|
||||||
|
├── deploy/ # Docker файлы
|
||||||
|
├── internal/ # Внутренняя логика
|
||||||
|
│ ├── database/ # Интерфейсы БД
|
||||||
|
│ ├── embeddings/ # Провайдеры эмбеддингов
|
||||||
|
│ ├── llm/ # Провайдеры LLM
|
||||||
|
│ ├── models/ # Модели данных
|
||||||
|
│ └── pkg/ # Пакеты общего назначения
|
||||||
|
├── scripts/ # Вспомогательные скрипты
|
||||||
|
└── tests/ # Тесты и коллекции Postman
|
||||||
|
```
|
||||||
|
|
||||||
|
### Запуск тестов
|
||||||
|
```bash
|
||||||
|
go test ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Использование Postman
|
||||||
|
В папке `tests/` находится коллекция Postman для тестирования API:
|
||||||
|
```
|
||||||
|
tests/RAG.postman_collection.json
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Производительность и масштабирование
|
||||||
|
|
||||||
|
### Рекомендации по производительности
|
||||||
|
- Используйте SSD для хранения векторной базы данных
|
||||||
|
- Настройте индексы Milvus для оптимальной производительности
|
||||||
|
- Рассмотрите использование GPU для генерации эмбеддингов
|
||||||
|
- Кэшируйте часто запрашиваемые результаты
|
||||||
|
|
||||||
|
### Масштабирование
|
||||||
|
- Горизонтальное масштабирование Milvus кластера
|
||||||
|
- Балансировка нагрузки между несколькими экземплярами API
|
||||||
|
- Асинхронная обработка загрузки документов
|
||||||
|
- Использование очередей для обработки больших объемов данных
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Устранение неполадок
|
||||||
|
|
||||||
|
### Частые проблемы
|
||||||
|
|
||||||
|
1. **Не удается подключиться к Milvus**
|
||||||
|
- Проверьте, что Milvus запущен и доступен
|
||||||
|
- Убедитесь в правильности MILVUS_HOST
|
||||||
|
|
||||||
|
2. **Ошибки LLM провайдера**
|
||||||
|
- Проверьте API ключи
|
||||||
|
- Убедитесь в доступности эндпоинтов
|
||||||
|
- Проверьте правильность названий моделей
|
||||||
|
|
||||||
|
3. **Медленная обработка документов**
|
||||||
|
- Уменьшите размер чанков
|
||||||
|
- Используйте более быстрые модели эмбеддингов
|
||||||
|
- Оптимизируйте настройки Milvus
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Вклад в проект
|
||||||
|
|
||||||
|
Мы приветствуем вклад в развитие проекта! Пожалуйста:
|
||||||
|
|
||||||
|
1. Форкните репозиторий
|
||||||
|
2. Создайте ветку для новой функции
|
||||||
|
3. Внесите изменения
|
||||||
|
4. Добавьте тесты
|
||||||
|
5. Отправьте Pull Request
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Лицензия
|
||||||
|
|
||||||
|
Проект распространяется под лицензией MIT. См. файл `LICENSE` для подробностей.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Поддержка
|
||||||
|
|
||||||
|
Если у вас есть вопросы или проблемы:
|
||||||
|
- Создайте Issue в GitHub репозитории
|
||||||
|
- Обратитесь к документации API
|
||||||
|
- Проверьте примеры в папке `tests/`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Roadmap
|
||||||
|
|
||||||
|
### Ближайшие планы
|
||||||
|
- [ ] Поддержка дополнительных форматов документов (PDF, DOCX)
|
||||||
|
- [ ] Веб-интерфейс для управления документами
|
||||||
|
- [ ] Улучшенная система метаданных
|
||||||
|
- [ ] Поддержка multimodal моделей
|
||||||
|
- [ ] Кэширование результатов
|
||||||
|
- [ ] Мониторинг и метрики
|
||||||
|
|
||||||
|
### Долгосрочные цели
|
||||||
|
- [ ] Поддержка множественных языков
|
||||||
|
- [ ] Федеративный поиск по нескольким источникам
|
||||||
|
- [ ] Интеграция с внешними системами (SharePoint, Confluence)
|
||||||
|
- [ ] Продвинутая аналитика использования
|
||||||
|
- [ ] Система разрешений и ролей
|
||||||
|
|||||||
35
api/api.go
Normal file
35
api/api.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
// Package api wires the HTTP layer of the Easy RAG service.
package api

import (
	"fmt"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"

	"easy_rag/internal/pkg/rag"
)

const (
	// APIVersion is the version of the API.
	APIVersion = "v1"
)

// NewAPI registers middleware and the versioned document/question routes
// on the given Echo instance.
//
// The *rag.Rag service is stored in every request context under the "Rag"
// key so the handlers can retrieve it with c.Get("Rag").
//
// Note: the parameter is named svc (not rag) to avoid shadowing the
// imported rag package.
func NewAPI(e *echo.Echo, svc *rag.Rag) {
	// Standard middleware: request logging and panic recovery.
	e.Use(middleware.Logger())
	e.Use(middleware.Recover())

	// Inject the RAG service pointer into each request context.
	e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			c.Set("Rag", svc)
			return next(c)
		}
	})

	api := e.Group(fmt.Sprintf("/api/%s", APIVersion))

	api.POST("/upload", UploadHandler)
	api.POST("/ask", AskDocHandler)
	api.GET("/docs", ListAllDocsHandler)
	api.GET("/doc/:id", GetDocHandler)
	api.DELETE("/doc/:id", DeleteDocHandler)
}
|
||||||
248
api/handler.go
Normal file
248
api/handler.go
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
package api

import (
	"fmt"
	"log"
	"net/http"

	"github.com/labstack/echo/v4"

	"easy_rag/internal/models"
	"easy_rag/internal/pkg/rag"
	"easy_rag/internal/pkg/textprocessor"

	"github.com/google/uuid"
)

// UploadDoc is a single document payload inside an upload request.
type UploadDoc struct {
	Content  string            `json:"content"`  // raw document text
	Link     string            `json:"link"`     // source URL of the document
	Filename string            `json:"filename"` // original file name
	Category string            `json:"category"` // free-form category label
	Metadata map[string]string `json:"metadata"` // arbitrary key/value metadata
}

// RequestUpload is the JSON body of POST /api/v1/upload.
type RequestUpload struct {
	Docs []UploadDoc `json:"docs"`
}

// RequestQuestion is the JSON body of POST /api/v1/ask.
type RequestQuestion struct {
	Question string `json:"question"`
}

// ResposeQuestion describes an answer together with the documents it is
// based on.
//
// NOTE(review): the name is misspelled ("Respose" instead of "Response"),
// but it is exported, so it is kept as-is to avoid breaking callers.
type ResposeQuestion struct {
	Version string            `json:"version"`
	Docs    []models.Document `json:"docs"`
	Answer  string            `json:"answer"`
}
|
||||||
|
|
||||||
|
func UploadHandler(c echo.Context) error {
|
||||||
|
// Retrieve the RAG instance from context
|
||||||
|
rag := c.Get("Rag").(*rag.Rag)
|
||||||
|
|
||||||
|
var request RequestUpload
|
||||||
|
if err := c.Bind(&request); err != nil {
|
||||||
|
return ErrorHandler(err, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a unique task ID
|
||||||
|
taskID := uuid.NewString()
|
||||||
|
|
||||||
|
// Launch the upload process in a separate goroutine
|
||||||
|
go func(taskID string, request RequestUpload) {
|
||||||
|
log.Printf("Task %s: started processing", taskID)
|
||||||
|
defer log.Printf("Task %s: completed processing", taskID)
|
||||||
|
|
||||||
|
var docs []models.Document
|
||||||
|
|
||||||
|
for idx, doc := range request.Docs {
|
||||||
|
// Generate a unique ID for each document
|
||||||
|
docID := uuid.NewString()
|
||||||
|
log.Printf("Task %s: processing document %d with generated ID %s (filename: %s)", taskID, idx, docID, doc.Filename)
|
||||||
|
|
||||||
|
// Step 1: Create chunks from document content
|
||||||
|
chunks := textprocessor.CreateChunks(doc.Content)
|
||||||
|
log.Printf("Task %s: created %d chunks for document %s", taskID, len(chunks), docID)
|
||||||
|
|
||||||
|
// Step 2: Generate summary for the document
|
||||||
|
var summaryChunks string
|
||||||
|
if len(chunks) < 4 {
|
||||||
|
summaryChunks = doc.Content
|
||||||
|
} else {
|
||||||
|
summaryChunks = textprocessor.ConcatenateStrings(chunks[:3])
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Task %s: generating summary for document %s", taskID, docID)
|
||||||
|
summary, err := rag.LLM.Generate(fmt.Sprintf("Give me only summary of the following text: %s", summaryChunks))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Task %s: error generating summary for document %s: %v", taskID, docID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Task %s: generated summary for document %s", taskID, docID)
|
||||||
|
|
||||||
|
// Step 3: Vectorize the summary
|
||||||
|
log.Printf("Task %s: vectorizing summary for document %s", taskID, docID)
|
||||||
|
vectorSum, err := rag.Embeddings.Vectorize(summary)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Task %s: error vectorizing summary for document %s: %v", taskID, docID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Task %s: vectorized summary for document %s", taskID, docID)
|
||||||
|
|
||||||
|
// Step 4: Save the document
|
||||||
|
document := models.Document{
|
||||||
|
ID: docID, // Use generated ID
|
||||||
|
Content: "",
|
||||||
|
Link: doc.Link,
|
||||||
|
Filename: doc.Filename,
|
||||||
|
Category: doc.Category,
|
||||||
|
EmbeddingModel: rag.Embeddings.GetModel(),
|
||||||
|
Summary: summary,
|
||||||
|
Vector: vectorSum[0],
|
||||||
|
Metadata: doc.Metadata,
|
||||||
|
}
|
||||||
|
log.Printf("Task %s: saving document %s", taskID, docID)
|
||||||
|
if err := rag.Database.SaveDocument(document); err != nil {
|
||||||
|
log.Printf("Task %s: error saving document %s: %v", taskID, docID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Task %s: saved document %s", taskID, docID)
|
||||||
|
|
||||||
|
// Step 5: Process and save embeddings for each chunk
|
||||||
|
var embeddings []models.Embedding
|
||||||
|
for order, chunk := range chunks {
|
||||||
|
log.Printf("Task %s: vectorizing chunk %d for document %s", taskID, order, docID)
|
||||||
|
vectorEmbedding, err := rag.Embeddings.Vectorize(chunk)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Task %s: error vectorizing chunk %d for document %s: %v", taskID, order, docID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Task %s: vectorized chunk %d for document %s", taskID, order, docID)
|
||||||
|
|
||||||
|
embedding := models.Embedding{
|
||||||
|
ID: uuid.NewString(),
|
||||||
|
DocumentID: docID,
|
||||||
|
Vector: vectorEmbedding[0],
|
||||||
|
TextChunk: chunk,
|
||||||
|
Dimension: int64(1024),
|
||||||
|
Order: int64(order),
|
||||||
|
}
|
||||||
|
embeddings = append(embeddings, embedding)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Task %s: saving %d embeddings for document %s", taskID, len(embeddings), docID)
|
||||||
|
if err := rag.Database.SaveEmbeddings(embeddings); err != nil {
|
||||||
|
log.Printf("Task %s: error saving embeddings for document %s: %v", taskID, docID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Task %s: saved embeddings for document %s", taskID, docID)
|
||||||
|
|
||||||
|
docs = append(docs, document)
|
||||||
|
}
|
||||||
|
}(taskID, request)
|
||||||
|
|
||||||
|
// Return the task ID and expected completion time
|
||||||
|
return c.JSON(http.StatusAccepted, map[string]interface{}{
|
||||||
|
"version": APIVersion,
|
||||||
|
"task_id": taskID,
|
||||||
|
"expected_time": "10m",
|
||||||
|
"status": "Processing started",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListAllDocsHandler(c echo.Context) error {
|
||||||
|
rag := c.Get("Rag").(*rag.Rag)
|
||||||
|
docs, err := rag.Database.ListDocuments()
|
||||||
|
if err != nil {
|
||||||
|
return ErrorHandler(err, c)
|
||||||
|
}
|
||||||
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
|
"version": APIVersion,
|
||||||
|
"docs": docs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDocHandler(c echo.Context) error {
|
||||||
|
rag := c.Get("Rag").(*rag.Rag)
|
||||||
|
id := c.Param("id")
|
||||||
|
doc, err := rag.Database.GetDocument(id)
|
||||||
|
if err != nil {
|
||||||
|
return ErrorHandler(err, c)
|
||||||
|
}
|
||||||
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
|
"version": APIVersion,
|
||||||
|
"doc": doc,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func AskDocHandler(c echo.Context) error {
|
||||||
|
rag := c.Get("Rag").(*rag.Rag)
|
||||||
|
|
||||||
|
var request RequestQuestion
|
||||||
|
err := c.Bind(&request)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return ErrorHandler(err, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
questionV, err := rag.Embeddings.Vectorize(request.Question)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return ErrorHandler(err, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings, err := rag.Database.Search(questionV)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return ErrorHandler(err, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(embeddings) == 0 {
|
||||||
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
|
"version": APIVersion,
|
||||||
|
"docs": nil,
|
||||||
|
"answer": "Don't found any relevant documents",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
answer, err := rag.LLM.Generate(fmt.Sprintf("Given the following information: %s \nAnswer the question: %s", embeddings[0].TextChunk, request.Question))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return ErrorHandler(err, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a map to track unique DocumentIDs
|
||||||
|
docSet := make(map[string]struct{})
|
||||||
|
for _, embedding := range embeddings {
|
||||||
|
docSet[embedding.DocumentID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the map keys to a slice
|
||||||
|
docs := make([]string, 0, len(docSet))
|
||||||
|
for docID := range docSet {
|
||||||
|
docs = append(docs, docID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
|
"version": APIVersion,
|
||||||
|
"docs": docs,
|
||||||
|
"answer": answer,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteDocHandler(c echo.Context) error {
|
||||||
|
rag := c.Get("Rag").(*rag.Rag)
|
||||||
|
id := c.Param("id")
|
||||||
|
err := rag.Database.DeleteDocument(id)
|
||||||
|
if err != nil {
|
||||||
|
return ErrorHandler(err, c)
|
||||||
|
}
|
||||||
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
|
"version": APIVersion,
|
||||||
|
"docs": nil,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorHandler(err error, c echo.Context) error {
|
||||||
|
return c.JSON(http.StatusBadRequest, map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
43
cmd/rag/main.go
Normal file
43
cmd/rag/main.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"easy_rag/api"
|
||||||
|
"easy_rag/config"
|
||||||
|
"easy_rag/internal/database"
|
||||||
|
"easy_rag/internal/embeddings"
|
||||||
|
"easy_rag/internal/llm"
|
||||||
|
"easy_rag/internal/pkg/rag"
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rag is the main struct for the rag application
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
|
||||||
|
llm := llm.NewOllama(
|
||||||
|
cfg.OllamaEndpoint,
|
||||||
|
cfg.OllamaModel,
|
||||||
|
)
|
||||||
|
embeddings := embeddings.NewOllamaEmbeddings(
|
||||||
|
cfg.OllamaEmbeddingEndpoint,
|
||||||
|
cfg.OllamaEmbeddingModel,
|
||||||
|
)
|
||||||
|
database := database.NewMilvus(cfg.MilvusHost)
|
||||||
|
|
||||||
|
// Rag instance
|
||||||
|
rag := rag.NewRag(
|
||||||
|
llm,
|
||||||
|
embeddings,
|
||||||
|
database,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Echo WebServer instance
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
// Wrapper for API
|
||||||
|
api.NewAPI(e, rag)
|
||||||
|
|
||||||
|
// Start Server
|
||||||
|
e.Logger.Fatal(e.Start(":4002"))
|
||||||
|
}
|
||||||
38
config/config.go
Normal file
38
config/config.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import cfg "github.com/eschao/config"
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
// LLM
|
||||||
|
OpenAIAPIKey string `env:"OPENAI_API_KEY"`
|
||||||
|
OpenAIEndpoint string `env:"OPENAI_ENDPOINT"`
|
||||||
|
OpenAIModel string `env:"OPENAI_MODEL"`
|
||||||
|
OpenRouteAPIKey string `env:"OPENROUTE_API_KEY"`
|
||||||
|
OpenRouteEndpoint string `env:"OPENROUTE_ENDPOINT"`
|
||||||
|
OpenRouteModel string `env:"OPENROUTE_MODEL"`
|
||||||
|
|
||||||
|
OllamaEndpoint string `env:"OLLAMA_ENDPOINT"`
|
||||||
|
OllamaModel string `env:"OLLAMA_MODEL"`
|
||||||
|
|
||||||
|
// Embeddings
|
||||||
|
OpenAIEmbeddingAPIKey string `env:"OPENAI_EMBEDDING_API_KEY"`
|
||||||
|
OpenAIEmbeddingEndpoint string `env:"OPENAI_EMBEDDING_ENDPOINT"`
|
||||||
|
OpenAIEmbeddingModel string `env:"OPENAI_EMBEDDING_MODEL"`
|
||||||
|
OllamaEmbeddingEndpoint string `env:"OLLAMA_EMBEDDING_ENDPOINT"`
|
||||||
|
OllamaEmbeddingModel string `env:"OLLAMA_EMBEDDING_MODEL"`
|
||||||
|
|
||||||
|
// Database
|
||||||
|
MilvusHost string `env:"MILVUS_HOST"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfig() Config {
|
||||||
|
config := Config{
|
||||||
|
MilvusHost: "192.168.10.56:19530",
|
||||||
|
OllamaEmbeddingEndpoint: "http://192.168.10.56:11434",
|
||||||
|
OllamaEmbeddingModel: "bge-m3",
|
||||||
|
OllamaEndpoint: "http://192.168.10.56:11434/api/chat",
|
||||||
|
OllamaModel: "qwen3:latest",
|
||||||
|
}
|
||||||
|
cfg.ParseEnv(&config)
|
||||||
|
return config
|
||||||
|
}
|
||||||
16
deploy/Dockerfile
Normal file
16
deploy/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Build stage.
# Fix: go.mod declares `go 1.24.3`, so the previous golang:1.20 base image
# refused to build the module. Also disable CGO so the resulting binary is
# statically linked and actually runs on the musl-based alpine runtime image.
FROM golang:1.24 AS builder

WORKDIR /cmd

COPY . .

RUN go mod download
RUN CGO_ENABLED=0 go build -o rag ./cmd/rag

# Runtime stage: minimal image containing only the compiled binary.
FROM alpine:latest

WORKDIR /root/

COPY --from=builder /cmd/rag ./

CMD ["./rag"]
|
||||||
0
deploy/docker-compose.yaml
Normal file
0
deploy/docker-compose.yaml
Normal file
5
embedEtcd.yaml
Normal file
5
embedEtcd.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Embedded etcd configuration used by the standalone Milvus deployment.
listen-client-urls: http://0.0.0.0:2379
advertise-client-urls: http://0.0.0.0:2379
# 4 GiB backend quota (bytes).
quota-backend-bytes: 4294967296
# Compact by revision, keeping the last 1000 revisions.
auto-compaction-mode: revision
auto-compaction-retention: '1000'
|
||||||
49
go.mod
Normal file
49
go.mod
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
module easy_rag
|
||||||
|
|
||||||
|
go 1.24.3
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/eschao/config v0.1.0
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/jonathanhecl/chunker v0.0.1
|
||||||
|
github.com/labstack/echo/v4 v4.13.4
|
||||||
|
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2
|
||||||
|
github.com/stretchr/testify v1.10.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/cockroachdb/errors v1.9.1 // indirect
|
||||||
|
github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
|
||||||
|
github.com/cockroachdb/redact v1.1.3 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/getsentry/sentry-go v0.12.0 // indirect
|
||||||
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
|
github.com/golang/protobuf v1.5.2 // indirect
|
||||||
|
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
|
||||||
|
github.com/kr/pretty v0.3.0 // indirect
|
||||||
|
github.com/kr/text v0.2.0 // indirect
|
||||||
|
github.com/labstack/gommon v0.4.2 // indirect
|
||||||
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a // indirect
|
||||||
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/rogpeppe/go-internal v1.8.1 // indirect
|
||||||
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
|
github.com/tidwall/gjson v1.14.4 // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
|
github.com/valyala/fasttemplate v1.2.2 // indirect
|
||||||
|
golang.org/x/crypto v0.38.0 // indirect
|
||||||
|
golang.org/x/net v0.40.0 // indirect
|
||||||
|
golang.org/x/sync v0.14.0 // indirect
|
||||||
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
|
golang.org/x/text v0.25.0 // indirect
|
||||||
|
golang.org/x/time v0.11.0 // indirect
|
||||||
|
google.golang.org/genproto v0.0.0-20220503193339-ba3ae3f07e29 // indirect
|
||||||
|
google.golang.org/grpc v1.48.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.33.0 // indirect
|
||||||
|
gopkg.in/yaml.v2 v2.2.8 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
20
internal/database/database.go
Normal file
20
internal/database/database.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import "easy_rag/internal/models"
|
||||||
|
|
||||||
|
// Database defines the storage contract used by the RAG service. Implementations
// (see database_milvus.go) persist documents plus their per-chunk embeddings
// and support vector similarity search.
type Database interface {
	// SaveDocument persists a document's metadata; the content will be
	// chunked and saved as embeddings separately.
	SaveDocument(document models.Document) error
	// GetDocumentInfo returns the document with the given id without content.
	GetDocumentInfo(id string) (models.Document, error)
	// GetDocument returns the document with the given id with its content
	// reassembled from stored chunks.
	GetDocument(id string) (models.Document, error)
	// Search performs a vector similarity search over stored embeddings.
	Search(vector [][]float32) ([]models.Embedding, error)
	// ListDocuments returns all stored documents (metadata only).
	ListDocuments() ([]models.Document, error)
	// DeleteDocument removes a document and its embeddings.
	DeleteDocument(id string) error
	// SaveEmbeddings persists a batch of text-chunk embeddings.
	SaveEmbeddings(embeddings []models.Embedding) error
	// to implement in future
	// SearchByCategory(category []string) ([]Embedding, error)
	// SearchByMetadata(metadata map[string]string) ([]Embedding, error)
	// GetAllEmbeddingByDocumentID(documentID string) ([]Embedding, error)
}
|
||||||
142
internal/database/database_milvus.go
Normal file
142
internal/database/database_milvus.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"easy_rag/internal/models"
|
||||||
|
"easy_rag/internal/pkg/database/milvus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Milvus implements the Database interface on top of a Milvus vector store.
type Milvus struct {
	Host   string         // host:port of the Milvus server
	Client *milvus.Client // low-level Milvus client wrapper
}

// NewMilvus connects to the Milvus server at host and returns a ready-to-use
// Milvus database. It panics if the connection cannot be established
// (fail-fast at startup).
func NewMilvus(host string) *Milvus {
	milviusClient, err := milvus.NewClient(host)
	if err != nil {
		panic(err)
	}

	return &Milvus{
		Host:   host,
		Client: milviusClient,
	}
}
|
||||||
|
|
||||||
|
// SaveDocument persists a single document's record via the Milvus client.
// Per the Database contract, content chunking is handled separately from this
// call (embeddings are stored through SaveEmbeddings).
func (m *Milvus) SaveDocument(document models.Document) error {
	// For now we use context.Background(); callers cannot cancel this yet.
	ctx := context.Background()

	return m.Client.InsertDocuments(ctx, []models.Document{document})
}

// SaveEmbeddings persists a batch of text-chunk embeddings for previously
// saved documents.
func (m *Milvus) SaveEmbeddings(embeddings []models.Embedding) error {
	ctx := context.Background()
	return m.Client.InsertEmbeddings(ctx, embeddings)
}
|
||||||
|
|
||||||
|
// GetDocumentInfo returns the document with the given id WITHOUT its content
// (no embedding chunks are fetched). A zero-value Document with a nil error
// is returned when the id is unknown.
func (m *Milvus) GetDocumentInfo(id string) (models.Document, error) {
	ctx := context.Background()
	doc, err := m.Client.GetDocumentByID(ctx, id)

	if err != nil {
		return models.Document{}, err
	}

	// Empty map means no such document; this is not treated as an error.
	if len(doc) == 0 {
		return models.Document{}, nil
	}
	// NOTE(review): these unchecked type assertions panic if a field is
	// missing or has an unexpected type — assumes the Milvus collection
	// schema always returns all of these keys. TODO confirm against
	// milvus.Client.GetDocumentByID.
	return models.Document{
		ID:             doc["ID"].(string),
		Link:           doc["Link"].(string),
		Filename:       doc["Filename"].(string),
		Category:       doc["Category"].(string),
		EmbeddingModel: doc["EmbeddingModel"].(string),
		Summary:        doc["Summary"].(string),
		Metadata:       doc["Metadata"].(map[string]string),
	}, nil
}
|
||||||
|
|
||||||
|
func (m *Milvus) GetDocument(id string) (models.Document, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
doc, err := m.Client.GetDocumentByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return models.Document{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
embeds, err := m.Client.GetAllEmbeddingByDocID(ctx, id)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return models.Document{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// order embed by order
|
||||||
|
sort.Slice(embeds, func(i, j int) bool {
|
||||||
|
return embeds[i].Order < embeds[j].Order
|
||||||
|
})
|
||||||
|
|
||||||
|
// concatenate text chunks
|
||||||
|
var buf bytes.Buffer
|
||||||
|
for _, embed := range embeds {
|
||||||
|
buf.WriteString(embed.TextChunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
textChunks := buf.String()
|
||||||
|
|
||||||
|
if len(doc) == 0 {
|
||||||
|
return models.Document{}, nil
|
||||||
|
}
|
||||||
|
return models.Document{
|
||||||
|
ID: doc["ID"].(string),
|
||||||
|
Content: textChunks,
|
||||||
|
Link: doc["Link"].(string),
|
||||||
|
Filename: doc["Filename"].(string),
|
||||||
|
Category: doc["Category"].(string),
|
||||||
|
EmbeddingModel: doc["EmbeddingModel"].(string),
|
||||||
|
Summary: doc["Summary"].(string),
|
||||||
|
Metadata: doc["Metadata"].(map[string]string),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Milvus) Search(vector [][]float32) ([]models.Embedding, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
results, err := m.Client.Search(ctx, vector, 10)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Milvus) ListDocuments() ([]models.Document, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
docs, err := m.Client.GetAllDocuments(ctx)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get docs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return docs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Milvus) DeleteDocument(id string) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
err := m.Client.DeleteDocument(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.Client.DeleteEmbedding(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
8
internal/embeddings/embeddings.go
Normal file
8
internal/embeddings/embeddings.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package embeddings
|
||||||
|
|
||||||
|
// EmbeddingsService abstracts a text-embedding backend (e.g. Ollama, OpenAI).
type EmbeddingsService interface {
	// Vectorize generates embedding vectors for the given text; one vector
	// per input the backend produces.
	Vectorize(text string) ([][]float32, error)
	// GetModel returns the configured embedding model identifier.
	GetModel() string
}
|
||||||
78
internal/embeddings/ollama_embeddings.go
Normal file
78
internal/embeddings/ollama_embeddings.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package embeddings
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OllamaEmbeddings implements EmbeddingsService against an Ollama server's
// /api/embed endpoint.
type OllamaEmbeddings struct {
	Endpoint string // base URL of the Ollama server (e.g. http://host:11434)
	Model    string // embedding model name (e.g. bge-m3)
}

// NewOllamaEmbeddings constructs an OllamaEmbeddings client for the given
// server endpoint and model.
func NewOllamaEmbeddings(endpoint string, model string) *OllamaEmbeddings {
	return &OllamaEmbeddings{
		Endpoint: endpoint,
		Model:    model,
	}
}
|
||||||
|
|
||||||
|
// Vectorize generates an embedding for the provided text by POSTing
// {"model": ..., "input": ...} to <Endpoint>/api/embed and decoding the
// "embeddings" field of the JSON response.
func (o *OllamaEmbeddings) Vectorize(text string) ([][]float32, error) {
	// Define the request payload
	payload := map[string]string{
		"model": o.Model,
		"input": text,
	}

	// Convert the payload to JSON
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request payload: %w", err)
	}

	// Create the HTTP request
	url := fmt.Sprintf("%s/api/embed", o.Endpoint)
	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create HTTP request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	// Execute the HTTP request.
	// NOTE(review): a fresh client with no timeout is created per call — a
	// hung Ollama server blocks this call indefinitely; consider a shared
	// client with a Timeout.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make HTTP request: %w", err)
	}
	defer resp.Body.Close()

	// Check for non-200 status code; include the body for diagnostics.
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("received non-200 response: %s", body)
	}

	// Read and parse the response body
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	// The response JSON contains an "embeddings" field with one float32
	// vector per input.
	var response struct {
		Embeddings [][]float32 `json:"embeddings"`
	}
	if err := json.Unmarshal(body, &response); err != nil {
		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
	}

	return response.Embeddings, nil
}
|
||||||
|
|
||||||
|
// GetModel returns the configured embedding model name.
func (o *OllamaEmbeddings) GetModel() string {
	return o.Model
}
|
||||||
23
internal/embeddings/openai_embeddings.go
Normal file
23
internal/embeddings/openai_embeddings.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package embeddings
|
||||||
|
|
||||||
|
// OpenAIEmbeddings is a stub EmbeddingsService backed by the OpenAI
// embeddings API. Vectorize is not implemented yet.
type OpenAIEmbeddings struct {
	APIKey   string // API key for the embeddings API
	Endpoint string // base URL of the embeddings API
	Model    string // embedding model identifier
}

// NewOpenAIEmbeddings constructs an OpenAIEmbeddings client.
func NewOpenAIEmbeddings(apiKey string, endpoint string, model string) *OpenAIEmbeddings {
	return &OpenAIEmbeddings{
		APIKey:   apiKey,
		Endpoint: endpoint,
		Model:    model,
	}
}

// Vectorize generates embedding vectors for the provided text.
//
// Fix: the return type was ([]float32, error), which does not satisfy the
// EmbeddingsService interface (Vectorize must return ([][]float32, error)),
// so *OpenAIEmbeddings could never be used where an EmbeddingsService is
// expected. The signature now matches OllamaEmbeddings.Vectorize.
//
// TODO: implement the actual OpenAI API call; currently returns (nil, nil).
func (o *OpenAIEmbeddings) Vectorize(text string) ([][]float32, error) {
	return nil, nil
}

// GetModel returns the configured embedding model name.
func (o *OpenAIEmbeddings) GetModel() string {
	return o.Model
}
|
||||||
8
internal/llm/llm.go
Normal file
8
internal/llm/llm.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
// LLMService abstracts a text-generation backend (e.g. Ollama, OpenAI).
type LLMService interface {
	// Generate produces a completion for the given prompt.
	Generate(prompt string) (string, error)
	// GetModel returns the configured model identifier.
	GetModel() string
}
|
||||||
82
internal/llm/ollama_llm.go
Normal file
82
internal/llm/ollama_llm.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ollama is an LLMService implementation that talks to an Ollama server's
// chat API.
type Ollama struct {
	Endpoint string // full chat endpoint URL (e.g. http://host:11434/api/chat)
	Model    string // model name sent with every request (e.g. qwen3:latest)
}

// NewOllama constructs an Ollama LLM client for the given endpoint and model.
func NewOllama(endpoint string, model string) *Ollama {
	return &Ollama{
		Endpoint: endpoint,
		Model:    model,
	}
}

// Response represents the structure of the expected response from the API:
// only the fields Generate needs are modeled.
type Response struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	// Message carries the assistant's reply.
	Message struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"message"`
}
|
||||||
|
|
||||||
|
// Generate sends a single-turn chat request to the Ollama endpoint and
// returns the assistant's reply text. The request is non-streaming.
func (o *Ollama) Generate(prompt string) (string, error) {
	// Create the request payload: one user message, streaming disabled.
	payload := map[string]interface{}{
		"model": o.Model,
		"messages": []map[string]string{
			{
				"role":    "user",
				"content": prompt,
			},
		},
		"stream": false,
	}

	// Marshal the payload into JSON
	data, err := json.Marshal(payload)
	if err != nil {
		return "", fmt.Errorf("failed to marshal payload: %w", err)
	}

	// Make the POST request.
	// NOTE(review): http.Post uses the default client with no timeout — a
	// hung Ollama server blocks this call indefinitely.
	resp, err := http.Post(o.Endpoint, "application/json", bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	// Read the body before the status check so error responses can be
	// included in the returned error message.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("API returned error: %s", string(body))
	}

	// Unmarshal the response into a predefined structure
	var response Response
	if err := json.Unmarshal(body, &response); err != nil {
		return "", fmt.Errorf("failed to unmarshal response: %w", err)
	}

	// Extract and return the content from the nested structure
	return response.Message.Content, nil
}
|
||||||
|
|
||||||
|
// GetModel returns the configured Ollama model name.
func (o *Ollama) GetModel() string {
	return o.Model
}
|
||||||
24
internal/llm/openai_llm.go
Normal file
24
internal/llm/openai_llm.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
// OpenAI is a stub LLMService for an OpenAI-compatible chat API.
// Generate is not implemented yet.
type OpenAI struct {
	APIKey   string // API key for the chat API
	Endpoint string // base URL of the chat API
	Model    string // model identifier
}

// NewOpenAI constructs an OpenAI LLM client.
func NewOpenAI(apiKey string, endpoint string, model string) *OpenAI {
	return &OpenAI{
		APIKey:   apiKey,
		Endpoint: endpoint,
		Model:    model,
	}
}

// Generate produces a completion for the given prompt.
// Currently a stub: returns an empty string and no error.
func (o *OpenAI) Generate(prompt string) (string, error) {
	return "", nil
	// TODO: implement
}

// GetModel returns the configured model name.
func (o *OpenAI) GetModel() string {
	return o.Model
}
|
||||||
155
internal/llm/openroute/definitions.go
Normal file
155
internal/llm/openroute/definitions.go
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
package openroute
|
||||||
|
|
||||||
|
// Request is the OpenRouter chat/completions request payload. Either
// Messages (chat mode) or Prompt (completion mode) is set by callers in
// route_agent.go, never both.
type Request struct {
	Messages          []MessageRequest     `json:"messages,omitempty"`
	Prompt            string               `json:"prompt,omitempty"`
	Model             string               `json:"model,omitempty"`
	ResponseFormat    *ResponseFormat      `json:"response_format,omitempty"`
	Stop              []string             `json:"stop,omitempty"`
	Stream            bool                 `json:"stream,omitempty"`
	MaxTokens         int                  `json:"max_tokens,omitempty"`
	Temperature       float64              `json:"temperature,omitempty"`
	Tools             []Tool               `json:"tools,omitempty"`
	ToolChoice        ToolChoice           `json:"tool_choice,omitempty"`
	Seed              int                  `json:"seed,omitempty"`
	TopP              float64              `json:"top_p,omitempty"`
	TopK              int                  `json:"top_k,omitempty"`
	FrequencyPenalty  float64              `json:"frequency_penalty,omitempty"`
	PresencePenalty   float64              `json:"presence_penalty,omitempty"`
	RepetitionPenalty float64              `json:"repetition_penalty,omitempty"`
	LogitBias         map[int]float64      `json:"logit_bias,omitempty"`
	TopLogprobs       int                  `json:"top_logprobs,omitempty"`
	MinP              float64              `json:"min_p,omitempty"`
	TopA              float64              `json:"top_a,omitempty"`
	Prediction        *Prediction          `json:"prediction,omitempty"`
	Transforms        []string             `json:"transforms,omitempty"`
	Models            []string             `json:"models,omitempty"`
	Route             string               `json:"route,omitempty"`
	Provider          *ProviderPreferences `json:"provider,omitempty"`
	IncludeReasoning  bool                 `json:"include_reasoning,omitempty"`
}

// ResponseFormat represents the response format structure.
type ResponseFormat struct {
	Type string `json:"type"`
}

// Prediction represents the prediction structure.
type Prediction struct {
	Type    string `json:"type"`
	Content string `json:"content"`
}

// ProviderPreferences represents the provider preferences structure; its
// fields are also copied into HTTP-Referer / X-Title headers by the client.
type ProviderPreferences struct {
	RefererURL string `json:"referer_url,omitempty"`
	SiteName   string `json:"site_name,omitempty"`
}

// MessageRequest represents a single message in the conversation history.
type MessageRequest struct {
	Role       MessageRole `json:"role"`
	Content    interface{} `json:"content"` // Can be string or []ContentPart
	Name       string      `json:"name,omitempty"`
	ToolCallID string      `json:"tool_call_id,omitempty"`
}

// MessageRole identifies the author of a message.
type MessageRole string

const (
	RoleSystem    MessageRole = "system"
	RoleUser      MessageRole = "user"
	RoleAssistant MessageRole = "assistant"
)

// ContentPart represents the content part structure (multimodal content).
type ContentPart struct {
	Type     ContnetType `json:"type"`
	Text     string      `json:"text,omitempty"`
	ImageURL *ImageURL   `json:"image_url,omitempty"`
}

// ContnetType discriminates ContentPart variants.
// NOTE(review): the name is misspelled ("Contnet"); it is exported, so
// renaming would break callers — kept as-is.
type ContnetType string

const (
	ContentTypeText  ContnetType = "text"
	ContentTypeImage ContnetType = "image_url"
)

// ImageURL represents the image URL structure.
type ImageURL struct {
	URL    string `json:"url"`
	Detail string `json:"detail,omitempty"`
}

// FunctionDescription represents the function description structure.
type FunctionDescription struct {
	Description string      `json:"description,omitempty"`
	Name        string      `json:"name"`
	Parameters  interface{} `json:"parameters"` // JSON Schema object
}

// Tool represents the tool structure.
type Tool struct {
	Type     string              `json:"type"`
	Function FunctionDescription `json:"function"`
}

// ToolChoice represents the tool choice structure.
type ToolChoice struct {
	Type     string `json:"type"`
	Function struct {
		Name string `json:"name"`
	} `json:"function"`
}

// Response is the OpenRouter chat/completions response payload.
type Response struct {
	ID                string         `json:"id"`
	Choices           []Choice       `json:"choices"`
	Created           int64          `json:"created"`
	Model             string         `json:"model"`
	Object            string         `json:"object"`
	SystemFingerprint *string        `json:"system_fingerprint,omitempty"`
	Usage             *ResponseUsage `json:"usage,omitempty"`
}

// ResponseUsage reports token accounting for a completed request.
type ResponseUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Choice is one candidate result. Text, Message, and Delta are mutually
// optional pointers/fields depending on the request mode.
type Choice struct {
	FinishReason string           `json:"finish_reason"`
	Text         string           `json:"text,omitempty"`
	Message      *MessageResponse `json:"message,omitempty"`
	Delta        *Delta           `json:"delta,omitempty"`
	Error        *ErrorResponse   `json:"error,omitempty"`
}

// MessageResponse is the assistant message of a non-streaming choice.
type MessageResponse struct {
	Content   string     `json:"content"`
	Role      string     `json:"role"`
	Reasoning string     `json:"reasoning,omitempty"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Delta is the incremental message fragment of a streaming choice.
type Delta struct {
	Content   string     `json:"content"`
	Role      string     `json:"role,omitempty"`
	Reasoning string     `json:"reasoning,omitempty"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// ErrorResponse is an in-band error attached to a choice.
type ErrorResponse struct {
	Code     int                    `json:"code"`
	Message  string                 `json:"message"`
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}

// ToolCall is a tool invocation requested by the model.
type ToolCall struct {
	ID       string      `json:"id"`
	Type     string      `json:"type"`
	Function interface{} `json:"function"`
}
|
||||||
198
internal/llm/openroute/route_agent.go
Normal file
198
internal/llm/openroute/route_agent.go
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
package openroute
|
||||||
|
|
||||||
|
import (
	"context"
	"fmt"
)
|
||||||
|
|
||||||
|
// RouterAgentConfig holds the per-agent generation parameters that are copied
// into every Request the agent sends (see route_agent.go methods).
type RouterAgentConfig struct {
	ResponseFormat    *ResponseFormat `json:"response_format,omitempty"`
	Stop              []string        `json:"stop,omitempty"`
	MaxTokens         int             `json:"max_tokens,omitempty"`
	Temperature       float64         `json:"temperature,omitempty"`
	Tools             []Tool          `json:"tools,omitempty"`
	ToolChoice        ToolChoice      `json:"tool_choice,omitempty"`
	Seed              int             `json:"seed,omitempty"`
	TopP              float64         `json:"top_p,omitempty"`
	TopK              int             `json:"top_k,omitempty"`
	FrequencyPenalty  float64         `json:"frequency_penalty,omitempty"`
	PresencePenalty   float64         `json:"presence_penalty,omitempty"`
	RepetitionPenalty float64         `json:"repetition_penalty,omitempty"`
	LogitBias         map[int]float64 `json:"logit_bias,omitempty"`
	TopLogprobs       int             `json:"top_logprobs,omitempty"`
	MinP              float64         `json:"min_p,omitempty"`
	TopA              float64         `json:"top_a,omitempty"`
}
|
||||||
|
|
||||||
|
// RouterAgent wraps an OpenRouterClient with a fixed model and generation
// config, exposing one-shot and streaming completion/chat helpers.
type RouterAgent struct {
	client *OpenRouterClient // transport used for all requests
	model  string            // model identifier sent with every request
	config RouterAgentConfig // generation parameters applied to every request
}

// NewRouterAgent constructs a RouterAgent bound to the given client, model,
// and generation config.
func NewRouterAgent(client *OpenRouterClient, model string, config RouterAgentConfig) *RouterAgent {
	return &RouterAgent{
		client: client,
		model:  model,
		config: config,
	}
}
|
||||||
|
|
||||||
|
func (agent RouterAgent) Completion(prompt string) (*Response, error) {
|
||||||
|
request := Request{
|
||||||
|
Prompt: prompt,
|
||||||
|
Model: agent.model,
|
||||||
|
ResponseFormat: agent.config.ResponseFormat,
|
||||||
|
Stop: agent.config.Stop,
|
||||||
|
MaxTokens: agent.config.MaxTokens,
|
||||||
|
Temperature: agent.config.Temperature,
|
||||||
|
Tools: agent.config.Tools,
|
||||||
|
ToolChoice: agent.config.ToolChoice,
|
||||||
|
Seed: agent.config.Seed,
|
||||||
|
TopP: agent.config.TopP,
|
||||||
|
TopK: agent.config.TopK,
|
||||||
|
FrequencyPenalty: agent.config.FrequencyPenalty,
|
||||||
|
PresencePenalty: agent.config.PresencePenalty,
|
||||||
|
RepetitionPenalty: agent.config.RepetitionPenalty,
|
||||||
|
LogitBias: agent.config.LogitBias,
|
||||||
|
TopLogprobs: agent.config.TopLogprobs,
|
||||||
|
MinP: agent.config.MinP,
|
||||||
|
TopA: agent.config.TopA,
|
||||||
|
Stream: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
return agent.client.FetchChatCompletions(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (agent RouterAgent) CompletionStream(prompt string, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) {
|
||||||
|
request := Request{
|
||||||
|
Prompt: prompt,
|
||||||
|
Model: agent.model,
|
||||||
|
ResponseFormat: agent.config.ResponseFormat,
|
||||||
|
Stop: agent.config.Stop,
|
||||||
|
MaxTokens: agent.config.MaxTokens,
|
||||||
|
Temperature: agent.config.Temperature,
|
||||||
|
Tools: agent.config.Tools,
|
||||||
|
ToolChoice: agent.config.ToolChoice,
|
||||||
|
Seed: agent.config.Seed,
|
||||||
|
TopP: agent.config.TopP,
|
||||||
|
TopK: agent.config.TopK,
|
||||||
|
FrequencyPenalty: agent.config.FrequencyPenalty,
|
||||||
|
PresencePenalty: agent.config.PresencePenalty,
|
||||||
|
RepetitionPenalty: agent.config.RepetitionPenalty,
|
||||||
|
LogitBias: agent.config.LogitBias,
|
||||||
|
TopLogprobs: agent.config.TopLogprobs,
|
||||||
|
MinP: agent.config.MinP,
|
||||||
|
TopA: agent.config.TopA,
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
agent.client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (agent RouterAgent) Chat(messages []MessageRequest) (*Response, error) {
|
||||||
|
request := Request{
|
||||||
|
Messages: messages,
|
||||||
|
Model: agent.model,
|
||||||
|
ResponseFormat: agent.config.ResponseFormat,
|
||||||
|
Stop: agent.config.Stop,
|
||||||
|
MaxTokens: agent.config.MaxTokens,
|
||||||
|
Temperature: agent.config.Temperature,
|
||||||
|
Tools: agent.config.Tools,
|
||||||
|
ToolChoice: agent.config.ToolChoice,
|
||||||
|
Seed: agent.config.Seed,
|
||||||
|
TopP: agent.config.TopP,
|
||||||
|
TopK: agent.config.TopK,
|
||||||
|
FrequencyPenalty: agent.config.FrequencyPenalty,
|
||||||
|
PresencePenalty: agent.config.PresencePenalty,
|
||||||
|
RepetitionPenalty: agent.config.RepetitionPenalty,
|
||||||
|
LogitBias: agent.config.LogitBias,
|
||||||
|
TopLogprobs: agent.config.TopLogprobs,
|
||||||
|
MinP: agent.config.MinP,
|
||||||
|
TopA: agent.config.TopA,
|
||||||
|
Stream: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
return agent.client.FetchChatCompletions(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (agent RouterAgent) ChatStream(messages []MessageRequest, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) {
|
||||||
|
request := Request{
|
||||||
|
Messages: messages,
|
||||||
|
Model: agent.model,
|
||||||
|
ResponseFormat: agent.config.ResponseFormat,
|
||||||
|
Stop: agent.config.Stop,
|
||||||
|
MaxTokens: agent.config.MaxTokens,
|
||||||
|
Temperature: agent.config.Temperature,
|
||||||
|
Tools: agent.config.Tools,
|
||||||
|
ToolChoice: agent.config.ToolChoice,
|
||||||
|
Seed: agent.config.Seed,
|
||||||
|
TopP: agent.config.TopP,
|
||||||
|
TopK: agent.config.TopK,
|
||||||
|
FrequencyPenalty: agent.config.FrequencyPenalty,
|
||||||
|
PresencePenalty: agent.config.PresencePenalty,
|
||||||
|
RepetitionPenalty: agent.config.RepetitionPenalty,
|
||||||
|
LogitBias: agent.config.LogitBias,
|
||||||
|
TopLogprobs: agent.config.TopLogprobs,
|
||||||
|
MinP: agent.config.MinP,
|
||||||
|
TopA: agent.config.TopA,
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
agent.client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterAgentChat is a RouterAgent that additionally keeps the running
// conversation history across calls to Chat.
type RouterAgentChat struct {
	RouterAgent
	Messages []MessageRequest // full conversation, starting with the system prompt
}

// NewRouterAgentChat constructs a stateful chat agent whose history is seeded
// with the given system prompt.
func NewRouterAgentChat(client *OpenRouterClient, model string, config RouterAgentConfig, system_prompt string) RouterAgentChat {
	return RouterAgentChat{
		RouterAgent: RouterAgent{
			client: client,
			model:  model,
			config: config,
		},
		Messages: []MessageRequest{
			{
				Role:    RoleSystem,
				Content: system_prompt,
			},
		},
	}
}
|
||||||
|
|
||||||
|
func (agent *RouterAgentChat) Chat(message string) error {
|
||||||
|
agent.Messages = append(agent.Messages, MessageRequest{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: message,
|
||||||
|
})
|
||||||
|
request := Request{
|
||||||
|
Messages: agent.Messages,
|
||||||
|
Model: agent.model,
|
||||||
|
ResponseFormat: agent.config.ResponseFormat,
|
||||||
|
Stop: agent.config.Stop,
|
||||||
|
MaxTokens: agent.config.MaxTokens,
|
||||||
|
Temperature: agent.config.Temperature,
|
||||||
|
Tools: agent.config.Tools,
|
||||||
|
ToolChoice: agent.config.ToolChoice,
|
||||||
|
Seed: agent.config.Seed,
|
||||||
|
TopP: agent.config.TopP,
|
||||||
|
TopK: agent.config.TopK,
|
||||||
|
FrequencyPenalty: agent.config.FrequencyPenalty,
|
||||||
|
PresencePenalty: agent.config.PresencePenalty,
|
||||||
|
RepetitionPenalty: agent.config.RepetitionPenalty,
|
||||||
|
LogitBias: agent.config.LogitBias,
|
||||||
|
TopLogprobs: agent.config.TopLogprobs,
|
||||||
|
MinP: agent.config.MinP,
|
||||||
|
TopA: agent.config.TopA,
|
||||||
|
Stream: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := agent.client.FetchChatCompletions(request)
|
||||||
|
|
||||||
|
agent.Messages = append(agent.Messages, MessageRequest{
|
||||||
|
Role: RoleAssistant,
|
||||||
|
Content: response.Choices[0].Message.Content,
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
188
internal/llm/openroute/route_client.go
Normal file
188
internal/llm/openroute/route_client.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package openroute
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenRouterClient is a thin HTTP client for the OpenRouter API.
type OpenRouterClient struct {
	apiKey     string       // bearer token sent on every request
	apiURL     string       // API base URL, without trailing slash
	httpClient *http.Client // transport; injectable for tests via NewOpenRouterClientFull
}

// NewOpenRouterClient returns a client for the public OpenRouter endpoint
// using a default http.Client.
func NewOpenRouterClient(apiKey string) *OpenRouterClient {
	return &OpenRouterClient{
		apiKey:     apiKey,
		apiURL:     "https://openrouter.ai/api/v1",
		httpClient: &http.Client{},
	}
}

// NewOpenRouterClientFull returns a client with a caller-supplied base URL
// and http.Client (useful for proxies and tests).
func NewOpenRouterClientFull(apiKey string, apiUrl string, client *http.Client) *OpenRouterClient {
	return &OpenRouterClient{
		apiKey:     apiKey,
		apiURL:     apiUrl,
		httpClient: client,
	}
}
|
||||||
|
|
||||||
|
func (c *OpenRouterClient) FetchChatCompletions(request Request) (*Response, error) {
|
||||||
|
headers := map[string]string{
|
||||||
|
"Authorization": "Bearer " + c.apiKey,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Provider != nil {
|
||||||
|
headers["HTTP-Referer"] = request.Provider.RefererURL
|
||||||
|
headers["X-Title"] = request.Provider.SiteName
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat/completions", c.apiURL), bytes.NewBuffer(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range headers {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
outputReponse := &Response{}
|
||||||
|
err = json.Unmarshal(output, outputReponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputReponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenRouterClient) FetchChatCompletionsStream(request Request, outputChan chan Response, processingChan chan interface{}, errChan chan error, ctx context.Context) {
|
||||||
|
headers := map[string]string{
|
||||||
|
"Authorization": "Bearer " + c.apiKey,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Provider != nil {
|
||||||
|
headers["HTTP-Referer"] = request.Provider.RefererURL
|
||||||
|
headers["X-Title"] = request.Provider.SiteName
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
close(errChan)
|
||||||
|
close(outputChan)
|
||||||
|
close(processingChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat/completions", c.apiURL), bytes.NewBuffer(body))
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
close(errChan)
|
||||||
|
close(outputChan)
|
||||||
|
close(processingChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range headers {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
close(errChan)
|
||||||
|
close(outputChan)
|
||||||
|
close(processingChan)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
defer close(errChan)
|
||||||
|
defer close(outputChan)
|
||||||
|
defer close(processingChan)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
errChan <- fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
errChan <- ctx.Err()
|
||||||
|
close(errChan)
|
||||||
|
close(outputChan)
|
||||||
|
close(processingChan)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if strings.HasPrefix(line, ":") {
|
||||||
|
select {
|
||||||
|
case processingChan <- true:
|
||||||
|
case <-ctx.Done():
|
||||||
|
errChan <- ctx.Err()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if line != "" {
|
||||||
|
if strings.Compare(line[6:], "[DONE]") == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response := Response{}
|
||||||
|
err = json.Unmarshal([]byte(line[6:]), &response)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case outputChan <- response:
|
||||||
|
case <-ctx.Done():
|
||||||
|
errChan <- ctx.Err()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
24
internal/llm/openroute_llm.go
Normal file
24
internal/llm/openroute_llm.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
// OpenRoute holds configuration for the OpenRouter-backed LLM service.
type OpenRoute struct {
	APIKey   string // API key used to authenticate against the service
	Endpoint string // base URL of the service endpoint
	Model    string // model identifier to request
}
|
||||||
|
|
||||||
|
func NewOpenRoute(apiKey string, endpoint string, model string) *OpenAI {
|
||||||
|
return &OpenAI{
|
||||||
|
APIKey: apiKey,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate produces a completion for prompt using the OpenRoute backend.
//
// NOTE(review): unimplemented stub — it always returns ("", nil), so
// callers cannot distinguish "no output" from success. Implement (or make
// it return an explicit not-implemented error) before relying on it.
func (o *OpenRoute) Generate(prompt string) (string, error) {
	return "", nil
	// TODO: implement
}
|
||||||
|
|
||||||
|
// GetModel returns the configured model identifier.
func (o *OpenRoute) GetModel() string {
	return o.Model
}
|
||||||
27
internal/models/models.go
Normal file
27
internal/models/models.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
// type VectorEmbedding [][]float32
|
||||||
|
// type Vector []float32
|
||||||
|
// Document represents the data structure for storing documents
|
||||||
|
// Document represents a stored document in the "documents" collection.
type Document struct {
	ID             string            `json:"id" milvus:"ID"`                         // Unique identifier for the document
	Content        string            `json:"content" milvus:"Content"`               // Full text of the document; chunked elsewhere (original comment suggests per-chunk text is not persisted here — verify against the chunking code)
	Link           string            `json:"link" milvus:"Link"`                     // Link to the document
	Filename       string            `json:"filename" milvus:"Filename"`             // Filename of the document
	Category       string            `json:"category" milvus:"Category"`             // Category of the document
	EmbeddingModel string            `json:"embedding_model" milvus:"EmbeddingModel"` // Embedding model used to generate the vector
	Summary        string            `json:"summary" milvus:"Summary"`               // Summary of the document
	Metadata       map[string]string `json:"metadata" milvus:"Metadata"`             // Additional metadata (e.g., author, timestamp); stored in Milvus as a JSON string
	Vector         []float32         `json:"vector" milvus:"Vector"`                 // Document-level embedding vector
}
|
||||||
|
|
||||||
|
// Embedding represents the vector embedding for a document or query
|
||||||
|
// Embedding represents one embedded text chunk of a document, stored in
// the "chunks" collection.
type Embedding struct {
	ID         string    `json:"id" milvus:"ID"`                  // Unique identifier of the chunk
	DocumentID string    `json:"document_id" milvus:"DocumentID"` // ID of the Document this chunk belongs to
	Vector     []float32 `json:"vector" milvus:"Vector"`          // The embedding vector
	TextChunk  string    `json:"text_chunk" milvus:"TextChunk"`   // Text chunk of the document
	Dimension  int64     `json:"dimension" milvus:"Dimension"`    // Dimensionality of the vector (narrowed to int32 in the Milvus schema)
	Order      int64     `json:"order" milvus:"Order"`            // Position of the chunk, used to reassemble the content in order
	Score      float32   `json:"score"`                           // Similarity score; populated by search, not persisted in Milvus
}
|
||||||
169
internal/pkg/database/milvus/client.go
Normal file
169
internal/pkg/database/milvus/client.go
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
package milvus
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
||||||
|
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client wraps a Milvus SDK client together with the collection setup
// this service relies on.
type Client struct {
	Instance client.Client // underlying Milvus SDK client
}
|
||||||
|
|
||||||
|
// InitMilvusClient initializes the Milvus client and returns a wrapper around it.
|
||||||
|
func NewClient(milvusAddr string) (*Client, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
c, err := client.NewClient(ctx, client.Config{Address: milvusAddr})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to connect to Milvus: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &Client{Instance: c}
|
||||||
|
|
||||||
|
err = client.EnsureCollections(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureCollections ensures that the required collections ("documents" and "chunks") exist.
|
||||||
|
// If they don't exist, it creates them based on the predefined structs.
|
||||||
|
func (m *Client) EnsureCollections(ctx context.Context) error {
|
||||||
|
collections := []struct {
|
||||||
|
Name string
|
||||||
|
Schema *entity.Schema
|
||||||
|
IndexField string
|
||||||
|
IndexType string
|
||||||
|
MetricType entity.MetricType
|
||||||
|
Nlist int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "documents",
|
||||||
|
Schema: createDocumentSchema(),
|
||||||
|
IndexField: "Vector", // Indexing the Vector field for similarity search
|
||||||
|
IndexType: "IVF_FLAT",
|
||||||
|
MetricType: entity.L2,
|
||||||
|
Nlist: 10, // Number of clusters for IVF_FLAT index
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "chunks",
|
||||||
|
Schema: createEmbeddingSchema(),
|
||||||
|
IndexField: "Vector", // Indexing the Vector field for similarity search
|
||||||
|
IndexType: "IVF_FLAT",
|
||||||
|
MetricType: entity.L2,
|
||||||
|
Nlist: 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, collection := range collections {
|
||||||
|
// drop collection
|
||||||
|
// err := m.Instance.DropCollection(ctx, collection.Name)
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("failed to drop collection '%s': %w", collection.Name, err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Ensure the collection exists
|
||||||
|
exists, err := m.Instance.HasCollection(ctx, collection.Name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check collection existence: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
err := m.Instance.CreateCollection(ctx, collection.Schema, entity.DefaultShardNumber)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create collection '%s': %w", collection.Name, err)
|
||||||
|
}
|
||||||
|
log.Printf("Collection '%s' created successfully", collection.Name)
|
||||||
|
} else {
|
||||||
|
log.Printf("Collection '%s' already exists", collection.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the default partition exists
|
||||||
|
hasPartition, err := m.Instance.HasPartition(ctx, collection.Name, "_default")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check default partition for collection '%s': %w", collection.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasPartition {
|
||||||
|
err = m.Instance.CreatePartition(ctx, collection.Name, "_default")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create default partition for collection '%s': %w", collection.Name, err)
|
||||||
|
}
|
||||||
|
log.Printf("Default partition created for collection '%s'", collection.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip index creation if IndexField is empty
|
||||||
|
if collection.IndexField == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the index exists
|
||||||
|
log.Printf("Creating index on field '%s' for collection '%s'", collection.IndexField, collection.Name)
|
||||||
|
|
||||||
|
idx, err := entity.NewIndexIvfFlat(collection.MetricType, collection.Nlist)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create IVF_FLAT index: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.Instance.CreateIndex(ctx, collection.Name, collection.IndexField, idx, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create index on field '%s' for collection '%s': %w", collection.IndexField, collection.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Index created on field '%s' for collection '%s'", collection.IndexField, collection.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.Instance.LoadCollection(ctx, "documents", false)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to load collection, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.Instance.LoadCollection(ctx, "chunks", false)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to load collection, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions for creating schemas
|
||||||
|
func createDocumentSchema() *entity.Schema {
|
||||||
|
return entity.NewSchema().
|
||||||
|
WithName("documents").
|
||||||
|
WithDescription("Collection for storing documents").
|
||||||
|
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeVarChar).WithIsPrimaryKey(true).WithMaxLength(512)).
|
||||||
|
WithField(entity.NewField().WithName("Content").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
|
||||||
|
WithField(entity.NewField().WithName("Link").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)).
|
||||||
|
WithField(entity.NewField().WithName("Filename").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)).
|
||||||
|
WithField(entity.NewField().WithName("Category").WithDataType(entity.FieldTypeVarChar).WithMaxLength(8048)).
|
||||||
|
WithField(entity.NewField().WithName("EmbeddingModel").WithDataType(entity.FieldTypeVarChar).WithMaxLength(256)).
|
||||||
|
WithField(entity.NewField().WithName("Summary").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
|
||||||
|
WithField(entity.NewField().WithName("Metadata").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
|
||||||
|
WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(1024)) // bge-m3
|
||||||
|
}
|
||||||
|
|
||||||
|
func createEmbeddingSchema() *entity.Schema {
|
||||||
|
return entity.NewSchema().
|
||||||
|
WithName("chunks").
|
||||||
|
WithDescription("Collection for storing document embeddings").
|
||||||
|
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeVarChar).WithIsPrimaryKey(true).WithMaxLength(512)).
|
||||||
|
WithField(entity.NewField().WithName("DocumentID").WithDataType(entity.FieldTypeVarChar).WithMaxLength(512)).
|
||||||
|
WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(1024)). // bge-m3
|
||||||
|
WithField(entity.NewField().WithName("TextChunk").WithDataType(entity.FieldTypeVarChar).WithMaxLength(65535)).
|
||||||
|
WithField(entity.NewField().WithName("Dimension").WithDataType(entity.FieldTypeInt32)).
|
||||||
|
WithField(entity.NewField().WithName("Order").WithDataType(entity.FieldTypeInt32))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the Milvus client connection.
|
||||||
|
// Close closes the underlying Milvus client connection.
func (m *Client) Close() {
	m.Instance.Close()
}
|
||||||
32
internal/pkg/database/milvus/client_test.go
Normal file
32
internal/pkg/database/milvus/client_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package milvus
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewClient is a generated table-driven test skeleton for NewClient.
//
// NOTE(review): no test cases are defined, so this test currently does
// nothing. NewClient dials a live Milvus instance, so any case added here
// will need a reachable server or a fake — confirm the intended test
// environment before filling in the table.
func TestNewClient(t *testing.T) {
	type args struct {
		milvusAddr string
	}
	tests := []struct {
		name    string
		args    args
		want    *Client
		wantErr bool
	}{
		// TODO: Add test cases.
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := NewClient(tt.args.milvusAddr)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("NewClient() = %v, want %v", got, tt.want)
			}
		})
	}
}
|
||||||
276
internal/pkg/database/milvus/helpers.go
Normal file
276
internal/pkg/database/milvus/helpers.go
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
package milvus
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"easy_rag/internal/models"
|
||||||
|
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
||||||
|
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper functions for extracting data
|
||||||
|
func extractIDs(docs []models.Document) []string {
|
||||||
|
ids := make([]string, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
ids[i] = doc.ID
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractLinks extracts the "Link" field from the documents.
|
||||||
|
func extractLinks(docs []models.Document) []string {
|
||||||
|
links := make([]string, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
links[i] = doc.Link
|
||||||
|
}
|
||||||
|
return links
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFilenames extracts the "Filename" field from the documents.
|
||||||
|
func extractFilenames(docs []models.Document) []string {
|
||||||
|
filenames := make([]string, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
filenames[i] = doc.Filename
|
||||||
|
}
|
||||||
|
return filenames
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractCategories extracts the "Category" field from the documents as a comma-separated string.
|
||||||
|
func extractCategories(docs []models.Document) []string {
|
||||||
|
categories := make([]string, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
categories[i] = fmt.Sprintf("%v", doc.Category)
|
||||||
|
}
|
||||||
|
return categories
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractEmbeddingModels extracts the "EmbeddingModel" field from the documents.
|
||||||
|
func extractEmbeddingModels(docs []models.Document) []string {
|
||||||
|
models := make([]string, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
models[i] = doc.EmbeddingModel
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractSummaries extracts the "Summary" field from the documents.
|
||||||
|
func extractSummaries(docs []models.Document) []string {
|
||||||
|
summaries := make([]string, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
summaries[i] = doc.Summary
|
||||||
|
}
|
||||||
|
return summaries
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractMetadata extracts the "Metadata" field from the documents as a JSON string.
|
||||||
|
func extractMetadata(docs []models.Document) []string {
|
||||||
|
metadata := make([]string, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
metaBytes, _ := json.Marshal(doc.Metadata)
|
||||||
|
metadata[i] = string(metaBytes)
|
||||||
|
}
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToMetadata decodes a JSON object string into a string-to-string
// map. Decoding is best-effort: invalid JSON yields a nil map, matching
// the original behavior of ignoring the unmarshal error.
func convertToMetadata(metadata string) map[string]string {
	var out map[string]string
	_ = json.Unmarshal([]byte(metadata), &out) // best-effort decode
	return out
}
|
||||||
|
|
||||||
|
func extractContents(docs []models.Document) []string {
|
||||||
|
contents := make([]string, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
contents[i] = doc.Content
|
||||||
|
}
|
||||||
|
return contents
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractEmbeddingIDs extracts the "ID" field from the embeddings.
|
||||||
|
func extractEmbeddingIDs(embeddings []models.Embedding) []string {
|
||||||
|
ids := make([]string, len(embeddings))
|
||||||
|
for i, embedding := range embeddings {
|
||||||
|
ids[i] = embedding.ID
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractDocumentIDs extracts the "DocumentID" field from the embeddings.
|
||||||
|
func extractDocumentIDs(embeddings []models.Embedding) []string {
|
||||||
|
documentIDs := make([]string, len(embeddings))
|
||||||
|
for i, embedding := range embeddings {
|
||||||
|
documentIDs[i] = embedding.DocumentID
|
||||||
|
}
|
||||||
|
return documentIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractVectors extracts the "Vector" field from the embeddings.
|
||||||
|
func extractVectors(embeddings []models.Embedding) [][]float32 {
|
||||||
|
vectors := make([][]float32, len(embeddings))
|
||||||
|
for i, embedding := range embeddings {
|
||||||
|
vectors[i] = embedding.Vector // Direct assignment since it's already []float32
|
||||||
|
}
|
||||||
|
return vectors
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractVectorsDocs extracts the "Vector" field from the documents.
|
||||||
|
func extractVectorsDocs(docs []models.Document) [][]float32 {
|
||||||
|
vectors := make([][]float32, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
vectors[i] = doc.Vector // Direct assignment since it's already []float32
|
||||||
|
}
|
||||||
|
return vectors
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractTextChunks extracts the "TextChunk" field from the embeddings.
|
||||||
|
func extractTextChunks(embeddings []models.Embedding) []string {
|
||||||
|
textChunks := make([]string, len(embeddings))
|
||||||
|
for i, embedding := range embeddings {
|
||||||
|
textChunks[i] = embedding.TextChunk
|
||||||
|
}
|
||||||
|
return textChunks
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractDimensions extracts the "Dimension" field from the embeddings.
|
||||||
|
func extractDimensions(embeddings []models.Embedding) []int32 {
|
||||||
|
dimensions := make([]int32, len(embeddings))
|
||||||
|
for i, embedding := range embeddings {
|
||||||
|
dimensions[i] = int32(embedding.Dimension)
|
||||||
|
}
|
||||||
|
return dimensions
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractOrders extracts the "Order" field from the embeddings.
|
||||||
|
func extractOrders(embeddings []models.Embedding) []int32 {
|
||||||
|
orders := make([]int32, len(embeddings))
|
||||||
|
for i, embedding := range embeddings {
|
||||||
|
orders[i] = int32(embedding.Order)
|
||||||
|
}
|
||||||
|
return orders
|
||||||
|
}
|
||||||
|
|
||||||
|
// transformResultSet converts a Milvus query ResultSet into a slice of
// row maps keyed by field name, for the requested outputFields only.
// Supported column types: Int64, Int32, Float, Double and VarChar; any
// other type yields an error. An empty or nil result set is an error.
func transformResultSet(rs client.ResultSet, outputFields ...string) ([]map[string]interface{}, error) {
	if rs == nil || rs.Len() == 0 {
		return nil, fmt.Errorf("empty result set")
	}

	results := []map[string]interface{}{}

	for i := 0; i < rs.Len(); i++ { // Iterate through rows
		row := map[string]interface{}{}

		for _, fieldName := range outputFields {
			column := rs.GetColumn(fieldName)
			if column == nil {
				return nil, fmt.Errorf("column %s does not exist in result set", fieldName)
			}

			switch column.Type() {
			case entity.FieldTypeInt64:
				value, err := column.GetAsInt64(i)
				if err != nil {
					return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value

			case entity.FieldTypeInt32:
				// Int32 columns are read through GetAsInt64, so the value
				// lands in the row map as int64.
				value, err := column.GetAsInt64(i)
				if err != nil {
					return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value

			case entity.FieldTypeFloat:
				// Float columns are widened to float64 via GetAsDouble.
				value, err := column.GetAsDouble(i)
				if err != nil {
					return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value

			case entity.FieldTypeDouble:
				value, err := column.GetAsDouble(i)
				if err != nil {
					return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value

			case entity.FieldTypeVarChar:
				value, err := column.GetAsString(i)
				if err != nil {
					return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err)
				}
				row[fieldName] = value

			default:
				return nil, fmt.Errorf("unsupported field type for column %s", fieldName)
			}
		}

		results = append(results, row)
	}

	return results, nil
}
|
||||||
|
|
||||||
|
func transformSearchResultSet(rs client.SearchResult, outputFields ...string) ([]map[string]interface{}, error) {
|
||||||
|
if rs.ResultCount == 0 {
|
||||||
|
return nil, fmt.Errorf("empty result set")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]map[string]interface{}, rs.ResultCount)
|
||||||
|
|
||||||
|
for i := 0; i < rs.ResultCount; i++ { // Iterate through rows
|
||||||
|
result[i] = make(map[string]interface{})
|
||||||
|
for _, fieldName := range outputFields {
|
||||||
|
column := rs.Fields.GetColumn(fieldName)
|
||||||
|
result[i]["Score"] = rs.Scores[i]
|
||||||
|
|
||||||
|
if column == nil {
|
||||||
|
return nil, fmt.Errorf("column %s does not exist in result set", fieldName)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch column.Type() {
|
||||||
|
case entity.FieldTypeInt64:
|
||||||
|
value, err := column.GetAsInt64(i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
|
||||||
|
}
|
||||||
|
result[i][fieldName] = value
|
||||||
|
|
||||||
|
case entity.FieldTypeInt32:
|
||||||
|
value, err := column.GetAsInt64(i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting int64 value for column %s, row %d: %w", fieldName, i, err)
|
||||||
|
}
|
||||||
|
result[i][fieldName] = value
|
||||||
|
|
||||||
|
case entity.FieldTypeFloat:
|
||||||
|
value, err := column.GetAsDouble(i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting float value for column %s, row %d: %w", fieldName, i, err)
|
||||||
|
}
|
||||||
|
result[i][fieldName] = value
|
||||||
|
|
||||||
|
case entity.FieldTypeDouble:
|
||||||
|
value, err := column.GetAsDouble(i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting double value for column %s, row %d: %w", fieldName, i, err)
|
||||||
|
}
|
||||||
|
result[i][fieldName] = value
|
||||||
|
|
||||||
|
case entity.FieldTypeVarChar:
|
||||||
|
value, err := column.GetAsString(i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting string value for column %s, row %d: %w", fieldName, i, err)
|
||||||
|
}
|
||||||
|
result[i][fieldName] = value
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported field type for column %s", fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
270
internal/pkg/database/milvus/operations.go
Normal file
270
internal/pkg/database/milvus/operations.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package milvus
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"easy_rag/internal/models"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
||||||
|
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InsertDocuments inserts documents into the "documents" collection.
|
||||||
|
func (m *Client) InsertDocuments(ctx context.Context, docs []models.Document) error {
|
||||||
|
idColumn := entity.NewColumnVarChar("ID", extractIDs(docs))
|
||||||
|
contentColumn := entity.NewColumnVarChar("Content", extractContents(docs))
|
||||||
|
linkColumn := entity.NewColumnVarChar("Link", extractLinks(docs))
|
||||||
|
filenameColumn := entity.NewColumnVarChar("Filename", extractFilenames(docs))
|
||||||
|
categoryColumn := entity.NewColumnVarChar("Category", extractCategories(docs))
|
||||||
|
embeddingModelColumn := entity.NewColumnVarChar("EmbeddingModel", extractEmbeddingModels(docs))
|
||||||
|
summaryColumn := entity.NewColumnVarChar("Summary", extractSummaries(docs))
|
||||||
|
metadataColumn := entity.NewColumnVarChar("Metadata", extractMetadata(docs))
|
||||||
|
vectorColumn := entity.NewColumnFloatVector("Vector", 1024, extractVectorsDocs(docs))
|
||||||
|
// Insert the data
|
||||||
|
_, err := m.Instance.Insert(ctx, "documents", "_default", idColumn, contentColumn, linkColumn, filenameColumn,
|
||||||
|
categoryColumn, embeddingModelColumn, summaryColumn, metadataColumn, vectorColumn)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to insert documents: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush the collection
|
||||||
|
err = m.Instance.Flush(ctx, "documents", false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to flush documents collection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertEmbeddings inserts embeddings into the "chunks" collection.
|
||||||
|
func (m *Client) InsertEmbeddings(ctx context.Context, embeddings []models.Embedding) error {
|
||||||
|
idColumn := entity.NewColumnVarChar("ID", extractEmbeddingIDs(embeddings))
|
||||||
|
documentIDColumn := entity.NewColumnVarChar("DocumentID", extractDocumentIDs(embeddings))
|
||||||
|
vectorColumn := entity.NewColumnFloatVector("Vector", 1024, extractVectors(embeddings))
|
||||||
|
textChunkColumn := entity.NewColumnVarChar("TextChunk", extractTextChunks(embeddings))
|
||||||
|
dimensionColumn := entity.NewColumnInt32("Dimension", extractDimensions(embeddings))
|
||||||
|
orderColumn := entity.NewColumnInt32("Order", extractOrders(embeddings))
|
||||||
|
|
||||||
|
_, err := m.Instance.Insert(ctx, "chunks", "_default", idColumn, documentIDColumn, vectorColumn,
|
||||||
|
textChunkColumn, dimensionColumn, orderColumn)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to insert embeddings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.Instance.Flush(ctx, "chunks", false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to flush chunks collection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDocumentByID retrieves a document from the "documents" collection by ID.
|
||||||
|
func (m *Client) GetDocumentByID(ctx context.Context, id string) (map[string]interface{}, error) {
|
||||||
|
collectionName := "documents"
|
||||||
|
expr := fmt.Sprintf("ID == '%s'", id)
|
||||||
|
projections := []string{"ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata"} // Fetch all fields
|
||||||
|
|
||||||
|
results, err := m.Instance.Query(ctx, collectionName, nil, expr, projections)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query document by ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) == 0 {
|
||||||
|
return nil, fmt.Errorf("document with ID '%s' not found", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
mp, err := transformResultSet(results, "ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal document: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert metadata to map
|
||||||
|
mp[0]["Metadata"] = convertToMetadata(mp[0]["Metadata"].(string))
|
||||||
|
|
||||||
|
return mp[0], err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllDocuments retrieves all documents from the "documents" collection.
|
||||||
|
func (m *Client) GetAllDocuments(ctx context.Context) ([]models.Document, error) {
|
||||||
|
collectionName := "documents"
|
||||||
|
projections := []string{"ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata"} // Fetch all fields
|
||||||
|
expr := ""
|
||||||
|
|
||||||
|
rs, err := m.Instance.Query(ctx, collectionName, nil, expr, projections, client.WithLimit(1000))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query all documents: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rs) == 0 {
|
||||||
|
return nil, fmt.Errorf("no documents found in the collection")
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := transformResultSet(rs, "ID", "Content", "Link", "Filename", "Category", "EmbeddingModel", "Summary", "Metadata")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal all documents: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var docs []models.Document = make([]models.Document, len(results))
|
||||||
|
for i, result := range results {
|
||||||
|
docs[i] = models.Document{
|
||||||
|
ID: result["ID"].(string),
|
||||||
|
Content: result["Content"].(string),
|
||||||
|
Link: result["Link"].(string),
|
||||||
|
Filename: result["Filename"].(string),
|
||||||
|
Category: result["Category"].(string),
|
||||||
|
EmbeddingModel: result["EmbeddingModel"].(string),
|
||||||
|
Summary: result["Summary"].(string),
|
||||||
|
Metadata: convertToMetadata(results[0]["Metadata"].(string)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return docs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllEmbeddingByDocID retrieves all embeddings linked to a specific DocumentID from the "chunks" collection.
|
||||||
|
func (m *Client) GetAllEmbeddingByDocID(ctx context.Context, documentID string) ([]models.Embedding, error) {
|
||||||
|
collectionName := "chunks"
|
||||||
|
projections := []string{"ID", "DocumentID", "TextChunk", "Order"} // Fetch all fields
|
||||||
|
expr := fmt.Sprintf("DocumentID == '%s'", documentID)
|
||||||
|
|
||||||
|
rs, err := m.Instance.Query(ctx, collectionName, nil, expr, projections, client.WithLimit(1000))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query embeddings by DocumentID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rs.Len() == 0 {
|
||||||
|
return nil, fmt.Errorf("no embeddings found for DocumentID '%s'", documentID)
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := transformResultSet(rs, "ID", "DocumentID", "TextChunk", "Order")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal all documents: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var embeddings []models.Embedding = make([]models.Embedding, rs.Len())
|
||||||
|
|
||||||
|
for i, result := range results {
|
||||||
|
embeddings[i] = models.Embedding{
|
||||||
|
ID: result["ID"].(string),
|
||||||
|
DocumentID: result["DocumentID"].(string),
|
||||||
|
TextChunk: result["TextChunk"].(string),
|
||||||
|
Order: result["Order"].(int64),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return embeddings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Client) Search(ctx context.Context, vectors [][]float32, topK int) ([]models.Embedding, error) {
|
||||||
|
const (
|
||||||
|
collectionName = "chunks"
|
||||||
|
vectorDim = 1024 // Replace with your actual vector dimension
|
||||||
|
)
|
||||||
|
projections := []string{"ID", "DocumentID", "TextChunk", "Order"}
|
||||||
|
metricType := entity.L2 // Default metric type
|
||||||
|
|
||||||
|
// Validate and convert input vectors
|
||||||
|
searchVectors, err := validateAndConvertVectors(vectors, vectorDim)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set search parameters
|
||||||
|
searchParams, err := entity.NewIndexIvfFlatSearchParam(16) // 16 is the number of clusters for IVF_FLAT index
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create search params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform the search
|
||||||
|
searchResults, err := m.Instance.Search(ctx, collectionName, nil, "", projections, searchVectors, "Vector", metricType, topK, searchParams, client.WithLimit(10))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to search collection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process search results
|
||||||
|
embeddings, err := processSearchResults(searchResults)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to process search results: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return embeddings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateAndConvertVectors validates vector dimensions and converts them to Milvus-compatible format.
|
||||||
|
func validateAndConvertVectors(vectors [][]float32, expectedDim int) ([]entity.Vector, error) {
|
||||||
|
searchVectors := make([]entity.Vector, len(vectors))
|
||||||
|
for i, vector := range vectors {
|
||||||
|
if len(vector) != expectedDim {
|
||||||
|
return nil, fmt.Errorf("vector dimension mismatch: expected %d, got %d", expectedDim, len(vector))
|
||||||
|
}
|
||||||
|
searchVectors[i] = entity.FloatVector(vector)
|
||||||
|
}
|
||||||
|
return searchVectors, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processSearchResults transforms and aggregates the search results into embeddings and sorts by score.
|
||||||
|
func processSearchResults(results []client.SearchResult) ([]models.Embedding, error) {
|
||||||
|
var embeddings []models.Embedding
|
||||||
|
|
||||||
|
for _, result := range results {
|
||||||
|
for i := 0; i < result.ResultCount; i++ {
|
||||||
|
embeddingMap, err := transformSearchResultSet(result, "ID", "DocumentID", "TextChunk", "Order")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to transform search result set: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, embedding := range embeddingMap {
|
||||||
|
embeddings = append(embeddings, models.Embedding{
|
||||||
|
ID: embedding["ID"].(string),
|
||||||
|
DocumentID: embedding["DocumentID"].(string),
|
||||||
|
TextChunk: embedding["TextChunk"].(string),
|
||||||
|
Order: embedding["Order"].(int64), // Assuming 'Order' is a float64 type
|
||||||
|
Score: embedding["Score"].(float32),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort embeddings by score in descending order (higher is better)
|
||||||
|
sort.Slice(embeddings, func(i, j int) bool {
|
||||||
|
return embeddings[i].Score > embeddings[j].Score
|
||||||
|
})
|
||||||
|
|
||||||
|
return embeddings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDocument deletes a document from the "documents" collection by ID.
|
||||||
|
func (m *Client) DeleteDocument(ctx context.Context, id string) error {
|
||||||
|
collectionName := "documents"
|
||||||
|
partitionName := "_default"
|
||||||
|
expr := fmt.Sprintf("ID == '%s'", id)
|
||||||
|
|
||||||
|
err := m.Instance.Delete(ctx, collectionName, partitionName, expr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete document by ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteEmbedding deletes an embedding from the "chunks" collection by ID.
|
||||||
|
func (m *Client) DeleteEmbedding(ctx context.Context, id string) error {
|
||||||
|
collectionName := "chunks"
|
||||||
|
partitionName := "_default"
|
||||||
|
expr := fmt.Sprintf("DocumentID == '%s'", id)
|
||||||
|
|
||||||
|
err := m.Instance.Delete(ctx, collectionName, partitionName, expr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete embedding by DocumentID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
21
internal/pkg/rag/rag.go
Normal file
21
internal/pkg/rag/rag.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package rag
|
||||||
|
|
||||||
|
import (
|
||||||
|
"easy_rag/internal/database"
|
||||||
|
"easy_rag/internal/embeddings"
|
||||||
|
"easy_rag/internal/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rag bundles the three services a retrieval-augmented-generation pipeline
// depends on: a text-generation LLM, an embeddings service for vectorizing
// text, and a database for document and vector storage.
type Rag struct {
	LLM        llm.LLMService
	Embeddings embeddings.EmbeddingsService
	Database   database.Database
}
|
||||||
|
|
||||||
|
func NewRag(llm llm.LLMService, embeddings embeddings.EmbeddingsService, database database.Database) *Rag {
|
||||||
|
return &Rag{
|
||||||
|
LLM: llm,
|
||||||
|
Embeddings: embeddings,
|
||||||
|
Database: database,
|
||||||
|
}
|
||||||
|
}
|
||||||
50
internal/pkg/textprocessor/textprocessor.go
Normal file
50
internal/pkg/textprocessor/textprocessor.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package textprocessor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/jonathanhecl/chunker"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateChunks(text string) []string {
|
||||||
|
// Maximum characters per chunk
|
||||||
|
const maxCharacters = 5000 // too slow otherwise
|
||||||
|
|
||||||
|
var chunks []string
|
||||||
|
var currentChunk strings.Builder
|
||||||
|
|
||||||
|
// Use the chunker library to split text into sentences
|
||||||
|
sentences := chunker.ChunkSentences(text)
|
||||||
|
|
||||||
|
for _, sentence := range sentences {
|
||||||
|
// Check if adding the sentence exceeds the character limit
|
||||||
|
if currentChunk.Len()+len(sentence) <= maxCharacters {
|
||||||
|
if currentChunk.Len() > 0 {
|
||||||
|
currentChunk.WriteString(" ") // Add a space between sentences
|
||||||
|
}
|
||||||
|
currentChunk.WriteString(sentence)
|
||||||
|
} else {
|
||||||
|
// Add the completed chunk to the chunks slice
|
||||||
|
chunks = append(chunks, currentChunk.String())
|
||||||
|
currentChunk.Reset() // Start a new chunk
|
||||||
|
currentChunk.WriteString(sentence) // Add the sentence to the new chunk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the last chunk if it has content
|
||||||
|
if currentChunk.Len() > 0 {
|
||||||
|
chunks = append(chunks, currentChunk.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the chunks
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConcatenateStrings joins all parts into one string with no separator.
//
// The parameter was renamed from "strings" — the old name shadowed the
// standard-library strings package inside the function body. strings.Join
// performs a single allocation for the result.
func ConcatenateStrings(parts []string) string {
	return strings.Join(parts, "")
}
|
||||||
169
scripts/standalone_embed.sh
Normal file
169
scripts/standalone_embed.sh
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# Licensed to the LF AI & Data foundation under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# run_embed writes the embedded-etcd and user override config files into the
# current directory, then launches a fresh milvus-standalone container with
# embedded etcd, local object storage, and a Docker health check on :9091.
run_embed() {
    cat << EOF > embedEtcd.yaml
listen-client-urls: http://0.0.0.0:2379
advertise-client-urls: http://0.0.0.0:2379
quota-backend-bytes: 4294967296
auto-compaction-mode: revision
auto-compaction-retention: '1000'
EOF

    cat << EOF > user.yaml
# Extra config to override default milvus.yaml
EOF

    # Data is persisted under ./volumes/milvus on the host.
    sudo docker run -d \
        --name milvus-standalone \
        --security-opt seccomp:unconfined \
        -e ETCD_USE_EMBED=true \
        -e ETCD_DATA_DIR=/var/lib/milvus/etcd \
        -e ETCD_CONFIG_PATH=/milvus/configs/embedEtcd.yaml \
        -e COMMON_STORAGETYPE=local \
        -v $(pwd)/volumes/milvus:/var/lib/milvus \
        -v $(pwd)/embedEtcd.yaml:/milvus/configs/embedEtcd.yaml \
        -v $(pwd)/user.yaml:/milvus/configs/user.yaml \
        -p 19530:19530 \
        -p 9091:9091 \
        -p 2379:2379 \
        --health-cmd="curl -f http://localhost:9091/healthz" \
        --health-interval=30s \
        --health-start-period=90s \
        --health-timeout=20s \
        --health-retries=3 \
        milvusdb/milvus:v2.4.16 \
        milvus run standalone 1> /dev/null
}
|
||||||
|
|
||||||
|
# wait_for_milvus_running polls `docker ps` once per second until the
# milvus-standalone container reports a healthy status.
# NOTE(review): there is no timeout — if the container never becomes healthy
# this loops forever; consider adding a retry cap.
wait_for_milvus_running() {
    echo "Wait for Milvus Starting..."
    while true
    do
        res=`sudo docker ps|grep milvus-standalone|grep healthy|wc -l`
        if [ $res -eq 1 ]
        then
            echo "Start successfully."
            echo "To change the default Milvus configuration, add your settings to the user.yaml file and then restart the service."
            break
        fi
        sleep 1
    done
}
|
||||||
|
|
||||||
|
# start brings Milvus up: no-op if already healthy, restarts an existing
# stopped container, or creates a fresh one via run_embed; then waits until
# the health check passes.
start() {
    # Already running and healthy — nothing to do.
    res=`sudo docker ps|grep milvus-standalone|grep healthy|wc -l`
    if [ $res -eq 1 ]
    then
        echo "Milvus is running."
        exit 0
    fi

    # Container exists (possibly stopped) — restart it; otherwise create it.
    res=`sudo docker ps -a|grep milvus-standalone|wc -l`
    if [ $res -eq 1 ]
    then
        sudo docker start milvus-standalone 1> /dev/null
    else
        run_embed
    fi

    # $? here is the exit status of whichever branch ran above.
    if [ $? -ne 0 ]
    then
        echo "Start failed."
        exit 1
    fi

    wait_for_milvus_running
}
|
||||||
|
|
||||||
|
# stop halts the milvus-standalone container; exits non-zero on failure.
stop() {
    sudo docker stop milvus-standalone 1> /dev/null

    if [ $? -ne 0 ]
    then
        echo "Stop failed."
        exit 1
    fi
    echo "Stop successfully."

}
|
||||||
|
|
||||||
|
# delete_container removes the milvus-standalone container. Refuses to run
# while the container is still listed by `docker ps` (i.e. running).
delete_container() {
    res=`sudo docker ps|grep milvus-standalone|wc -l`
    if [ $res -eq 1 ]
    then
        echo "Please stop Milvus service before delete."
        exit 1
    fi
    sudo docker rm milvus-standalone 1> /dev/null
    if [ $? -ne 0 ]
    then
        echo "Delete milvus container failed."
        exit 1
    fi
    echo "Delete milvus container successfully."
}
|
||||||
|
|
||||||
|
# delete removes the container AND all local state: the data volume plus the
# generated embedEtcd.yaml / user.yaml config files. Destructive.
delete() {
    delete_container
    sudo rm -rf $(pwd)/volumes
    sudo rm -rf $(pwd)/embedEtcd.yaml
    sudo rm -rf $(pwd)/user.yaml
    echo "Delete successfully."
}
|
||||||
|
|
||||||
|
# upgrade interactively confirms, stops and removes any existing container,
# then downloads and runs the latest upstream standalone_embed.sh to start the
# newest Milvus version. Data under ./volumes is preserved.
upgrade() {
    read -p "Please confirm if you'd like to proceed with the upgrade. The default will be to the latest version. Confirm with 'y' for yes or 'n' for no. > " check
    if [ "$check" == "y" ] ||[ "$check" == "Y" ];then
        res=`sudo docker ps -a|grep milvus-standalone|wc -l`
        if [ $res -eq 1 ]
        then
            stop
            delete_container
        fi

        # Fetch and run the latest upstream script; chain with && so a failed
        # download does not print a false success message.
        curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed_latest.sh && \
        bash standalone_embed_latest.sh start 1> /dev/null && \
        echo "Upgrade successfully."
    else
        echo "Exit upgrade"
        exit 0
    fi
}
|
||||||
|
|
||||||
|
# Dispatch on the first CLI argument: restart|start|stop|upgrade|delete.
case $1 in
    restart)
        stop
        start
        ;;
    start)
        start
        ;;
    stop)
        stop
        ;;
    upgrade)
        upgrade
        ;;
    delete)
        delete
        ;;
    *)
        echo "please use bash standalone_embed.sh restart|start|stop|upgrade|delete"
        ;;
esac
|
||||||
313
tests/api_test.go
Normal file
313
tests/api_test.go
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
package api_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"easy_rag/internal/models"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"easy_rag/api"
|
||||||
|
"easy_rag/internal/pkg/rag"
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestUploadHandler verifies that POST /api/v1/upload responds 202 Accepted
// immediately with a task ID, while the document-processing work is handed
// off to a background goroutine (whose collaborators are mocked with .Maybe()
// expectations, since it may not finish before the test asserts).
func TestUploadHandler(t *testing.T) {
	e := echo.New()

	// Create a mock for the LLM, Embeddings, and Database
	mockLLM := new(MockLLMService)
	mockEmbeddings := new(MockEmbeddingsService)
	mockDB := new(MockDatabase)

	// Setup the Rag object
	r := &rag.Rag{
		LLM:        mockLLM,
		Embeddings: mockEmbeddings,
		Database:   mockDB,
	}

	// We expect calls to these mocks in the background goroutine, for each document.

	// The request body
	requestBody := api.RequestUpload{
		Docs: []api.UploadDoc{
			{
				Content:  "Test document content",
				Link:     "http://example.com/doc",
				Filename: "doc1.txt",
				Category: "TestCategory",
				Metadata: map[string]string{"Author": "Me"},
			},
		},
	}

	// Convert requestBody to JSON
	reqBodyBytes, _ := json.Marshal(requestBody)

	// Create a new request
	req := httptest.NewRequest(http.MethodPost, "/api/v1/upload", bytes.NewReader(reqBodyBytes))
	req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)

	// Create a ResponseRecorder
	rec := httptest.NewRecorder()

	// New echo context
	c := e.NewContext(req, rec)
	// Set the rag object in context
	c.Set("Rag", r)

	// Because the UploadHandler spawns a goroutine, we only test the immediate HTTP response.
	// We can still set expectations for the calls that happen in the goroutine to ensure they're invoked.
	// For example, we expect the summary to be generated, so:

	testSummary := "Test summary from LLM"
	mockLLM.On("Generate", mock.Anything).Return(testSummary, nil).Maybe() // .Maybe() because the concurrency might not complete by the time we assert

	// The embedding vector returned from the embeddings service
	testVector := [][]float32{{0.1, 0.2, 0.3, 0.4}}

	// We'll mock calls to Vectorize() for summary and each chunk
	mockEmbeddings.On("Vectorize", mock.AnythingOfType("string")).Return(testVector, nil).Maybe()

	// The database SaveDocument / SaveEmbeddings calls
	mockDB.On("SaveDocument", mock.AnythingOfType("models.Document")).Return(nil).Maybe()
	mockDB.On("SaveEmbeddings", mock.AnythingOfType("[]models.Embedding")).Return(nil).Maybe()

	// Invoke the handler
	err := api.UploadHandler(c)

	// Check no immediate errors
	assert.NoError(t, err)

	// Check the response
	assert.Equal(t, http.StatusAccepted, rec.Code)

	var resp map[string]interface{}
	_ = json.Unmarshal(rec.Body.Bytes(), &resp)

	// We expect certain fields in the JSON response
	assert.Equal(t, "v1", resp["version"])
	assert.NotEmpty(t, resp["task_id"])
	assert.Equal(t, "Processing started", resp["status"])

	// Typically, you might want to wait or do more advanced concurrency checks if you want to test
	// the logic in the goroutine, but that goes beyond a simple unit test.
	// The background process can be tested more thoroughly in integration or end-to-end tests.

	// Optionally assert that our mocks were called
	mockLLM.AssertExpectations(t)
	mockEmbeddings.AssertExpectations(t)
	mockDB.AssertExpectations(t)
}
|
||||||
|
|
||||||
|
// TestListAllDocsHandler verifies that GET /api/v1/docs returns 200 OK and a
// JSON body whose "docs" array mirrors what the database's ListDocuments
// returns (mocked with two documents here).
func TestListAllDocsHandler(t *testing.T) {
	e := echo.New()

	mockLLM := new(MockLLMService)
	mockEmbeddings := new(MockEmbeddingsService)
	mockDB := new(MockDatabase)

	r := &rag.Rag{
		LLM:        mockLLM,
		Embeddings: mockEmbeddings,
		Database:   mockDB,
	}

	// Mock data
	doc1 := models.Document{
		ID:       uuid.NewString(),
		Filename: "doc1.txt",
		Summary:  "summary doc1",
	}
	doc2 := models.Document{
		ID:       uuid.NewString(),
		Filename: "doc2.txt",
		Summary:  "summary doc2",
	}
	docs := []models.Document{doc1, doc2}

	// Expect the database to return the docs
	mockDB.On("ListDocuments").Return(docs, nil)

	req := httptest.NewRequest(http.MethodGet, "/api/v1/docs", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	c.Set("Rag", r)

	err := api.ListAllDocsHandler(c)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, rec.Code)

	var resp map[string]interface{}
	_ = json.Unmarshal(rec.Body.Bytes(), &resp)
	assert.Equal(t, "v1", resp["version"])

	// The "docs" field should match the ones we returned
	docsInterface, ok := resp["docs"].([]interface{})
	assert.True(t, ok)
	assert.Len(t, docsInterface, 2)

	// Verify mocks
	mockDB.AssertExpectations(t)
}
|
||||||
|
|
||||||
|
// TestGetDocHandler verifies that GET /api/v1/doc/:id fetches a single
// document by its path parameter and echoes it back in the "doc" field of
// the JSON response.
func TestGetDocHandler(t *testing.T) {
	e := echo.New()

	mockLLM := new(MockLLMService)
	mockEmbeddings := new(MockEmbeddingsService)
	mockDB := new(MockDatabase)

	r := &rag.Rag{
		LLM:        mockLLM,
		Embeddings: mockEmbeddings,
		Database:   mockDB,
	}

	// Mock a single doc
	docID := "123"
	testDoc := models.Document{
		ID:       docID,
		Filename: "doc3.txt",
		Summary:  "summary doc3",
	}

	mockDB.On("GetDocument", docID).Return(testDoc, nil)

	req := httptest.NewRequest(http.MethodGet, "/api/v1/doc/123", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	// set path param
	c.SetParamNames("id")
	c.SetParamValues(docID)
	c.Set("Rag", r)

	err := api.GetDocHandler(c)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, rec.Code)

	var resp map[string]interface{}
	_ = json.Unmarshal(rec.Body.Bytes(), &resp)
	assert.Equal(t, "v1", resp["version"])

	docInterface := resp["doc"].(map[string]interface{})
	assert.Equal(t, "doc3.txt", docInterface["filename"])

	// Verify mocks
	mockDB.AssertExpectations(t)
}
|
||||||
|
|
||||||
|
// TestAskDocHandler verifies the question-answering flow end to end at the
// handler level: the question is vectorized, the vector is searched in the
// database, the retrieved chunk is fed to the LLM, and the generated answer
// plus source document IDs are returned in the JSON response.
func TestAskDocHandler(t *testing.T) {
	e := echo.New()

	mockLLM := new(MockLLMService)
	mockEmbeddings := new(MockEmbeddingsService)
	mockDB := new(MockDatabase)

	r := &rag.Rag{
		LLM:        mockLLM,
		Embeddings: mockEmbeddings,
		Database:   mockDB,
	}

	// 1) We expect to Vectorize the question
	question := "What is the summary of doc?"
	questionVector := [][]float32{{0.5, 0.2, 0.1}}
	mockEmbeddings.On("Vectorize", question).Return(questionVector, nil)

	// 2) We expect a DB search
	emb := []models.Embedding{
		{
			ID:         "emb1",
			DocumentID: "doc123",
			TextChunk:  "Relevant content chunk",
			Score:      0.99,
		},
	}
	mockDB.On("Search", questionVector).Return(emb, nil)

	// 3) We expect the LLM to generate an answer from the chunk
	generatedAnswer := "Here is an answer from the chunk"
	// The prompt we pass is something like: "Given the following information: chunk ... Answer the question: question"
	mockLLM.On("Generate", mock.AnythingOfType("string")).Return(generatedAnswer, nil)

	// Prepare request
	reqBody := api.RequestQuestion{
		Question: question,
	}
	reqBytes, _ := json.Marshal(reqBody)
	req := httptest.NewRequest(http.MethodPost, "/api/v1/ask", bytes.NewReader(reqBytes))
	req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	c.Set("Rag", r)

	// Execute
	err := api.AskDocHandler(c)

	// Verify
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, rec.Code)

	var resp map[string]interface{}
	_ = json.Unmarshal(rec.Body.Bytes(), &resp)
	assert.Equal(t, "v1", resp["version"])
	assert.Equal(t, generatedAnswer, resp["answer"])

	// The docs field should have the docID "doc123"
	docsInterface := resp["docs"].([]interface{})
	assert.Len(t, docsInterface, 1)
	assert.Equal(t, "doc123", docsInterface[0])

	// Verify mocks
	mockLLM.AssertExpectations(t)
	mockEmbeddings.AssertExpectations(t)
	mockDB.AssertExpectations(t)
}
|
||||||
|
|
||||||
|
// TestDeleteDocHandler verifies that DELETE /api/v1/doc/:id forwards the path
// parameter to the database's DeleteDocument and responds 200 OK with a nil
// "docs" field.
func TestDeleteDocHandler(t *testing.T) {
	e := echo.New()
	mockLLM := new(MockLLMService)
	mockEmbeddings := new(MockEmbeddingsService)
	mockDB := new(MockDatabase)

	r := &rag.Rag{
		LLM:        mockLLM,
		Embeddings: mockEmbeddings,
		Database:   mockDB,
	}

	docID := "abc"
	mockDB.On("DeleteDocument", docID).Return(nil)

	req := httptest.NewRequest(http.MethodDelete, "/api/v1/doc/abc", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	c.SetParamNames("id")
	c.SetParamValues(docID)
	c.Set("Rag", r)

	err := api.DeleteDocHandler(c)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, rec.Code)

	var resp map[string]interface{}
	_ = json.Unmarshal(rec.Body.Bytes(), &resp)
	assert.Equal(t, "v1", resp["version"])

	// docs should be nil according to DeleteDocHandler
	assert.Nil(t, resp["docs"])

	// Verify mocks
	mockDB.AssertExpectations(t)
}
|
||||||
89
tests/mock_test.go
Normal file
89
tests/mock_test.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package api_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"easy_rag/internal/models"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --------------------
|
||||||
|
// Mock LLM
|
||||||
|
// --------------------
|
||||||
|
// MockLLMService is a testify mock implementing the llm.LLMService interface
// for handler tests.
type MockLLMService struct {
	mock.Mock
}
|
||||||
|
|
||||||
|
// Generate returns the canned completion configured via On("Generate", ...).
func (m *MockLLMService) Generate(prompt string) (string, error) {
	args := m.Called(prompt)
	return args.String(0), args.Error(1)
}
|
||||||
|
|
||||||
|
// GetModel returns the canned model name configured via On("GetModel").
func (m *MockLLMService) GetModel() string {
	args := m.Called()
	return args.String(0)
}
|
||||||
|
|
||||||
|
// --------------------
|
||||||
|
// Mock Embeddings
|
||||||
|
// --------------------
|
||||||
|
// MockEmbeddingsService is a testify mock implementing the
// embeddings.EmbeddingsService interface for handler tests.
type MockEmbeddingsService struct {
	mock.Mock
}
|
||||||
|
|
||||||
|
// Vectorize returns the canned vectors configured via On("Vectorize", ...).
// Note: the first return is fetched with an unchecked type assertion, so a
// test that configures a non-[][]float32 return will panic.
func (m *MockEmbeddingsService) Vectorize(text string) ([][]float32, error) {
	args := m.Called(text)
	return args.Get(0).([][]float32), args.Error(1)
}
|
||||||
|
|
||||||
|
// GetModel returns the canned model name configured via On("GetModel").
func (m *MockEmbeddingsService) GetModel() string {
	args := m.Called()
	return args.String(0)
}
|
||||||
|
|
||||||
|
// --------------------
|
||||||
|
// Mock Database
|
||||||
|
// --------------------
|
||||||
|
// MockDatabase is a testify mock implementing the database.Database interface
// for handler tests.
type MockDatabase struct {
	mock.Mock
}
|
||||||
|
|
||||||
|
// GetDocumentInfo returns the canned document configured via
// On("GetDocumentInfo", ...).
func (m *MockDatabase) GetDocumentInfo(id string) (models.Document, error) {
	args := m.Called(id)
	return args.Get(0).(models.Document), args.Error(1)
}
|
||||||
|
|
||||||
|
// SaveDocument records the call and returns the canned error configured via
// On("SaveDocument", ...).
func (m *MockDatabase) SaveDocument(doc models.Document) error {
	args := m.Called(doc)
	return args.Error(0)
}
|
||||||
|
|
||||||
|
// SaveEmbeddings records the call and returns the canned error configured via
// On("SaveEmbeddings", ...).
func (m *MockDatabase) SaveEmbeddings(emb []models.Embedding) error {
	args := m.Called(emb)
	return args.Error(0)
}
|
||||||
|
|
||||||
|
// ListDocuments returns the canned document slice configured via
// On("ListDocuments").
func (m *MockDatabase) ListDocuments() ([]models.Document, error) {
	args := m.Called()
	return args.Get(0).([]models.Document), args.Error(1)
}
|
||||||
|
|
||||||
|
// GetDocument returns the canned document configured via
// On("GetDocument", ...).
func (m *MockDatabase) GetDocument(id string) (models.Document, error) {
	args := m.Called(id)
	return args.Get(0).(models.Document), args.Error(1)
}
|
||||||
|
|
||||||
|
// DeleteDocument records the call and returns the canned error configured via
// On("DeleteDocument", ...).
func (m *MockDatabase) DeleteDocument(id string) error {
	args := m.Called(id)
	return args.Error(0)
}
|
||||||
|
|
||||||
|
// Search returns the canned embeddings configured via On("Search", ...).
// The parameter is a batch of vectors ([][]float32), matching the real
// database interface.
func (m *MockDatabase) Search(vector [][]float32) ([]models.Embedding, error) {
	args := m.Called(vector)
	return args.Get(0).([]models.Embedding), args.Error(1)
}
|
||||||
125
tests/openrouter_test.go
Normal file
125
tests/openrouter_test.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package api_test
|
||||||
|
|
||||||
|
import (
	"context"
	"fmt"
	"os"
	"testing"

	openroute2 "easy_rag/internal/llm/openroute"
)
|
||||||
|
|
||||||
|
func TestFetchChatCompletions(t *testing.T) {
|
||||||
|
client := openroute2.NewOpenRouterClient("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")
|
||||||
|
|
||||||
|
request := openroute2.Request{
|
||||||
|
Model: "qwen/qwen2.5-vl-72b-instruct:free",
|
||||||
|
Messages: []openroute2.MessageRequest{
|
||||||
|
{openroute2.RoleUser, "Привет!", "", ""},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := client.FetchChatCompletions(request)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("output: %v", output.Choices[0].Message.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchChatCompletionsStreaming(t *testing.T) {
|
||||||
|
client := openroute2.NewOpenRouterClient("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")
|
||||||
|
|
||||||
|
request := openroute2.Request{
|
||||||
|
Model: "qwen/qwen2.5-vl-72b-instruct:free",
|
||||||
|
Messages: []openroute2.MessageRequest{
|
||||||
|
{openroute2.RoleUser, "Привет!", "", ""},
|
||||||
|
},
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
outputChan := make(chan openroute2.Response)
|
||||||
|
processingChan := make(chan interface{})
|
||||||
|
errChan := make(chan error)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go client.FetchChatCompletionsStream(request, outputChan, processingChan, errChan, ctx)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case output := <-outputChan:
|
||||||
|
if len(output.Choices) > 0 {
|
||||||
|
t.Logf("%s", output.Choices[0].Delta.Content)
|
||||||
|
}
|
||||||
|
case <-processingChan:
|
||||||
|
t.Logf("Обработка\n")
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Ошибка: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
fmt.Println("Контекст отменен:", ctx.Err())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchChatCompletionsAgentStreaming(t *testing.T) {
|
||||||
|
client := openroute2.NewOpenRouterClient("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")
|
||||||
|
agent := openroute2.NewRouterAgent(client, "qwen/qwen2.5-vl-72b-instruct:freet", openroute2.RouterAgentConfig{
|
||||||
|
Temperature: 0.7,
|
||||||
|
MaxTokens: 100,
|
||||||
|
})
|
||||||
|
|
||||||
|
outputChan := make(chan openroute2.Response)
|
||||||
|
processingChan := make(chan interface{})
|
||||||
|
errChan := make(chan error)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
chat := []openroute2.MessageRequest{
|
||||||
|
{Role: openroute2.RoleSystem, Content: "Вы полезный помощник."},
|
||||||
|
{Role: openroute2.RoleUser, Content: "Привет!"},
|
||||||
|
}
|
||||||
|
|
||||||
|
go agent.ChatStream(chat, outputChan, processingChan, errChan, ctx)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case output := <-outputChan:
|
||||||
|
if len(output.Choices) > 0 {
|
||||||
|
t.Logf("%s", output.Choices[0].Delta.Content)
|
||||||
|
}
|
||||||
|
case <-processingChan:
|
||||||
|
t.Logf("Обработка\n")
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Ошибка: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
fmt.Println("Контекст отменен:", ctx.Err())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchChatCompletionsAgentSimpleChat(t *testing.T) {
|
||||||
|
client := openroute2.NewOpenRouterClient("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")
|
||||||
|
agent := openroute2.NewRouterAgentChat(client, "qwen/qwen2.5-vl-72b-instruct:free", openroute2.RouterAgentConfig{
|
||||||
|
Temperature: 0.0,
|
||||||
|
MaxTokens: 100,
|
||||||
|
}, "Вы полезный помощник, отвечайте короткими словами.")
|
||||||
|
|
||||||
|
agent.Chat("Запомни это: \"wojtess\"")
|
||||||
|
agent.Chat("Что я просил вас запомнить?")
|
||||||
|
|
||||||
|
for _, msg := range agent.Messages {
|
||||||
|
content, ok := msg.Content.(string)
|
||||||
|
if ok {
|
||||||
|
t.Logf("%s: %s", msg.Role, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user