| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- import asyncio
- import codecs
- import functools
- import io
- import time
- import chromadb
- from aiohttp import web
- from chromadb.config import Settings
- from chromadb.utils import embedding_functions
- from chromadb.api.models.Collection import Collection
- import logging
- from pathlib import Path
- from config import (
- DATABASE_DIR,
- HTTP_HOST,
- HTTP_PORT,
- QUIET,
- WEBDATA_DIR,
- COLLECTION_NAME,
- MODEL_NAME,
- TORCH_DEVICE,
- )
- from rich.logging import RichHandler
- from rich.console import Console
- from concurrent.futures import ThreadPoolExecutor
- import uuid
- import csv
# Module-level logger; its handlers are attached later by run() through
# logging.basicConfig.
log = logging.getLogger(__name__)

# The Chroma collection every request handler works on; it stays None until
# run() assigns it.
COLLECTION: Collection = None

# Exactly one worker thread: blocking calls submitted through async_run are
# executed one at a time, in the order they were submitted.
executor = ThreadPoolExecutor(max_workers=1)

# Route table that the decorated handlers below register on; run() mounts it.
routes = web.RouteTableDef()

# The configured web root may be a plain string; normalise it to a Path so
# handlers can join file names with "/".
WEBDATA_DIR = Path(WEBDATA_DIR)
async def async_run(func, /, *args, **kwargs):
    """Run a blocking callable on the module's worker executor and await it.

    The Chroma collection methods used by the handlers are synchronous, so
    they are pushed onto ``executor`` to keep the aiohttp event loop free.

    :param func: the blocking callable to invoke.
    :param args: positional arguments forwarded to ``func``.
    :param kwargs: keyword arguments forwarded to ``func``.
    :return: whatever ``func`` returns; exceptions raised by ``func``
        propagate to the awaiting coroutine.
    """
    loop = asyncio.get_running_loop()
    # run_in_executor only forwards positional arguments, so keyword
    # arguments are bound up front with functools.partial.
    call = functools.partial(func, *args, **kwargs)
    # The executor future is directly awaitable; wrapping a single future in
    # asyncio.gather and unpacking [0] added nothing.
    return await loop.run_in_executor(executor, call)
@routes.get('/')
async def serve_index(request):
    """Serve the front-end entry page from the configured web root."""
    index_page = WEBDATA_DIR / 'index.html'
    return web.FileResponse(index_page)
@routes.get('/favicon.ico')
async def serve_favicon(request):
    """Serve the site icon from the configured web root."""
    icon_file = WEBDATA_DIR / 'favicon.ico'
    return web.FileResponse(icon_file)
@routes.post('/api/v1/search')
async def search_items(request):
    """Return the stored documents closest to the posted query text.

    Expects a JSON body such as ``{"text": "..."}``; a missing ``text`` is
    queried as the empty string.

    Named ``search_items`` (previously ``add_items``, a name immediately
    shadowed by the insert handler). The route registration is unaffected.

    :return: JSON list of ``[id, distance, document, metadata]`` entries in
        the order Chroma returns them, at most 10 of them.
    """
    data = await request.json()
    results = await async_run(
        COLLECTION.query,
        query_texts=[data.get('text', '')],
        n_results=10,
    )
    # query() returns one result list per query text; exactly one was sent,
    # hence the [0] on every column.
    items = [
        [item_id, distance, document, metadata]
        for item_id, distance, document, metadata in zip(
            results['ids'][0],
            results['distances'][0],
            results['documents'][0],
            results['metadatas'][0],
        )
    ]
    return web.json_response(items)
def _parse_items(body, source):
    """Split a free-text submission into ``(title, document, source)`` tuples.

    The body is split into paragraphs on ``"\\n\\n"``. Within a paragraph,
    blank lines are dropped. If two or more lines remain, the first is the
    title and the rest form the document; a single line is the document
    with an empty title. Whitespace in the document is collapsed to single
    spaces, and paragraphs whose document ends up empty are skipped.

    When *source* is ``'#'``, sources are taken from the body instead. A
    paragraph whose first line starts with ``#`` sets the source, stripped
    of ``#`` and surrounding spaces, for the paragraphs after it. That whole
    paragraph produces no item. Until such a heading appears, the source
    is ``''``.

    :param body: the submitted text.
    :param source: source label for every item, or ``'#'`` to read sources
        from ``#`` heading paragraphs.
    :return: list of ``(title, document, source)`` tuples in body order.
    """
    search_sources = source == '#'
    if search_sources:
        source = ''
    parsed = []
    for paragraph in body.split('\n\n'):
        lines = [line for line in paragraph.split('\n') if line.strip()]
        if not lines:
            continue
        if search_sources and lines[0].startswith('#'):
            # Heading paragraph: any lines after the heading are ignored.
            source = lines[0].lstrip('#').strip()
            continue
        title = ''
        if len(lines) >= 2:
            title, *lines = lines
        # Join the lines and collapse every whitespace run to one space.
        doc = ' '.join(' '.join(lines).split())
        if doc:
            parsed.append((title, doc, source))
    return parsed


@routes.post('/api/v1/item')
async def add_items(request):
    """Add the questions parsed from a posted free-text body.

    Expects JSON ``{"body": "...", "source": "..."}``; ``_parse_items``
    describes how the body is split. Each item gets a random hex id, the
    ``'question'`` type and the current Unix time as ``added``.

    :return: JSON ``{"count": N}`` with the collection size afterwards.
    """
    data = await request.json()
    # NOTE(review): Chroma appears to reject None metadata values (at least
    # in some versions), so a missing source is stored as '' instead.
    source = data.get('source') or ''
    parsed = _parse_items(data.get('body', ''), source)
    # Chroma rejects an upsert with an empty ids list, so only call it when
    # the body actually produced items.
    if parsed:
        added = int(time.time())
        await async_run(
            COLLECTION.upsert,
            ids=[uuid.uuid4().hex for _ in parsed],
            metadatas=[
                {'type': 'question', 'source': item_source, 'added': added, 'title': title}
                for title, _doc, item_source in parsed
            ],
            documents=[doc for _title, doc, _source in parsed],
        )
    return web.json_response({'count': await async_run(COLLECTION.count)})
@routes.delete('/api/v1/item')
async def delete_item(request: web.BaseRequest):
    """Delete the item whose id is given in the ``id`` query parameter.

    :return: 400 if ``id`` is missing or empty; otherwise JSON
        ``{"count": N}`` with the collection size after the delete.
    """
    item_id = request.query.get('id')
    if not item_id:
        return web.Response(text='Missing id', status=400)
    await async_run(COLLECTION.delete, ids=[item_id])
    # aiohttp handlers must return a response object; the previous bare
    # ``return`` made aiohttp answer 500 even after a successful delete.
    return web.json_response({'count': await async_run(COLLECTION.count)})
@routes.get('/api/v1/database')
async def get_database(request: web.BaseRequest):
    """Export the whole collection as a downloadable CSV attachment.

    Columns are ``ID, Source, Title, Document``, the same header that the
    ``POST /api/v1/database`` import checks for.
    """
    results = await async_run(COLLECTION.get)
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(['ID', 'Source', 'Title', 'Document'])
    for item_id, metadata, document in zip(
        results['ids'], results['metadatas'], results['documents']
    ):
        # Items stored without metadata, or without these keys, get empty
        # cells instead of aborting the whole export with KeyError/TypeError.
        metadata = metadata or {}
        writer.writerow([
            item_id,
            metadata.get('source', ''),
            metadata.get('title', ''),
            document,
        ])
    return web.Response(text=buffer.getvalue(), status=200, content_type='text/csv',
                        headers={'Content-Disposition': 'attachment; filename=ipt_questions.csv'})
def _parse_database_csv(stream):
    """Read an exported question CSV into parallel upsert lists.

    :param stream: text stream positioned at the start of the CSV.
    :return: ``(ids, documents, metadatas)``. Blank lines are skipped, and
        columns past the fourth are ignored.
    :raises ValueError: if the header is not exactly
        ``ID,Source,Title,Document`` (including an empty file), if a row has
        fewer than four columns, or if the text cannot be decoded
        (``UnicodeDecodeError`` is a ``ValueError``).
    """
    reader = csv.reader(stream)
    header = next(reader, None)
    if header is None or ','.join(header) != 'ID,Source,Title,Document':
        raise ValueError('Invalid CSV header')
    added = int(time.time())
    ids, documents, metadatas = [], [], []
    for row in reader:
        if not row:
            # csv.reader yields [] for blank lines.
            continue
        if len(row) < 4:
            raise ValueError(f'Line {reader.line_num}: expected 4 columns, got {len(row)}')
        item_id, source, title, document = row[:4]
        ids.append(item_id)
        documents.append(document)
        metadatas.append({"type": "question", "source": source, "added": added, "title": title})
    return ids, documents, metadatas


@routes.post('/api/v1/database')
async def import_database(request: web.BaseRequest):
    """Upsert every row of an uploaded CSV (multipart field ``database``).

    The CSV must use the export's ``ID,Source,Title,Document`` header. Rows
    keep their ids, so re-importing an export overwrites the matching items.

    Named ``import_database`` (previously ``get_database``, a name that
    shadowed the GET export handler). The route registration is unaffected.

    :return: 400 with a plain-text reason for a missing file or a malformed
        CSV; otherwise JSON ``{"count": N}`` with the collection size.
    """
    data = await request.post()
    upload_file = getattr(data.get('database'), 'file', None)
    if upload_file is None:
        return web.Response(text='Missing database file', status=400)
    # utf-8-sig decodes plain UTF-8 as well but drops a leading BOM, which
    # spreadsheet exports often add and which would break the header check.
    body = codecs.getreader('utf-8-sig')(upload_file)
    try:
        ids, documents, metadatas = _parse_database_csv(body)
    except ValueError as err:
        return web.Response(text=str(err), status=400)
    # Chroma rejects an upsert with an empty ids list (header-only file).
    if ids:
        await async_run(COLLECTION.upsert, ids=ids, metadatas=metadatas, documents=documents)
    return web.json_response({'count': await async_run(COLLECTION.count)})
@routes.put('/api/v1/item')
async def edit_item(request: web.BaseRequest):
    """Update the document and/or metadata of the item named by ``?id=``.

    Expects JSON ``{"document": "...", "metadata": {...}}``. Only the keys
    that are present are sent to ``COLLECTION.update``. The previous
    defaults (``{}`` for both) sent a dict as a document and an empty
    metadata dict.

    :return: 400 if ``id`` is missing or neither key is given; otherwise
        JSON ``{"count": N}`` with the collection size.
    """
    item_id = request.query.get('id')
    if not item_id:
        return web.Response(text='Missing id', status=400)
    data = await request.json()
    changes = {}
    if 'metadata' in data:
        changes['metadatas'] = [data['metadata']]
    if 'document' in data:
        changes['documents'] = [data['document']]
    if not changes:
        return web.Response(text='Nothing to update', status=400)
    await async_run(COLLECTION.update, ids=[item_id], **changes)
    # aiohttp handlers must return a response object; the previous bare
    # ``return`` made aiohttp answer 500 even after a successful update.
    return web.json_response({'count': await async_run(COLLECTION.count)})
def run():
    """Configure logging, open the Chroma collection and serve the web app.

    Assigns the module-level COLLECTION, then blocks inside web.run_app
    until the server stops.
    """
    global COLLECTION

    # Route all log records through a Rich console handler.
    console = Console(color_system="standard", quiet=QUIET, width=180)
    rich_handler = RichHandler(console=console, enable_link_path=False)
    logging.basicConfig(
        level="NOTSET",
        format="%(message)s",
        datefmt="[%X]",
        handlers=[rich_handler],
    )

    log.info("Loading collection: " + COLLECTION_NAME)
    embedder = embedding_functions.InstructorEmbeddingFunction(
        model_name=MODEL_NAME,
        device=TORCH_DEVICE,
    )
    chroma_settings = Settings(anonymized_telemetry=False)
    client = chromadb.PersistentClient(path=DATABASE_DIR, settings=chroma_settings)
    COLLECTION = client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedder,
    )

    log.info(f"Using web root: {WEBDATA_DIR.absolute()}")
    app = web.Application(logger=log)
    app.add_routes(routes)
    # Static asset folders served straight from the web root.
    static_mounts = (
        ('/css', 'css'),
        ('/fonts', 'fonts'),
        ('/img', 'img'),
        ('/js', 'js'),
    )
    for url_prefix, folder in static_mounts:
        app.router.add_static(url_prefix, WEBDATA_DIR / folder)

    log.info(f'Serving http at http://{HTTP_HOST}:{HTTP_PORT}')
    web.run_app(
        app,
        host=HTTP_HOST,
        port=HTTP_PORT,
        access_log=log,
        access_log_format='%a "%r" %s %b "%{Referer}i" "%{User-Agent}i"',
        print=False,
    )


if __name__ == '__main__':
    run()
|