From 40daf46c03f4d841370aa8ce28cb8e4208e8f02a Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Thu, 19 Oct 2023 18:56:48 -0400 Subject: [PATCH] [fix] lower on query, [add] metadata response, [add] context distance & reference links --- README.md | 6 +- minyma/__init__.py | 9 +-- minyma/api/v1.py | 26 ++++++- minyma/normalizer.py | 27 ++++++-- minyma/oai.py | 12 ++-- minyma/templates/index.html | 134 +++++++++++++++++++++++++----------- minyma/vdb.py | 37 +++++----- 7 files changed, 174 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index 7b5c2c9..9ee04d5 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,11 @@ and the only supported normalizer is the `pubmed` normalizer. To normalize data, you can use Minyma's `normalize` CLI command: ```bash -minyma normalize --filename ./pubmed_manuscripts.jsonl --normalizer pubmed --database chroma --datapath ./chroma +minyma normalize \ + --filename ./pubmed_manuscripts.jsonl \ + --normalizer pubmed \ + --database chroma \ + --datapath ./chroma ``` The above example does the following: diff --git a/minyma/__init__.py b/minyma/__init__.py index b85bf5f..44b388d 100644 --- a/minyma/__init__.py +++ b/minyma/__init__.py @@ -1,8 +1,8 @@ +from os import path import click import signal import sys from importlib.metadata import version -from minyma.config import Config from minyma.oai import OpenAIConnector from minyma.vdb import ChromaDB from flask import Flask @@ -15,14 +15,15 @@ def signal_handler(sig, frame): def create_app(): - global oai, cdb + global oai, vdb + from minyma.config import Config import minyma.api.common as api_common import minyma.api.v1 as api_v1 app = Flask(__name__) - cdb = ChromaDB(Config.DATA_PATH) - oai = OpenAIConnector(Config.OPENAI_API_KEY, cdb) + vdb = ChromaDB(path.join(Config.DATA_PATH, "chroma")) + oai = OpenAIConnector(Config.OPENAI_API_KEY, vdb) app.register_blueprint(api_common.bp) app.register_blueprint(api_v1.bp) diff --git a/minyma/api/v1.py b/minyma/api/v1.py index 4edb3a2..18a1eb9 100644 --- a/minyma/api/v1.py +++ b/minyma/api/v1.py @@ -17,8 +17,28 @@ def get_response(): if message == "": return {"error": "Empty Message"} - oai_response = minyma.oai.query(message) - return oai_response + resp = minyma.oai.query(message) + + # Derive LLM Data + llm_resp = resp.get("llm", {}) + llm_choices = llm_resp.get("choices", []) + + # Derive VDB Data + vdb_resp = resp.get("vdb", {}) + combined_context = [{ + "id": vdb_resp.get("ids")[i], + "distance": vdb_resp.get("distances")[i], + "doc": vdb_resp.get("docs")[i], + "metadata": vdb_resp.get("metadatas")[i], + } for i, _ in enumerate(vdb_resp.get("docs", []))] + + # Return Data + return { + "response": None if len(llm_choices) == 0 else llm_choices[0].get("message", {}).get("content"), + "context": combined_context, + "usage": llm_resp.get("usage"), + } + """ @@ -34,5 +54,5 @@ def get_related(): if message == "": return {"error": "Empty Message"} - related_documents = minyma.cdb.get_related(message) + related_documents = minyma.vdb.get_related(message) return related_documents diff --git a/minyma/normalizer.py b/minyma/normalizer.py index 42de7b0..2d6b2b6 100644 --- a/minyma/normalizer.py +++ b/minyma/normalizer.py @@ -1,12 +1,16 @@ from io import TextIOWrapper import json -class DataNormalizer: +class DataNormalizer(): def __init__(self, file: TextIOWrapper): - pass + self.file = file + + def __len__(self) -> int: + return 0 def __iter__(self): - pass + yield None + class PubMedNormalizer(DataNormalizer): """ @@ -14,7 +18,14 @@ class PubMedNormalizer(DataNormalizer): normalized inside the iterator. """ def __init__(self, file: TextIOWrapper): - self.file = file + self.file = file + self.length = 0 + + def __len__(self): + last_pos = self.file.tell() + self.length = sum(1 for _ in self.file) + self.file.seek(last_pos) + return self.length def __iter__(self): count = 0 @@ -42,4 +53,10 @@ class PubMedNormalizer(DataNormalizer): count += 1 # ID = Line Number - yield { "doc": norm_text, "id": str(count - 1) } + yield { + "id": str(count - 1), + "doc": norm_text, + "metadata": { + "file": l.get("file") + }, + } diff --git a/minyma/oai.py b/minyma/oai.py index 1c3fc13..b47f6aa 100644 --- a/minyma/oai.py +++ b/minyma/oai.py @@ -5,8 +5,8 @@ from minyma.vdb import VectorDB # Stolen LangChain Prompt PROMPT_TEMPLATE = """ -Use the following pieces of context to answer the question at the end. -If you don't know the answer, just say that you don't know, don't try to +Use the following pieces of context to answer the question at the end. +If you don't know the answer, just say that you don't know, don't try to make up an answer. {context} @@ -19,6 +19,7 @@ class OpenAIConnector: def __init__(self, api_key: str, vdb: VectorDB): self.vdb = vdb self.model = "gpt-3.5-turbo" + self.word_cap = 1000 openai.api_key = api_key def query(self, question: str) -> Any: @@ -30,8 +31,9 @@ class OpenAIConnector: if len(all_docs) == 0: return { "error": "No Context Found" } - # Join on new line, generate main prompt - context = '\n'.join(all_docs) + # Join on new line (cap @ word limit), generate main prompt + reduced_docs = list(map(lambda x: " ".join(x.split()[:self.word_cap]), all_docs)) + context = '\n'.join(reduced_docs) prompt = PROMPT_TEMPLATE.format(context = context, question = question) # Query OpenAI ChatCompletion @@ -41,4 +43,4 @@ class OpenAIConnector: ) # Return Response - return response + return { "llm": response, "vdb": related } diff --git a/minyma/templates/index.html b/minyma/templates/index.html index d6bdc8f..8c1bd96 100644 --- a/minyma/templates/index.html +++ b/minyma/templates/index.html @@ -72,41 +72,41 @@