parent 2d1cb076c6
commit 42a5b3ec17
@@ -0,0 +1,26 @@
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from libs.login import login_required
+from services.advanced_prompt_template_service import AdvancedPromptTemplateService
+
+
+class AdvancedPromptTemplateList(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('app_mode', type=str, required=True, location='args')
+        parser.add_argument('model_mode', type=str, required=True, location='args')
+        parser.add_argument('has_context', type=str, required=False, default='true', location='args')
+        parser.add_argument('model_name', type=str, required=True, location='args')
+        args = parser.parse_args()
+
+        service = AdvancedPromptTemplateService()
+        return service.get_prompt(args)
+
+api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
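Note: a minimal sketch of exercising the new resource with Flask's test client. The '/console/api' mount point and a session that passes the auth decorators are assumptions for illustration, not part of this diff.

    client = app.test_client()
    resp = client.get('/console/api/app/prompt-templates', query_string={
        'app_mode': 'chat',
        'model_mode': 'completion',
        'model_name': 'gpt-3.5-turbo',
        'has_context': 'true',
    })
    # resp.json is the prompt config dict built by AdvancedPromptTemplateService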
@@ -0,0 +1,79 @@
+CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answering the user:\n- If you don't know, just say that you don't know.\n- If you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnswer in the language of the user's question.\n"
+
+BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"
+
+CHAT_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}\nHere is the chat history between the human and the assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "
+        },
+        "conversation_histories_role": {
+            "user_prefix": "Human",
+            "assistant_prefix": "Assistant"
+        }
+    }
+}
+
+CHAT_APP_CHAT_PROMPT_CONFIG = {
+    "chat_prompt_config": {
+        "prompt": [{
+            "role": "system",
+            "text": "{{#pre_prompt#}}"
+        }]
+    }
+}
+
+COMPLETION_APP_CHAT_PROMPT_CONFIG = {
+    "chat_prompt_config": {
+        "prompt": [{
+            "role": "user",
+            "text": "{{#pre_prompt#}}"
+        }]
+    }
+}
+
+COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}"
+        }
+    }
+}
+
+BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
+        },
+        "conversation_histories_role": {
+            "user_prefix": "用户",
+            "assistant_prefix": "助手"
+        }
+    }
+}
+
+BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
+    "chat_prompt_config": {
+        "prompt": [{
+            "role": "system",
+            "text": "{{#pre_prompt#}}"
+        }]
+    }
+}
+
+BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
+    "chat_prompt_config": {
+        "prompt": [{
+            "role": "user",
+            "text": "{{#pre_prompt#}}"
+        }]
+    }
+}
+
+BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
+    "completion_prompt_config": {
+        "prompt": {
+            "text": "{{#pre_prompt#}}"
+        }
+    }
+}
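Note: these constants are consumed by the service added further down. As a sketch of the intended composition (mirroring get_completion_prompt() below), has_context='true' simply prepends the matching context preamble to the template text; copy.deepcopy keeps the module-level constant unmodified.

    import copy

    config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
    config['completion_prompt_config']['prompt']['text'] = (
        CONTEXT + config['completion_prompt_config']['prompt']['text']
    )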
@@ -1,38 +1,24 @@
-import re
-
-from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
-from langchain.schema import BaseMessage
+from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
 
-from core.prompt.prompt_template import JinjaPromptTemplate
+from core.prompt.prompt_template import PromptTemplateParser
 
 
 class PromptBuilder:
     @classmethod
+    def parse_prompt(cls, prompt: str, inputs: dict) -> str:
+        prompt_template = PromptTemplateParser(prompt)
+        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+        prompt = prompt_template.format(prompt_inputs)
+        return prompt
+
+    @classmethod
     def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
-        system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
-        prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
-        system_message = system_prompt_template.format(**prompt_inputs)
-        return system_message
+        return SystemMessage(content=cls.parse_prompt(prompt_content, inputs))
 
     @classmethod
     def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
-        ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
-        prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
-        ai_message = ai_prompt_template.format(**prompt_inputs)
-        return ai_message
+        return AIMessage(content=cls.parse_prompt(prompt_content, inputs))
 
     @classmethod
     def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
-        human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
-        human_message = human_prompt_template.format(**inputs)
-        return human_message
-
-    @classmethod
-    def process_template(cls, template: str):
-        processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
-        # processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
-        # processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
-        return processed_template
+        return HumanMessage(content=cls.parse_prompt(prompt_content, inputs))
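Note: after this refactor all three message builders funnel through parse_prompt, which substitutes {{variable}} placeholders via PromptTemplateParser instead of Jinja. A short illustration with made-up values:

    from core.prompt.prompt_builder import PromptBuilder

    msg = PromptBuilder.to_system_message('You are {{role}}.', {'role': 'a helpful assistant'})
    # msg.content == 'You are a helpful assistant.'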
@@ -1,79 +1,39 @@
 import re
-from typing import Any
 
-from jinja2 import Environment, meta
-from langchain import PromptTemplate
-from langchain.formatting import StrictFormatter
+REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{1,29}|#histories#|#query#|#context#)\}\}")
 
 
-class JinjaPromptTemplate(PromptTemplate):
-    template_format: str = "jinja2"
-    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
+class PromptTemplateParser:
+    """
+    Rules:
 
-    @classmethod
-    def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
-        """Load a prompt template from a template."""
-        env = Environment()
-        template = template.replace("{{}}", "{}")
-        ast = env.parse(template)
-        input_variables = meta.find_undeclared_variables(ast)
-
-        if "partial_variables" in kwargs:
-            partial_variables = kwargs["partial_variables"]
-            input_variables = {
-                var for var in input_variables if var not in partial_variables
-            }
-
-        return cls(
-            input_variables=list(sorted(input_variables)), template=template, **kwargs
-        )
-
-
-class OutLinePromptTemplate(PromptTemplate):
-    @classmethod
-    def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
-        """Load a prompt template from a template."""
-        input_variables = {
-            v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
-        }
-        return cls(
-            input_variables=list(sorted(input_variables)), template=template, **kwargs
-        )
-
-    def format(self, **kwargs: Any) -> str:
-        """Format the prompt with the inputs.
+    1. Template variables must be enclosed in `{{}}`.
+    2. The template variable Key can only be: letters + numbers + underscore, with a maximum length of 16 characters,
+       and can only start with letters and underscores.
+    3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
+    4. In addition to the above, 3 types of special template variable Keys are accepted:
+       `{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
+    """
 
-        Args:
-            kwargs: Any arguments to be passed to the prompt template.
+    def __init__(self, template: str):
+        self.template = template
+        self.variable_keys = self.extract()
 
-        Returns:
-            A formatted string.
+    def extract(self) -> list:
+        # Regular expression to match the template rules
+        return re.findall(REGEX, self.template)
 
-        Example:
+    def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
+        def replacer(match):
+            key = match.group(1)
+            value = inputs.get(key, match.group(0))  # return original matched string if key not found
 
-            .. code-block:: python
+            if remove_template_variables:
+                return PromptTemplateParser.remove_template_variables(value)
+            return value
 
-                prompt.format(variable1="foo")
-        """
-        kwargs = self._merge_partial_and_user_variables(**kwargs)
-        return OneLineFormatter().format(self.template, **kwargs)
+        return re.sub(REGEX, replacer, self.template)
 
-
-class OneLineFormatter(StrictFormatter):
-    def parse(self, format_string):
-        last_end = 0
-        results = []
-        for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
-            field_name = match.group(1)
-            start, end = match.span()
-
-            literal_text = format_string[last_end:start]
-            last_end = end
-
-            results.append((literal_text, field_name, '', None))
-
-        remaining_literal_text = format_string[last_end:]
-        if remaining_literal_text:
-            results.append((remaining_literal_text, None, None, None))
-
-        return results
+    @classmethod
+    def remove_template_variables(cls, text: str):
+        return re.sub(REGEX, r'{\1}', text)
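Note: a short illustration of the new parser (values are made up). Only {{key}} names matching REGEX plus the three special keys are extracted, and by default any {{...}} markers inside substituted values are defused to {...}, so user-supplied input cannot inject new template variables:

    parser = PromptTemplateParser('<context>\n{{#context#}}\n</context>\n\nHuman: {{#query#}}')
    parser.variable_keys  # ['#context#', '#query#']
    parser.format({'#context#': 'Paris is in France.', '#query#': 'Where is {{city}}?'})
    # -> '<context>\nParis is in France.\n</context>\n\nHuman: Where is {city}?'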
@@ -1,14 +0,0 @@
-from events.message_event import message_was_created
-from tasks.generate_conversation_summary_task import generate_conversation_summary_task
-
-
-@message_was_created.connect
-def handle(sender, **kwargs):
-    message = sender
-    conversation = kwargs.get('conversation')
-    is_first_message = kwargs.get('is_first_message')
-
-    if not is_first_message and conversation.mode == 'chat' and not conversation.summary:
-        history_message_count = conversation.message_count
-        if history_message_count >= 5:
-            generate_conversation_summary_task.delay(conversation.id)
@@ -0,0 +1,37 @@
+"""add advanced prompt templates
+
+Revision ID: b3a09c049e8e
+Revises: 2e9819ca5b28
+Create Date: 2023-10-10 15:23:23.395420
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = 'b3a09c049e8e'
+down_revision = '2e9819ca5b28'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
+        batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
+        batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
+        batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.drop_column('dataset_configs')
+        batch_op.drop_column('completion_prompt_config')
+        batch_op.drop_column('chat_prompt_config')
+        batch_op.drop_column('prompt_type')
+
+    # ### end Alembic commands ###
@@ -0,0 +1,56 @@
+import copy
+
+from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
+    BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
+
+
+class AdvancedPromptTemplateService:
+
+    def get_prompt(self, args: dict) -> dict:
+        app_mode = args['app_mode']
+        model_mode = args['model_mode']
+        model_name = args['model_name']
+        has_context = args['has_context']
+
+        if 'baichuan' in model_name:
+            return self.get_baichuan_prompt(app_mode, model_mode, has_context)
+        else:
+            return self.get_common_prompt(app_mode, model_mode, has_context)
+
+    def get_common_prompt(self, app_mode: str, model_mode: str, has_context: str) -> dict:
+        if app_mode == 'chat':
+            if model_mode == 'completion':
+                return self.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
+            elif model_mode == 'chat':
+                return self.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
+        elif app_mode == 'completion':
+            if model_mode == 'completion':
+                return self.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT)
+            elif model_mode == 'chat':
+                return self.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT)
+
+    def get_completion_prompt(self, prompt_template: dict, has_context: str, context: str) -> dict:
+        if has_context == 'true':
+            prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
+
+        return prompt_template
+
+    def get_chat_prompt(self, prompt_template: dict, has_context: str, context: str) -> dict:
+        if has_context == 'true':
+            prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
+
+        return prompt_template
+
+    def get_baichuan_prompt(self, app_mode: str, model_mode: str, has_context: str) -> dict:
+        if app_mode == 'chat':
+            if model_mode == 'completion':
+                return self.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
+            elif model_mode == 'chat':
+                return self.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
+        elif app_mode == 'completion':
+            if model_mode == 'completion':
+                return self.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
+            elif model_mode == 'chat':
+                return self.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT)
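Note: end to end, the service returns a deep copy of the matching constant, optionally with the context preamble prepended. A usage sketch with assumed argument values:

    service = AdvancedPromptTemplateService()
    prompt = service.get_prompt({
        'app_mode': 'chat',
        'model_mode': 'chat',
        'model_name': 'gpt-3.5-turbo',
        'has_context': 'true',
    })
    # prompt == CHAT_APP_CHAT_PROMPT_CONFIG with CONTEXT prepended to the system text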
@@ -1,55 +0,0 @@
-import logging
-import time
-
-import click
-from celery import shared_task
-from werkzeug.exceptions import NotFound
-
-from core.generator.llm_generator import LLMGenerator
-from core.model_providers.error import LLMError, ProviderTokenNotInitError
-from extensions.ext_database import db
-from models.model import Conversation, Message
-
-
-@shared_task(queue='generation')
-def generate_conversation_summary_task(conversation_id: str):
-    """
-    Async Generate conversation summary
-    :param conversation_id:
-
-    Usage: generate_conversation_summary_task.delay(conversation_id)
-    """
-    logging.info(click.style('Start generate conversation summary: {}'.format(conversation_id), fg='green'))
-    start_at = time.perf_counter()
-
-    conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
-    if not conversation:
-        raise NotFound('Conversation not found')
-
-    try:
-        # get conversation messages count
-        history_message_count = conversation.message_count
-        if history_message_count >= 5 and not conversation.summary:
-            app_model = conversation.app
-            if not app_model:
-                return
-
-            history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
-                .order_by(Message.created_at.asc()).all()
-
-            conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages)
-            db.session.add(conversation)
-            db.session.commit()
-    except (LLMError, ProviderTokenNotInitError):
-        conversation.summary = '[No Summary]'
-        db.session.commit()
-        pass
-    except Exception as e:
-        conversation.summary = '[No Summary]'
-        db.session.commit()
-        logging.exception(e)
-
-    end_at = time.perf_counter()
-    logging.info(
-        click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at),
-                    fg='green'))