| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- import asyncio
- import functools
- import time
- import chromadb
- from aiohttp import web
- from chromadb.config import Settings
- from chromadb.utils import embedding_functions
- from chromadb.api.models.Collection import Collection
- import logging
- from pathlib import Path
- from config import DATABASE_DIR, HTTP_HOST, HTTP_PORT, QUIET, WEBDATA_DIR, COLLECTION_NAME
- from rich.logging import RichHandler
- from rich.console import Console
- from concurrent.futures import ThreadPoolExecutor
- import uuid
# Module-wide logger; output handlers are configured in run().
log = logging.getLogger(__name__)

# Route table the request handlers register on; run() attaches it to the app.
routes = web.RouteTableDef()

# A single worker thread: every blocking call made through async_run executes
# on this one thread, so those calls run one at a time.
executor = ThreadPoolExecutor(max_workers=1)

# Chroma collection handle; stays None until run() assigns it.
COLLECTION: Collection = None

# The web root from config may be a plain string; normalise it to a Path.
WEBDATA_DIR = Path(WEBDATA_DIR)
async def async_run(func, /, *args, **kwargs):
    """Run a blocking callable on the module's worker executor and await it.

    ``loop.run_in_executor`` only forwards positional arguments, so keyword
    arguments are bound up front with ``functools.partial``.

    :param func: blocking callable to execute (e.g. a Chroma collection method)
    :param args: positional arguments forwarded to ``func``
    :param kwargs: keyword arguments forwarded to ``func``
    :return: whatever ``func`` returns; exceptions raised by ``func`` propagate
    """
    loop = asyncio.get_running_loop()
    # Await the executor future directly: gathering a one-element list and
    # taking result [0] added a wrapper task without changing the outcome.
    return await loop.run_in_executor(executor, functools.partial(func, *args, **kwargs))
@routes.get('/')
async def serve_index(request):
    """Serve the front-end entry page, ``index.html``, from the web root."""
    index_file = WEBDATA_DIR / 'index.html'
    return web.FileResponse(index_file)
@routes.get('/favicon.ico')
async def serve_favicon(request):
    """Serve ``favicon.ico`` from the web root."""
    icon_file = WEBDATA_DIR / 'favicon.ico'
    return web.FileResponse(icon_file)
@routes.post('/api/v1/search')
async def search_items(request):
    """Return the stored items nearest to the posted query text.

    Request body (JSON): ``{"text": "<query>"}``; a missing ``text`` queries
    with the empty string.

    Responds with a JSON list of ``[id, distance, document, metadata]`` rows
    for the (up to) 10 results Chroma returns, in Chroma's order.
    """
    # Named search_items (previously also "add_items") so this handler's
    # module-level name is not shadowed by the POST /api/v1/item handler.
    data = await request.json()
    results = await async_run(
        COLLECTION.query,
        query_texts=[data.get('text', '')],
        n_results=10,
    )
    # Each result field holds one list per query text; exactly one was sent.
    rows = zip(
        results['ids'][0],
        results['distances'][0],
        results['documents'][0],
        results['metadatas'][0],
    )
    return web.json_response([list(row) for row in rows])
def _parse_item_body(body, source):
    """Split *body* into ``(title, document, source)`` tuples (rules in ``add_items``)."""
    search_sources = source == '#'
    if search_sources:
        source = ''
    # Normalise CRLF / CR line endings so blank-line paragraph splitting works.
    text = body.replace('\r\n', '\n').replace('\r', '\n')
    parsed = []
    for paragraph in text.split('\n\n'):
        lines = [line for line in paragraph.split('\n') if line.strip()]
        if not lines:
            continue
        if search_sources and lines[0].startswith('#'):
            # NOTE(review): any further lines in a heading paragraph are
            # discarded -- confirm headings always stand in their own paragraph.
            source = lines[0].lstrip('#').strip()
            continue
        title = ''
        if len(lines) >= 2:
            title, *lines = lines
        # Join the lines and collapse every run of whitespace to one space.
        doc = ' '.join(' '.join(lines).split())
        if doc:
            parsed.append((title, doc, source))
    return parsed


@routes.post('/api/v1/item')
async def add_items(request):
    """Split a posted text body into documents and upsert them into the collection.

    Request body (JSON):
      * ``body``   -- text; paragraphs separated by a blank line become
        documents. In a paragraph of two or more non-blank lines, the first
        line is stored as the ``title`` metadata and the rest as the document.
      * ``source`` -- stored as the ``source`` metadata of every document. The
        special value ``'#'`` instead takes the source from paragraphs whose
        first line starts with ``#``; it applies to the paragraphs after it.

    Responds with JSON ``{"count": <number of items in the collection>}``.
    """
    data = await request.json()
    source = data.get('source')
    if source is None:
        # Older Chroma versions reject None metadata values; store '' instead.
        source = ''
    parsed = _parse_item_body(data.get('body', ''), source)

    if parsed:
        # Chroma rejects an empty id list, so only upsert when something parsed.
        added = int(time.time())  # one timestamp for the whole request
        ids = [uuid.uuid4().hex for _ in parsed]
        documents = [doc for _, doc, _ in parsed]
        metadatas = [
            {'type': 'question', 'source': src, 'added': added, 'title': title}
            for title, _, src in parsed
        ]
        await async_run(COLLECTION.upsert, ids=ids, metadatas=metadatas, documents=documents)
    return web.json_response({'count': await async_run(COLLECTION.count)})
@routes.delete('/api/v1/item')
async def delete_item(request: web.BaseRequest):
    """Delete the item whose id is given in the ``id`` query parameter.

    Responds with JSON ``{"count": <items remaining in the collection>}``;
    responds 400 Bad Request when ``id`` is missing.
    """
    item_id = request.query.get('id')
    if not item_id:
        raise web.HTTPBadRequest(text="missing 'id' query parameter")
    await async_run(COLLECTION.delete, ids=[item_id])
    # aiohttp handlers must return a response object; returning None makes
    # the server answer 500 even though the delete succeeded.
    return web.json_response({'count': await async_run(COLLECTION.count)})
@routes.put('/api/v1/item')
async def edit_item(request: web.BaseRequest):
    """Update the item named by the ``id`` query parameter.

    Request body (JSON): optional ``document`` (str) and ``metadata`` (dict).
    An absent field is passed to Chroma as ``None`` (not updated) instead of
    an empty placeholder value. Responds 400 Bad Request when ``id`` is
    missing or neither field is supplied; otherwise JSON
    ``{"count": <number of items in the collection>}``.
    """
    item_id = request.query.get('id')
    if not item_id:
        raise web.HTTPBadRequest(text="missing 'id' query parameter")
    data = await request.json()
    document = data.get('document')
    metadata = data.get('metadata')
    if document is None and metadata is None:
        raise web.HTTPBadRequest(text="expected 'document' and/or 'metadata' in body")
    await async_run(
        COLLECTION.update,
        ids=[item_id],
        metadatas=None if metadata is None else [metadata],
        documents=None if document is None else [document],
    )
    # aiohttp handlers must return a response object; None yields a 500.
    return web.json_response({'count': await async_run(COLLECTION.count)})
- # @routes.get('/api/v1/sources')
- # async def get_sources(request: web.BaseRequest):
- # await async_run(COLLECTION.get, where=)
- # return
def run(device="cuda"):
    """Configure logging, open the Chroma collection and serve the web app.

    Blocks in ``web.run_app`` until the server is stopped.

    :param device: device string passed to ``InstructorEmbeddingFunction``;
        defaults to ``"cuda"`` (the previously hard-coded value). Presumably
        ``"cpu"`` works on machines without a GPU -- confirm for your setup.
    """
    global COLLECTION
    console = Console(color_system="standard", quiet=QUIET, width=180)
    logging.basicConfig(level="NOTSET", format="%(message)s", datefmt="[%X]", handlers=[
        RichHandler(console=console, enable_link_path=False)
    ])

    log.info(f"Loading collection: {COLLECTION_NAME}")
    instructor_ef = embedding_functions.InstructorEmbeddingFunction(
        model_name="hkunlp/instructor-large", device=device)
    client = chromadb.PersistentClient(
        path=DATABASE_DIR,
        settings=Settings(anonymized_telemetry=False),
    )
    COLLECTION = client.get_or_create_collection(name=COLLECTION_NAME, embedding_function=instructor_ef)

    log.info(f"Using web root: {WEBDATA_DIR.absolute()}")
    app = web.Application(logger=log)
    app.add_routes(routes)
    # Each asset directory under the web root is served at the same URL prefix.
    for asset_dir in ('css', 'fonts', 'img', 'js'):
        app.router.add_static(f'/{asset_dir}', WEBDATA_DIR / asset_dir)

    log.info(f'Serving http at http://{HTTP_HOST}:{HTTP_PORT}')
    # print=None is aiohttp's documented way to suppress the startup banner.
    web.run_app(app, host=HTTP_HOST, port=HTTP_PORT, access_log=log,
                access_log_format='%a "%r" %s %b "%{Referer}i" "%{User-Agent}i"', print=None)
if __name__ == "__main__":
    # Script entry point: start the server with run()'s defaults.
    run()
|