[fix] lower on query, [add] metadata response, [add] context distance & reference links

This commit is contained in:
Evan Reichard 2023-10-19 18:56:48 -04:00
parent 05c5546c10
commit 40daf46c03
7 changed files with 174 additions and 77 deletions

View File

@ -43,7 +43,11 @@ and the only supported normalizer is the `pubmed` normalizer.
To normalize data, you can use Minyma's `normalize` CLI command: To normalize data, you can use Minyma's `normalize` CLI command:
```bash ```bash
minyma normalize --filename ./pubmed_manuscripts.jsonl --normalizer pubmed --database chroma --datapath ./chroma minyma normalize \
--filename ./pubmed_manuscripts.jsonl \
--normalizer pubmed \
--database chroma \
--datapath ./chroma
``` ```
The above example does the following: The above example does the following:

View File

@ -1,8 +1,8 @@
from os import path
import click import click
import signal import signal
import sys import sys
from importlib.metadata import version from importlib.metadata import version
from minyma.config import Config
from minyma.oai import OpenAIConnector from minyma.oai import OpenAIConnector
from minyma.vdb import ChromaDB from minyma.vdb import ChromaDB
from flask import Flask from flask import Flask
@ -15,14 +15,15 @@ def signal_handler(sig, frame):
def create_app(): def create_app():
global oai, cdb global oai, vdb
from minyma.config import Config
import minyma.api.common as api_common import minyma.api.common as api_common
import minyma.api.v1 as api_v1 import minyma.api.v1 as api_v1
app = Flask(__name__) app = Flask(__name__)
cdb = ChromaDB(Config.DATA_PATH) vdb = ChromaDB(path.join(Config.DATA_PATH, "chroma"))
oai = OpenAIConnector(Config.OPENAI_API_KEY, cdb) oai = OpenAIConnector(Config.OPENAI_API_KEY, vdb)
app.register_blueprint(api_common.bp) app.register_blueprint(api_common.bp)
app.register_blueprint(api_v1.bp) app.register_blueprint(api_v1.bp)

View File

@ -17,8 +17,28 @@ def get_response():
if message == "": if message == "":
return {"error": "Empty Message"} return {"error": "Empty Message"}
oai_response = minyma.oai.query(message) resp = minyma.oai.query(message)
return oai_response
# Derive LLM Data
llm_resp = resp.get("llm", {})
llm_choices = llm_resp.get("choices", [])
# Derive VDB Data
vdb_resp = resp.get("vdb", {})
combined_context = [{
"id": vdb_resp.get("ids")[i],
"distance": vdb_resp.get("distances")[i],
"doc": vdb_resp.get("docs")[i],
"metadata": vdb_resp.get("metadatas")[i],
} for i, _ in enumerate(vdb_resp.get("docs", []))]
# Return Data
return {
"response": None if len(llm_choices) == 0 else llm_choices[0].get("message", {}).get("content"),
"context": combined_context,
"usage": llm_resp.get("usage"),
}
""" """
@ -34,5 +54,5 @@ def get_related():
if message == "": if message == "":
return {"error": "Empty Message"} return {"error": "Empty Message"}
related_documents = minyma.cdb.get_related(message) related_documents = minyma.vdb.get_related(message)
return related_documents return related_documents

View File

@ -1,12 +1,16 @@
from io import TextIOWrapper from io import TextIOWrapper
import json import json
class DataNormalizer: class DataNormalizer():
def __init__(self, file: TextIOWrapper): def __init__(self, file: TextIOWrapper):
pass self.file = file
def __len__(self) -> int:
return 0
def __iter__(self): def __iter__(self):
pass yield None
class PubMedNormalizer(DataNormalizer): class PubMedNormalizer(DataNormalizer):
""" """
@ -14,7 +18,14 @@ class PubMedNormalizer(DataNormalizer):
normalized inside the iterator. normalized inside the iterator.
""" """
def __init__(self, file: TextIOWrapper): def __init__(self, file: TextIOWrapper):
self.file = file self.file = file
self.length = 0
def __len__(self):
last_pos = self.file.tell()
self.length = sum(1 for _ in self.file)
self.file.seek(last_pos)
return self.length
def __iter__(self): def __iter__(self):
count = 0 count = 0
@ -42,4 +53,10 @@ class PubMedNormalizer(DataNormalizer):
count += 1 count += 1
# ID = Line Number # ID = Line Number
yield { "doc": norm_text, "id": str(count - 1) } yield {
"id": str(count - 1),
"doc": norm_text,
"metadata": {
"file": l.get("file")
},
}

View File

@ -5,8 +5,8 @@ from minyma.vdb import VectorDB
# Stolen LangChain Prompt # Stolen LangChain Prompt
PROMPT_TEMPLATE = """ PROMPT_TEMPLATE = """
Use the following pieces of context to answer the question at the end. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to If you don't know the answer, just say that you don't know, don't try to
make up an answer. make up an answer.
{context} {context}
@ -19,6 +19,7 @@ class OpenAIConnector:
def __init__(self, api_key: str, vdb: VectorDB): def __init__(self, api_key: str, vdb: VectorDB):
self.vdb = vdb self.vdb = vdb
self.model = "gpt-3.5-turbo" self.model = "gpt-3.5-turbo"
self.word_cap = 1000
openai.api_key = api_key openai.api_key = api_key
def query(self, question: str) -> Any: def query(self, question: str) -> Any:
@ -30,8 +31,9 @@ class OpenAIConnector:
if len(all_docs) == 0: if len(all_docs) == 0:
return { "error": "No Context Found" } return { "error": "No Context Found" }
# Join on new line, generate main prompt # Join on new line (cap @ word limit), generate main prompt
context = '\n'.join(all_docs) reduced_docs = list(map(lambda x: " ".join(x.split()[:self.word_cap]), all_docs))
context = '\n'.join(reduced_docs)
prompt = PROMPT_TEMPLATE.format(context = context, question = question) prompt = PROMPT_TEMPLATE.format(context = context, question = question)
# Query OpenAI ChatCompletion # Query OpenAI ChatCompletion
@ -41,4 +43,4 @@ class OpenAIConnector:
) )
# Return Response # Return Response
return response return { "llm": response, "vdb": related }

View File

@ -72,41 +72,41 @@
</main> </main>
<script> <script>
const LOADING_SVG = `<svg const LOADING_SVG = `<svg
width="24" width="24"
height="24" height="24"
viewBox="0 0 24 24" viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
fill="currentColor" fill="currentColor"
> >
<style> <style>
.spinner_qM83 { .spinner_qM83 {
animation: spinner_8HQG 1.05s infinite; animation: spinner_8HQG 1.05s infinite;
} }
.spinner_oXPr { .spinner_oXPr {
animation-delay: 0.1s; animation-delay: 0.1s;
} }
.spinner_ZTLf { .spinner_ZTLf {
animation-delay: 0.2s; animation-delay: 0.2s;
} }
@keyframes spinner_8HQG { @keyframes spinner_8HQG {
0%, 0%,
57.14% { 57.14% {
animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1); animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
transform: translate(0); transform: translate(0);
} }
28.57% { 28.57% {
animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33); animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
transform: translateY(-6px); transform: translateY(-6px);
} }
100% { 100% {
transform: translate(0); transform: translate(0);
} }
} }
</style> </style>
<circle class="spinner_qM83" cx="4" cy="12" r="3"></circle> <circle class="spinner_qM83" cx="4" cy="12" r="3"></circle>
<circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle> <circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle>
<circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle> <circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle>
</svg>`; </svg>`;
/** /**
* Wrapper API Call * Wrapper API Call
@ -125,9 +125,9 @@
// Wrapping Element // Wrapping Element
let wrapEl = document.createElement("div"); let wrapEl = document.createElement("div");
wrapEl.innerHTML = `<div class="flex"> wrapEl.innerHTML = `<div class="flex">
<span class="font-bold w-24 grow-0 shrink-0"></span> <span class="font-bold w-24 grow-0 shrink-0"></span>
<span class="whitespace-break-spaces w-full"></span> <span class="whitespace-break-spaces w-full"></span>
</div>`; </div>`;
// Get Elements // Get Elements
let nameEl = wrapEl.querySelector("span"); let nameEl = wrapEl.querySelector("span");
@ -158,7 +158,63 @@
}) })
.then((data) => { .then((data) => {
console.log("SUCCESS:", data); console.log("SUCCESS:", data);
content.innerText = data.choices[0].message.content;
// Create Response Element
let responseEl = document.createElement("p");
responseEl.setAttribute(
"class",
"whitespace-break-spaces border-b pb-3 mb-3"
);
responseEl.innerText = data.response;
// Create Context Element
let contextEl = document.createElement("div");
contextEl.innerHTML = `
<h1 class="font-bold">Context:</h1>
<ul class="list-disc ml-6"></ul>`;
let ulEl = contextEl.querySelector("ul");
// Create Context Links
data.context
// Capture PubMed ID & Distance
.map((item) => [
item.metadata.file.match("\/(.*)\.txt$"),
item.distance,
])
// Filter Non-Matches
.filter(([match]) => match)
// Get Match Value & Round Distance (2)
.map(([match, distance]) => [
match[1],
Math.round(distance * 100) / 100,
])
// Create Links
.forEach(([pmid, distance]) => {
let newEl = document.createElement("li");
let linkEl = document.createElement("a");
linkEl.setAttribute("target", "_blank");
linkEl.setAttribute(
"class",
"text-blue-500 hover:text-blue-600"
);
linkEl.setAttribute(
"href",
"https://www.ncbi.nlm.nih.gov/pmc/articles/" + pmid
);
linkEl.textContent = "[" + distance + "] " + pmid;
newEl.append(linkEl);
ulEl.append(newEl);
});
// Add to DOM
content.setAttribute("class", "w-full");
content.innerHTML = "";
content.append(responseEl);
content.append(contextEl);
}) })
.catch((e) => { .catch((e) => {
console.log("ERROR:", e); console.log("ERROR:", e);

View File

@ -1,7 +1,6 @@
from chromadb.api import API from chromadb.api import API
from itertools import islice from itertools import islice
from os import path from tqdm import tqdm
from tqdm.auto import tqdm
from typing import Any, cast from typing import Any, cast
import chromadb import chromadb
@ -29,47 +28,45 @@ class VectorDB:
ChromaDB VectorDB Type ChromaDB VectorDB Type
""" """
class ChromaDB(VectorDB): class ChromaDB(VectorDB):
def __init__(self, base_path: str): def __init__(self, path: str):
chroma_path = path.join(base_path, "chroma") self.client: API = chromadb.PersistentClient(path=path)
self.client: API = chromadb.PersistentClient(path=chroma_path) self.word_cap = 2500
self.word_limit = 1000
self.collection_name: str = "vdb" self.collection_name: str = "vdb"
self.collection: chromadb.Collection = self.client.create_collection(name=self.collection_name, get_or_create=True) self.collection: chromadb.Collection = self.client.create_collection(name=self.collection_name, get_or_create=True)
def get_related(self, question) -> Any: def get_related(self, question: str) -> Any:
"""Returns line separated related docs""" """Returns line separated related docs"""
results = self.collection.query( results = self.collection.query(
query_texts=[question], query_texts=[question.lower()],
n_results=2 n_results=2
) )
all_docs: list = cast(list, results.get("documents", [[]]))[0] all_docs: list = cast(list, results.get("documents", [[]]))[0]
all_metadata: list = cast(list, results.get("metadatas", [[]]))[0]
all_distances: list = cast(list, results.get("distances", [[]]))[0] all_distances: list = cast(list, results.get("distances", [[]]))[0]
all_ids: list = cast(list, results.get("ids", [[]]))[0] all_ids: list = cast(list, results.get("ids", [[]]))[0]
return { return {
"distances":all_distances, "distances": all_distances,
"metadatas": all_metadata,
"docs": all_docs, "docs": all_docs,
"ids": all_ids "ids": all_ids
} }
def load_documents(self, normalizer: DataNormalizer): def load_documents(self, normalizer: DataNormalizer, chunk_size: int = 10):
# 10 Item Chunking length = len(normalizer) / chunk_size
for items in tqdm(chunk(normalizer, 50)): for items in tqdm(chunk(normalizer, chunk_size), total=length):
ids = [] ids = []
documents = [] documents = []
metadatas = []
# Limit words per document to accommodate context token limits
for item in items: for item in items:
doc = " ".join(item.get("doc").split()[:self.word_limit]) documents.append(" ".join(item.get("doc").split()[:self.word_cap]))
documents.append(doc)
ids.append(item.get("id")) ids.append(item.get("id"))
metadatas.append(item.get("metadata", {}))
# Ideally we parse out metadata from each document
# and pass to the metadata kwarg. However, each
# document appears to have a slightly different format,
# so it's difficult to parse out.
self.collection.add( self.collection.add(
ids=ids,
documents=documents, documents=documents,
ids=ids metadatas=metadatas,
) )