|
|
@ -6,6 +6,7 @@ import com.alibaba.fastjson.JSONObject; |
|
|
import com.bnyer.common.core.constant.RedisKeyConstant; |
|
|
import com.bnyer.common.core.constant.RedisKeyConstant; |
|
|
import com.bnyer.common.core.domain.AiPaint; |
|
|
import com.bnyer.common.core.domain.AiPaint; |
|
|
import com.bnyer.common.core.dto.TextToImgDto; |
|
|
import com.bnyer.common.core.dto.TextToImgDto; |
|
|
|
|
|
import com.bnyer.common.core.exception.ServiceException; |
|
|
import com.bnyer.common.core.utils.TranslateUtils; |
|
|
import com.bnyer.common.core.utils.TranslateUtils; |
|
|
import com.bnyer.common.core.utils.file.Base64ToMultipartFileUtils; |
|
|
import com.bnyer.common.core.utils.file.Base64ToMultipartFileUtils; |
|
|
import com.bnyer.common.core.vo.TextToImgVo; |
|
|
import com.bnyer.common.core.vo.TextToImgVo; |
|
|
@ -13,8 +14,13 @@ import com.bnyer.common.redis.service.RedisService; |
|
|
import com.bnyer.file.api.RemoteFileService; |
|
|
import com.bnyer.file.api.RemoteFileService; |
|
|
import com.bnyer.img.config.StableDiffusionConfig; |
|
|
import com.bnyer.img.config.StableDiffusionConfig; |
|
|
import com.bnyer.img.config.TencentTranslateConfig; |
|
|
import com.bnyer.img.config.TencentTranslateConfig; |
|
|
|
|
|
import com.bnyer.img.enums.AiPaintButtonEnum; |
|
|
import com.bnyer.img.service.AiPaintService; |
|
|
import com.bnyer.img.service.AiPaintService; |
|
|
import com.bnyer.img.service.StableDiffusionService; |
|
|
import com.bnyer.img.service.StableDiffusionService; |
|
|
|
|
|
import com.tencentcloudapi.aiart.v20221229.AiartClient; |
|
|
|
|
|
import com.tencentcloudapi.aiart.v20221229.models.ResultConfig; |
|
|
|
|
|
import com.tencentcloudapi.aiart.v20221229.models.TextToImageRequest; |
|
|
|
|
|
import com.tencentcloudapi.aiart.v20221229.models.TextToImageResponse; |
|
|
import com.tencentcloudapi.common.Credential; |
|
|
import com.tencentcloudapi.common.Credential; |
|
|
import com.tencentcloudapi.common.exception.TencentCloudSDKException; |
|
|
import com.tencentcloudapi.common.exception.TencentCloudSDKException; |
|
|
import com.tencentcloudapi.common.profile.ClientProfile; |
|
|
import com.tencentcloudapi.common.profile.ClientProfile; |
|
|
@ -29,10 +35,7 @@ import org.springframework.web.client.RestTemplate; |
|
|
import org.springframework.web.multipart.MultipartFile; |
|
|
import org.springframework.web.multipart.MultipartFile; |
|
|
|
|
|
|
|
|
import java.text.SimpleDateFormat; |
|
|
import java.text.SimpleDateFormat; |
|
|
import java.util.Date; |
|
|
import java.util.*; |
|
|
import java.util.HashMap; |
|
|
|
|
|
import java.util.List; |
|
|
|
|
|
import java.util.Map; |
|
|
|
|
|
|
|
|
|
|
|
@Service |
|
|
@Service |
|
|
@Slf4j |
|
|
@Slf4j |
|
|
@ -85,43 +88,63 @@ public class StableDiffusionServiceImpl implements StableDiffusionService { |
|
|
|
|
|
|
|
|
@Override |
|
|
@Override |
|
|
public TextToImgVo textToImg(TextToImgDto param) { |
|
|
public TextToImgVo textToImg(TextToImgDto param) { |
|
|
try{ |
|
|
|
|
|
String prompt = ""; |
|
|
//配置控制采用sd还是tx文生图
|
|
|
//判断prompt是否包含中文,中文则翻译,否则跳过
|
|
|
int button = redisService.getCacheObject(RedisKeyConstant.AI_PAINT_BUTTON); |
|
|
if(TranslateUtils.isContainChinese(param.getPrompt())){ |
|
|
if(button == AiPaintButtonEnum.TX.getCode()){ |
|
|
//调用翻译api
|
|
|
//采用腾讯文生图
|
|
|
prompt = translate(param.getPrompt()); |
|
|
try{ |
|
|
}else{ |
|
|
Credential cred = new Credential(tencentTranslateConfig.getSecretId(), tencentTranslateConfig.getSecretKey()); |
|
|
prompt = param.getPrompt(); |
|
|
// 实例化一个http选项,可选的,没有特殊需求可以跳过
|
|
|
} |
|
|
HttpProfile httpProfile = new HttpProfile(); |
|
|
System.out.println(prompt); |
|
|
httpProfile.setEndpoint("aiart.tencentcloudapi.com"); |
|
|
|
|
|
// 实例化一个client选项,可选的,没有特殊需求可以跳过
|
|
|
//TODO 根据选择的风格来选择模型
|
|
|
ClientProfile clientProfile = new ClientProfile(); |
|
|
Map<String, Object> map = new HashMap<>(); |
|
|
clientProfile.setHttpProfile(httpProfile); |
|
|
map.put("width",param.getWidth() == null ? 512 : param.getWidth()); |
|
|
// 实例化要请求产品的client对象,clientProfile是可选的
|
|
|
map.put("height",param.getHeight() == null ? 512 : param.getHeight()); |
|
|
AiartClient client = new AiartClient(cred, "ap-guangzhou", clientProfile); |
|
|
map.put("prompt", prompt); |
|
|
// 实例化一个请求对象,每个接口都会对应一个request对象
|
|
|
//map.put("prompt", param.getPrompt());
|
|
|
TextToImageRequest req = new TextToImageRequest(); |
|
|
map.put("seed",-1); |
|
|
req.setPrompt(param.getPrompt()); |
|
|
map.put("batch_size",1); |
|
|
//请求风格
|
|
|
map.put("cfg_scale",7); |
|
|
String[] styles1 = new String[0]; |
|
|
map.put("restore_faces",false); |
|
|
switch (param.getStyleName()) { |
|
|
map.put("tiling",false); |
|
|
case "细腻": |
|
|
map.put("eta",0); |
|
|
styles1 = new String[]{"110"}; |
|
|
map.put("sampler_index","DPM++ 2S a Karras"); |
|
|
break; |
|
|
//map.put("sampler_index",param.getSamplerIndex());
|
|
|
case "卡通": |
|
|
map.put("steps",25); |
|
|
styles1 = new String[]{"201"}; |
|
|
map.put("negative_prompt","easynegative"); |
|
|
break; |
|
|
//log.info("请求stable_diffusion请求体为:【{}】", JSON.toJSONString(map));
|
|
|
case "科幻": |
|
|
JSONObject jsonObject = restTemplate.postForObject(stableDiffusionConfig.getTxt2ImgUrl(), map, JSONObject.class); |
|
|
styles1 = new String[]{"114"}; |
|
|
//log.info("请求stable_diffusion响应体的为:【{}】", JSON.toJSONString(jsonObject));
|
|
|
break; |
|
|
TextToImgVo img = new TextToImgVo(); |
|
|
case "中国风": |
|
|
if(jsonObject != null && jsonObject.getJSONArray("images").size() > 0){ |
|
|
styles1 = new String[]{"101"}; |
|
|
List<String> images = jsonObject.getJSONArray("images").toJavaList(String.class); |
|
|
break; |
|
|
img.setImages(images); |
|
|
} |
|
|
|
|
|
req.setStyles(styles1); |
|
|
|
|
|
|
|
|
|
|
|
//画布大小
|
|
|
|
|
|
ResultConfig resultConfig1 = new ResultConfig(); |
|
|
|
|
|
if(param.getWidth() == 512 && param.getHeight() == 512){ |
|
|
|
|
|
resultConfig1.setResolution("768:768"); |
|
|
|
|
|
}else if(param.getWidth() == 512 && param.getHeight() == 1024){ |
|
|
|
|
|
resultConfig1.setResolution("768:1024"); |
|
|
|
|
|
}else{ |
|
|
|
|
|
resultConfig1.setResolution("1024:768"); |
|
|
|
|
|
} |
|
|
|
|
|
req.setResultConfig(resultConfig1); |
|
|
|
|
|
|
|
|
|
|
|
// 返回的resp是一个TextToImageResponse的实例,与请求对象对应
|
|
|
|
|
|
TextToImageResponse resp = client.TextToImage(req); |
|
|
|
|
|
TextToImgVo img = new TextToImgVo(); |
|
|
|
|
|
String images = resp.getResultImage(); |
|
|
|
|
|
List<String> list = new ArrayList<>(); |
|
|
|
|
|
list.add(images); |
|
|
|
|
|
img.setImages(list); |
|
|
String paintId = null; |
|
|
String paintId = null; |
|
|
Date paintTime = null; |
|
|
Date paintTime = null; |
|
|
for (String image : images) { |
|
|
for (String image : list) { |
|
|
//base64转file
|
|
|
//base64转file
|
|
|
MultipartFile file = new Base64ToMultipartFileUtils(image, "data:image/png;base64", "file", "tempSDImg"); |
|
|
MultipartFile file = new Base64ToMultipartFileUtils(image, "data:image/png;base64", "file", "tempSDImg"); |
|
|
//上传图片到七牛云/minio
|
|
|
//上传图片到七牛云/minio
|
|
|
@ -149,10 +172,82 @@ public class StableDiffusionServiceImpl implements StableDiffusionService { |
|
|
} |
|
|
} |
|
|
img.setPaintId(paintId); |
|
|
img.setPaintId(paintId); |
|
|
img.setPaintTime(paintTime); |
|
|
img.setPaintTime(paintTime); |
|
|
|
|
|
return img; |
|
|
|
|
|
} catch (TencentCloudSDKException e) { |
|
|
|
|
|
log.error("腾讯文生图调用错误!"+e.getMessage()); |
|
|
|
|
|
throw new ServiceException(e.getMessage(),500); |
|
|
|
|
|
} |
|
|
|
|
|
}else{ |
|
|
|
|
|
//采用sd
|
|
|
|
|
|
try{ |
|
|
|
|
|
String prompt = ""; |
|
|
|
|
|
//判断prompt是否包含中文,中文则翻译,否则跳过
|
|
|
|
|
|
if(TranslateUtils.isContainChinese(param.getPrompt())){ |
|
|
|
|
|
//调用翻译api
|
|
|
|
|
|
prompt = translate(param.getPrompt()); |
|
|
|
|
|
}else{ |
|
|
|
|
|
prompt = param.getPrompt(); |
|
|
|
|
|
} |
|
|
|
|
|
System.out.println(prompt); |
|
|
|
|
|
|
|
|
|
|
|
//TODO 根据选择的风格来选择模型
|
|
|
|
|
|
Map<String, Object> map = new HashMap<>(); |
|
|
|
|
|
map.put("width",param.getWidth() == null ? 512 : param.getWidth()); |
|
|
|
|
|
map.put("height",param.getHeight() == null ? 512 : param.getHeight()); |
|
|
|
|
|
map.put("prompt", prompt); |
|
|
|
|
|
//map.put("prompt", param.getPrompt());
|
|
|
|
|
|
map.put("seed",-1); |
|
|
|
|
|
map.put("batch_size",1); |
|
|
|
|
|
map.put("cfg_scale",7); |
|
|
|
|
|
map.put("restore_faces",false); |
|
|
|
|
|
map.put("tiling",false); |
|
|
|
|
|
map.put("eta",0); |
|
|
|
|
|
map.put("sampler_index","DPM++ 2S a Karras"); |
|
|
|
|
|
//map.put("sampler_index",param.getSamplerIndex());
|
|
|
|
|
|
map.put("steps",25); |
|
|
|
|
|
map.put("negative_prompt","easynegative,nsfw,naked"); |
|
|
|
|
|
//log.info("请求stable_diffusion请求体为:【{}】", JSON.toJSONString(map));
|
|
|
|
|
|
JSONObject jsonObject = restTemplate.postForObject(stableDiffusionConfig.getTxt2ImgUrl(), map, JSONObject.class); |
|
|
|
|
|
//log.info("请求stable_diffusion响应体的为:【{}】", JSON.toJSONString(jsonObject));
|
|
|
|
|
|
TextToImgVo img = new TextToImgVo(); |
|
|
|
|
|
if(jsonObject != null && jsonObject.getJSONArray("images").size() > 0){ |
|
|
|
|
|
List<String> images = jsonObject.getJSONArray("images").toJavaList(String.class); |
|
|
|
|
|
img.setImages(images); |
|
|
|
|
|
String paintId = null; |
|
|
|
|
|
Date paintTime = null; |
|
|
|
|
|
for (String image : images) { |
|
|
|
|
|
//base64转file
|
|
|
|
|
|
MultipartFile file = new Base64ToMultipartFileUtils(image, "data:image/png;base64", "file", "tempSDImg"); |
|
|
|
|
|
//上传图片到七牛云/minio
|
|
|
|
|
|
String imgStr = remoteFileService.uploadBanner(file).getData(); |
|
|
|
|
|
//保存生辰该图片到ai绘画表
|
|
|
|
|
|
AiPaint paint = new AiPaint(); |
|
|
|
|
|
//paint.setId(); 主键改成雪花算法后启用
|
|
|
|
|
|
paintId = IdUtil.getSnowflakeNextIdStr(); |
|
|
|
|
|
paintTime = new Date(); |
|
|
|
|
|
paint.setPaintId(paintId); |
|
|
|
|
|
paint.setCreateTime(paintTime); |
|
|
|
|
|
paint.setImgUrl(imgStr); |
|
|
|
|
|
paint.setPrompt(param.getPrompt()); |
|
|
|
|
|
paint.setModel(param.getModelName()); |
|
|
|
|
|
paint.setStyleName(param.getStyleName()); |
|
|
|
|
|
paint.setHeight(param.getHeight() == null ? "512" : String.valueOf(param.getHeight())); |
|
|
|
|
|
paint.setWidth(param.getWidth() == null ? "512" : String.valueOf(param.getWidth())); |
|
|
|
|
|
paint.setIsShow("1"); |
|
|
|
|
|
paint.setSource(param.getPlatform()); |
|
|
|
|
|
paint.setPainterId(param.getPainterId()); |
|
|
|
|
|
paint.setPainterName(param.getPainterName()); |
|
|
|
|
|
aiPaintService.insert(paint); |
|
|
|
|
|
//写入ai绘画次数
|
|
|
|
|
|
writePlatformUserAiPaintNum(param.getAppType(),param.getPlatform(),param.getPainterId()); |
|
|
|
|
|
} |
|
|
|
|
|
img.setPaintId(paintId); |
|
|
|
|
|
img.setPaintTime(paintTime); |
|
|
|
|
|
} |
|
|
|
|
|
return img; |
|
|
|
|
|
}catch (Exception e){ |
|
|
|
|
|
log.error("文本翻译错误!"+e); |
|
|
} |
|
|
} |
|
|
return img; |
|
|
|
|
|
}catch (Exception e){ |
|
|
|
|
|
log.error("文本翻译错误!"+e); |
|
|
|
|
|
} |
|
|
} |
|
|
return null; |
|
|
return null; |
|
|
} |
|
|
} |
|
|
@ -186,4 +281,9 @@ public class StableDiffusionServiceImpl implements StableDiffusionService { |
|
|
String redisKey = RedisKeyConstant.PLATFORM_USER_AI_PAINT_KEY + date; |
|
|
String redisKey = RedisKeyConstant.PLATFORM_USER_AI_PAINT_KEY + date; |
|
|
redisService.hashIncr(redisKey, hashKey, -patinNum); |
|
|
redisService.hashIncr(redisKey, hashKey, -patinNum); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@Override |
|
|
|
|
|
public void setPaintButton(int buttonValue) { |
|
|
|
|
|
redisService.setCacheObject(RedisKeyConstant.AI_PAINT_BUTTON,buttonValue); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|