"""HTTP front end for a ChromaDB-backed question collection.

Serves the static web UI from ``WEBDATA_DIR`` plus a small API under
``/api/v1`` for semantic search, adding, editing, deleting, exporting (CSV)
and importing (CSV) items. Every ChromaDB call is run on a single worker
thread via ``async_run`` so the blocking client stays off the event loop.
"""
import asyncio
import codecs
import csv
import functools
import io
import itertools
import logging
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import chromadb
from aiohttp import web
from chromadb.api.models.Collection import Collection
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from rich.console import Console
from rich.logging import RichHandler

from config import (
    COLLECTION_NAME,
    DATABASE_DIR,
    HTTP_HOST,
    HTTP_PORT,
    MODEL_NAME,
    QUIET,
    TORCH_DEVICE,
    WEBDATA_DIR,
)

log = logging.getLogger(__name__)

# Assigned by run(); the request handlers use it and assume it is set.
COLLECTION: Collection = None

# A single worker thread: ChromaDB / embedding calls are executed one at a
# time, off the aiohttp event loop.
executor = ThreadPoolExecutor(1)

routes = web.RouteTableDef()

WEBDATA_DIR = Path(WEBDATA_DIR)

# Column layout shared by the CSV export and import endpoints.
CSV_HEADER = ['ID', 'Source', 'Title', 'Document']


async def async_run(func, /, *args, **kwargs):
    """Run blocking ``func(*args, **kwargs)`` on ``executor`` and return its result.

    Exceptions raised by ``func`` propagate to the awaiting coroutine.
    """
    loop = asyncio.get_running_loop()
    # run_in_executor only forwards positional arguments, hence the partial.
    return await loop.run_in_executor(executor, functools.partial(func, *args, **kwargs))


async def _read_json_object(request: web.Request) -> dict:
    """Parse the request body as a JSON object, raising 400 Bad Request otherwise."""
    try:
        data = await request.json()
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        raise web.HTTPBadRequest(text='Request body must be valid JSON') from None
    if not isinstance(data, dict):
        raise web.HTTPBadRequest(text='Request body must be a JSON object')
    return data


def _require_item_id(request: web.Request) -> str:
    """Return the non-empty ``id`` query parameter, raising 400 Bad Request if absent."""
    item_id = request.query.get('id')
    if not item_id:
        raise web.HTTPBadRequest(text="Missing 'id' query parameter")
    return item_id


def _split_into_items(body: str, source: str) -> list:
    """Split free-form ``body`` text into ``(source, title, document)`` tuples.

    Items are separated by one or more blank (or whitespace-only) lines; CRLF
    and CR line endings are accepted. Within an item with two or more lines,
    the first line (stripped) is the title and the rest form the document;
    a one-line item has an empty title. Runs of whitespace in the document
    are collapsed to single spaces.

    When ``source`` is ``'#'``, a paragraph whose first line starts with ``#``
    sets the source for that paragraph's remaining lines and for all later
    items (items before the first such header get source ``''``). Otherwise
    every item gets ``source`` unchanged.
    """
    search_sources = source == '#'
    if search_sources:
        source = ''

    lines = body.replace('\r\n', '\n').replace('\r', '\n').split('\n')
    # Group consecutive non-blank lines; blank-line groups are separators.
    paragraphs = [
        list(group)
        for has_text, group in itertools.groupby(lines, key=lambda line: bool(line.strip()))
        if has_text
    ]

    items = []
    for paragraph in paragraphs:
        if search_sources and paragraph[0].strip().startswith('#'):
            source = paragraph[0].strip().lstrip('#').strip()
            # Keep any lines that follow the header instead of dropping them.
            paragraph = paragraph[1:]
            if not paragraph:
                continue
        title = ''
        if len(paragraph) >= 2:
            title, *paragraph = paragraph
            title = title.strip()
        document = ' '.join(' '.join(paragraph).split())
        items.append((source, title, document))
    return items


@routes.get('/')
async def serve_index(request: web.Request) -> web.FileResponse:
    """Serve the single-page UI."""
    return web.FileResponse(WEBDATA_DIR / 'index.html')


@routes.get('/favicon.ico')
async def serve_favicon(request: web.Request) -> web.FileResponse:
    """Serve the site icon."""
    return web.FileResponse(WEBDATA_DIR / 'favicon.ico')


@routes.post('/api/v1/search')
async def search_items(request: web.Request) -> web.Response:
    """Query the collection with ``{"text": ...}`` and return up to 10 hits.

    Responds with ``[[id, distance, document, metadata], ...]`` in the order
    ChromaDB returns them.
    """
    data = await _read_json_object(request)
    text = data.get('text', '')
    if not isinstance(text, str):
        raise web.HTTPBadRequest(text="'text' must be a string")
    results = await async_run(COLLECTION.query, query_texts=[text], n_results=10)
    # query() returns one result list per query text; exactly one was sent.
    columns = (
        results['ids'][0],
        results['distances'][0],
        results['documents'][0],
        results['metadatas'][0],
    )
    items = [list(row) for row in zip(*columns)]
    return web.json_response(items)


@routes.post('/api/v1/item')
async def add_items(request: web.Request) -> web.Response:
    """Add the items parsed from ``{"body": ..., "source": ...}``.

    See ``_split_into_items`` for the text format. Each item gets a random
    hex ID. Responds with ``{"count": <total items in collection>}``.
    """
    data = await _read_json_object(request)
    body = data.get('body') or ''
    # ChromaDB rejects None metadata values, so default to an empty source.
    source = data.get('source') or ''
    if not isinstance(body, str) or not isinstance(source, str):
        raise web.HTTPBadRequest(text="'body' and 'source' must be strings")

    parsed = _split_into_items(body, source)
    if parsed:  # upserting empty lists is an error in ChromaDB
        added = int(time.time())
        await async_run(
            COLLECTION.upsert,
            ids=[uuid.uuid4().hex for _ in parsed],
            metadatas=[
                {'type': 'question', 'source': item_source, 'added': added, 'title': title}
                for item_source, title, _ in parsed
            ],
            documents=[document for _, _, document in parsed],
        )
    return web.json_response({'count': await async_run(COLLECTION.count)})


@routes.delete('/api/v1/item')
async def delete_item(request: web.Request) -> web.Response:
    """Delete the item named by the ``id`` query parameter.

    Responds with ``{"count": <total items in collection>}``.
    """
    item_id = _require_item_id(request)
    await async_run(COLLECTION.delete, ids=[item_id])
    # aiohttp handlers must return a response object; returning None made
    # the request fail with a 500 after the delete had already happened.
    return web.json_response({'count': await async_run(COLLECTION.count)})


@routes.get('/api/v1/database')
async def get_database(request: web.Request) -> web.Response:
    """Download every item as a CSV attachment with columns ID, Source, Title, Document."""
    results = await async_run(COLLECTION.get)
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(CSV_HEADER)
    for item_id, metadata, document in zip(
        results['ids'], results['metadatas'], results['documents']
    ):
        # Items not created by this server may lack metadata or some keys.
        metadata = metadata or {}
        writer.writerow([item_id, metadata.get('source', ''), metadata.get('title', ''), document])
    return web.Response(
        text=buffer.getvalue(),
        status=200,
        content_type='text/csv',
        headers={'Content-Disposition': 'attachment; filename=ipt_questions.csv'},
    )


@routes.post('/api/v1/database')
async def import_database(request: web.Request) -> web.Response:
    """Upsert items from a CSV uploaded as multipart field ``database``.

    The CSV must use the export format (header ``ID,Source,Title,Document``).
    Rows keep their ID, so re-importing an export updates items in place; a
    row with an empty ID gets a fresh random one. Blank rows are skipped.
    Responds with ``{"count": <total items in collection>}``.
    """
    form = await request.post()
    upload = form.get('database')
    if not isinstance(upload, web.FileField):
        raise web.HTTPBadRequest(text="Missing 'database' file upload")
    try:
        # utf-8-sig also accepts files saved with a byte-order mark.
        text = upload.file.read().decode('utf-8-sig')
    except UnicodeDecodeError:
        raise web.HTTPBadRequest(text='CSV file must be UTF-8 encoded') from None

    # newline='' leaves line endings untouched so csv can parse quoted
    # multi-line fields correctly.
    reader = csv.reader(io.StringIO(text, newline=''))
    if next(reader, None) != CSV_HEADER:
        raise web.HTTPBadRequest(text='Invalid CSV header')

    ids, documents, metadatas = [], [], []
    added = int(time.time())
    for row in reader:
        if not any(cell.strip() for cell in row):
            continue
        if len(row) < len(CSV_HEADER):
            raise web.HTTPBadRequest(text=f'Malformed CSV row ending at line {reader.line_num}')
        item_id, source, title, document = row[:len(CSV_HEADER)]
        ids.append(item_id.strip() or uuid.uuid4().hex)
        documents.append(document)
        metadatas.append({'type': 'question', 'source': source, 'added': added, 'title': title})

    if ids:  # upserting empty lists is an error in ChromaDB
        await async_run(COLLECTION.upsert, ids=ids, metadatas=metadatas, documents=documents)
    return web.json_response({'count': await async_run(COLLECTION.count)})


@routes.put('/api/v1/item')
async def edit_item(request: web.Request) -> web.Response:
    """Update the item named by the ``id`` query parameter.

    Body: ``{"metadata": {...}, "document": "..."}``; either key may be
    omitted to leave that part of the item unchanged.
    Responds with ``{"count": <total items in collection>}``.
    """
    item_id = _require_item_id(request)
    data = await _read_json_object(request)
    metadata = data.get('metadata')
    document = data.get('document')
    if metadata is not None and not isinstance(metadata, dict):
        raise web.HTTPBadRequest(text="'metadata' must be an object")
    if document is not None and not isinstance(document, str):
        raise web.HTTPBadRequest(text="'document' must be a string")

    if metadata or document is not None:
        await async_run(
            COLLECTION.update,
            ids=[item_id],
            # None (not a placeholder such as {}) tells ChromaDB to leave
            # that field untouched.
            metadatas=[metadata] if metadata else None,
            documents=[document] if document is not None else None,
        )
    # aiohttp handlers must return a response object; returning None made
    # the request fail with a 500 after the update had already happened.
    return web.json_response({'count': await async_run(COLLECTION.count)})


def run():
    """Configure logging, open the persistent collection and serve HTTP until stopped."""
    global COLLECTION

    console = Console(color_system="standard", quiet=QUIET, width=180)
    logging.basicConfig(
        level="NOTSET",
        format="%(message)s",
        datefmt="[%X]",
        handlers=[RichHandler(console=console, enable_link_path=False)],
    )

    log.info("Loading collection: %s", COLLECTION_NAME)
    embedding_fn = embedding_functions.InstructorEmbeddingFunction(
        model_name=MODEL_NAME, device=TORCH_DEVICE
    )
    client = chromadb.PersistentClient(
        path=DATABASE_DIR,
        settings=Settings(anonymized_telemetry=False),
    )
    COLLECTION = client.get_or_create_collection(
        name=COLLECTION_NAME, embedding_function=embedding_fn
    )

    log.info("Using web root: %s", WEBDATA_DIR.absolute())
    app = web.Application(logger=log)
    app.add_routes(routes)
    for subdir in ('css', 'fonts', 'img', 'js'):
        app.router.add_static(f'/{subdir}', WEBDATA_DIR / subdir)

    log.info("Serving http at http://%s:%s", HTTP_HOST, HTTP_PORT)
    web.run_app(
        app,
        host=HTTP_HOST,
        port=HTTP_PORT,
        access_log=log,
        access_log_format='%a "%r" %s %b "%{Referer}i" "%{User-Agent}i"',
        print=None,  # None is aiohttp's documented way to suppress the startup banner
    )


if __name__ == '__main__':
    run()