|
|
@@ -0,0 +1,151 @@
|
|
|
+import asyncio
|
|
|
+import functools
|
|
|
+import time
|
|
|
+
|
|
|
+import chromadb
|
|
|
+from aiohttp import web
|
|
|
+from chromadb.config import Settings
|
|
|
+from chromadb.utils import embedding_functions
|
|
|
+from chromadb.api.models.Collection import Collection
|
|
|
+import logging
|
|
|
+from pathlib import Path
|
|
|
+from config import DATABASE_DIR, HTTP_HOST, HTTP_PORT, QUIET, WEBDATA_DIR, COLLECTION_NAME
|
|
|
+from rich.logging import RichHandler
|
|
|
+from rich.console import Console
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
+import uuid
|
|
|
+
|
|
|
+
|
|
|
# Module-wide logger; run() installs the Rich handler via logging.basicConfig.
log = logging.getLogger(__name__)

# Route table populated by the @routes.<method> decorators further down.
routes = web.RouteTableDef()

# One worker thread only, so blocking calls submitted through async_run()
# execute one at a time instead of concurrently.
executor = ThreadPoolExecutor(max_workers=1)

# Active Chroma collection; remains None until run() assigns it.
COLLECTION: Collection = None

# Rebind the configured web root as a Path so handlers can use the `/` operator.
WEBDATA_DIR = Path(WEBDATA_DIR)
|
|
|
+
|
|
|
async def async_run(func, /, *args, **kwargs):
    """Run a blocking callable on the module's thread pool and await it.

    Args:
        func: Synchronous callable to execute off the event loop.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value returned by ``func(*args, **kwargs)``. An exception raised
        by ``func`` propagates to the awaiting coroutine.
    """
    # run_in_executor() only forwards positional arguments, so bind
    # everything up front with functools.partial.
    bound_call = functools.partial(func, *args, **kwargs)
    return await asyncio.get_running_loop().run_in_executor(executor, bound_call)
|
|
|
+
|
|
|
+
|
|
|
@routes.get('/')
async def serve_index(request):
    """Serve the front-end entry page, ``index.html``, from the web root."""
    index_page = WEBDATA_DIR / 'index.html'
    return web.FileResponse(index_page)
|
|
|
+
|
|
|
@routes.get('/favicon.ico')
async def serve_favicon(request):
    """Serve ``favicon.ico`` from the web root."""
    icon_file = WEBDATA_DIR / 'favicon.ico'
    return web.FileResponse(icon_file)
|
|
|
+
|
|
|
+
|
|
|
@routes.post('/api/v1/search')
async def add_items(request):
    """Handle ``POST /api/v1/search``: return the closest stored documents.

    NOTE(review): despite its name this is the search handler. The name
    ``add_items`` is reused by the ``POST /api/v1/item`` handler defined
    later in this file, which rebinds the module-level name. The route is
    registered by the decorator, so dispatch is unaffected, but renaming
    this function (e.g. ``search_items``) would be clearer.

    The request body is a JSON object with an optional ``text`` key holding
    the query string (defaults to the empty string).

    Returns:
        A JSON array with at most 10 entries, each of the form
        ``[id, distance, document, metadata]``, in the order the collection
        query returned them.

    Raises:
        web.HTTPBadRequest: If the body is not valid JSON or is not a JSON
            object.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError (and UnicodeDecodeError) subclass ValueError;
        # report a client error instead of letting it become a 500.
        raise web.HTTPBadRequest(text='Request body must be valid JSON') from None
    if not isinstance(data, dict):
        raise web.HTTPBadRequest(text='Request body must be a JSON object')

    results = await async_run(
        COLLECTION.query,
        query_texts=[data.get('text', '')],
        n_results=10,
    )
    # Exactly one query text is sent, so only the first row of each result
    # column is relevant; zip walks the four columns in lockstep.
    rows = zip(
        results['ids'][0],
        results['distances'][0],
        results['documents'][0],
        results['metadatas'][0],
    )
    return web.json_response([list(row) for row in rows])
|
|
|
+
|
|
|
+
|
|
|
def _parse_question_blocks(body, source):
    """Split raw text into ``(title, document, source)`` entries.

    ``body`` is cut into blocks at blank lines (``'\\n\\n'``). Within a block,
    whitespace-only lines are dropped. When at least two lines remain, the
    first becomes the title and the rest form the document; a single line is
    a document with an empty title. Runs of whitespace inside a document are
    collapsed to single spaces.

    If ``source`` is the literal ``'#'``, sources are read from the text
    instead: a block whose first line starts with ``#`` is consumed entirely
    as a source heading that applies to the blocks after it.

    Args:
        body: The raw text to split.
        source: Source label for every entry, ``'#'`` to enable in-text
            source headings, or ``None`` for no source (stored as ``''``).

    Returns:
        A list of ``(title, document, source)`` tuples, in input order.
    """
    scan_for_sources = source == '#'
    if scan_for_sources or source is None:
        # '' is the "no source" value; None is not a usable metadata value.
        source = ''

    entries = []
    for block in body.split('\n\n'):
        lines = [line for line in block.split('\n') if line.strip()]
        if not lines:
            continue
        if scan_for_sources and lines[0].startswith('#'):
            source = lines[0].lstrip('#').strip()
            continue
        title = ''
        if len(lines) >= 2:
            title, *lines = lines
        # split()/join collapses every run of whitespace to a single space
        # and trims both ends in one pass.
        doc = ' '.join(' '.join(lines).split())
        if doc:
            entries.append((title, doc, source))
    return entries


@routes.post('/api/v1/item')
async def add_items(request):
    """Handle ``POST /api/v1/item``: store the documents found in the body.

    The request body is a JSON object with:
        ``body``: text containing one or more blank-line-separated items
            (defaults to ``''``).
        ``source``: optional source label, or ``'#'`` to read ``#``-prefixed
            source headings from the text itself.

    Each stored item gets a fresh random hex id and metadata with the keys
    ``type`` (always ``'question'``), ``source``, ``added`` (Unix time in
    whole seconds) and ``title``.

    Returns:
        A JSON object ``{"count": <collection size after the insert>}``.

    Raises:
        web.HTTPBadRequest: If the body is not valid JSON, is not a JSON
            object, or ``body`` is not a string.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError subclasses ValueError; answer with a 400
        # rather than an unhandled 500.
        raise web.HTTPBadRequest(text='Request body must be valid JSON') from None
    if not isinstance(data, dict):
        raise web.HTTPBadRequest(text='Request body must be a JSON object')
    body = data.get('body', '')
    if not isinstance(body, str):
        raise web.HTTPBadRequest(text='"body" must be a string')

    entries = _parse_question_blocks(body, data.get('source'))
    # Nothing parsed means nothing to store: skip the upsert instead of
    # handing the collection empty lists (which Chroma is expected to reject).
    if entries:
        added = int(time.time())  # one timestamp for the whole batch
        await async_run(
            COLLECTION.upsert,
            ids=[uuid.uuid4().hex for _ in entries],
            metadatas=[
                {'type': 'question', 'source': entry_source, 'added': added, 'title': title}
                for title, _, entry_source in entries
            ],
            documents=[doc for _, doc, _ in entries],
        )
    return web.json_response({'count': await async_run(COLLECTION.count)})
|
|
|
+
|
|
|
+
|
|
|
@routes.delete('/api/v1/item')
async def delete_item(request: web.BaseRequest):
    """Handle ``DELETE /api/v1/item?id=<id>``: remove one stored item.

    Args:
        request: Incoming request; the item id is read from the ``id`` query
            parameter.

    Returns:
        A JSON object ``{"count": <collection size after the delete>}``.

    Raises:
        web.HTTPBadRequest: If the ``id`` query parameter is missing or empty.
    """
    item_id = request.query.get('id')
    if not item_id:
        # Without this check a missing parameter would send ids=[None] to
        # the collection.
        raise web.HTTPBadRequest(text='Missing "id" query parameter')
    await async_run(COLLECTION.delete, ids=[item_id])
    # aiohttp handlers must return a response object; a bare ``return``
    # turns a successful delete into a server error.
    return web.json_response({'count': await async_run(COLLECTION.count)})
|
|
|
+
|
|
|
+
|
|
|
@routes.put('/api/v1/item')
async def edit_item(request: web.BaseRequest):
    """Handle ``PUT /api/v1/item?id=<id>``: update one stored item.

    The request body is a JSON object that may carry ``metadata`` (the new
    metadata mapping) and/or ``document`` (the new document text). Only the
    fields actually supplied are forwarded to the collection, so an omitted
    field is left unchanged rather than being sent as an empty placeholder.

    Args:
        request: Incoming request; the item id is read from the ``id`` query
            parameter.

    Returns:
        A JSON object ``{"id": <the updated id>}``.

    Raises:
        web.HTTPBadRequest: If ``id`` is missing, the body is not a JSON
            object, or neither ``metadata`` nor ``document`` is supplied.
    """
    item_id = request.query.get('id')
    if not item_id:
        raise web.HTTPBadRequest(text='Missing "id" query parameter')

    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError subclasses ValueError; answer with a 400
        # rather than an unhandled 500.
        raise web.HTTPBadRequest(text='Request body must be valid JSON') from None
    if not isinstance(data, dict):
        raise web.HTTPBadRequest(text='Request body must be a JSON object')

    # Map request keys to the keyword arguments of COLLECTION.update and
    # include only what the client supplied.
    changes = {}
    if data.get('metadata') is not None:
        changes['metadatas'] = [data['metadata']]
    if data.get('document') is not None:
        changes['documents'] = [data['document']]
    if not changes:
        raise web.HTTPBadRequest(text='Provide "metadata" and/or "document"')

    await async_run(COLLECTION.update, ids=[item_id], **changes)
    # aiohttp handlers must return a response object; a bare ``return``
    # turns a successful update into a server error.
    return web.json_response({'id': item_id})
|
|
|
+
|
|
|
+# @routes.get('/api/v1/sources')
|
|
|
+# async def get_sources(request: web.BaseRequest):
|
|
|
+# await async_run(COLLECTION.get, where=)
|
|
|
+# return
|
|
|
+
|
|
|
+
|
|
|
def run():
    """Configure logging, open the Chroma collection and serve the web app.

    Steps, in order:
        1. Install a Rich logging handler on the root logger.
        2. Build the Instructor embedding function and open (or create) the
           collection named ``COLLECTION_NAME`` in the persistent store at
           ``DATABASE_DIR``, publishing it through the module-level
           ``COLLECTION``.
        3. Register the API routes plus the static asset directories and run
           the aiohttp application on ``HTTP_HOST``:``HTTP_PORT``.

    ``web.run_app`` is the last call, so whatever follows it only happens
    once that call returns.
    """
    global COLLECTION

    rich_console = Console(color_system="standard", quiet=QUIET, width=180)
    rich_handler = RichHandler(console=rich_console, enable_link_path=False)
    logging.basicConfig(
        level="NOTSET",
        format="%(message)s",
        datefmt="[%X]",
        handlers=[rich_handler],
    )

    log.info("Loading collection: " + COLLECTION_NAME)
    embedder = embedding_functions.InstructorEmbeddingFunction(
        model_name="hkunlp/instructor-large",
        device="cuda",
    )
    chroma_settings = Settings(anonymized_telemetry=False)
    client = chromadb.PersistentClient(path=DATABASE_DIR, settings=chroma_settings)
    COLLECTION = client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedder,
    )

    log.info(f"Using web root: {WEBDATA_DIR.absolute()}")
    app = web.Application(logger=log)
    app.add_routes(routes)
    # URL prefix -> sub-directory of the web root served as static files.
    static_mounts = (
        ('/css', 'css'),
        ('/fonts', 'fonts'),
        ('/img', 'img'),
        ('/js', 'js'),
    )
    for url_prefix, folder in static_mounts:
        app.router.add_static(url_prefix, WEBDATA_DIR / folder)

    log.info(f'Serving http at http://{HTTP_HOST}:{HTTP_PORT}')
    web.run_app(
        app,
        host=HTTP_HOST,
        port=HTTP_PORT,
        access_log=log,
        access_log_format='%a "%r" %s %b "%{Referer}i" "%{User-Agent}i"',
        print=False,
    )


# Start the server only when this file is executed as a script.
if __name__ == '__main__':
    run()
|