[fix] lower on query, [add] metadata response, [add] context distance & reference links
This commit is contained in:
parent
05c5546c10
commit
40daf46c03
@ -43,7 +43,11 @@ and the only supported normalizer is the `pubmed` normalizer.
|
||||
To normalize data, you can use Minyma's `normalize` CLI command:
|
||||
|
||||
```bash
|
||||
minyma normalize --filename ./pubmed_manuscripts.jsonl --normalizer pubmed --database chroma --datapath ./chroma
|
||||
minyma normalize \
|
||||
--filename ./pubmed_manuscripts.jsonl \
|
||||
--normalizer pubmed \
|
||||
--database chroma \
|
||||
--datapath ./chroma
|
||||
```
|
||||
|
||||
The above example does the following:
|
||||
|
@ -1,8 +1,8 @@
|
||||
from os import path
|
||||
import click
|
||||
import signal
|
||||
import sys
|
||||
from importlib.metadata import version
|
||||
from minyma.config import Config
|
||||
from minyma.oai import OpenAIConnector
|
||||
from minyma.vdb import ChromaDB
|
||||
from flask import Flask
|
||||
@ -15,14 +15,15 @@ def signal_handler(sig, frame):
|
||||
|
||||
|
||||
def create_app():
|
||||
global oai, cdb
|
||||
global oai, vdb
|
||||
|
||||
from minyma.config import Config
|
||||
import minyma.api.common as api_common
|
||||
import minyma.api.v1 as api_v1
|
||||
|
||||
app = Flask(__name__)
|
||||
cdb = ChromaDB(Config.DATA_PATH)
|
||||
oai = OpenAIConnector(Config.OPENAI_API_KEY, cdb)
|
||||
vdb = ChromaDB(path.join(Config.DATA_PATH, "chroma"))
|
||||
oai = OpenAIConnector(Config.OPENAI_API_KEY, vdb)
|
||||
|
||||
app.register_blueprint(api_common.bp)
|
||||
app.register_blueprint(api_v1.bp)
|
||||
|
@ -17,8 +17,28 @@ def get_response():
|
||||
if message == "":
|
||||
return {"error": "Empty Message"}
|
||||
|
||||
oai_response = minyma.oai.query(message)
|
||||
return oai_response
|
||||
resp = minyma.oai.query(message)
|
||||
|
||||
# Derive LLM Data
|
||||
llm_resp = resp.get("llm", {})
|
||||
llm_choices = llm_resp.get("choices", [])
|
||||
|
||||
# Derive VDB Data
|
||||
vdb_resp = resp.get("vdb", {})
|
||||
combined_context = [{
|
||||
"id": vdb_resp.get("ids")[i],
|
||||
"distance": vdb_resp.get("distances")[i],
|
||||
"doc": vdb_resp.get("docs")[i],
|
||||
"metadata": vdb_resp.get("metadatas")[i],
|
||||
} for i, _ in enumerate(vdb_resp.get("docs", []))]
|
||||
|
||||
# Return Data
|
||||
return {
|
||||
"response": None if len(llm_choices) == 0 else llm_choices[0].get("message", {}).get("content"),
|
||||
"context": combined_context,
|
||||
"usage": llm_resp.get("usage"),
|
||||
}
|
||||
|
||||
|
||||
|
||||
"""
|
||||
@ -34,5 +54,5 @@ def get_related():
|
||||
if message == "":
|
||||
return {"error": "Empty Message"}
|
||||
|
||||
related_documents = minyma.cdb.get_related(message)
|
||||
related_documents = minyma.vdb.get_related(message)
|
||||
return related_documents
|
||||
|
@ -1,12 +1,16 @@
|
||||
from io import TextIOWrapper
|
||||
import json
|
||||
|
||||
class DataNormalizer:
|
||||
class DataNormalizer():
|
||||
def __init__(self, file: TextIOWrapper):
|
||||
pass
|
||||
self.file = file
|
||||
|
||||
def __len__(self) -> int:
|
||||
return 0
|
||||
|
||||
def __iter__(self):
|
||||
pass
|
||||
yield None
|
||||
|
||||
|
||||
class PubMedNormalizer(DataNormalizer):
|
||||
"""
|
||||
@ -14,7 +18,14 @@ class PubMedNormalizer(DataNormalizer):
|
||||
normalized inside the iterator.
|
||||
"""
|
||||
def __init__(self, file: TextIOWrapper):
|
||||
self.file = file
|
||||
self.file = file
|
||||
self.length = 0
|
||||
|
||||
def __len__(self):
|
||||
last_pos = self.file.tell()
|
||||
self.length = sum(1 for _ in self.file)
|
||||
self.file.seek(last_pos)
|
||||
return self.length
|
||||
|
||||
def __iter__(self):
|
||||
count = 0
|
||||
@ -42,4 +53,10 @@ class PubMedNormalizer(DataNormalizer):
|
||||
count += 1
|
||||
|
||||
# ID = Line Number
|
||||
yield { "doc": norm_text, "id": str(count - 1) }
|
||||
yield {
|
||||
"id": str(count - 1),
|
||||
"doc": norm_text,
|
||||
"metadata": {
|
||||
"file": l.get("file")
|
||||
},
|
||||
}
|
||||
|
@ -5,8 +5,8 @@ from minyma.vdb import VectorDB
|
||||
|
||||
# Stolen LangChain Prompt
|
||||
PROMPT_TEMPLATE = """
|
||||
Use the following pieces of context to answer the question at the end.
|
||||
If you don't know the answer, just say that you don't know, don't try to
|
||||
Use the following pieces of context to answer the question at the end.
|
||||
If you don't know the answer, just say that you don't know, don't try to
|
||||
make up an answer.
|
||||
|
||||
{context}
|
||||
@ -19,6 +19,7 @@ class OpenAIConnector:
|
||||
def __init__(self, api_key: str, vdb: VectorDB):
|
||||
self.vdb = vdb
|
||||
self.model = "gpt-3.5-turbo"
|
||||
self.word_cap = 1000
|
||||
openai.api_key = api_key
|
||||
|
||||
def query(self, question: str) -> Any:
|
||||
@ -30,8 +31,9 @@ class OpenAIConnector:
|
||||
if len(all_docs) == 0:
|
||||
return { "error": "No Context Found" }
|
||||
|
||||
# Join on new line, generate main prompt
|
||||
context = '\n'.join(all_docs)
|
||||
# Join on new line (cap @ word limit), generate main prompt
|
||||
reduced_docs = list(map(lambda x: " ".join(x.split()[:self.word_cap]), all_docs))
|
||||
context = '\n'.join(reduced_docs)
|
||||
prompt = PROMPT_TEMPLATE.format(context = context, question = question)
|
||||
|
||||
# Query OpenAI ChatCompletion
|
||||
@ -41,4 +43,4 @@ class OpenAIConnector:
|
||||
)
|
||||
|
||||
# Return Response
|
||||
return response
|
||||
return { "llm": response, "vdb": related }
|
||||
|
@ -72,41 +72,41 @@
|
||||
</main>
|
||||
<script>
|
||||
const LOADING_SVG = `<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="currentColor"
|
||||
>
|
||||
<style>
|
||||
.spinner_qM83 {
|
||||
animation: spinner_8HQG 1.05s infinite;
|
||||
}
|
||||
.spinner_oXPr {
|
||||
animation-delay: 0.1s;
|
||||
}
|
||||
.spinner_ZTLf {
|
||||
animation-delay: 0.2s;
|
||||
}
|
||||
@keyframes spinner_8HQG {
|
||||
0%,
|
||||
57.14% {
|
||||
animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
|
||||
transform: translate(0);
|
||||
}
|
||||
28.57% {
|
||||
animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
|
||||
transform: translateY(-6px);
|
||||
}
|
||||
100% {
|
||||
transform: translate(0);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
<circle class="spinner_qM83" cx="4" cy="12" r="3"></circle>
|
||||
<circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle>
|
||||
<circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle>
|
||||
</svg>`;
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="currentColor"
|
||||
>
|
||||
<style>
|
||||
.spinner_qM83 {
|
||||
animation: spinner_8HQG 1.05s infinite;
|
||||
}
|
||||
.spinner_oXPr {
|
||||
animation-delay: 0.1s;
|
||||
}
|
||||
.spinner_ZTLf {
|
||||
animation-delay: 0.2s;
|
||||
}
|
||||
@keyframes spinner_8HQG {
|
||||
0%,
|
||||
57.14% {
|
||||
animation-timing-function: cubic-bezier(0.33, 0.66, 0.66, 1);
|
||||
transform: translate(0);
|
||||
}
|
||||
28.57% {
|
||||
animation-timing-function: cubic-bezier(0.33, 0, 0.66, 0.33);
|
||||
transform: translateY(-6px);
|
||||
}
|
||||
100% {
|
||||
transform: translate(0);
|
||||
}
|
||||
}
|
||||
</style>
|
||||
<circle class="spinner_qM83" cx="4" cy="12" r="3"></circle>
|
||||
<circle class="spinner_qM83 spinner_oXPr" cx="12" cy="12" r="3"></circle>
|
||||
<circle class="spinner_qM83 spinner_ZTLf" cx="20" cy="12" r="3"></circle>
|
||||
</svg>`;
|
||||
|
||||
/**
|
||||
* Wrapper API Call
|
||||
@ -125,9 +125,9 @@
|
||||
// Wrapping Element
|
||||
let wrapEl = document.createElement("div");
|
||||
wrapEl.innerHTML = `<div class="flex">
|
||||
<span class="font-bold w-24 grow-0 shrink-0"></span>
|
||||
<span class="whitespace-break-spaces w-full"></span>
|
||||
</div>`;
|
||||
<span class="font-bold w-24 grow-0 shrink-0"></span>
|
||||
<span class="whitespace-break-spaces w-full"></span>
|
||||
</div>`;
|
||||
|
||||
// Get Elements
|
||||
let nameEl = wrapEl.querySelector("span");
|
||||
@ -158,7 +158,63 @@
|
||||
})
|
||||
.then((data) => {
|
||||
console.log("SUCCESS:", data);
|
||||
content.innerText = data.choices[0].message.content;
|
||||
|
||||
// Create Response Element
|
||||
let responseEl = document.createElement("p");
|
||||
responseEl.setAttribute(
|
||||
"class",
|
||||
"whitespace-break-spaces border-b pb-3 mb-3"
|
||||
);
|
||||
responseEl.innerText = data.response;
|
||||
|
||||
// Create Context Element
|
||||
let contextEl = document.createElement("div");
|
||||
contextEl.innerHTML = `
|
||||
<h1 class="font-bold">Context:</h1>
|
||||
<ul class="list-disc ml-6"></ul>`;
|
||||
let ulEl = contextEl.querySelector("ul");
|
||||
|
||||
// Create Context Links
|
||||
data.context
|
||||
|
||||
// Capture PubMed ID & Distance
|
||||
.map((item) => [
|
||||
item.metadata.file.match("\/(.*)\.txt$"),
|
||||
item.distance,
|
||||
])
|
||||
|
||||
// Filter Non-Matches
|
||||
.filter(([match]) => match)
|
||||
|
||||
// Get Match Value & Round Distance (2)
|
||||
.map(([match, distance]) => [
|
||||
match[1],
|
||||
Math.round(distance * 100) / 100,
|
||||
])
|
||||
|
||||
// Create Links
|
||||
.forEach(([pmid, distance]) => {
|
||||
let newEl = document.createElement("li");
|
||||
let linkEl = document.createElement("a");
|
||||
linkEl.setAttribute("target", "_blank");
|
||||
linkEl.setAttribute(
|
||||
"class",
|
||||
"text-blue-500 hover:text-blue-600"
|
||||
);
|
||||
linkEl.setAttribute(
|
||||
"href",
|
||||
"https://www.ncbi.nlm.nih.gov/pmc/articles/" + pmid
|
||||
);
|
||||
linkEl.textContent = "[" + distance + "] " + pmid;
|
||||
newEl.append(linkEl);
|
||||
ulEl.append(newEl);
|
||||
});
|
||||
|
||||
// Add to DOM
|
||||
content.setAttribute("class", "w-full");
|
||||
content.innerHTML = "";
|
||||
content.append(responseEl);
|
||||
content.append(contextEl);
|
||||
})
|
||||
.catch((e) => {
|
||||
console.log("ERROR:", e);
|
||||
|
@ -1,7 +1,6 @@
|
||||
from chromadb.api import API
|
||||
from itertools import islice
|
||||
from os import path
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
from typing import Any, cast
|
||||
import chromadb
|
||||
|
||||
@ -29,47 +28,45 @@ class VectorDB:
|
||||
ChromaDV VectorDB Type
|
||||
"""
|
||||
class ChromaDB(VectorDB):
|
||||
def __init__(self, base_path: str):
|
||||
chroma_path = path.join(base_path, "chroma")
|
||||
self.client: API = chromadb.PersistentClient(path=chroma_path)
|
||||
self.word_limit = 1000
|
||||
def __init__(self, path: str):
|
||||
self.client: API = chromadb.PersistentClient(path=path)
|
||||
self.word_cap = 2500
|
||||
self.collection_name: str = "vdb"
|
||||
self.collection: chromadb.Collection = self.client.create_collection(name=self.collection_name, get_or_create=True)
|
||||
|
||||
def get_related(self, question) -> Any:
|
||||
def get_related(self, question: str) -> Any:
|
||||
"""Returns line separated related docs"""
|
||||
results = self.collection.query(
|
||||
query_texts=[question],
|
||||
query_texts=[question.lower()],
|
||||
n_results=2
|
||||
)
|
||||
|
||||
all_docs: list = cast(list, results.get("documents", [[]]))[0]
|
||||
all_metadata: list = cast(list, results.get("metadatas", [[]]))[0]
|
||||
all_distances: list = cast(list, results.get("distances", [[]]))[0]
|
||||
all_ids: list = cast(list, results.get("ids", [[]]))[0]
|
||||
|
||||
return {
|
||||
"distances":all_distances,
|
||||
"distances": all_distances,
|
||||
"metadatas": all_metadata,
|
||||
"docs": all_docs,
|
||||
"ids": all_ids
|
||||
}
|
||||
|
||||
def load_documents(self, normalizer: DataNormalizer):
|
||||
# 10 Item Chunking
|
||||
for items in tqdm(chunk(normalizer, 50)):
|
||||
def load_documents(self, normalizer: DataNormalizer, chunk_size: int = 10):
|
||||
length = len(normalizer) / chunk_size
|
||||
for items in tqdm(chunk(normalizer, chunk_size), total=length):
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
# Limit words per document to accommodate context token limits
|
||||
for item in items:
|
||||
doc = " ".join(item.get("doc").split()[:self.word_limit])
|
||||
documents.append(doc)
|
||||
documents.append(" ".join(item.get("doc").split()[:self.word_cap]))
|
||||
ids.append(item.get("id"))
|
||||
metadatas.append(item.get("metadata", {}))
|
||||
|
||||
# Ideally we parse out metadata from each document
|
||||
# and pass to the metadata kwarg. However, each
|
||||
# document appears to have a slightly different format,
|
||||
# so it's difficult to parse out.
|
||||
self.collection.add(
|
||||
ids=ids,
|
||||
documents=documents,
|
||||
ids=ids
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user