601 lines
14 KiB
JavaScript
601 lines
14 KiB
JavaScript
const axios = require('axios');
|
|
require('dotenv').config();
|
|
const { OpenAI } = require('openai');
|
|
const yaml = require('js-yaml');
|
|
const clickhouse = require('../database/clickhouse');
|
|
|
|
// Azure OpenAI configuration. Each value can be overridden from the
// environment; the literals are the previously hard-coded defaults. `||` (not
// `??`) is deliberate so an empty environment variable also falls back.
const azureEndpoint =
  process.env.AZURE_OPENAI_ENDPOINT ||
  "https://cpmindiayoda-resource.services.ai.azure.com";
const deploymentName = process.env.AZURE_OPENAI_DEPLOYMENT || "gpt-4o-mini";
const apiVersion = process.env.AZURE_OPENAI_API_VERSION || "2024-08-01-preview";

// OpenAI SDK client pointed at the Azure deployment route. Azure OpenAI expects
// the key in an `api-key` header and an `api-version` query parameter, so both
// are attached to every request as defaults.
const client = new OpenAI({
  baseURL: `${azureEndpoint}/openai/deployments/${deploymentName}`,
  apiKey: process.env.AZURE_OPENAI_KEY,
  defaultHeaders: { "api-key": process.env.AZURE_OPENAI_KEY },
  defaultQuery: { "api-version": apiVersion },
});

// Wren AI GraphQL endpoint.
// NOTE(review): the default is plain HTTP to a fixed public IP; prefer setting
// WREN_URL per environment (ideally HTTPS or an internal address).
const WREN_URL =
  process.env.WREN_URL || "http://172.236.172.26:3000/api/graphql";
|
|
|
|
|
|
|
|
// axios / Node error codes that indicate a transient network-level failure.
const GQL_RETRYABLE_ERROR_CODES = new Set([
  "ECONNABORTED",
  "ECONNRESET",
  "ETIMEDOUT",
  "ENOTFOUND",
  "ECONNREFUSED",
]);

// Gateway-style HTTP statuses that usually mean "try again shortly".
const GQL_RETRYABLE_HTTP_STATUSES = new Set([502, 503, 504]);

// True when a failed attempt is worth repeating.
const isRetryableGqlError = (error) =>
  GQL_RETRYABLE_ERROR_CODES.has(error?.code) ||
  GQL_RETRYABLE_HTTP_STATUSES.has(error?.status) ||
  String(error?.message ?? "").includes("timeout");

/**
 * Execute a GraphQL operation against the Wren AI endpoint, retrying
 * transient failures with a linear back-off (2s, 4s, ...).
 *
 * Retried: connection/timeout errors (see GQL_RETRYABLE_ERROR_CODES) and
 * HTTP 502/503/504. Not retried: other HTTP errors, GraphQL `errors`, and
 * empty or data-less responses.
 *
 * NOTE(review): mutations are retried too. A request that timed out
 * client-side may still have been applied by the server, so a retry can
 * duplicate a CreateAskingTask/CreateThread.
 *
 * @param {string} operationName - GraphQL operation name (also used in logs).
 * @param {string} query - GraphQL document.
 * @param {object} [variables={}] - Operation variables.
 * @param {object} [options={}]
 * @param {number} [options.retries=3] - Maximum number of attempts; values
 *   below 1 (or non-numbers) are treated as 1.
 * @param {number} [options.timeout=60000] - Per-attempt timeout in ms.
 * @returns {Promise<object>} The `data` field of the GraphQL response.
 * @throws {Error} When the final attempt fails. The message reports the real
 *   number of attempts made, and the last underlying error is on `cause`.
 */
const gql = async (operationName, query, variables = {}, options = {}) => {
  const { retries = 3, timeout = 60000 } = options;
  // Always perform at least one request.
  const maxAttempts =
    Number.isFinite(retries) && retries >= 1 ? Math.floor(retries) : 1;

  const payload = { operationName, query, variables };

  let lastError;
  let attemptsMade = 0;

  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    attemptsMade = attempt;
    const startTime = Date.now();

    try {
      console.log(`[GraphQL] ${operationName} | Attempt ${attempt}/${maxAttempts}`);

      const response = await axios.post(WREN_URL, payload, {
        timeout,
        headers: {
          "Content-Type": "application/json",
          Accept: "application/json",
        },
        // Never let axios throw on status; every status is inspected below so
        // HTTP and GraphQL failures are reported uniformly.
        validateStatus: () => true,
      });

      console.log(
        `[GraphQL] ${operationName} | ${response.status} | ${Date.now() - startTime}ms`
      );

      if (response.status >= 400) {
        const httpError = new Error(
          `HTTP ${response.status}: ${response.data?.message || response.statusText}`
        );
        // Exposed so isRetryableGqlError can recognise gateway errors.
        httpError.status = response.status;
        throw httpError;
      }

      if (!response.data) {
        throw new Error(`Empty response received for ${operationName}`);
      }

      const { errors, data } = response.data;

      if (Array.isArray(errors) && errors.length > 0) {
        throw new Error(errors.map((err) => err.message).join(" | "));
      }

      if (!data) {
        throw new Error(`No data returned from GraphQL operation: ${operationName}`);
      }

      return data;
    } catch (error) {
      lastError = error;

      console.error(
        `[GraphQL Error] ${operationName} | Attempt ${attempt}/${maxAttempts} | ${Date.now() - startTime}ms`
      );
      console.error(error?.message);

      if (attempt >= maxAttempts || !isRetryableGqlError(error)) {
        break;
      }

      const delay = attempt * 2000;
      console.log(`[GraphQL Retry] Waiting ${delay}ms before retry...`);
      await new Promise((resolve) => setTimeout(resolve, delay));
    }
  }

  const attemptLabel = attemptsMade === 1 ? "attempt" : "attempts";
  throw new Error(
    `[${operationName}] Failed after ${attemptsMade} ${attemptLabel}. ${lastError?.message}`,
    { cause: lastError }
  );
};
|
|
|
|
/**
 * Poll a Wren asking task until it reaches a terminal state.
 *
 * @param {string} taskId - ID returned by the CreateAskingTask mutation.
 * @param {number} [maxAttempts=50] - Maximum number of status polls.
 * @param {number} [intervalMs=2000] - Delay between consecutive polls, in ms.
 * @returns {Promise<{sql: string, type: "sql"} | {sql: null, type: "clarification", message: string}>}
 *   `type: "sql"` when the task finished with a non-empty SQL candidate;
 *   `type: "clarification"` when it finished without usable SQL.
 * @throws {Error} If the task reports an error, ends as FAILED/STOPPED, or
 *   has not finished after `maxAttempts` polls ("Wren polling timeout").
 */
const pollUntilFinished = async (taskId, maxAttempts = 50, intervalMs = 2000) => {
  for (let i = 0; i < maxAttempts; i++) {
    const { askingTask } = await gql(
      "AskingTask",
      `query AskingTask($taskId: String!) {
        askingTask(taskId: $taskId) {
          status
          candidates { sql }
          error { message }
        }
      }`,
      { taskId }
    );

    const status = askingTask?.status;
    console.log(`Poll ${i + 1} => ${status}`);

    if (askingTask?.error) {
      throw new Error(askingTask.error.message || "Wren asking task failed");
    }

    if (status === "FINISHED") {
      const sql = askingTask.candidates?.[0]?.sql;
      if (sql) {
        return { sql, type: "sql" };
      }
      // A finished task with no (or empty) SQL cannot be executed downstream.
      return {
        sql: null,
        type: "clarification",
        message: "I couldn't generate SQL for this question. Please try rephrasing.",
      };
    }

    // Terminal states without an error payload would otherwise be polled
    // until the timeout is reached.
    if (status === "FAILED" || status === "STOPPED") {
      throw new Error(`Wren asking task ${taskId} ended with status ${status}`);
    }

    // Do not wait after the final poll; the loop is about to give up.
    if (i < maxAttempts - 1) {
      await new Promise((resolve) => setTimeout(resolve, intervalMs));
    }
  }

  throw new Error("Wren polling timeout");
};
|
|
|
|
/**
 * Ask Wren AI a natural-language question and return the resulting rows.
 *
 * Pipeline:
 *   1. CreateAskingTask
 *   2. poll until SQL is ready
 *   3. CreateThread
 *   4. CreateThreadResponse
 *   5. PreviewData
 *
 * Never throws. Failures come back as `{ success: false, type: "error" }`.
 * Questions Wren could not turn into SQL come back as
 * `{ success: false, type: "clarification" }`.
 *
 * @param {string} prompt - The user's question.
 * @returns {Promise<object>} On success:
 *   `{ success: true, type: "data", prompt, sql, totalRows, columns, data,
 *   responseId, threadId }`, where `data` is an array of `{ column: value }`
 *   row objects.
 */
const fetchWrenData = async (prompt) => {
  try {
    // Step 1: create an asking task for the question.
    const { createAskingTask } = await gql(
      "CreateAskingTask",
      `mutation CreateAskingTask($data: AskingTaskInput!) {
        createAskingTask(data: $data) { id }
      }`,
      { data: { question: prompt } }
    );
    const taskId = createAskingTask?.id;
    if (taskId == null) {
      throw new Error("Wren did not return an asking task id");
    }
    console.log("Task =>", taskId);

    // Step 2: poll until Wren produces SQL (or reports it cannot).
    const pollResult = await pollUntilFinished(taskId);

    if (pollResult.type === "clarification") {
      return {
        success: false,
        type: "clarification",
        message: pollResult.message,
        data: [],
        chart: null,
      };
    }

    const wrenSql = pollResult.sql;
    console.log("SQL ready");

    // Step 3: create a thread holding the question and its SQL.
    const { createThread } = await gql(
      "CreateThread",
      `mutation CreateThread($data: CreateThreadInput!) {
        createThread(data: $data) { id }
      }`,
      { data: { question: prompt, sql: wrenSql } }
    );
    const threadId = createThread?.id;
    if (threadId == null) {
      throw new Error("Wren did not return a thread id");
    }
    console.log("Thread =>", threadId);

    // Step 4: create a response on that thread (required for preview).
    const { createThreadResponse } = await gql(
      "CreateThreadResponse",
      `mutation CreateThreadResponse($threadId: Int!, $data: CreateThreadResponseInput!) {
        createThreadResponse(threadId: $threadId, data: $data) { id }
      }`,
      { threadId, data: { question: prompt, sql: wrenSql } }
    );
    const responseId = createThreadResponse?.id;
    if (responseId == null) {
      throw new Error("Wren did not return a thread response id");
    }
    console.log("Response ID =>", responseId);

    // Step 5: execute the SQL and fetch the preview rows.
    const { previewData } = await gql(
      "PreviewData",
      `mutation PreviewData($where: PreviewDataInput!) {
        previewData(where: $where)
      }`,
      { where: { responseId: Number.parseInt(responseId, 10) } }
    );

    if (!Array.isArray(previewData?.columns) || !Array.isArray(previewData?.data)) {
      throw new Error("Wren returned no preview data for this response");
    }

    // previewData.data is positional (array per row); zip it with the column
    // names to produce row objects.
    const columns = previewData.columns.map((column) => column.name);
    const rows = previewData.data.map((row) =>
      Object.fromEntries(columns.map((col, i) => [col, row[i]]))
    );

    // Only the count is logged; dumping full result sets to stdout is costly
    // and may leak data into logs.
    console.log(`Done — ${rows.length} rows`);

    return {
      success: true,
      type: "data",
      prompt,
      sql: wrenSql,
      totalRows: rows.length,
      columns,
      data: rows,
      responseId,
      threadId,
    };
  } catch (err) {
    console.error("WREN ERROR =>", err.message);

    return {
      success: false,
      type: "error",
      data: [],
      chart: null,
      message: err.message,
    };
  }
};
|
|
|
|
// Extract the JSON object from a model reply. Tolerates Markdown code fences
// (```json / ```JSON / ```) and stray prose before or after the object.
const parseModelJson = (text) => {
  const unfenced = text.replace(/```(?:json)?/gi, "").trim();
  const start = unfenced.indexOf("{");
  const end = unfenced.lastIndexOf("}");
  if (start === -1 || end < start) {
    throw new Error("Model response did not contain a JSON object");
  }
  return JSON.parse(unfenced.slice(start, end + 1));
};

/**
 * Ask the Azure OpenAI deployment for a Vega-Lite specification that
 * visualises `queryResult`.
 *
 * @param {object} queryResult - Serialised verbatim into the user message.
 *   `ask` passes `{ columns, data, question, sql }`.
 * @returns {Promise<object>} The parsed Vega-Lite specification.
 * @throws {Error} If the model returns no content, or content that does not
 *   contain a valid JSON object. The error is logged before being rethrown.
 */
const generateVegaJson = async (queryResult) => {
  try {
    const systemPrompt = `You are a data visualization expert. I will provide a user's question and a JSON array of data. Your task is to generate a strictly valid Vega-Lite JSON specification to visualize this data. The data array will be provided to the Vega spec internally. Map the JSON keys to the correct x, y, and color axes. Choose the best chart type (bar, line, or arc for pie/donut charts) based on the question. Respond with the JSON object only.`;

    const userPrompt = [
      "DATA:",
      JSON.stringify(queryResult, null, 2),
      "",
      "Generate Vega-Lite JSON.",
    ].join("\n");

    const completion = await client.chat.completions.create({
      model: deploymentName,
      temperature: 0,
      // JSON mode constrains the reply to a single JSON object. The messages
      // above mention "JSON", which this mode requires.
      response_format: { type: "json_object" },
      messages: [
        { role: "system", content: systemPrompt },
        { role: "user", content: userPrompt },
      ],
    });

    const content = completion.choices?.[0]?.message?.content;
    if (typeof content !== "string" || content.trim() === "") {
      throw new Error("Model returned an empty chart specification");
    }

    return parseModelJson(content);
  } catch (err) {
    console.error("Vega Generation Error =>", err.message);
    throw err;
  }
};
|
|
/**
 * Express-style handler: answer a natural-language question with data rows
 * and a Vega-Lite chart, then record the exchange in ClickHouse chat history.
 *
 * Request body: `{ prompt: string }`. The user id is read from
 * `req.user.client_id` when present.
 *
 * Responses:
 * - 400 when `prompt` is missing, not a string, or blank.
 * - 200 with fetchWrenData's failure object when Wren could not answer.
 * - 200 with `{ ...wrenResult, vegaSpec }` on success. `vegaSpec` is `null` if
 *   chart generation failed, because the data is still useful without a chart.
 * - 500 on unexpected errors.
 */
const ask = async (req, res) => {
  try {
    const { prompt } = req.body ?? {};
    const userId = req.user?.client_id || null;

    // Guard the type as well; a non-string prompt would crash on `.trim()`
    // and surface as a 500 instead of a 400.
    if (typeof prompt !== "string" || !prompt.trim()) {
      return res.status(400).json({
        success: false,
        message: "Prompt required",
      });
    }

    const result = await fetchWrenData(prompt);
    if (!result.success) {
      return res.json(result);
    }

    // Chart generation is auxiliary: on failure, still return the data.
    let vegaSpec = null;
    try {
      vegaSpec = await generateVegaJson({
        columns: result.columns,
        data: result.data,
        question: result.prompt,
        sql: result.sql,
      });
    } catch (chartErr) {
      console.error("Chart generation failed =>", chartErr.message);
    }

    const finalResponse = { ...result, vegaSpec };

    // Persisting history is best-effort. A ClickHouse failure must not turn
    // an already-computed answer into a 500.
    try {
      await clickhouse.insert({
        table: "userdetails.chat_history",
        values: [
          {
            user_id: userId,
            response_id: result.responseId,
            prompt,
            response_json: JSON.stringify(finalResponse),
          },
        ],
        format: "JSONEachRow",
      });
    } catch (historyErr) {
      console.error("Chat history insert failed =>", historyErr.message);
    }

    return res.json(finalResponse);
  } catch (err) {
    console.error(err);

    return res.status(500).json({
      success: false,
      message: err.message,
    });
  }
};
|
|
|
|
/**
 * Build suggested questions from the summaries of existing Wren threads.
 *
 * Summaries are trimmed and de-duplicated; blank or missing summaries are
 * skipped. Threads are taken in the order Wren returns them. Best-effort: any
 * failure is logged and an empty list is returned.
 *
 * @param {number} [limit=5] - Maximum number of suggestions to return.
 * @returns {Promise<Array<{question: string, category: string}>>}
 */
const getSuggestedQuestions = async (limit = 5) => {
  try {
    const { threads } = await gql(
      "Threads",
      `query Threads {
        threads {
          id
          summary
        }
      }`,
      {}
    );

    const seen = new Set();
    const questions = [];

    for (const thread of threads ?? []) {
      if (questions.length >= limit) break;

      const question =
        typeof thread?.summary === "string" ? thread.summary.trim() : "";
      if (!question || seen.has(question)) continue;

      seen.add(question);
      questions.push({ question, category: "Recent" });
    }

    return questions;
  } catch (err) {
    console.error("Failed to load suggested questions =>", err.message);
    return [];
  }
};
|
|
/**
 * Express-style handler: respond with suggested questions taken from recent
 * Wren threads. `getSuggestedQuestions` handles its own errors and returns an
 * empty list on failure, so this handler always answers `success: true`.
 */
const suggestions = async (req, res) => {
  const questions = await getSuggestedQuestions();
  return res.json({ success: true, questions });
};
|
|
/**
 * Express-style handler: return a Wren thread's responses plus a summary of
 * the most recent one.
 *
 * Request body: `{ threadId }`, which must be an integer (number or numeric
 * string).
 *
 * Responses:
 * - 400 when `threadId` is missing or not an integer.
 * - 404 when Wren returns no thread.
 * - 500 on upstream errors.
 *
 * `latestResponse` is the last element of `thread.responses`.
 */
const getThreadDetails = async (req, res) => {
  try {
    const { threadId } = req.body ?? {};

    if (!threadId) {
      return res.status(400).json({
        success: false,
        message: "threadId is required",
      });
    }

    // Validate here; otherwise Number("abc") === NaN reaches Wren and comes
    // back as a 500.
    const numericThreadId = Number(threadId);
    if (!Number.isInteger(numericThreadId)) {
      return res.status(400).json({
        success: false,
        message: "threadId must be an integer",
      });
    }

    const query = `
      query Thread($threadId: Int!) {
        thread(threadId: $threadId) {
          id
          responses {
            id
            threadId
            question
            sql
            answerDetail {
              queryId
              status
              content
              numRowsUsedInLLM
              error {
                code
                shortMessage
                message
              }
            }
            chartDetail {
              queryId
              status
              description
              chartType
              chartSchema
            }
          }
        }
      }
    `;

    const { thread } = await gql("Thread", query, { threadId: numericThreadId });

    if (!thread) {
      return res.status(404).json({
        success: false,
        message: "Thread not found",
      });
    }

    // Treat a null responses list as empty rather than crashing on `.length`.
    const responses = thread.responses ?? [];
    const latestResponse = responses.at(-1) ?? null;

    return res.json({
      success: true,
      threadId: thread.id,
      totalResponses: responses.length,
      latestResponse: latestResponse
        ? {
            responseId: latestResponse.id,
            question: latestResponse.question,
            sql: latestResponse.sql,
            status: latestResponse.answerDetail?.status,
            queryId: latestResponse.answerDetail?.queryId,
            content: latestResponse.answerDetail?.content,
            error: latestResponse.answerDetail?.error ?? null,
          }
        : null,
      responses: responses.map((response) => ({
        responseId: response.id,
        question: response.question,
        status: response.answerDetail?.status,
        queryId: response.answerDetail?.queryId,
        content: response.answerDetail?.content,
      })),
    });
  } catch (err) {
    console.error(err);

    return res.status(500).json({
      success: false,
      message: err.message,
    });
  }
};
|
|
/**
 * Express-style handler: return the details of a single Wren thread response.
 *
 * Request body: `{ responseId }`, which must be an integer (number or numeric
 * string).
 *
 * Responses:
 * - 400 when `responseId` is missing or not an integer.
 * - 404 when Wren returns no response.
 * - 500 on upstream errors.
 *
 * The payload flattens `answerDetail` (status/queryId/content/error) and
 * passes `chartDetail` through unchanged.
 */
const getResponseDetails = async (req, res) => {
  try {
    const { responseId } = req.body ?? {};

    if (!responseId) {
      return res.status(400).json({
        success: false,
        message: "responseId is required",
      });
    }

    // Validate here; otherwise Number("abc") === NaN reaches Wren and comes
    // back as a 500.
    const numericResponseId = Number(responseId);
    if (!Number.isInteger(numericResponseId)) {
      return res.status(400).json({
        success: false,
        message: "responseId must be an integer",
      });
    }

    const query = `
      query ThreadResponse($responseId: Int!) {
        threadResponse(responseId: $responseId) {
          id
          threadId
          question
          sql
          answerDetail {
            queryId
            status
            content
            numRowsUsedInLLM
            error {
              code
              shortMessage
              message
            }
          }
          chartDetail {
            queryId
            status
            description
            chartType
            chartSchema
          }
        }
      }
    `;

    const { threadResponse } = await gql("ThreadResponse", query, {
      responseId: numericResponseId,
    });

    if (!threadResponse) {
      return res.status(404).json({
        success: false,
        message: "Response not found",
      });
    }

    const { answerDetail } = threadResponse;

    return res.json({
      success: true,
      responseId: threadResponse.id,
      threadId: threadResponse.threadId,
      question: threadResponse.question,
      sql: threadResponse.sql,
      status: answerDetail?.status,
      queryId: answerDetail?.queryId,
      content: answerDetail?.content,
      // Already fetched by the query; surfaced so clients can explain FAILED answers.
      error: answerDetail?.error ?? null,
      chartDetail: threadResponse.chartDetail || null,
    });
  } catch (err) {
    console.error(err);

    return res.status(500).json({
      success: false,
      message: err.message,
    });
  }
};
|
|
|
|
module.exports = { ask, suggestions, getThreadDetails, getResponseDetails };
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|