Feat/firecrawl data source (#5232)
Co-authored-by: Nicolas <nicolascamara29@gmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: takatost <takatost@gmail.com>
parent 918ebe1620
commit ba5f8afaa8
@@ -0,0 +1,67 @@
from flask_login import current_user
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required
from services.auth.api_key_auth_service import ApiKeyAuthService

from ..setup import setup_required
from ..wraps import account_initialization_required


class ApiKeyAuthDataSource(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        # The role of the current user in the table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()
        data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
        if data_source_api_key_bindings:
            return {
                'settings': [data_source_api_key_binding.to_dict() for data_source_api_key_binding in
                             data_source_api_key_bindings]}
        return {'settings': []}


class ApiKeyAuthDataSourceBinding(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        # The role of the current user in the table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument('category', type=str, required=True, nullable=False, location='json')
        parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
        args = parser.parse_args()
        ApiKeyAuthService.validate_api_key_auth_args(args)
        try:
            ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
        except Exception as e:
            raise ApiKeyAuthFailedError(str(e))
        return {'result': 'success'}, 200


class ApiKeyAuthDataSourceBindingDelete(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, binding_id):
        # The role of the current user in the table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)

        return {'result': 'success'}, 200


api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
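For context, a minimal sketch of how a console client might exercise the three endpoints registered above. It is illustrative only and not part of this diff: the host, the '/console/api' prefix, the Authorization header, and the placeholder IDs are assumptions.

import requests

BASE = 'http://localhost:5001/console/api'                    # assumed console prefix
HEADERS = {'Authorization': 'Bearer <console-access-token>'}  # assumed auth mechanism

# List existing API-key-auth bindings for the current workspace.
print(requests.get(f'{BASE}/api-key-auth/data-source', headers=HEADERS).json())

# Create a Firecrawl binding; the payload shape mirrors the reqparse arguments above.
payload = {
    'category': 'website',
    'provider': 'firecrawl',
    'credentials': {'auth_type': 'bearer', 'config': {'api_key': 'fc-...'}},
}
print(requests.post(f'{BASE}/api-key-auth/data-source/binding', headers=HEADERS, json=payload).json())

# Remove a binding by its UUID.
requests.delete(f'{BASE}/api-key-auth/data-source/<binding-uuid>', headers=HEADERS)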
@@ -0,0 +1,7 @@
from libs.exception import BaseHTTPException


class ApiKeyAuthFailedError(BaseHTTPException):
    error_code = 'auth_failed'
    description = "{message}"
    code = 500
@@ -0,0 +1,49 @@
from flask_restful import Resource, reqparse

from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.website_service import WebsiteService


class WebsiteCrawlApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('provider', type=str, choices=['firecrawl'],
                            required=True, nullable=True, location='json')
        parser.add_argument('url', type=str, required=True, nullable=True, location='json')
        parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
        args = parser.parse_args()
        WebsiteService.document_create_args_validate(args)
        # crawl url
        try:
            result = WebsiteService.crawl_url(args)
        except Exception as e:
            raise WebsiteCrawlError(str(e))
        return result, 200


class WebsiteCrawlStatusApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, job_id: str):
        parser = reqparse.RequestParser()
        parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
        args = parser.parse_args()
        # get crawl status
        try:
            result = WebsiteService.get_crawl_status(job_id, args['provider'])
        except Exception as e:
            raise WebsiteCrawlError(str(e))
        return result, 200


api.add_resource(WebsiteCrawlApi, '/website/crawl')
api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
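A brief sketch of the request flow these two resources expose, start a crawl and then poll its status. The host, prefix, and auth header are assumptions for illustration; the response shapes follow WebsiteService.crawl_url and get_crawl_status below.

import requests

BASE = 'http://localhost:5001/console/api'                    # assumed console prefix
HEADERS = {'Authorization': 'Bearer <console-access-token>'}  # assumed auth mechanism

# Kick off a crawl; 'options' is validated by WebsiteService.document_create_args_validate.
crawl = requests.post(f'{BASE}/website/crawl', headers=HEADERS, json={
    'provider': 'firecrawl',
    'url': 'https://example.com',
    'options': {'crawl_sub_pages': True, 'only_main_content': True, 'limit': 5},
}).json()

# Poll the job until the service reports it as completed.
status = requests.get(
    f"{BASE}/website/crawl/status/{crawl['job_id']}",
    headers=HEADERS,
    params={'provider': 'firecrawl'},
).json()
print(status['status'], status.get('total'), status.get('current'))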
@@ -0,0 +1,132 @@
import json
import time

import requests

from extensions.ext_storage import storage


class FirecrawlApp:
    def __init__(self, api_key=None, base_url=None):
        self.api_key = api_key
        self.base_url = base_url or 'https://api.firecrawl.dev'
        if self.api_key is None and self.base_url == 'https://api.firecrawl.dev':
            raise ValueError('No API key provided')

    def scrape_url(self, url, params=None) -> dict:
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}'
        }
        json_data = {'url': url}
        if params:
            json_data.update(params)
        response = requests.post(
            f'{self.base_url}/v0/scrape',
            headers=headers,
            json=json_data
        )
        if response.status_code == 200:
            response = response.json()
            if response['success']:
                data = response['data']
                return {
                    'title': data.get('metadata').get('title'),
                    'description': data.get('metadata').get('description'),
                    'source_url': data.get('metadata').get('sourceURL'),
                    'markdown': data.get('markdown')
                }
            else:
                raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
        elif response.status_code in [402, 409, 500]:
            error_message = response.json().get('error', 'Unknown error occurred')
            raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}')
        else:
            raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')

    def crawl_url(self, url, params=None) -> str:
        start_time = time.time()
        headers = self._prepare_headers()
        json_data = {'url': url}
        if params:
            json_data.update(params)
        response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers)
        if response.status_code == 200:
            job_id = response.json().get('jobId')
            return job_id
        else:
            self._handle_error(response, 'start crawl job')

    def check_crawl_status(self, job_id) -> dict:
        headers = self._prepare_headers()
        response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers)
        if response.status_code == 200:
            crawl_status_response = response.json()
            if crawl_status_response.get('status') == 'completed':
                total = crawl_status_response.get('total', 0)
                if total == 0:
                    raise Exception('Failed to check crawl status. Error: No page found')
                data = crawl_status_response.get('data', [])
                url_data_list = []
                for item in data:
                    if isinstance(item, dict) and 'metadata' in item and 'markdown' in item:
                        url_data = {
                            'title': item.get('metadata').get('title'),
                            'description': item.get('metadata').get('description'),
                            'source_url': item.get('metadata').get('sourceURL'),
                            'markdown': item.get('markdown')
                        }
                        url_data_list.append(url_data)
                if url_data_list:
                    file_key = 'website_files/' + job_id + '.txt'
                    if storage.exists(file_key):
                        storage.delete(file_key)
                    storage.save(file_key, json.dumps(url_data_list).encode('utf-8'))
                return {
                    'status': 'completed',
                    'total': crawl_status_response.get('total'),
                    'current': crawl_status_response.get('current'),
                    'data': url_data_list
                }
            else:
                return {
                    'status': crawl_status_response.get('status'),
                    'total': crawl_status_response.get('total'),
                    'current': crawl_status_response.get('current'),
                    'data': []
                }
        else:
            self._handle_error(response, 'check crawl status')

    def _prepare_headers(self):
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}'
        }

    def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
        for attempt in range(retries):
            response = requests.post(url, headers=headers, json=data)
            if response.status_code == 502:
                time.sleep(backoff_factor * (2 ** attempt))
            else:
                return response
        return response

    def _get_request(self, url, headers, retries=3, backoff_factor=0.5):
        for attempt in range(retries):
            response = requests.get(url, headers=headers)
            if response.status_code == 502:
                time.sleep(backoff_factor * (2 ** attempt))
            else:
                return response
        return response

    def _handle_error(self, response, action):
        error_message = response.json().get('error', 'Unknown error occurred')
        raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}')
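A hedged usage sketch of the client above: scrape a single page, then start a crawl and poll the job. The API key and poll interval are placeholders, and note that check_crawl_status persists completed results through the shared storage extension, so this assumes an initialized Flask app context.

import time

from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp

app = FirecrawlApp(api_key='fc-...')  # base_url defaults to https://api.firecrawl.dev

# Single-page scrape: returns title/description/source_url/markdown.
page = app.scrape_url('https://example.com', params={'pageOptions': {'onlyMainContent': True}})
print(page['title'])

# Multi-page crawl: crawl_url returns a jobId; check_crawl_status caches results to storage.
job_id = app.crawl_url('https://example.com', params={'crawlerOptions': {'limit': 5}})
while True:
    status = app.check_crawl_status(job_id)
    if status['status'] == 'completed':
        break
    time.sleep(5)  # placeholder poll interval
print(len(status['data']), 'pages crawled')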
@@ -0,0 +1,60 @@
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from services.website_service import WebsiteService


class FirecrawlWebExtractor(BaseExtractor):
    """
    Crawl and scrape websites and return content in clean, LLM-ready markdown.

    Args:
        url: The URL to scrape.
        job_id: The Firecrawl crawl job ID whose results should be read.
        tenant_id: The tenant ID used to look up the stored Firecrawl credentials.
        mode: The mode of operation. Defaults to 'crawl'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
        only_main_content: Whether to extract only the main page content when scraping.
    """

    def __init__(
            self,
            url: str,
            job_id: str,
            tenant_id: str,
            mode: str = 'crawl',
            only_main_content: bool = False
    ):
        """Initialize with url, job_id, tenant_id and mode."""
        self._url = url
        self.job_id = job_id
        self.tenant_id = tenant_id
        self.mode = mode
        self.only_main_content = only_main_content

    def extract(self) -> list[Document]:
        """Extract content from the URL."""
        documents = []
        if self.mode == 'crawl':
            crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id)
            if crawl_data is None:
                return []
            document = Document(page_content=crawl_data.get('markdown', ''),
                                metadata={
                                    'source_url': crawl_data.get('source_url'),
                                    'description': crawl_data.get('description'),
                                    'title': crawl_data.get('title')
                                })
            documents.append(document)
        elif self.mode == 'scrape':
            scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id,
                                                             self.only_main_content)

            document = Document(page_content=scrape_data.get('markdown', ''),
                                metadata={
                                    'source_url': scrape_data.get('source_url'),
                                    'description': scrape_data.get('description'),
                                    'title': scrape_data.get('title')
                                })
            documents.append(document)
        return documents
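A short sketch of driving the extractor once a crawl job has finished. The job_id and tenant_id values are placeholders; credentials and cached crawl results are resolved through WebsiteService, so this assumes a configured workspace with a Firecrawl binding.

from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor

# 'crawl' mode reads one URL's result out of a completed Firecrawl job.
extractor = FirecrawlWebExtractor(
    url='https://example.com/docs/intro',
    job_id='<firecrawl-job-id>',   # placeholder
    tenant_id='<tenant-uuid>',     # placeholder
    mode='crawl',
    only_main_content=True,
)
documents = extractor.extract()
for doc in documents:
    print(doc.metadata['source_url'], len(doc.page_content))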
@@ -0,0 +1,64 @@
# [REVIEW] Implement if needed? Do we need a new type of data source?
from abc import abstractmethod

import requests
from flask_login import current_user

from api.models.source import DataSourceBearerBinding
from extensions.ext_database import db


class BearerDataSource:
    def __init__(self, api_key: str, api_base_url: str):
        self.api_key = api_key
        self.api_base_url = api_base_url

    @abstractmethod
    def validate_bearer_data_source(self):
        """
        Validate the data source.
        """


class FireCrawlDataSource(BearerDataSource):
    def validate_bearer_data_source(self):
        TEST_CRAWL_SITE_URL = "https://www.google.com"
        FIRECRAWL_API_VERSION = "v0"

        test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape"

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        data = {
            "url": TEST_CRAWL_SITE_URL,
        }

        response = requests.get(test_api_endpoint, headers=headers, json=data)

        return response.json().get("status") == "success"

    def save_credentials(self):
        # save data source binding
        data_source_binding = DataSourceBearerBinding.query.filter(
            db.and_(
                DataSourceBearerBinding.tenant_id == current_user.current_tenant_id,
                DataSourceBearerBinding.provider == 'firecrawl',
                DataSourceBearerBinding.endpoint_url == self.api_base_url,
                DataSourceBearerBinding.bearer_key == self.api_key
            )
        ).first()
        if data_source_binding:
            data_source_binding.disabled = False
            db.session.commit()
        else:
            new_data_source_binding = DataSourceBearerBinding(
                tenant_id=current_user.current_tenant_id,
                provider='firecrawl',
                endpoint_url=self.api_base_url,
                bearer_key=self.api_key
            )
            db.session.add(new_data_source_binding)
            db.session.commit()
@@ -0,0 +1,67 @@
"""add-api-key-auth-binding

Revision ID: 7b45942e39bb
Revises: 47cc7df8c4f3
Create Date: 2024-05-14 07:31:29.702766

"""
import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = '7b45942e39bb'
down_revision = '4e99a8df00ff'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('data_source_api_key_auth_bindings',
                    sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
                    sa.Column('tenant_id', models.StringUUID(), nullable=False),
                    sa.Column('category', sa.String(length=255), nullable=False),
                    sa.Column('provider', sa.String(length=255), nullable=False),
                    sa.Column('credentials', sa.Text(), nullable=True),
                    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
                    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
                    sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True),
                    sa.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey')
                    )
    with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
        batch_op.create_index('data_source_api_key_auth_binding_provider_idx', ['provider'], unique=False)
        batch_op.create_index('data_source_api_key_auth_binding_tenant_id_idx', ['tenant_id'], unique=False)

    with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
        batch_op.drop_index('source_binding_tenant_id_idx')
        batch_op.drop_index('source_info_idx')

    op.rename_table('data_source_bindings', 'data_source_oauth_bindings')

    with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
        batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
        batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin')
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
        batch_op.drop_index('source_info_idx', postgresql_using='gin')
        batch_op.drop_index('source_binding_tenant_id_idx')

    op.rename_table('data_source_oauth_bindings', 'data_source_bindings')

    with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
        batch_op.create_index('source_info_idx', ['source_info'], unique=False)
        batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)

    with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
        batch_op.drop_index('data_source_api_key_auth_binding_tenant_id_idx')
        batch_op.drop_index('data_source_api_key_auth_binding_provider_idx')

    op.drop_table('data_source_api_key_auth_bindings')
    # ### end Alembic commands ###
@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod


class ApiKeyAuthBase(ABC):
    def __init__(self, credentials: dict):
        self.credentials = credentials

    @abstractmethod
    def validate_credentials(self):
        raise NotImplementedError
@@ -0,0 +1,14 @@
from services.auth.firecrawl import FirecrawlAuth


class ApiKeyAuthFactory:

    def __init__(self, provider: str, credentials: dict):
        if provider == 'firecrawl':
            self.auth = FirecrawlAuth(credentials)
        else:
            raise ValueError('Invalid provider')

    def validate_credentials(self):
        return self.auth.validate_credentials()
@@ -0,0 +1,70 @@
import json

from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory


class ApiKeyAuthService:

    @staticmethod
    def get_provider_auth_list(tenant_id: str) -> list:
        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
            DataSourceApiKeyAuthBinding.disabled.is_(False)
        ).all()
        return data_source_api_key_bindings

    @staticmethod
    def create_provider_auth(tenant_id: str, args: dict):
        auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
        if auth_result:
            # Encrypt the api key
            api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
            args['credentials']['config']['api_key'] = api_key

            data_source_api_key_binding = DataSourceApiKeyAuthBinding()
            data_source_api_key_binding.tenant_id = tenant_id
            data_source_api_key_binding.category = args['category']
            data_source_api_key_binding.provider = args['provider']
            data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
            db.session.add(data_source_api_key_binding)
            db.session.commit()

    @staticmethod
    def get_auth_credentials(tenant_id: str, category: str, provider: str):
        data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
            DataSourceApiKeyAuthBinding.category == category,
            DataSourceApiKeyAuthBinding.provider == provider,
            DataSourceApiKeyAuthBinding.disabled.is_(False)
        ).first()
        if not data_source_api_key_bindings:
            return None
        credentials = json.loads(data_source_api_key_bindings.credentials)
        return credentials

    @staticmethod
    def delete_provider_auth(tenant_id: str, binding_id: str):
        data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
            DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
            DataSourceApiKeyAuthBinding.id == binding_id
        ).first()
        if data_source_api_key_binding:
            db.session.delete(data_source_api_key_binding)
            db.session.commit()

    @classmethod
    def validate_api_key_auth_args(cls, args):
        if 'category' not in args or not args['category']:
            raise ValueError('category is required')
        if 'provider' not in args or not args['provider']:
            raise ValueError('provider is required')
        if 'credentials' not in args or not args['credentials']:
            raise ValueError('credentials is required')
        if not isinstance(args['credentials'], dict):
            raise ValueError('credentials must be a dictionary')
        if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
            raise ValueError('auth_type is required')
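For reference, the args shape the service expects, as implied by validate_api_key_auth_args and the FirecrawlAuth provider below. The api_key value and tenant_id are placeholders.

args = {
    'category': 'website',
    'provider': 'firecrawl',
    'credentials': {
        'auth_type': 'bearer',                       # FirecrawlAuth only accepts 'bearer'
        'config': {
            'api_key': 'fc-...',                     # encrypted per tenant before persisting
            'base_url': 'https://api.firecrawl.dev'  # optional; defaults to the hosted API
        }
    }
}

ApiKeyAuthService.validate_api_key_auth_args(args)
ApiKeyAuthService.create_provider_auth(tenant_id='<tenant-uuid>', args=args)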
@@ -0,0 +1,56 @@
import json

import requests

from services.auth.api_key_auth_base import ApiKeyAuthBase


class FirecrawlAuth(ApiKeyAuthBase):
    def __init__(self, credentials: dict):
        super().__init__(credentials)
        auth_type = credentials.get('auth_type')
        if auth_type != 'bearer':
            raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
        self.api_key = credentials.get('config').get('api_key', None)
        self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')

        if not self.api_key:
            raise ValueError('No API key provided')

    def validate_credentials(self):
        headers = self._prepare_headers()
        options = {
            'url': 'https://example.com',
            'crawlerOptions': {
                'excludes': [],
                'includes': [],
                'limit': 1
            },
            'pageOptions': {
                'onlyMainContent': True
            }
        }
        response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
        if response.status_code == 200:
            return True
        else:
            self._handle_error(response)

    def _prepare_headers(self):
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}'
        }

    def _post_request(self, url, data, headers):
        return requests.post(url, headers=headers, json=data)

    def _handle_error(self, response):
        if response.status_code in [402, 409, 500]:
            error_message = response.json().get('error', 'Unknown error occurred')
            raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
        else:
            if response.text:
                error_message = json.loads(response.text).get('error', 'Unknown error occurred')
                raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
            raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
@@ -0,0 +1,171 @@
import datetime
import json

from flask_login import current_user

from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService


class WebsiteService:

    @classmethod
    def document_create_args_validate(cls, args: dict):
        if 'url' not in args or not args['url']:
            raise ValueError('url is required')
        if 'options' not in args or not args['options']:
            raise ValueError('options is required')
        if 'limit' not in args['options'] or not args['options']['limit']:
            raise ValueError('limit is required')

    @classmethod
    def crawl_url(cls, args: dict) -> dict:
        provider = args.get('provider')
        url = args.get('url')
        options = args.get('options')
        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            # decrypt api_key
            api_key = encrypter.decrypt_token(
                tenant_id=current_user.current_tenant_id,
                token=credentials.get('config').get('api_key')
            )
            firecrawl_app = FirecrawlApp(api_key=api_key,
                                         base_url=credentials.get('config').get('base_url', None))
            crawl_sub_pages = options.get('crawl_sub_pages', False)
            only_main_content = options.get('only_main_content', False)
            if not crawl_sub_pages:
                params = {
                    'crawlerOptions': {
                        "includes": [],
                        "excludes": [],
                        "generateImgAltText": True,
                        "limit": 1,
                        'returnOnlyUrls': False,
                        'pageOptions': {
                            'onlyMainContent': only_main_content,
                            "includeHtml": False
                        }
                    }
                }
            else:
                includes = options.get('includes').split(',') if options.get('includes') else []
                excludes = options.get('excludes').split(',') if options.get('excludes') else []
                params = {
                    'crawlerOptions': {
                        "includes": includes if includes else [],
                        "excludes": excludes if excludes else [],
                        "generateImgAltText": True,
                        "limit": options.get('limit', 1),
                        'returnOnlyUrls': False,
                        'pageOptions': {
                            'onlyMainContent': only_main_content,
                            "includeHtml": False
                        }
                    }
                }
                if options.get('max_depth'):
                    params['crawlerOptions']['maxDepth'] = options.get('max_depth')
            job_id = firecrawl_app.crawl_url(url, params)
            website_crawl_time_cache_key = f'website_crawl_{job_id}'
            time = str(datetime.datetime.now().timestamp())
            redis_client.setex(website_crawl_time_cache_key, 3600, time)
            return {
                'status': 'active',
                'job_id': job_id
            }
        else:
            raise ValueError('Invalid provider')

    @classmethod
    def get_crawl_status(cls, job_id: str, provider: str) -> dict:
        credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            # decrypt api_key
            api_key = encrypter.decrypt_token(
                tenant_id=current_user.current_tenant_id,
                token=credentials.get('config').get('api_key')
            )
            firecrawl_app = FirecrawlApp(api_key=api_key,
                                         base_url=credentials.get('config').get('base_url', None))
            result = firecrawl_app.check_crawl_status(job_id)
            crawl_status_data = {
                'status': result.get('status', 'active'),
                'job_id': job_id,
                'total': result.get('total', 0),
                'current': result.get('current', 0),
                'data': result.get('data', [])
            }
            if crawl_status_data['status'] == 'completed':
                website_crawl_time_cache_key = f'website_crawl_{job_id}'
                start_time = redis_client.get(website_crawl_time_cache_key)
                if start_time:
                    end_time = datetime.datetime.now().timestamp()
                    time_consuming = abs(end_time - float(start_time))
                    crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
                    redis_client.delete(website_crawl_time_cache_key)
        else:
            raise ValueError('Invalid provider')
        return crawl_status_data

    @classmethod
    def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            file_key = 'website_files/' + job_id + '.txt'
            if storage.exists(file_key):
                data = storage.load_once(file_key)
                if data:
                    data = json.loads(data.decode('utf-8'))
            else:
                # decrypt api_key
                api_key = encrypter.decrypt_token(
                    tenant_id=tenant_id,
                    token=credentials.get('config').get('api_key')
                )
                firecrawl_app = FirecrawlApp(api_key=api_key,
                                             base_url=credentials.get('config').get('base_url', None))
                result = firecrawl_app.check_crawl_status(job_id)
                if result.get('status') != 'completed':
                    raise ValueError('Crawl job is not completed')
                data = result.get('data')
            if data:
                for item in data:
                    if item.get('source_url') == url:
                        return item
            return None
        else:
            raise ValueError('Invalid provider')

    @classmethod
    def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
        credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
                                                             'website',
                                                             provider)
        if provider == 'firecrawl':
            # decrypt api_key
            api_key = encrypter.decrypt_token(
                tenant_id=tenant_id,
                token=credentials.get('config').get('api_key')
            )
            firecrawl_app = FirecrawlApp(api_key=api_key,
                                         base_url=credentials.get('config').get('base_url', None))
            params = {
                'pageOptions': {
                    'onlyMainContent': only_main_content,
                    "includeHtml": False
                }
            }
            result = firecrawl_app.scrape_url(url, params)
            return result
        else:
            raise ValueError('Invalid provider')
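A sketch of the 'options' dict crawl_url understands, based on the branches above. The values are illustrative, and calling the method requires an authenticated console request context since it reads current_user.

options = {
    'crawl_sub_pages': True,        # False -> single page (limit forced to 1)
    'only_main_content': True,      # mapped to Firecrawl pageOptions.onlyMainContent
    'includes': 'blog/*,docs/*',    # comma-separated, split into crawlerOptions.includes
    'excludes': 'private/*',        # comma-separated, split into crawlerOptions.excludes
    'limit': 10,                    # mapped to crawlerOptions.limit
    'max_depth': 2,                 # optional, mapped to crawlerOptions.maxDepth
}

result = WebsiteService.crawl_url({
    'provider': 'firecrawl',
    'url': 'https://example.com',
    'options': options,
})
print(result)  # {'status': 'active', 'job_id': '...'}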
@@ -0,0 +1,90 @@
import datetime
import logging
import time

import click
from celery import shared_task

from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService


@shared_task(queue='dataset')
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
    """
    Async process document
    :param dataset_id:
    :param document_id:

    Usage: sync_website_document_indexing_task.delay(dataset_id, document_id)
    """
    start_at = time.perf_counter()

    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

    sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id)
    # check document limit
    features = FeatureService.get_features(dataset.tenant_id)
    try:
        if features.billing.enabled:
            vector_space = features.vector_space
            if 0 < vector_space.limit <= vector_space.size:
                raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
                                 "your subscription.")
    except Exception as e:
        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()
        if document:
            document.indexing_status = 'error'
            document.error = str(e)
            document.stopped_at = datetime.datetime.utcnow()
            db.session.add(document)
            db.session.commit()
        redis_client.delete(sync_indexing_cache_key)
        return

    logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green'))
    document = db.session.query(Document).filter(
        Document.id == document_id,
        Document.dataset_id == dataset_id
    ).first()
    try:
        if document:
            # clean old data
            index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
            if segments:
                index_node_ids = [segment.index_node_id for segment in segments]
                # delete from vector index
                index_processor.clean(dataset, index_node_ids)

                for segment in segments:
                    db.session.delete(segment)
                db.session.commit()

            document.indexing_status = 'parsing'
            document.processing_started_at = datetime.datetime.utcnow()
            db.session.add(document)
            db.session.commit()

            indexing_runner = IndexingRunner()
            indexing_runner.run([document])
            redis_client.delete(sync_indexing_cache_key)
    except Exception as ex:
        document.indexing_status = 'error'
        document.error = str(ex)
        document.stopped_at = datetime.datetime.utcnow()
        db.session.add(document)
        db.session.commit()
        logging.info(click.style(str(ex), fg='yellow'))
        redis_client.delete(sync_indexing_cache_key)
    end_at = time.perf_counter()
    logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
@@ -0,0 +1,33 @@
import os
from unittest import mock

from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.models.document import Document
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response


def test_firecrawl_web_extractor_crawl_mode(mocker):
    url = "https://firecrawl.dev"
    api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-'
    base_url = 'https://api.firecrawl.dev'
    firecrawl_app = FirecrawlApp(api_key=api_key,
                                 base_url=base_url)
    params = {
        'crawlerOptions': {
            "includes": [],
            "excludes": [],
            "generateImgAltText": True,
            "maxDepth": 1,
            "limit": 1,
            'returnOnlyUrls': False,
        }
    }
    mocked_firecrawl = {
        "jobId": "test",
    }
    mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
    job_id = firecrawl_app.crawl_url(url, params)
    print(job_id)
    assert isinstance(job_id, str)